- Fork/Join:线程池的实现,体现是分治思想,适用于能够进行任务拆分的 CPU 密集型运算,用于并行计算。
任务拆分:将一个大任务拆分为算法上相同的小任务,直至不能拆分可以直接求解。跟递归相关的一些计算,如归并排序、斐波那契数列都可以用分治思想进行求解。
Fork/Join 在分治的基础上加入了多线程,把每个任务的分解和合并交给不同的线程来完成,提升了运算效率。
ForkJoin 使用
ForkJoinPool
来启动,ForkJoinPool
是一个特殊的线程池,默认会创建与 CPU 核心数大小相同的线程池。如果要得到任务的返回值,那么任务就继承
RecursiveTask
,否则可以选择继承RecursiveAction
。
ForkJoin线程池
Fork/Join:线程池的实现,体现是分治思想,适用于能够进行任务拆分的 CPU 密集型运算,用于并行计算。
任务拆分:将一个大任务拆分为算法上相同的小任务,直至不能拆分可以直接求解。跟递归相关的一些计算,如归并排序、斐波那契数列都可以用分治思想进行求解。
Fork/Join 在分治的基础上加入了多线程,把每个任务的分解和合并交给不同的线程来完成,提升了运算效率
ForkJoin 使用
ForkJoinPool
来启动,ForkJoinPool
是一个特殊的线程池,默认会创建与 CPU 核心数大小相同的线程池。如果要得到任务的返回值,那么任务就继承
RecursiveTask
,否则可以选择继承RecursiveAction
。
考虑这么一个场景:要对 1000 万个数字进行求和,该怎么尽快地计算出结果?
- 对于这种 CPU 密集型的任务,就可以考虑使用 Fork/Join 线程池进行并行计算了。难点在于如何对任务进行拆分。
任务拆分思路:
任务拆分的核心在于将大任务递归地分割成若干小任务,直到每个小任务足够简单,然后汇总计算结果。对于 1000 万个数字求和,以下是具体步骤:
确定分割的阈值:通常我们设置一个任务分割的阈值(例如每个任务处理 1,000 个数字或更少),当任务量达到这个阈值时,直接进行处理(例如通过 for 循环求和),而不再进一步分割。
任务拆分:递归地将任务二分,每次将数组从中间分成两半,分别交给两个子任务去处理,直到数组的长度小于或等于阈值。
使用 Fork/Join 代码进行实现:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
public class ForkJoinSumTask extends RecursiveTask<Long> {
private static final int THRESHOLD = 1000; // 阈值
private int[] arr;
private int start;
private int end;
public ForkJoinSumTask(int[] arr, int start, int end) {
this.arr = arr;
this.start = start;
this.end = end;
}
protected Long compute() {
// 如果任务小于阈值,直接计算
if (end - start <= THRESHOLD) {
long sum = 0;
for (int i = start; i <= end; i++) {
sum += arr[i];
}
return sum;
} else {
// 任务大于阈值,拆分成两个子任务
int middle = (start + end) / 2;
ForkJoinSum leftTask = new ForkJoinSum(arr, start, middle);
ForkJoinSum rightTask = new ForkJoinSum(arr, middle + 1, end);
// 提交子任务
leftTask.fork();
rightTask.fork();
// 合并结果
long leftResult = leftTask.join();
long rightResult = rightTask.join();
return leftResult + rightResult;
}
}
public static void main(String[] args) {
// 初始化数据
int[] arr = new int[10_000_000];
for (int i = 0; i < arr.length; i++) {
arr[i] = i + 1; // 示例数据
}
// 创建ForkJoinPool线程池
ForkJoinPool pool = new ForkJoinPool();
// 创建初始任务
ForkJoinSum task = new ForkJoinSum(arr, 0, arr.length - 1);
// 执行任务并获取结果
long result = pool.invoke(task);
System.out.println("Sum: " + result);
}
}