写点什么

ThreadLocal 源码分析及避坑指南

作者:喝水不抬头
  • 2023-02-12
    上海
  • 本文字数:4356 字

    阅读完需:约 14 分钟

ThreadLocal 可以为每个线程保存一份变量的副本,防止在多线程情况下,属于某个线程的变量被其他线程修改。下面从源码角度分析其实现原理。观察最常使用的 get()和 set()方法可以看出:


  • get()操作需要获取当前线程对应的 ThreadLocalMap,再根据当前 ThreadLoca 变量的引用,获取当前线程的变量副本。

  • set(T value)操作同样需要获取当前线程对应的 ThreadLocalMap,再根据当前 ThreadLoca 变量的引用,设置当前线程的变量副本。


  public T get() {        Thread t = Thread.currentThread();        ThreadLocalMap map = getMap(t);        if (map != null) {            ThreadLocalMap.Entry e = map.getEntry(this);            if (e != null) {                @SuppressWarnings("unchecked")                T result = (T)e.value;                return result;            }        }        return setInitialValue();    }
复制代码


    public void set(T value) {        Thread t = Thread.currentThread();        ThreadLocalMap map = getMap(t);        if (map != null)            map.set(this, value);        else            createMap(t, value);    }
复制代码


我们来看一下 ThreadLocal、ThreadLocalMap 和 Thread 的关系。ThreadLocalMap 是 ThreadLocal 的静态内部类,而每一个 Thread 对象都包含一个 ThreadLocalMap。


public class ThreadLocal<T> {    ......       static class ThreadLocalMap {        ......    }    ......}
复制代码


public class Thread implements Runnable {
...... /* ThreadLocal values pertaining to this thread. This map is maintained * by the ThreadLocal class. */ ThreadLocal.ThreadLocalMap threadLocals = null; ....}
复制代码


既然每个线程都维护一个 ThreadLocalMap,那么为什么不设计 Map<Thread,T>这种形式,一个线程对应一个存储对象,而“托管”给 ThreadLocal 来保存每个线程的变量副本呢?ThreadLocal 这样设计的目的主要有两个:


  • 一是可以保证当前线程结束时相关对象能尽快被回收;

  • 二是 ThreadLocalMap 中的元素会大大减少,我们都知道 map 过大更容易造成哈希冲突而导致性能变差。


下面我们着重看下 ThreadLocalMap 这个数据结构。


static class ThreadLocalMap {
static class Entry extends WeakReference<ThreadLocal<?>> { /** The value associated with this ThreadLocal. */ Object value;
Entry(ThreadLocal<?> k, Object v) { super(k); value = v; } } ......}
复制代码


ThreadLocalMap 中的 key 是 ThreadLocal<?>对象,value 值当前线程的变量副本。这里需要注意的是,ThreadLocalMap 的 Entry 是继承 WeakReference,和 HashMap 很大的区别是,Entry 中没有 next 字段,所以就不存在链表的情况了。那么 ThreadLocalMap 在 set 和 get 时是如何解决 hash 冲突的呢,接下来进行介绍。

hash 冲突

private void set(ThreadLocal<?> key, Object value) {
// We don't use a fast path as with get() because it is at // least as common to use set() to create new entries as // it is to replace existing ones, in which case, a fast // path would fail more often than not.
Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { ThreadLocal<?> k = e.get();
if (k == key) { e.value = value; return; }
if (k == null) { replaceStaleEntry(key, value, i); return; } }
tab[i] = new Entry(key, value); int sz = ++size; if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); }
复制代码


在往 ThreadLocalMap 中 put 元素时,首先计算索引


  • 如果该索引出没有 Entry,则退出循环,构造一个新的 Entry 插入;

  • 如果该索引处已插入 Entry,并且对应的 key 正好为当前的 ThreadLocal<?>对象,则直接进行 value 的替换,

  • 如果该索引处已插入 Entry,并且对应的 key 不是当前的 ThreadLocal<?>对象,则计算下一个索引。计算下一个索引的方式其实就是当前索引加 1,若超过数组长度,则索引为 0。


   /**     * Increment i modulo len.   */   private static int nextIndex(int i, int len) {       return ((i + 1 < len) ? i + 1 : 0);   }
复制代码


在从 ThreadLocalMap 中 get 元素时,首先计算索引


private Entry getEntry(ThreadLocal<?> key) {            int i = key.threadLocalHashCode & (table.length - 1);            Entry e = table[i];            if (e != null && e.get() == key)                return e;            else                return getEntryAfterMiss(key, i, e);        }
复制代码


private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {            Entry[] tab = table;            int len = tab.length;
while (e != null) { ThreadLocal<?> k = e.get(); if (k == key) return e; if (k == null) expungeStaleEntry(i); else i = nextIndex(i, len); e = tab[i]; } return null; }
复制代码


  • 若索引处有 Entry,并且对应的 key 正好为当前的 ThreadLocal<?>对象,则返回对应的 value

  • 若索引处没有 Entry,则按照与 set 方法相似的过程,计算下一个索引,直到找到某个 Entry,对应的 key 正好为当前的 ThreadLocal<?>对象,如果找不到,最终返回 null

常见的坑

由于 ThreadLocal 其内部条目为弱引用,当 key 为 null 时,该条目就变成“废弃条目”,相关“value”的回收,往往依赖于几个关键点,即 set、remove、rehash。下面是 set 示例:


private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { ThreadLocal<?> k = e.get();
if (k == key) { e.value = value; return; }
if (k == null) { // 替换废弃条目 replaceStaleEntry(key, value, i); return; } }
tab[i] = new Entry(key, value); int sz = ++size; // 扫描并清理发现的废弃条目,并检查容量是否超限 if (!cleanSomeSlots(i, sz) && sz >= threshold) // 清理废弃条目,如果仍然超限,则扩容 rehash(); }
复制代码


具体的清理逻辑是在 cleanSomeSlots 和 expungeStaleEntry 中。可以看出,废弃项目的回收依赖于显示的触发,否则就要等待线程结束,进而回收相应的 ThreadLocalMap!这就是很多 OOM 的来源,所以通常建议:

  1. 应用一定要自己负责 remove

  2. 不要和线程池配合,因为 worker 线程往往是不会退出的


下面举一个例子说明,ThreadLocal 在线程池中使用的坑使用 SpringBoot 创建一个 Web 应用程序,使用 ThreadLocal 存放一个 Integer 的值,来暂且代表需要在线程中保存的用户 ID,这个值初始时 null,在业务逻辑中,会把外部传入的用户 ID 设置到 ThreadLocal 中,示例代码如下


@RestControllerpublic class WrongDemoController {    private static final ThreadLocal<Integer> currentUser = ThreadLocal.withInitial(() -> null);
@GetMapping("/wrong") public Map wrong(@RequestParam(value = "userId") Integer userId) { String before = Thread.currentThread().getName() + ":" + currentUser.get(); currentUser.set(userId); String after = Thread.currentThread().getName() + ":" + currentUser.get(); Map result = new HashMap(); result.put("before", before); result.put("after", after); return result; }}
复制代码


线程池会重用固定的几个线程,为了更快地重现问题,在配置文件中设置一下 tomcat 的参数,把工作线程池最大线程数设置为 1,这样始终是同一个线程在处理请求:


server.tomcat.max-threads=1
复制代码


在浏览器中依次输入 userId=1 和 userId=2,可以看出:


  • 当 userId=1 时,设置 ThreadLocal 之前和之后,从 ThreadLocal 中拿到的值分别为 null 和 1


  • 当 userId=2 时,设置 ThreadLocal 之前和之后,从 ThreadLocal 中拿到的值分别为 1 和 2



问题出现了,为什么当 userId=2 时,从 ThreadLocal 拿到的初始值是 1 呢?原因是 tomact 的工作线程被重用了(在我们的例子中只有一个工作线程),那么很可能从 ThreadLocal 中拿到的值是别的用户的请求遗留的值(真实生产环境可能会导致用户信息错乱)。解决方案:ThreadLocal 工具用来存放一些数据时,需要特别注意在代码运行完后,显示地去清空设置的数据。比如在上面的案例中,可以再 finally 代码块中显示清除 ThreadLocal 中的数据。


@RestControllerpublic class WrongDemoController {    private static final ThreadLocal<Integer> currentUser = ThreadLocal.withInitial(() -> null);
@GetMapping("/wrong") public Map wrong(@RequestParam(value = "userId") Integer userId) { try{ String before = Thread.currentThread().getName() + ":" + currentUser.get(); currentUser.set(userId); String after = Thread.currentThread().getName() + ":" + currentUser.get(); Map result = new HashMap(); result.put("before", before); result.put("after", after); return result; }finally { // 显示清除ThreadLocal中的数据 currentUser.remove(); }
}}
复制代码


发布于: 2023-02-12阅读数: 21
用户头像

还未添加个人签名 2018-05-15 加入

还未添加个人简介

评论

发布
暂无评论
ThreadLocal源码分析及避坑指南_喝水不抬头_InfoQ写作社区