目录
- 1、背景
- 2、CountDownLatch 入门
- 2.1、概念
- 2.2、案例
- 3、CountDownLatch 源码分析
- 3.1、类结构
- 3.2、`await()` 方法 —— CountDownLatch
- 3.2.1、`acquireSharedInterruptibly()` 方法 —— AQS
- 3.2.1.1、`tryAcquireShared()` 方法 —— CountDownLatch.Sync
- 3.2.1.2、`doAcquireSharedInterruptibly()` 方法 —— AQS
- 3.2.1.2.1、`setHeadAndPropagate()` 方法 —— AQS
- 3.2.1.2.1.1、`doReleaseShared()` 方法 —— AQS
- 3.3、`countDown()` 方法 —— CountDownLatch
- 3.3.1、`releaseShared()` 方法 —— AQS
- 3.3.1.1、`tryReleaseShared()` 方法 —— CountDownLatch.Sync
- 3.3.2、`doReleaseShared()` 方法 —— AQS
- 4、应用案例
- 5、总结
1、背景
先看一个常见的面试题:
如何实现让主线程等所有子线程执行完了后,主线程再继续执行?即:如何实现一个线程等其他线程执行完了后再继续执行?
这里我们可以使用 Thread#join()
方法实现。
Thread#join()
方法的实现原理:在 join()
方法内部,通常有一个循环结构,循环条件为 targetThread.isAlive()
,即:目标线程是否仍然存活。当目标线程尚未结束时,当前线程会进入循环体内部调用 wait()
方法进行等待(释放锁);当目标线程在其 run()
方法执行完毕后,其生命周期状态变为已终止(TERMINATED),并自动调用 notifyAll()
方法【JVM 底层】,会唤醒所有因为调用 wait()
而在目标线程对象上等待的线程,包括通过 join()
方法暂停的当前线程
public static void main(String[] args) throws InterruptedException {Runnable task = () -> {Random random = new Random();try {Thread.sleep(random.nextInt(10000) + 1000);} catch (InterruptedException e) {e.printStackTrace();}};Thread thread1 = new Thread(task, "线程1");Thread thread2 = new Thread(task, "线程2");Thread thread3 = new Thread(task, "线程3");thread1.start();thread2.start();thread3.start();// 启动了3个线程,然后让四个线程一直检测自己是否已经结束thread1.join();thread2.join();thread3.join();System.out.println("主线程继续执行...");
}
这种方式虽然能够解决问题,但是有些不尽人意的地方:每个线程都得调用 join()
方法。有没有更好的方法呢?
这个时候并发工具类
CountDownLatch
来了
2、CountDownLatch 入门
2.1、概念
CountDownLatch
:JDK1.5 提供的一个同步工具,基于 AQS 构建同步器【共享模式】。它可以让一个或多个线程等待,一直等到其他线程中执行完成一组操作;适用于在多线程的场景需要等待所有子线程全部执行完毕之后再做操作的场景
CountDownLatch 可以理解为并发计数器:当一个任务被拆分成多个子任务时,需要等待子任务全部完成后再操作,不然会阻塞线程(当前线程),每完成一个任务计数器会 -1,直到没有。
【注意】:一般用作多线程倒计时计数器,强制它们等待其他一组任务,计数器的减法是一个不可逆的过程。即:计数器值递减到 0 的时候,不能再复原。
接下来用
CountDownLatch
完成上述案例
2.2、案例
public static void main(String[] args) throws InterruptedException {int threadCount = 3;CountDownLatch countDownLatch = new CountDownLatch(threadCount);Runnable task = () -> {System.out.println(Thread.currentThread().getName() + " 线程开始");Random random = new Random();try {Thread.sleep(random.nextInt(10000) + 1000);} catch (InterruptedException e) {e.printStackTrace();}System.out.println( Thread.currentThread().getName() + " 线程执行完毕");countDownLatch.countDown();};for (int i = 0; i < threadCount; i++) {new Thread(task, "线程" + i).start();}countDownLatch.await();System.out.println("主线程继续执行...");
}
3、CountDownLatch 源码分析
3.1、类结构
public class CountDownLatch {private final Sync sync;public CountDownLatch(int count) {if (count < 0) throw new IllegalArgumentException("count < 0");this.sync = new Sync(count);}public void await() throws InterruptedException {sync.acquireSharedInterruptibly(1);}public void countDown() {sync.releaseShared(1);}// 内部类:使用了 state 计数private static final class Sync extends AbstractQueuedSynchronizer {Sync(int count) {setState(count);}protected boolean tryReleaseShared(int releases) {//...}protected int tryAcquireShared(int acquires) {//...}}
}
3.2、await()
方法 —— CountDownLatch
public void await() throws InterruptedException {sync.acquireSharedInterruptibly(1);
}
3.2.1、acquireSharedInterruptibly()
方法 —— AQS
public final void acquireSharedInterruptibly(int arg) throws InterruptedException {if (Thread.interrupted()) {throw new InterruptedException();}// 判断计数器 state 是否等于 0if (tryAcquireShared(arg) < 0) {// 如果 state > 0 ,则添加到同步等待队列中doAcquireSharedInterruptibly(arg);}
}
acquireSharedInterruptibly()
方法:共享模式下可中断地获取锁方法。如果计数器 state 为 0,则跳过逻辑【调用者不用阻塞,可继续往下执行】;否则,将此线程添加到同步等待队列中
3.2.1.1、tryAcquireShared()
方法 —— CountDownLatch.Sync
protected int tryAcquireShared(int acquires) {return (getState() == 0) ? 1 : -1;
}
3.2.1.2、doAcquireSharedInterruptibly()
方法 —— AQS
private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {// 【共享模式】节点入同步队列final Node node = addWaiter(Node.SHARED);boolean failed = true;try {// 自旋for (;;) {// 获取 node 的前驱final Node p = node.predecessor();if (p == head) {// 再次获取 state:如果 state == 0,则 r = 1;否则 r = -1int r = tryAcquireShared(arg);if (r >= 0) {// state == 0,将 node 设置为头节点setHeadAndPropagate(node, r);p.next = null;failed = false;return;}}// 自旋两次后,阻塞线程//第一次,waitStatus 默认为 0,shouldParkAfterFailedAcquire() 方法将 waitStatus 赋值为 SIGNAL并返回 false;//第二次 for 循环,shouldParkAfterFailedAcquire() 方法返回 true,通过调用 parkAndCheckInterrupt() 将自己阻塞if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) {throw new InterruptedException(); }}} finally {if (failed)cancelAcquire(node);}
}
假如:现在有 3 个线程 A、B、C 调用了
await()
方法,此时 state 的值还不为 0,所以这三个线程都会加入到 AQS 队列中。并且三个线程都处于阻塞状态
如下图:
线程 A、B、C 自旋两次,通过 shouldParkAfterFailedAcquire()
方法将 waitStatus
由 0 修改为 SIGNAL
,并通过 parkAndCheckInterrupt()
方法进行阻塞起来;
它们现在都不会去调用 setHeadAndPropagate()
方法,只有等到 countdown()
方法使得 state=0 的时候才会被唤醒
3.2.1.2.1、setHeadAndPropagate()
方法 —— AQS
看完下面的
countDown()
方法再来看此方法
private void setHeadAndPropagate(Node node, int propagate) {// 旧 head 节点Node h = head;// 将当前节点设置为 head 节点setHead(node);// propagate 大于 0(一般情况下都会这样)或者 存在可唤醒的线程if (propagate > 0 || h == null || h.waitStatus < 0 || (h = head) == null || h.waitStatus < 0) {Node s = node.next;// 只有一个节点或者存在多个节点且是共享模式,则释放所有等待的线程,各自尝试抢占锁if (s == null || s.isShared()) {doReleaseShared(); }}
}
3.2.1.2.1.1、doReleaseShared()
方法 —— AQS
private void doReleaseShared() {for (;;) {Node h = head;if (h != null && h != tail) {int ws = h.waitStatus;if (ws == Node.SIGNAL) {if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) {continue;}// 唤醒后继节点【一个】unparkSuccessor(h);} else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) {continue;}}if (h == head) {break;}}
}
循环唤醒后续节点
3.3、countDown()
方法 —— CountDownLatch
public void countDown() {// 递减锁重入次数,当 state == 0 时,唤醒所有阻塞的线程sync.releaseShared(1);
}
3.3.1、releaseShared()
方法 —— AQS
public final boolean releaseShared(int arg) {if (tryReleaseShared(arg)) {doReleaseShared();return true;}return false;
}
只有当 state 减为 0 的时候,tryReleaseShared()
才返回 true,继而会调用 doReleaseShared()
方法来唤醒处于 await 状态下的线程;否则,只是简单的 state = state - 1
3.3.1.1、tryReleaseShared()
方法 —— CountDownLatch.Sync
protected boolean tryReleaseShared(int releases) {for (;;) {int c = getState();if (c == 0) {return false;}int nextc = c-1;// 共享模式:CAS 操作(存在多个线程)if (compareAndSetState(c, nextc)) {// 只有最后一个计数器减完才为 0,返回truereturn nextc == 0;}}
}
【共享模式】:存在多个线程,所以需要自旋 + CAS 操作
tryReleaseShared()
方法:自旋,对计数器进行 CAS 操作 -1,如果计数器减到 0【需要唤醒阻塞的线程】,返回 true;否则,返回 false
3.3.2、doReleaseShared()
方法 —— AQS
private void doReleaseShared() {// 自旋for (;;) {// 记录旧 head 节点Node h = head;if (h != null && h != tail) {int ws = h.waitStatus;// 前驱节点状态为 SIGNAL,后继节点需要被唤醒if (ws == Node.SIGNAL) {// 将头结点的 waitstatue 设置为0,以后就不会再次唤醒该后继节点了,这一步是为了解决并发问题,保证只 unpark()一次,不成功就继续if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) {// 如果 CAS 失败,则继续执行continue; }// 唤醒头节点的一个后继节点unparkSuccessor(h);// ws == 0:head 节点刚入队列,未调用 shouldParkAfterFailedAcquire() 方法【将 waitStatus 由 0 修改为 SIGNAL】// CAS 操作:将 head 节点状态设置为 PROPAGATE,表示要向下传播,依次唤醒// CAS 操作失败场景:// 1.这时,刚好有节点入队列,且已调用了 shouldParkAfterFailedAcquire() 方法,修改为了 SIGNAL 状态// 2.有其它线程尝试将其设置为 PROPAGATE 状态} else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) {// CAS 操作失败,继续尝试continue;}}// 判断 head 节点是否是原 head 节点// 如果是:说明了之前唤醒的线程还未唤醒 | 就没唤醒过线程【执行 else-if 逻辑】,跳出循环// 如果不是:说明了之前唤醒的线程已唤醒【线程A】,跳过当前循环,继续在 for 循环中执行第二次if (h == head) {break;}}
}
一旦线程 A 被唤醒,代码又会继续回到 doAcquireSharedInterruptibly()
中来执行。如果当前 state 满足 ==0 的条件,则会执行 setHeadAndPropagate()
方法
对于下面这块代码:在 CountDownLatch
的实现中,头节点状态为 PROPAGATE
的情况并不常见
else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) {// CAS 操作失败,继续尝试continue;
}
因为 CountDownLatch
的唤醒逻辑并不依赖于节点状态为 PROPAGATE
。通常情况下,当计数器归零时,CountDownLatch
会直接一次性唤醒所有等待线程,而不会特别处理节点状态为 PROPAGATE
的情况
对于
CountDownLatch
来说,其核心逻辑相对简单:当计数器递减至 0 时,意味着所有等待的线程已完成其预定任务。此时,doReleaseShared()
方法的主要任务是确保所有等待在CountDownLatch
上的线程都能被唤醒,而不是传播某种释放信号
ROPAGATE
状态主要出现在其他基于 AbstractQueuedSynchronizer(AQS)
构建的同步组件(如Semaphore
、ReentrantReadWriteLock
的读锁等)中,用于表示释放操作应当继续向下传播,唤醒更多等待的线程。在这些组件中,当某个节点释放资源后,可能需要将释放操作传播到队列中的其他节点,此时会将节点状态设置为 PROPAGATE
,以便后续逻辑处理
4、应用案例
等顾客们来齐了,服务员再来上菜,吃饭,人不齐不能动筷子,大家都坐那等着
public static void main(String[] args) throws InterruptedException {// 5 个顾客final int customerCount = 5;// 7 道菜,需要 7 个服务员final int waitressCount = 7;CountDownLatch customerCountDownLatch = new CountDownLatch(customerCount);CountDownLatch countDownLatch = new CountDownLatch(waitressCount);Runnable customerTask = () -> {try {SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss");Random random = new Random();System.out.println(sdf.format(new Date()) + " " + Thread.currentThread().getName() + "出发去饭店");Thread.sleep((long) (random.nextDouble() * 3000) + 1000);System.out.println(sdf.format(new Date()) + " " + Thread.currentThread().getName() + "到了饭店");customerCountDownLatch.countDown();} catch (InterruptedException e) {throw new RuntimeException(e);}};Runnable waitressTask = () -> {try {SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss");System.out.println(sdf.format(new Date()) + " " + Thread.currentThread().getName() + "等待顾客");customerCountDownLatch.await();System.out.println(sdf.format(new Date()) + " " + Thread.currentThread().getName() + "人齐了,开始上菜");Random random = new Random();Thread.sleep((long) (random.nextDouble() * 3000) + 1000);countDownLatch.countDown();System.out.println(Thread.currentThread().getName() + " 完成上菜,还差 " + countDownLatch.getCount() + " 个菜没上");} catch (Exception e) {e.printStackTrace();}};for (int i = 0; i < customerCount; i++) {new Thread(customerTask, "customer" + i).start();}for (int i = 0; i < waitressCount; i++) {new Thread(waitressTask, "waitress" + i).start();}countDownLatch.await();System.out.println("菜都上完了,可以吃了");
}
运行结果:
14:19:28 customer4出发去饭店
14:19:28 waitress3等待顾客
14:19:28 customer3出发去饭店
14:19:28 waitress6等待顾客
14:19:28 customer1出发去饭店
14:19:28 waitress0等待顾客
14:19:28 waitress2等待顾客
14:19:28 waitress5等待顾客
14:19:28 waitress1等待顾客
14:19:28 customer2出发去饭店
14:19:28 customer0出发去饭店
14:19:28 waitress4等待顾客
14:19:29 customer3到了饭店
14:19:29 customer4到了饭店
14:19:29 customer2到了饭店
14:19:31 customer0到了饭店
14:19:31 customer1到了饭店
14:19:31 waitress3人齐了,开始上菜
14:19:31 waitress2人齐了,开始上菜
14:19:31 waitress5人齐了,开始上菜
14:19:31 waitress0人齐了,开始上菜
14:19:31 waitress6人齐了,开始上菜
14:19:31 waitress1人齐了,开始上菜
14:19:31 waitress4人齐了,开始上菜
waitress3 完成上菜,还差 6 个菜没上
waitress4 完成上菜,还差 5 个菜没上
waitress5 完成上菜,还差 4 个菜没上
waitress1 完成上菜,还差 3 个菜没上
waitress0 完成上菜,还差 2 个菜没上
waitress2 完成上菜,还差 1 个菜没上
waitress6 完成上菜,还差 0 个菜没上
菜都上完了,可以吃了
5、总结
- 通过构造方法初始化
CountDownLatch
:设置 AQS 中的 state 的值 - 调用
countDown()
方法:调用 AQS 的释放同步状态的方法,每调用一次,state 就自减 1,直至为 0 - 调用
await()
方法:如果 state 不为 0,则阻塞线程并入队列。当 state 为 0 后,唤醒其它所有阻塞的线程