您的当前位置:首页正文

Spring Security中,多线程操作导致安全上下文丢失(附CountDownLatch的用法)

2024-11-29 来源:个人技术集锦

一、问题描述

之前做项目的时候,遇到的这个问题。

1. 前景描述

该项目应用的是Spring Security + JWT的安全框架,用户在登录时会携带有Authorization信息,Spring Security会对其进行认证,并在成功后,将当前登录的用户信息存储到安全上下文,然后在更新或插入数据库数据时,会从安全上下文中取出当前登录用户信息,作为这条数据的最后更新人。

2. 问题出现

某个功能因涉及的表比较多,数据量比较大,导致效率很慢,所以决定将其改为异步操作,使用多线程来实现。但是在功能实现完后,问题出来了。。。
最后更新人获取错误。
明明是我操作的,但是最后更新人却是另外一个人。

二、解决方案

通过上网搜索得知,在主线程中,启用另外的线程执行之后操作的时候,异步线程中的安全上下文会丢失。因为Spring Security的安全上下文默认是存储在ThreadLocal(也就是线程本地)的,启动其他线程执行的时候,就会丢失掉上下文信息。
知道了问题所在,接下来就简单了,既然在异步线程中,安全上下文会丢失,那么我只要把主线程中的安全上下文带到异步线程中去,不就好了嘛。

    // 1. 在主线程中获取安全上下文。
    SecurityContext securityContext = SecurityContextHolder.getContext();
    threadTaskExecutor.execute(() -> {
		try {
			// 2. 将主线程中的安全上下文设置到子线程中的ThreadLocal中。
        	SecurityContextHolder.setContext(securityContext);
        	// 业务代码
	    } catch (Exception e) {
	        // 异常信息捕获
	    } finally {
	        // 清除操作
	        // 3. 将调用者中的安全上下文设置到当前业务子线程中的ThreadLocal中。
	        SecurityContextHolder.clearContext();
	    }
    });

这里有一点需要注意的是,第三步的清除操作必不可少,不然会导致异步线程的安全上下文传播到线程池中,而如果该线程为线程池的核心线程,下次该线程执行时又没有设置安全上下文,则会获取到错误的登陆者信息(也就是这次设置的安全上下文信息)。

三、附加

1. 线程池

附上一个简单的线程池的配置类吧:

@Configuration
public class ThreadPoolConfig {
    @Bean
    public ThreadPoolTaskExecutor threadTaskExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setCorePoolSize(10);
        executor.setKeepAliveSeconds(200);
        executor.setMaxPoolSize(20);
        executor.setQueueCapacity(20);
        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        return executor;
    }
}

然后只需要在用到的地方,注入一下就可以了:

	@Autowired
    private ThreadPoolTaskExecutor threadTaskExecutor;

然后就在需要开线程的地方,调用其execute或者submit即可:

threadTaskExecutor.execute(() -> {
	//业务代码
});

2. CountDownLatch

如果后续操作需要基于所有的线程执行完,那么可以使用CountDownLatch。

1. CountDownLatch概念
2. CountDownLatch用法

某一线程在开始运行前等待n个线程执行完毕。

// 将CountDownLatch的计数器初始化
final CountDownLatch latch = new CountDownLatch(projectIds.size());
projectIds.forEach(it -> {
	threadTaskExecutor.execute(() -> {
		try {
        	// 业务代码
	    } catch (Exception e) {
	        // 异常信息捕获
	    } finally {
	    	// 将计数器减1
	        latch.countDown();
	    }
	});
});
// 等待计算器的值变为0
latch.await();
// 基于所有线程执行完后的代码
显示全文