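Previously we only inserted and saved a single entity; for details, see:
What we need now is a shared batchSave method, which naturally calls for generics, so we adapt the original code slightly: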
I. Add the batch import interface BatchSaveRepository
isSave: true means save (insert), false means update.
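package com.easemob.oa.persistence.jpa;
import org.springframework.data.repository.NoRepositoryBean;
import java.util.List;
@NoRepositoryBean
public interface BatchSaveRepository<T> {
    /**
     * Saves the given entities in batches.
     * isSave = true inserts them (persist); isSave = false updates them (merge).
     * Returns the processed entities; inserted entities carry their generated ids.
     */
    <S extends T> List<S> batchSave(Iterable<S> entities, Boolean isSave);
}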
II. Add the implementation class BatchSaveRepositoryImpl
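Note that EntityManager is not thread-safe, so a single instance must not be shared across threads.
How, then, do we give each thread an EntityManager (em for short) that it can use safely?
EntityManagerFactory (emf for short) is thread-safe, so each worker thread can obtain its own em from the shared emf.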
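A minimal, JDK-only sketch of this idea, using hypothetical stand-ins rather than JPA: `SharedFactory` plays the role of the thread-safe EntityManagerFactory and `Worker` the role of a non-thread-safe EntityManager. Every task creates its own Worker from the one shared factory, so no unsynchronized state crosses threads.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical stand-in for EntityManagerFactory: thread-safe, shared by all tasks
final class SharedFactory {
    private final AtomicInteger created = new AtomicInteger();

    Worker create() {
        created.incrementAndGet();
        return new Worker();
    }

    int createdCount() {
        return created.get();
    }
}

// Hypothetical stand-in for EntityManager: unsynchronized state, so never shared
final class Worker {
    private final List<Integer> buffer = new ArrayList<>();

    void add(int value) {
        buffer.add(value);
    }

    int size() {
        return buffer.size();
    }
}

final class PerThreadDemo {
    // Runs `tasks` tasks on a pool; each task gets its own Worker from the shared factory
    static int run(SharedFactory factory, int tasks, int itemsPerTask) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        try {
            List<Future<Integer>> futures = new ArrayList<>();
            for (int t = 0; t < tasks; t++) {
                futures.add(pool.submit(() -> {
                    Worker worker = factory.create(); // private to this task
                    for (int i = 0; i < itemsPerTask; i++) {
                        worker.add(i);
                    }
                    return worker.size();
                }));
            }
            int total = 0;
            for (Future<Integer> f : futures) {
                total += f.get();
            }
            return total;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }
}
```

Calling `PerThreadDemo.run(new SharedFactory(), 8, 1000)` creates eight Workers, one per task, and returns 8000.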
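And how do we obtain the EntityManagerFactory?
Usually the EntityManagerFactory is created dynamically from the persistence-unit-name specified in a configuration file.
There are two ways to specify the persistence-unit-name: * configure persistence.xml, or * configure JpaConfig.java.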
META-INF/persistence.xml
<?xml version="1.0" encoding="UTF-8"?>
<persistence xmlns="http://java.sun.com/xml/ns/persistence" version="2.0">
    <!-- A persistence-unit element is required.
         Persistence unit:
           name: the name of the persistence unit
           transaction-type: how transactions are managed
             JTA: distributed transaction management (needed when tables are spread across different databases)
             RESOURCE_LOCAL: local transaction management
    -->
    <persistence-unit name="turnfly" transaction-type="RESOURCE_LOCAL">
        <!-- JPA implementation (provider) -->
        <provider>org.hibernate.jpa.HibernatePersistenceProvider</provider>
        <!-- Optional: configuration for the JPA provider -->
        <properties>
            <!-- Database connection:
                 user name:    javax.persistence.jdbc.user
                 password:     javax.persistence.jdbc.password
                 driver:       javax.persistence.jdbc.driver
                 database URL: javax.persistence.jdbc.url
            -->
            <property name="javax.persistence.jdbc.user" value="oatransfer"/>
            <property name="javax.persistence.jdbc.password" value="qwert"/>
            <property name="javax.persistence.jdbc.driver" value="com.oscar.Driver"/>
            <property name="javax.persistence.jdbc.url" value="jdbc:oscar://x.x.x.x:2003/OSRDB?useSSL=false"/>
            <!-- Provider (Hibernate) settings:
                 show SQL: hibernate.show_sql (true|false)
                 automatic schema generation: hibernate.hbm2ddl.auto
                   create: create tables at startup (existing tables are dropped first)
                   update: create missing tables at startup (existing tables are kept and updated, not recreated)
                   none:   do not create tables
            -->
            <!-- show SQL -->
            <property name="hibernate.show_sql" value="true" />
            <property name="hibernate.dialect" value="org.hibernate.dialect.OracleDialect" />
            <!-- automatic schema generation -->
            <property name="hibernate.hbm2ddl.auto" value="update" />
        </properties>
    </persistence-unit>
</persistence>
JpaConfig.java
package com.easemob.oa.persistence.config;
// Spring Boot 1.x location; Spring Boot 2.x moved it to org.springframework.boot.jdbc.DataSourceBuilder
import org.springframework.boot.autoconfigure.jdbc.DataSourceBuilder;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.datasource.DriverManagerDataSource;
import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean;
import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter;
import javax.sql.DataSource;
import java.util.Properties;
@Configuration
public class JpaConfig {
    /**
     * DataSource config
     */
    @Bean
    public DataSource getDataSource() {
        DriverManagerDataSource dataSource = new DriverManagerDataSource();
        dataSource.setDriverClassName("com.oscar.Driver");
        dataSource.setUrl("jdbc:oscar://127.0.0.1:2003/OSRDB_TF?useSSL=false&rewriteBatchedStatements=TRUE");
        dataSource.setUsername("oatransfer");
        dataSource.setPassword("0098");
        return dataSource;
    }
    // Second data source, bound from the oa-server.second-datasource.* properties
    @Bean
    @ConfigurationProperties(prefix = "oa-server.second-datasource")
    public DataSource getSecondDataSource() {
        return DataSourceBuilder.create().build();
    }
    // No DataSource parameter: with two DataSource beans and no @Primary, injection by type
    // would be ambiguous, so the data source is selected explicitly below
    @Bean
    public LocalContainerEntityManagerFactoryBean entityManagerFactory() {
        // Hibernate as the JPA vendor
        HibernateJpaVendorAdapter vendorAdapter = new HibernateJpaVendorAdapter();
        LocalContainerEntityManagerFactoryBean factoryBean = new LocalContainerEntityManagerFactoryBean();
        factoryBean.setJpaVendorAdapter(vendorAdapter);
        Properties properties = new Properties();
        // the dialect must match the database behind the selected data source
        properties.setProperty("hibernate.dialect", "org.hibernate.dialect.MySQLDialect");
        properties.setProperty("hibernate.show_sql", "true");
        // once the factory is initialized, create/update all related tables
        properties.setProperty("hibernate.hbm2ddl.auto", "update");
        properties.setProperty("hibernate.format_sql", "false");
        factoryBean.setJpaProperties(properties);
        // package containing the entity classes
        factoryBean.setPackagesToScan("com.easemob.oa.models.entity");
        factoryBean.setDataSource(getSecondDataSource());
        factoryBean.setPersistenceUnitName("turnfly");
        return factoryBean;
    }
}
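Spring Boot spares us this configuration: at startup it creates a persistence unit named "default" automatically.
However, if we try to load that default unit the configuration-file way,
emf = Persistence.createEntityManagerFactory("default");
it gets us nowhere: without a persistence.xml declaring that unit, the Persistence bootstrap cannot find it and fails (typically with a PersistenceException).
On reflection this makes sense: we have no configuration file, so looking up the persistence-unit-name through one cannot work. Spring Boot has already registered an EntityManagerFactory bean for the "default" unit in the singleton pool, so we can simply take that one.
The correct approach is to add a utility class annotated with @Repository, so that it is picked up by component scanning at startup, and to inject the EntityManagerFactory with @PersistenceUnit.
For more on EntityManager and EntityManagerFactory, see my other article: EntityManager and EntityManagerFactory Explained.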
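1. Create the utility class EntityManagerHelper
A ThreadLocal keeps each thread's EntityManager private to that thread, and emf is assigned through setter injection.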
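The ThreadLocal caching in EntityManagerHelper can be sketched generically without JPA. `ThreadLocalHolder` and `ThreadLocalDemo` are hypothetical names, and the `Supplier` stands in for `emf.createEntityManager()`. The sketch shows that the same thread always gets the same instance, another thread gets its own, and `close()` (using `remove()`) makes the next call create a fresh one.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Supplier;

// Hypothetical generic version of EntityManagerHelper's ThreadLocal caching
final class ThreadLocalHolder<R> {
    private final ThreadLocal<R> local = new ThreadLocal<>();
    private final Supplier<R> factory; // stands in for emf::createEntityManager

    ThreadLocalHolder(Supplier<R> factory) {
        this.factory = factory;
    }

    // Returns the calling thread's instance, creating it on first use
    R get() {
        R value = local.get();
        if (value == null) {
            value = factory.get();
            local.set(value);
        }
        return value;
    }

    // Forgets the calling thread's instance; remove() leaves no stale entry on pooled threads
    void close() {
        local.remove();
    }
}

final class ThreadLocalDemo {
    // Calls holder.get() on a different thread and returns what that thread sees
    static <R> R getOnNewThread(ThreadLocalHolder<R> holder) {
        ExecutorService single = Executors.newSingleThreadExecutor();
        try {
            Callable<R> task = holder::get;
            return single.submit(task).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            single.shutdown();
        }
    }
}
```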
package com.easemob.oa.persistence.config;
import org.springframework.stereotype.Repository;
import javax.persistence.*;
import java.util.logging.Level;
@Repository
public class EntityManagerHelper {
    // Shared, thread-safe EntityManagerFactory; assigned by Spring through the setter below
    //private static final EntityManagerFactory emf;
    private static EntityManagerFactory emf;
    // Holds one EntityManager per thread
    private static final ThreadLocal<EntityManager> threadLocal;
    // Static initializer for the variables above
    static {
        //emf = Persistence.createEntityManagerFactory("zcxx");
        threadLocal = new ThreadLocal<EntityManager>();
    }
    // Spring injects the EntityManagerFactory here; it is stored in the static field
    @PersistenceUnit
    public void setEntityManagerFactory(EntityManagerFactory emf) {
        EntityManagerHelper.emf = emf;
    }
    // Returns the current thread's EntityManager, creating a new one if none is open
    public static EntityManager getEntityManager() {
        EntityManager manager = threadLocal.get();
        if (manager == null || !manager.isOpen()) {
            manager = emf.createEntityManager();
            threadLocal.set(manager);
        }
        return manager;
    }
    // Closes the current thread's EntityManager
    public static void closeEntityManager() {
        EntityManager em = threadLocal.get();
        // remove() rather than set(null), so reused pool threads keep no stale entry
        threadLocal.remove();
        if (em != null && em.isOpen()) {
            em.close();
        }
    }
    // Begins a transaction
    public static void beginTransaction() {
        getEntityManager().getTransaction().begin();
    }
    // Commits the transaction
    public static void commitTransaction() {
        getEntityManager().getTransaction().commit();
    }
    // Rolls back the transaction
    public static void rollback() {
        getEntityManager().getTransaction().rollback();
    }
    // Creates a JPQL query
    public static Query createQuery(String query) {
        return getEntityManager().createQuery(query);
    }
    public static void log(String string, Level info, Object object) {
        // TODO Auto-generated method stub
    }
}
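2. BatchSaveRepositoryImpl.java and CallableResultVo.java
Extending SimpleJpaRepository is the usual way to have the em injected; the Spring Data documentation describes it in the section on custom repository implementations. In the multithreaded case, however, each thread obtains its own em from the emf, so extending SimpleJpaRepository is not strictly necessary (feel free to remove it).
In addition, to collect each sub-list's saved entities (with their generated ids), the worker tasks implement Callable, and a small CallableResultVo class carries the result.
A CountDownLatch ensures that all threads have finished before the next step.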
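A JDK-only sketch of the Callable + Future + CountDownLatch combination, assuming a hypothetical `LatchDemo.processAll` that doubles the numbers in each chunk and treats a negative number as a failed insert. The point it illustrates is that `countDown()` belongs in a `finally` block: if it only ran on success, a single failing task would leave `await()` blocked forever.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

final class LatchDemo {
    // One Callable per chunk; a chunk containing a negative number fails
    static List<Integer> processAll(List<List<Integer>> chunks) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        CountDownLatch latch = new CountDownLatch(chunks.size());
        List<Future<List<Integer>>> futures = new ArrayList<>();
        try {
            for (List<Integer> chunk : chunks) {
                futures.add(pool.submit(() -> {
                    try {
                        List<Integer> out = new ArrayList<>();
                        for (int n : chunk) {
                            if (n < 0) {
                                throw new IllegalArgumentException("bad value " + n);
                            }
                            out.add(n * 2);
                        }
                        return out;
                    } finally {
                        latch.countDown(); // runs on success and on failure
                    }
                }));
            }
            latch.await(); // every task has reached its finally block
            List<Integer> results = new ArrayList<>();
            for (Future<List<Integer>> f : futures) {
                try {
                    results.addAll(f.get());
                } catch (ExecutionException e) {
                    // a failed chunk is skipped here; real code should log or rethrow
                }
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }
}
```

Results are gathered in chunk order because the futures list follows submission order, not completion order.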
CallableResultVo.java
package com.easemob.oa.models.callable;
import lombok.Data;
import java.util.List;
@Data
public class CallableResultVo<T> {
    // entities handled by one worker task (inserted ones carry generated ids after commit)
    private List<T> result;
}
BatchSaveRepositoryImpl.java
package com.easemob.oa.persistence.jpa.impl;
import com.easemob.oa.models.callable.CallableResultVo;
import com.easemob.oa.persistence.jpa.BatchSaveRepository;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.data.repository.NoRepositoryBean;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import javax.persistence.*;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
/**
 * @Author turnflys
 * @Date 1/12/21 10:51 PM
 */
@NoRepositoryBean
@Slf4j
public class BatchSaveRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID> implements BatchSaveRepository<T> {
    // number of entities handled by each worker task
    private static final int BATCH_SIZE = 1500;
    // maximum number of worker threads
    private static final int MAX_THREAD = 12;
    // em injected for this repository instance (instance field, so repositories do not overwrite each other's em)
    private final EntityManager em;
    public BatchSaveRepositoryImpl(JpaEntityInformation<T, ?> entityInformation, EntityManager entityManager) {
        super(entityInformation, entityManager);
        this.em = entityManager;
    }
    //@Async
    @Override
    // Note: this transaction is bound to the calling thread; each worker task commits its own resource-local transaction
    @Transactional(propagation = Propagation.REQUIRES_NEW)
    public <S extends T> List<S> batchSave(Iterable<S> entities, Boolean isSave) {
        // returned list; inserted entities carry their generated ids
        List<S> result = new ArrayList<>();
        List<S> lists = Lists.newArrayList(entities);
        int listSize = lists.size();
        // number of chunks (ceiling division); the last chunk may be smaller than BATCH_SIZE
        int loopCount = (listSize + BATCH_SIZE - 1) / BATCH_SIZE;
        // thread pool: at most MAX_THREAD tasks run at once, each handling up to BATCH_SIZE entities
        ExecutorService executorService = Executors.newFixedThreadPool(MAX_THREAD);
        // count-down latch: each task calls countDown() when done; await() blocks until the count reaches 0
        CountDownLatch cdl = new CountDownLatch(loopCount);
        // one Future per task, holding that task's result
        List<Future<CallableResultVo<S>>> futureList = new ArrayList<>();
        try {
            for (int i = 0; i < loopCount; i++) {
                int start = i * BATCH_SIZE;
                int end = Math.min(start + BATCH_SIZE, listSize);
                log.info("------ chunk index range: start - {}, end - {}.", start, end);
                PartSaveCallable<S> psc = new PartSaveCallable<>(lists.subList(start, end), cdl, isSave);
                futureList.add(executorService.submit(psc));
            }
            // wait until every task has finished
            cdl.await();
            for (Future<CallableResultVo<S>> future : futureList) {
                try {
                    CallableResultVo<S> crv = future.get();
                    if (crv != null && crv.getResult() != null) {
                        result.addAll(crv.getResult());
                    }
                } catch (ExecutionException e) {
                    log.error("batch task failed", e.getCause());
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("interrupted while waiting for batch tasks", e);
        } finally {
            // shut the pool down once the work is done
            executorService.shutdown();
        }
        return result;
    }
    // Single-threaded alternative using this repository's em; must run inside a transaction.
    // Flushes and clears every 500 entities to keep the persistence context small.
    <S extends T> List<S> partBatchSave(Iterable<S> entities) {
        List<S> saved = new ArrayList<>();
        int index = 0;
        for (S entity : entities) {
            em.persist(entity);
            saved.add(entity);
            index++;
            if (index % 500 == 0) {
                em.flush();
                em.clear();
            }
        }
        if (index % 500 != 0) {
            em.flush();
            em.clear();
        }
        return saved;
    }
}
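The chunking arithmetic deserves a closer look. Computing the number of chunks as `listSize / BATCH_SIZE + 1` produces an extra, empty chunk whenever the size is an exact multiple: with 3000 entities and a batch size of 1500 it gives 3 chunks, the last being `subList(3000, 3000)`. A small sketch with a hypothetical `BatchRanges.ranges` helper shows the ceiling-division version, returning `[start, end)` pairs.

```java
import java.util.ArrayList;
import java.util.List;

final class BatchRanges {
    // Returns [start, end) index pairs covering 0..size in steps of batchSize
    static List<int[]> ranges(int size, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        int loopCount = (size + batchSize - 1) / batchSize; // ceiling division
        List<int[]> out = new ArrayList<>(loopCount);
        for (int i = 0; i < loopCount; i++) {
            int start = i * batchSize;
            int end = Math.min(start + batchSize, size); // last chunk may be shorter
            out.add(new int[] {start, end});
        }
        return out;
    }
}
```

For 3000 items this yields two chunks, for 3001 items three (the last being `[3000, 3001)`), and for an empty list none, so no empty task is ever submitted.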
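III. Worker task class PartSaveCallable
Note the difference between Callable and Runnable:
a Callable can return a value when it finishes, and any exception it throws (checked ones included) is propagated to the caller, wrapped in an ExecutionException by Future.get().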
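A short JDK-only sketch of that difference; `CallableVsRunnable` and its methods are hypothetical names. The Callable returns a sum or throws a checked `IOException`, which reaches the caller as the cause of an `ExecutionException`; a Runnable's `run()` is `void` and cannot throw checked exceptions, so the Future from `submit(Runnable)` only ever yields `null`.

```java
import java.io.IOException;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

final class CallableVsRunnable {
    // A Callable may return a value and may throw checked exceptions
    static Callable<Integer> sumTask(int[] values, boolean fail) {
        return () -> {
            if (fail) {
                throw new IOException("simulated failure"); // allowed in call()
            }
            int sum = 0;
            for (int v : values) {
                sum += v;
            }
            return sum;
        };
    }

    // Runs the task and reports either its value or the type of its failure
    static String runAndDescribe(Callable<Integer> task) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            return "result=" + pool.submit(task).get();
        } catch (ExecutionException e) {
            // the exception thrown in call() arrives here as the cause
            return "failed: " + e.getCause().getClass().getSimpleName();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return "interrupted";
        } finally {
            pool.shutdown();
        }
    }

    // A Runnable returns nothing: the Future from submit(Runnable) yields null
    static Object resultOfRunnable() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Runnable r = () -> { /* run() is void and cannot throw checked exceptions */ };
            return pool.submit(r).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }
}
```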
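Each thread also needs its own transaction. Because PartSaveCallable is created with new rather than obtained from the Spring container, a @Transactional annotation on call() (including @Transactional(propagation = Propagation.REQUIRES_NEW)) is not applied by Spring; the per-thread transaction has to be managed manually through that thread's EntityTransaction.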
package com.easemob.oa.persistence.jpa.impl;
import com.easemob.oa.models.callable.CallableResultVo;
import com.easemob.oa.persistence.config.EntityManagerHelper;
import lombok.extern.slf4j.Slf4j;
import javax.persistence.EntityManager;
import javax.persistence.EntityTransaction;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
@Slf4j
public class PartSaveCallable<S> implements Callable<CallableResultVo<S>> {
    private final List<S> lists;
    private final CountDownLatch cdl;
    private final Boolean isSave;
    public PartSaveCallable(List<S> lists, CountDownLatch cdl, Boolean isSave) {
        this.lists = lists;
        this.cdl = cdl;
        this.isSave = isSave;
    }
    // No @Transactional: this object is not a Spring bean, so the annotation would be ignored.
    // The transaction is managed manually through this thread's EntityTransaction.
    @Override
    public CallableResultVo<S> call() {
        CallableResultVo<S> result = new CallableResultVo<>();
        List<S> resultList = new ArrayList<>();
        String su = isSave ? "insert" : "update";
        log.info("------ thread {} starts {}, current cdl is: {} ------", Thread.currentThread(), su, cdl.getCount());
        EntityTransaction entityTransaction = null;
        try {
            // this thread's own em
            EntityManager em = EntityManagerHelper.getEntityManager();
            log.info("------ current EntityManager is: {} ------", em);
            entityTransaction = em.getTransaction();
            entityTransaction.begin();
            for (S s : lists) {
                if (isSave) {
                    em.persist(s);
                    resultList.add(s);
                } else {
                    // merge returns the managed copy, which reflects the saved state
                    resultList.add(em.merge(s));
                }
            }
            entityTransaction.commit();
            result.setResult(resultList);
            log.info("------ thread {} finished {}, cdl before countDown: {} ------", Thread.currentThread(), su, cdl.getCount());
        } catch (RuntimeException e) {
            if (entityTransaction != null && entityTransaction.isActive()) {
                entityTransaction.rollback();
            }
            log.error("error during batch {}", su, e);
            // surfaces in batchSave as an ExecutionException from future.get()
            throw e;
        } finally {
            // always release the em and count down; otherwise cdl.await() would block forever after a failure
            EntityManagerHelper.closeEntityManager();
            cdl.countDown();
        }
        return result;
    }
}
Result: