技术开发 频道

JDK 7 中的 Fork/Join 模式

   【IT168 技术文章】

         介绍

  随着多核芯片逐渐成为主流,大多数软件开发人员不可避免地需要了解并行编程的知识。而同时,主流程序语言正在将越来越多的并行特性合并到标准库或者语言本身之中。我们可以看到,JDK 在这方面同样走在潮流的前方。在 JDK 标准版 5 中,由 Doug Lea 提供的并行框架成为了标准库的一部分(JSR-166)。随后,在 JDK 6 中,一些新的并行特性,例如并行 collection 框架,合并到了标准库中(JSR-166x)。直到今天,尽管 Java SE 7 还没有正式发布,一些并行相关的新特性已经出现在 JSR-166y 中:

  Fork/Join 模式;

  TransferQueue,它继承自 BlockingQueue 并能在队列满时阻塞“生产者”;

  ArrayTasks/ListTasks,用于并行执行某些数组/列表相关任务的类;

  IntTasks/LongTasks/DoubleTasks,用于并行处理数字类型数组的工具类,提供了排序、查找、求和、求最小值、求最大值等功能;

  其中,对 Fork/Join 模式的支持可能是对开发并行软件来说最通用的新特性。在 JSR-166y 中,Doug Lea 实现 ArrayTasks/ListTasks/IntTasks/LongTasks/DoubleTasks 时就大量的用到了 Fork/Join 模式。读者还需要注意一点,因为 JDK 7 还没有正式发布,因此本文涉及到的功能和发布版本有可能不一样。

  Fork/Join 模式有自己的适用范围。如果一个应用能被分解成多个子任务,并且组合多个子任务的结果就能够获得最终的答案,那么这个应用就适合用 Fork/Join 模式来解决。图 1 给出了一个 Fork/Join 模式的示意图,位于图上部的 Task 依赖于位于其下的 Task 的执行,只有当所有的子任务都完成之后,调用者才能获得 Task 0 的返回结果。

  图 1. Fork/Join 模式示意图

  可以说,Fork/Join 模式能够解决很多种类的并行问题。通过使用 Doug Lea 提供的 Fork/Join 框架,软件开发人员只需要关注任务的划分和中间结果的组合就能充分利用并行平台的优良性能。其他和并行相关的诸多难于处理的问题,例如负载平衡、同步等,都可以由框架采用统一的方式解决。这样,我们就能够轻松地获得并行的好处而避免了并行编程的困难且容易出错的缺点。

  使用 Fork/Join 模式

  在开始尝试 Fork/Join 模式之前,我们需要从 Doug Lea 主持的 Concurrency JSR-166 Interest Site 上下载 JSR-166y 的源代码,并且我们还需要安装最新版本的 JDK 6(下载网址请参阅 参考资源)。Fork/Join 模式的使用方式非常直观。首先,我们需要编写一个 ForkJoinTask 来完成子任务的分割、中间结果的合并等工作。随后,我们将这个 ForkJoinTask 交给 ForkJoinPool 来完成应用的执行。

  通常我们并不直接继承 ForkJoinTask,它包含了太多的抽象方法。针对特定的问题,我们可以选择 ForkJoinTask 的不同子类来完成任务。RecursiveAction 是 ForkJoinTask 的一个子类,它代表了一类最简单的 ForkJoinTask:不需要返回值,当子任务都执行完毕之后,不需要进行中间结果的组合。如果我们从 RecursiveAction 开始继承,那么我们只需要重载 protected void compute() 方法。下面,我们来看看怎么为快速排序算法建立一个 ForkJoinTask 的子类:

  清单 1. ForkJoinTask 的子类

1 class SortTask extends RecursiveAction {
2     final long[] array;
3     final int lo;
4     final int hi;
5     private int THRESHOLD = 30;
6
7     public SortTask(long[] array) {
8         this.array = array;
9         this.lo = 0;
10         this.hi = array.length - 1;
11     }
12
13     public SortTask(long[] array, int lo, int hi) {
14         this.array = array;
15         this.lo = lo;
16         this.hi = hi;
17     }
18
19     protected void compute() {
20         if (hi - lo < THRESHOLD)
21             sequentiallySort(array, lo, hi);
22         else {
23             int pivot = partition(array, lo, hi);
24             coInvoke(new SortTask(array, lo, pivot - 1), new SortTask(array,
25                 pivot + 1, hi));
26         }
27     }
28
29     private int partition(long[] array, int lo, int hi) {
30         long x = array[hi];
31         int i = lo - 1;
32         for (int j = lo; j < hi; j++) {
33             if (array[j] <= x) {
34                 i++;
35                 swap(array, i, j);
36             }
37         }
38         swap(array, i + 1, hi);
39         return i + 1;
40     }
41
42     private void swap(long[] array, int i, int j) {
43         if (i != j) {
44             long temp = array[i];
45             array[i] = array[j];
46             array[j] = temp;
47         }
48     }
49
50     private void sequentiallySort(long[] array, int lo, int hi) {
51         Arrays.sort(array, lo, hi + 1);
52     }
53 }
54

  在 清单 1 中,SortTask 首先通过 partition() 方法将数组分成两个部分。随后,两个子任务将被生成并分别排序数组的两个部分。当子任务足够小时,再将其分割为更小的任务反而引起性能的降低。因此,这里我们使用一个 THRESHOLD,限定在子任务规模较小时,使用直接排序,而不是再将其分割成为更小的任务。其中,我们用到了 RecursiveAction 提供的方法 coInvoke()。它表示:启动所有的任务,并在所有任务都正常结束后返回。如果其中一个任务出现异常,则其它所有的任务都取消。coInvoke() 的参数还可以是任务的数组。

  现在剩下的工作就是将 SortTask 提交到 ForkJoinPool 了。ForkJoinPool() 默认建立具有与 CPU 可使用线程数相等线程个数的线程池。我们在一个 JUnit 的 test 方法中将 SortTask 提交给一个新建的 ForkJoinPool:

  清单 2. 新建的 ForkJoinPool

1 @Test
2 public void testSort() throws Exception {
3     ForkJoinTask sort = new SortTask(array);
4     ForkJoinPool fjpool = new ForkJoinPool();
5     fjpool.submit(sort);
6     fjpool.shutdown();
7
8     fjpool.awaitTermination(30, TimeUnit.SECONDS);
9
10     assertTrue(checkSorted(array));
11 }
12

  在上面的代码中,我们用到了 ForkJoinPool 提供的如下函数:

  submit():将 ForkJoinTask 类的对象提交给 ForkJoinPool,ForkJoinPool 将立刻开始执行 ForkJoinTask。

  shutdown():执行此方法之后,ForkJoinPool 不再接受新的任务,但是已经提交的任务可以继续执行。如果希望立刻停止所有的任务,可以尝试 shutdownNow() 方法。

  awaitTermination():阻塞当前线程直到 ForkJoinPool 中所有的任务都执行结束。

  并行快速排序的完整代码如下所示:

  清单 3. 并行快速排序的完整代码

  1 package tests;
  2
  3 import static org.junit.Assert.*;
  4
  5 import java.util.Arrays;
  6 import java.util.Random;
  7 import java.util.concurrent.TimeUnit;
  8
  9 import jsr166y.forkjoin.ForkJoinPool;
10 import jsr166y.forkjoin.ForkJoinTask;
11 import jsr166y.forkjoin.RecursiveAction;
12
13 import org.junit.Before;
14 import org.junit.Test;
15
16 class SortTask extends RecursiveAction {
17     final long[] array;
18     final int lo;
19     final int hi;
20     private int THRESHOLD = 0; //For demo only
21
22     public SortTask(long[] array) {
23         this.array = array;
24         this.lo = 0;
25         this.hi = array.length - 1;
26     }
27
28     public SortTask(long[] array, int lo, int hi) {
29         this.array = array;
30         this.lo = lo;
31         this.hi = hi;
32     }
33
34     protected void compute() {
35         if (hi - lo < THRESHOLD)
36             sequentiallySort(array, lo, hi);
37         else {
38             int pivot = partition(array, lo, hi);
39             System.out.println("\npivot = " + pivot + ", low = " + lo + ", high = " + hi);
40             System.out.println("array" + Arrays.toString(array));
41             coInvoke(new SortTask(array, lo, pivot - 1), new SortTask(array,
42                     pivot + 1, hi));
43         }
44     }
45
46     private int partition(long[] array, int lo, int hi) {
47         long x = array[hi];
48         int i = lo - 1;
49         for (int j = lo; j < hi; j++) {
50             if (array[j] <= x) {
51                 i++;
52                 swap(array, i, j);
53             }
54         }
55         swap(array, i + 1, hi);
56         return i + 1;
57     }
58
59     private void swap(long[] array, int i, int j) {
60         if (i != j) {
61             long temp = array[i];
62             array[i] = array[j];
63             array[j] = temp;
64         }
65     }
66
67     private void sequentiallySort(long[] array, int lo, int hi) {
68         Arrays.sort(array, lo, hi + 1);
69     }
70 }
71
72 public class TestForkJoinSimple {
73     private static final int NARRAY = 16; //For demo only
74     long[] array = new long[NARRAY];
75     Random rand = new Random();
76
77     @Before
78     public void setUp() {
79         for (int i = 0; i < array.length; i++) {
80             array[i] = rand.nextLong()%100; //For demo only
81         }
82         System.out.println("Initial Array: " + Arrays.toString(array));
83     }
84
85     @Test
86     public void testSort() throws Exception {
87         ForkJoinTask sort = new SortTask(array);
88         ForkJoinPool fjpool = new ForkJoinPool();
89         fjpool.submit(sort);
90         fjpool.shutdown();
91
92         fjpool.awaitTermination(30, TimeUnit.SECONDS);
93
94         assertTrue(checkSorted(array));
95     }
96
97     boolean checkSorted(long[] a) {
98         for (int i = 0; i < a.length - 1; i++) {
99             if (a[i] > (a[i + 1])) {
100                 return false;
101             }
102         }
103         return true;
104     }
105 }
106

  运行以上代码,我们可以得到以下结果:

1 Initial Array: [46, -12, 74, -67, 76, -13, -91, -96]
2
3 pivot = 0, low = 0, high = 7
4 array[-96, -12, 74, -67, 76, -13, -91, 46]
5
6 pivot = 5, low = 1, high = 7
7 array[-96, -12, -67, -13, -91, 46, 76, 74]
8
9 pivot = 1, low = 1, high = 4
10 array[-96, -91, -67, -13, -12, 46, 74, 76]
11
12 pivot = 4, low = 2, high = 4
13 array[-96, -91, -67, -13, -12, 46, 74, 76]
14
15 pivot = 3, low = 2, high = 3
16 array[-96, -91, -67, -13, -12, 46, 74, 76]
17
18 pivot = 2, low = 2, high = 2
19 array[-96, -91, -67, -13, -12, 46, 74, 76]
20
21 pivot = 6, low = 6, high = 7
22 array[-96, -91, -67, -13, -12, 46, 74, 76]
23
24 pivot = 7, low = 7, high = 7
25 array[-96, -91, -67, -13, -12, 46, 74, 76]
26
0
相关文章