使用 RecursiveTask
除了 RecursiveAction,Fork/Join 框架还提供了其他 ForkJoinTask 子类:带有返回值的 RecursiveTask,使用 finish() 方法显式中止的 AsyncAction 和 LinkedAsyncAction,以及可使用 TaskBarrier 为每个任务设置不同中止条件的 CyclicAction。
从 RecursiveTask 继承的子类同样需要重载 protected void compute() 方法。与 RecursiveAction 稍有不同的是,它可使用泛型指定一个返回值的类型。下面,我们来看看如何使用 RecursiveTask 的子类。
清单 4. RecursiveTask 的子类
2 final int n;
3
4 Fibonacci(int n) {
5 this.n = n;
6 }
7
8 private int compute(int small) {
9 final int[] results = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
10 return results[small];
11 }
12
13 public Integer compute() {
14 if (n <= 10) {
15 return compute(n);
16 }
17 Fibonacci f1 = new Fibonacci(n - 1);
18 Fibonacci f2 = new Fibonacci(n - 2);
19 f1.fork();
20 f2.fork();
21 return f1.join() + f2.join();
22 }
23 }
24
在 清单 4 中, Fibonacci 的返回值为 Integer 类型。其 compute() 函数首先建立两个子任务,启动子任务执行,阻塞以等待子任务的结果返回,相加后得到最终结果。同样,当子任务足够小时,通过查表得到其结果,以减小因过多地分割任务引起的性能降低。其中,我们用到了 RecursiveTask 提供的方法 fork() 和 join()。它们分别表示:子任务的异步执行和阻塞等待结果完成。
现在剩下的工作就是将 Fibonacci 提交到 ForkJoinPool 了,我们在一个 JUnit 的 test 方法中作了如下处理:
清单 5. 将 Fibonacci 提交到 ForkJoinPool
2 public void testFibonacci() throws InterruptedException, ExecutionException {
3 ForkJoinTask<Integer> fjt = new Fibonacci(45);
4 ForkJoinPool fjpool = new ForkJoinPool();
5 Future<Integer> result = fjpool.submit(fjt);
6
7 // do something
8 System.out.println(result.get());
9 }
10
使用 CyclicAction 来处理循环任务
CyclicAction 的用法稍微复杂一些。如果一个复杂任务需要几个线程协作完成,并且线程之间需要在某个点等待所有其他线程到达,那么我们就能方便的用 CyclicAction 和 TaskBarrier 来完成。图 2 描述了使用 CyclicAction 和 TaskBarrier 的一个典型场景。
图 2. 使用 CyclicAction 和 TaskBarrier 执行多线程任务

继承自 CyclicAction 的子类需要 TaskBarrier 为每个任务设置不同的中止条件。从 CyclicAction 继承的子类需要重载 protected void compute() 方法,定义在 barrier 的每个步骤需要执行的动作。compute() 方法将被反复执行直到 barrier 的 isTerminated() 方法返回 True。TaskBarrier 的行为类似于 CyclicBarrier。下面,我们来看看如何使用 CyclicAction 的子类。
清单 6. 使用 CyclicAction 的子类
2 protected void compute() {
3 TaskBarrier b = new TaskBarrier() {
4 protected boolean terminate(int cycle, int registeredParties) {
5 System.out.println("Cycle is " + cycle + ";"
6 + registeredParties + " parties");
7 return cycle >= 10;
8 }
9 };
10 int n = 3;
11 CyclicAction[] actions = new CyclicAction[n];
12 for (int i = 0; i < n; ++i) {
13 final int index = i;
14 actions[i] = new CyclicAction(b) {
15 protected void compute() {
16 System.out.println("I'm working " + getCycle() + " "
17 + index);
18 try {
19 Thread.sleep(500);
20 } catch (InterruptedException e) {
21 e.printStackTrace();
22 }
23 }
24 };
25 }
26 for (int i = 0; i < n; ++i)
27 actions[i].fork();
28 for (int i = 0; i < n; ++i)
29 actions[i].join();
30 }
31 }
32
在 清单 6 中,CyclicAction[] 数组建立了三个任务,打印各自的工作次数和序号。而在 b.terminate() 方法中,我们设置的中止条件表示重复 10 次计算后中止。现在剩下的工作就是将 ConcurrentPrint 提交到 ForkJoinPool 了。我们可以在 ForkJoinPool 的构造函数中指定需要的线程数目,例如 ForkJoinPool(4) 就表明线程池包含 4 个线程。我们在一个 JUnit 的 test 方法中运行 ConcurrentPrint 的这个循环任务:
清单 7. 运行 ConcurrentPrint 循环任务
2 public void testBarrier () throws InterruptedException, ExecutionException {
3 ForkJoinTask fjt = new ConcurrentPrint();
4 ForkJoinPool fjpool = new ForkJoinPool(4);
5 fjpool.submit(fjt);
6 fjpool.shutdown();
7 }
8
RecursiveTask 和 CyclicAction 两个例子的完整代码如下所示:
清单 8. RecursiveTask 和 CyclicAction 两个例子的完整代码
2
3 import java.util.concurrent.ExecutionException;
4 import java.util.concurrent.Future;
5
6 import jsr166y.forkjoin.CyclicAction;
7 import jsr166y.forkjoin.ForkJoinPool;
8 import jsr166y.forkjoin.ForkJoinTask;
9 import jsr166y.forkjoin.RecursiveAction;
10 import jsr166y.forkjoin.RecursiveTask;
11 import jsr166y.forkjoin.TaskBarrier;
12
13 import org.junit.Test;
14
15 class Fibonacci extends RecursiveTask<Integer> {
16 final int n;
17
18 Fibonacci(int n) {
19 this.n = n;
20 }
21
22 private int compute(int small) {
23 final int[] results = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
24 return results[small];
25 }
26
27 public Integer compute() {
28 if (n <= 10) {
29 return compute(n);
30 }
31 Fibonacci f1 = new Fibonacci(n - 1);
32 Fibonacci f2 = new Fibonacci(n - 2);
33 System.out.println("fork new thread for " + (n - 1));
34 f1.fork();
35 System.out.println("fork new thread for " + (n - 2));
36 f2.fork();
37 return f1.join() + f2.join();
38 }
39 }
40
41 class ConcurrentPrint extends RecursiveAction {
42 protected void compute() {
43 TaskBarrier b = new TaskBarrier() {
44 protected boolean terminate(int cycle, int registeredParties) {
45 System.out.println("Cycle is " + cycle + ";"
46 + registeredParties + " parties");
47 return cycle >= 10;
48 }
49 };
50 int n = 3;
51 CyclicAction[] actions = new CyclicAction[n];
52 for (int i = 0; i < n; ++i) {
53 final int index = i;
54 actions[i] = new CyclicAction(b) {
55 protected void compute() {
56 System.out.println("I'm working " + getCycle() + " "
57 + index);
58 try {
59 Thread.sleep(500);
60 } catch (InterruptedException e) {
61 e.printStackTrace();
62 }
63 }
64 };
65 }
66 for (int i = 0; i < n; ++i)
67 actions[i].fork();
68 for (int i = 0; i < n; ++i)
69 actions[i].join();
70 }
71 }
72
73 public class TestForkJoin {
74 @Test
75 public void testBarrier () throws InterruptedException, ExecutionException {
76 System.out.println("\ntesting Task Barrier ...");
77 ForkJoinTask fjt = new ConcurrentPrint();
78 ForkJoinPool fjpool = new ForkJoinPool(4);
79 fjpool.submit(fjt);
80 fjpool.shutdown();
81 }
82
83 @Test
84 public void testFibonacci () throws InterruptedException, ExecutionException {
85 System.out.println("\ntesting Fibonacci ...");
86 final int num = 14; //For demo only
87 ForkJoinTask<Integer> fjt = new Fibonacci(num);
88 ForkJoinPool fjpool = new ForkJoinPool();
89 Future<Integer> result = fjpool.submit(fjt);
90
91 // do something
92 System.out.println("Fibonacci(" + num + ") = " + result.get());
93 }
94 }
95
运行以上代码,我们可以得到以下结果:
2 I'm working 0 2
3 I'm working 0 0
4 I'm working 0 1
5 Cycle is 0; 3 parties
6 I'm working 1 2
7 I'm working 1 0
8 I'm working 1 1
9 Cycle is 1; 3 parties
10 I'm working 2 0
11 I'm working 2 1
12 I'm working 2 2
13 Cycle is 2; 3 parties
14 I'm working 3 0
15 I'm working 3 2
16 I'm working 3 1
17 Cycle is 3; 3 parties
18 I'm working 4 2
19 I'm working 4 0
20 I'm working 4 1
21 Cycle is 4; 3 parties
22 I'm working 5 1
23 I'm working 5 0
24 I'm working 5 2
25 Cycle is 5; 3 parties
26 I'm working 6 0
27 I'm working 6 2
28 I'm working 6 1
29 Cycle is 6; 3 parties
30 I'm working 7 2
31 I'm working 7 0
32 I'm working 7 1
33 Cycle is 7; 3 parties
34 I'm working 8 1
35 I'm working 8 0
36 I'm working 8 2
37 Cycle is 8; 3 parties
38 I'm working 9 0
39 I'm working 9 2
40
41 testing Fibonacci ...
42 fork new thread for 13
43 fork new thread for 12
44 fork new thread for 11
45 fork new thread for 10
46 fork new thread for 12
47 fork new thread for 11
48 fork new thread for 10
49 fork new thread for 9
50 fork new thread for 10
51 fork new thread for 9
52 fork new thread for 11
53 fork new thread for 10
54 fork new thread for 10
55 fork new thread for 9
56 Fibonacci(14) = 610
57
结论
从以上的例子中可以看到,通过使用 Fork/Join 模式,软件开发人员能够方便地利用多核平台的计算能力。尽管还没有做到对软件开发人员完全透明,Fork/Join 模式已经极大地简化了编写并发程序的琐碎工作。对于符合 Fork/Join 模式的应用,软件开发人员不再需要处理各种并行相关事务,例如同步、通信等,以难以调试而闻名的死锁和 data race 等错误也就不会出现,提升了思考问题的层次。你可以把 Fork/Join 模式看作并行版本的 Divide and Conquer 策略,仅仅关注如何划分任务和组合中间结果,将剩下的事情丢给 Fork/Join 框架。
在实际工作中利用 Fork/Join 模式,可以充分享受多核平台为应用带来的免费午餐。