写点什么

多线程如何实现事务回滚?一招帮你搞定!

作者:Java你猿哥
  • 2023-04-23
    湖南
  • 本文字数:8506 字

    阅读完需:约 28 分钟

多线程如何实现事务回滚?一招帮你搞定!

特别说明 CountDownLatch

CountDownLatch 是一个类 springboot 自带的类,可以直接用,变量 AtomicBoolean 也是可以直接使用

CountDownLatch 的用法

CountDownLatch 典型用法:

1、某一线程在开始运行前等待 n 个线程执行完毕。 将 CountDownLatch 的计数器初始化为 new CountDownLatch(n),每当一个任务线程执行完毕,就将计数器减 1 countdownLatch.countDown(),当计数器的值变为 0 时,在 CountDownLatch 上 await()的线程就会被唤醒。一个典型应用场景就是启动一个服务时,主线程需要等待多个组件加载完毕,之后再继续执行。

2、实现多个线程开始执行任务的最大并行性。 注意是并行性,不是并发,强调的是多个线程在某一时刻同时开始执行。类似于赛跑,将多个线程放到起点,等待发令枪响,然后同时开跑。做法是初始化一个共享的 CountDownLatch(1),将其计算器初始化为 1,多个线程在开始执行任务前首先 countdownlatch.await(),当主线程调用 countDown()时,计数器变为 0,多个线程同时被唤醒。

CountDownLatch(num) 简单说明

new 一个 CountDownLatch(num) 对象

建立对象的时候 num 代表的是需要等待 num 个线程

// 建立对象的时候 num 代表的是需要等待 num 个线程//主线程CountDownLatch mainThreadLatch = new CountDownLatch(num);//子线程CountDownLatch rollBackLatch  = new CountDownLatch(1);
复制代码

主线程:mainThreadLatch.await() 和 mainThreadLatch.countDown()

新建对象

CountDownLatch mainThreadLatch = new CountDownLatch(num);
复制代码

卡住主线程,让其等待子线程,代码 mainThreadLatch.await(),放在主线程里

mainThreadLatch.await();
复制代码


代码 mainThreadLatch.countDown(),放在子线程里,每一个子线程运行一到这个代码,意味着 CountDownLatch(num),里面的 num-1(自动减一)

mainThreadLatch.countDown();
复制代码

CountDownLatch(num)里面的 num 减到 0,也就是 CountDownLatch(0),被卡住的主线程 mainThreadLatch.await(),就会往下执行


子线程:rollBackLatch.await() 和 rollBackLatch.countDown()

新建对象,特别注意:子线程这个 num 就是 1(关于只能为 1 的解答在后面)

CountDownLatch rollBackLatch  = new CountDownLatch(1);
复制代码

卡住子线程,阻止每一个子线程的事务提交和回滚

rollBackLatch.await();
复制代码



代码 rollBackLatch.countDown();放在主线程里,而且是放在主线程的等待代码 mainThreadLatch.await();后面。

rollBackLatch.countDown();
复制代码



为什么所有的子线程会在一瞬间就被所有都释放了?


事务的回滚是怎么结合进去的?

假设总共 20 个子线程,那么其中一个线程报错了怎么实现所有线程回滚。

引入变量

AtomicBoolean rollbackFlag = new AtomicBoolean(false)
复制代码

和字面意思是一样的:根据 rollbackFlag 的 true 或者 false 判断子线程里面,是否回滚。

首先我们确定的一点:rollbackFlag 是所有的子线程都用着这一个判断



主线程类 Entry

package org.apache.dolphinscheduler.api.utils;
import com.alibaba.fastjson.JSONArray;import com.alibaba.fastjson.JSONObject;import org.apache.dolphinscheduler.api.controller.WorkThread;import org.apache.dolphinscheduler.common.enums.DbType;import org.springframework.web.bind.annotation.*;
import java.text.SimpleDateFormat;import java.util.ArrayList;import java.util.Date;import java.util.List;import java.util.TimeZone;import java.util.concurrent.CountDownLatch;import java.util.concurrent.atomic.AtomicBoolean;

@RestController@RequestMapping("importDatabase")public class Entry {
/** * @param dbid 数据库的id * @param tablename 表名 * @param sftpFileName 文件名称 * @param head 是否有头文件 * @param splitSign 分隔符 * @param type 数据库类型 */ private static String SFTP_HOST = "192.168.1.92"; private static int SFTP_PORT = 22; private static String SFTP_USERNAME = "root"; private static String SFTP_PASSWORD = "rootroot"; private static String SFTP_BASEPATH = "/opt/testSFTP/"; @PostMapping("/thread") @ResponseBody public static JSONObject importDatabase(@RequestParam("dbid") int dbid ,@RequestParam("tablename") String tablename ,@RequestParam("sftpFileName") String sftpFileName ,@RequestParam("head") String head ,@RequestParam("splitSign") String splitSign ,@RequestParam("type") DbType type ,@RequestParam("heads") String heads ,@RequestParam("scolumns") String scolumns ,@RequestParam("tcolumns") String tcolumns ) throws Exception { JSONObject obForRetrun = new JSONObject();
try {
JSONArray jsonArray = JSONArray.parseArray(tcolumns); JSONArray scolumnArray = JSONArray.parseArray(scolumns); JSONArray headsArray = JSONArray.parseArray(heads); List<Integer> listInteger = getRrightDataNum(headsArray,scolumnArray); JSONArray bodys = SFTPUtils.getSftpContent(SFTP_HOST,SFTP_PORT,SFTP_USERNAME,SFTP_PASSWORD,SFTP_BASEPATH,sftpFileName,head,splitSign); int total = bodys.size(); int num = 20; //一个批次的数据有多少 int count = total/num;//周期 int lastNum =total- count*num;//余数
List<Thread> list = new ArrayList<Thread>(); SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS"); TimeZone t = sdf.getTimeZone(); t.setRawOffset(0); sdf.setTimeZone(t); Long startTime=System.currentTimeMillis();

int countForCountDownLatch = 0; if(lastNum==0){//整除 countForCountDownLatch= count; }else{ countForCountDownLatch= count + 1; } //子线程 CountDownLatch rollBackLatch = new CountDownLatch(1); //主线程 CountDownLatch mainThreadLatch = new CountDownLatch(countForCountDownLatch);
AtomicBoolean rollbackFlag = new AtomicBoolean(false); StringBuffer message = new StringBuffer(); message.append("报错信息:");
//子线程 for(int i=0;i<count;i++) {//这里的count代表有几个线程 Thread g = new Thread(new WorkThread(i,num,tablename,jsonArray,dbid,type,bodys,listInteger,mainThreadLatch,rollBackLatch,rollbackFlag,message )); g.start(); list.add(g); }
if(lastNum!=0){//有小数的情况下 Thread g = new Thread(new WorkThread(0,lastNum,tablename,jsonArray,dbid,type,bodys,listInteger,mainThreadLatch,rollBackLatch,rollbackFlag,message )); g.start(); list.add(g); }
// for(Thread thread:list){// System.out.println(thread.getState());// thread.join();//是等待这个线程结束;// }
mainThreadLatch.await(); //所有等待的子线程全部放开 rollBackLatch.countDown();
//是主线程等待子线程的终止。也就是说主线程的代码块中,如果碰到了t.join()方法,此时主线程需要等待(阻塞),等待子线程结束了(Waits for this thread to die.),才能继续执行t.join()之后的代码块。

Long endTime=System.currentTimeMillis(); System.out.println("总共用时: "+sdf.format(new Date(endTime-startTime)));
if(rollbackFlag.get()){ obForRetrun.put("code",500); obForRetrun.put("msg",message); }else{ obForRetrun.put("code",200); obForRetrun.put("msg","提交成功!"); } obForRetrun.put("data",null); }catch (InterruptedException e){ e.printStackTrace(); obForRetrun.put("code",500); obForRetrun.put("msg",e.getMessage()); obForRetrun.put("data",null); } return obForRetrun;
}
/** * 文件里第几列被作为导出列 * @param headsArray * @param scolumnArray * @return */ public static List<Integer> getRrightDataNum(JSONArray headsArray, JSONArray scolumnArray){ List<Integer> list = new ArrayList<Integer>(); String arrayA [] = new String[headsArray.size()]; for(int i=0;i<headsArray.size();i++){ JSONObject ob = (JSONObject)headsArray.get(i); arrayA[i] =String.valueOf(ob.get("title")); }
String arrayB [] = new String[scolumnArray.size()]; for(int i=0;i<scolumnArray.size();i++){ JSONObject ob = (JSONObject)scolumnArray.get(i); arrayB[i] =String.valueOf(ob.get("columnName")); }
for(int i =0;i<arrayA.length;i++){ for(int j=0;j<arrayB.length;j++){ if(arrayA[i].equals(arrayB[j])){ list.add(i); break; } } }
return list; }}
复制代码


子线程类 WorkThread

package org.apache.dolphinscheduler.api.controller;
import com.alibaba.fastjson.JSONArray;import com.alibaba.fastjson.JSONObject;import org.apache.dolphinscheduler.api.service.DataSourceService;import org.apache.dolphinscheduler.common.enums.DbType;import org.apache.dolphinscheduler.dao.entity.DataSource;import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper;import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;import org.springframework.transaction.PlatformTransactionManager;
import java.sql.Connection;import java.sql.PreparedStatement;import java.sql.SQLException;import java.text.ParseException;import java.text.SimpleDateFormat;import java.util.Date;import java.util.List;import java.util.TimeZone;import java.util.concurrent.CountDownLatch;import java.util.concurrent.atomic.AtomicBoolean;

/** * 多线程 */public class WorkThread implements Runnable{ //建立线程的两种方法 1 实现Runnable 接口 2 继承 Thread 类
private DataSourceService dataSourceService;
private DataSourceMapper dataSourceMapper;
private Integer begin; private Integer end; private String tableName; private JSONArray columnArray; private Integer dbid; private DbType type; private JSONArray bodys; private List<Integer> listInteger; private PlatformTransactionManager transactionManager; private CountDownLatch mainThreadLatch; private CountDownLatch rollBackLatch; private AtomicBoolean rollbackFlag; private StringBuffer message;


/** * @param i * @param num * @param tableFrom * @param columnArrayFrom * @param dbidFrom * @param typeFrom */ public WorkThread(int i, int num, String tableFrom, JSONArray columnArrayFrom, int dbidFrom , DbType typeFrom, JSONArray bodysFrom, List<Integer> listIntegerFrom ,CountDownLatch mainThreadLatch,CountDownLatch rollBackLatch,AtomicBoolean rollbackFlag ,StringBuffer messageFrom) { begin=i*num; end=begin+num; tableName = tableFrom; columnArray = columnArrayFrom; dbid = dbidFrom; type = typeFrom; bodys = bodysFrom; listInteger = listIntegerFrom; this.dataSourceMapper = SpringApplicationContext.getBean(DataSourceMapper.class); this.dataSourceService = SpringApplicationContext.getBean(DataSourceService.class); this.transactionManager = SpringApplicationContext.getBean(PlatformTransactionManager.class); this.mainThreadLatch = mainThreadLatch; this.rollBackLatch = rollBackLatch; this.rollbackFlag = rollbackFlag; this.message = messageFrom; }
public void run() {
DataSource dataSource = dataSourceMapper.queryDataSourceByID(dbid); String cp = dataSource.getConnectionParams(); Connection con=null; con = dataSourceService.getConnection(type,cp); if(con!=null) { SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS"); TimeZone t = sdf.getTimeZone(); t.setRawOffset(0); sdf.setTimeZone(t); Long startTime = System.currentTimeMillis(); try { con.setAutoCommit(false);
//---------------------------- 获取字段和类型 String columnString = null;//活动的字段 int intForType = 0; String type[] = new String[columnArray.size()];//类型集合 for(int i=0;i<columnArray.size();i++){ JSONObject ob = (JSONObject)columnArray.get(i); if(columnString==null){ columnString = String.valueOf(ob.get("name")); }else{ columnString = columnString + ","+ String.valueOf(ob.get("name")); } type[intForType] = String.valueOf(ob.get("type")); intForType = intForType + 1; } intForType = 0;
//这一步是为了形成 insert into "+tableName+"(id,name,age) values (?,?,?); String dataString = null; for(int i=0;i<columnArray.size();i++){ if(dataString==null){ dataString = "?"; }else{ dataString = dataString +","+"?"; } }
//--------------------------------
StringBuffer sql = new StringBuffer(); sql = sql.append("insert into "+tableName+"("+columnString+") values ("+dataString+")") ; PreparedStatement pst= (PreparedStatement)con.prepareStatement(sql.toString()); for(int i=begin;i<end;i++) { JSONObject ob = (JSONObject)bodys.get(i); if(ob!=null){ String [] array = ob.get(i).toString().split("\\,"); String [] arrayFinal = getFinalData(listInteger,array); for(int j=0;j<type.length;j++){ String typeString = type[j].toLowerCase(); int z = j+1; if("string".equals(typeString)||"varchar".equals(typeString)){ pst.setString(z,arrayFinal[j]);//这里的第一个参数 是指 替换第几个? }else if("int".equals(typeString)||"bigint".equals(typeString)){ pst.setInt(z,Integer.valueOf(arrayFinal[j]));//这里的第一个参数 是指 替换第几个? }else if("long".equals(typeString)){ pst.setLong(z,Long.valueOf(arrayFinal[j]));//这里的第一个参数 是指 替换第几个? }else if("double".equals(typeString)){ pst.setDouble(z,Double.parseDouble(arrayFinal[j])); }else if("date".equals(typeString)||"datetime".equals(typeString)){ pst.setDate(z, setDateback(arrayFinal[j])); }else if("Timestamp".equals(typeString)){ pst.setTimestamp(z, setTimestampback(arrayFinal[j])); } } } pst.addBatch(); } pst.executeBatch();
mainThreadLatch.countDown(); rollBackLatch.await();
if(rollbackFlag.get()){ con.rollback();//表示回滚事务; }else{ con.commit();//事务提交 } con.close(); } catch (Exception e) { System.out.println(e.getMessage()); message = message.append(e.getMessage()); rollbackFlag.set(true); mainThreadLatch.countDown(); try { con.close(); } catch (SQLException throwables) { throwables.printStackTrace(); } } Long endTime = System.currentTimeMillis(); System.out.println(Thread.currentThread().getName()+":startTime= "+sdf.format(new Date(startTime))+",endTime= "+sdf.format(new Date(endTime)) +" 用时:"+sdf.format(new Date(endTime - startTime)));
} }

public java.sql.Date setDateback(String dateString) throws ParseException { SimpleDateFormat sdf = new SimpleDateFormat( "yyyy-MM-dd HH:mm:ss" ); java.util.Date date = sdf.parse( "2015-5-6 10:30:00" ); long lg = date.getTime();// 日期 转 时间戳 return new java.sql.Date( lg ); }
public java.sql.Timestamp setTimestampback(String dateString) throws ParseException { SimpleDateFormat sdf = new SimpleDateFormat( "yyyy-MM-dd HH:mm:ss" ); java.util.Date date = sdf.parse( "2015-5-6 10:30:00" ); long lg = date.getTime();// 日期 转 时间戳 return new java.sql.Timestamp( lg ); }
public String [] getFinalData(List<Integer> listInteger,String[] array){ String [] arrayFinal = new String [listInteger.size()]; for(int i=0;i<listInteger.size();i++){ int a = listInteger.get(i); arrayFinal[i] = array[a]; } return arrayFinal; }}
复制代码

代码实际运用踩坑!!!!


还记得这里有个一批次处理多少数据么,我这边设置了 20,实际到运用中的时候客户给了个 20W 的数据,我批次设置为 20,那就有 1W 个子线程!!!!

这还不是最糟糕的,最糟糕的是每个子线程都会创建一个数据库连接,数据库直接被我搞炸了


所以这里需要把:

int num = 20; //一个批次的数据有多少
复制代码

改成:

int num = 20000; //一个批次的数据有多少
复制代码


用户头像

Java你猿哥

关注

一只在编程路上渐行渐远的程序猿 2023-03-09 加入

关注我,了解更多Java、架构、Spring等知识

评论

发布
暂无评论
多线程如何实现事务回滚?一招帮你搞定!_Java_Java你猿哥_InfoQ写作社区