跳到主要内容

ThreadLocal 线程内的容器

学习编程的过程中,总会不断地发现新的事物,就像我今天认识的 ThreadLocal

ThreadLocal 看名字翻译过来是“线程本地”?这个名字乍一看有些懵,但我认为还挺贴切,线程内的本地变量(或者说本地容器)。为什么说它是线程内的变量,这家伙又是干嘛用的呢?咱们接着往下瞧。

假设有这样一个需求,要求搞一个计数器,每调用一次加1,并且在并发环境下,这个计数器不能受到其它线程的影响。

举例,先来写一个计数器:

public class Counter {

private static int count = 0;

public static int getCount() {
return ++count;
}
}

创建3个线程来模拟下使用场景。

public class MainTest {

public static void main(String[] args) {
// 假设一个线程调用3次计数器
Runnable task = new Runnable() {
@Override
public void run() {
for (int i = 0; i < 3; i++) {
System.out.printf("%s => %s\n", Thread.currentThread().getName(), Counter.getCount());
}
}
};

new Thread(task).start();
new Thread(task).start();
new Thread(task).start();
}
}

结果:

Thread-0 => 1
Thread-0 => 4
Thread-0 => 5
Thread-2 => 2
Thread-2 => 6
Thread-2 => 7
Thread-1 => 3
Thread-1 => 8
Thread-1 => 9

看样子这个计数器还不能满足需求,由于计数器中计数是存在 static 变量中,这个计数器被所有线程共享了,导致所有线程都共享了这一个计数变量。看来是时候请上今天的主角 ThreadLocal 了,现在用 ThreadLocal 改造下这个计数器。

public class Counter {

private static ThreadLocal<Integer> counter = new ThreadLocal<Integer>() {
@Override
protected Integer initialValue() {
return 0;
}
};

public static int getCount() {
counter.set(counter.get() + 1);
return counter.get();
}
}

这里把计数器中的计数存到到 ThreadLocal 中,再看一下效果:

public class MainTest {

public static void main(String[] args) {
// 假设一个线程调用3次计数器
Runnable task = new Runnable() {
@Override
public void run() {
for (int i = 0; i < 3; i++) {
System.out.printf("%s => %s\n", Thread.currentThread().getName(), Counter.getCount());
}
}
};

new Thread(task).start();
new Thread(task).start();
new Thread(task).start();
}
}

执行结果:

Thread-0 => 1
Thread-0 => 2
Thread-0 => 3
Thread-2 => 1
Thread-1 => 1
Thread-1 => 2
Thread-1 => 3
Thread-2 => 2
Thread-2 => 3

使用 ThreadLocal 改造后的计数器满足需求,那么 ThreadLocal 内部做了什么,才使得产生了 “线程隔离” 的效果呢?下面,我来自定义一个 MyThreadLocal 试试。

public class MyThreadLocal<T> {

/**
* 定义一个线程安全的 Map 用来存放所有数据,Thread 对象作为 Map 的 Key,其实看到这里,应该就应该明白了 ThreadLocal 的原理
*/
private final Map<Thread, T> container = Collections.synchronizedMap(new HashMap<Thread, T>());

public void set(T value) {
Thread thread = Thread.currentThread();
container.put(thread, value);
}

public T get() {
Thread thread = Thread.currentThread();
T value = container.get(thread);
// 如果 map 中取值为 null,并且 map 不包含当前的key,才取初始值。可能存在的情况是,当前的key存的值就是null,所以不应该取初始值。
if (value == null && !container.containsKey(thread)) {
value = initialValue();
}
return value;
}

public void remove() {
Thread thread = Thread.currentThread();
container.remove(thread);
}

/**
* 用来重写设置初始值
* @return
*/
protected T initialValue() {
return null;
}
}

用自定义的 MyThreadLocal 改造之前的计数器,看看是否同样能实现需求。

public class Counter {

private static MyThreadLocal<Integer> counter = new MyThreadLocal<Integer>() {
@Override
protected Integer initialValue() {
return 0;
}
};

public static int getCount() {
counter.set(counter.get() + 1);
return counter.get();
}
}
public class MainTest {

public static void main(String[] args) {
// 假设一个线程调用3次计数器
Runnable task = new Runnable() {
@Override
public void run() {
for (int i = 0; i < 3; i++) {
System.out.printf("%s => %s\n", Thread.currentThread().getName(), Counter.getCount());
}
}
};

new Thread(task).start();
new Thread(task).start();
new Thread(task).start();
}
}

执行结果:

Thread-0 => 1
Thread-0 => 2
Thread-0 => 3
Thread-2 => 1
Thread-2 => 2
Thread-2 => 3
Thread-1 => 1
Thread-1 => 2
Thread-1 => 3

看结果,自定义的 MyThreadLocal 是符合预期的。自定义的 MyThreadLocal 中定义了一个 Map 用来存储所有的数据,用 Thread 对象来作为 Map 的 key,这样就保证当前线程的操作都是当前线程的值。 实际 ThreadLocal 也只有这几个方法,使用非常简单。在需要实现 “线程隔离” 这种效果的时候,就可以使用 ThreadLocal。

注意:在实际开发中使用 ThreadLocal ,用完要及时 remove() 掉数据,否则会出现意想不到的问题。

End.