juc · 2022-05-31 0

ForkJoin 使用

一、概述

fork/join 框架是 Java 7 中引入的 ,它是一个工具,通过“分而治之”的方法尝试将所有可用的处理器内核使用起来帮助加速并行处理

fork/join 框架使用了一个名为 ForkJoinPool 的线程池,用于管理 ForkJoinWorkerThread 类型的工作线程

ForkJoinPool 线程池并不会为每个子任务创建一个单独的线程,相反,池中的每个线程都有自己的双端队列用于存储任务 ( double-ended queue )( 或 deque,发音 deck )

这种架构使用了一种名为工作窃取( work-stealing )算法来平衡线程的工作负载

工作窃取( work-stealing )算法:简单来说,就是 空闲的线程试图从繁忙线程的 deques 中 窃取 工作

ForkJoinTask 是 ForkJoinPool 线程之中执行的任务的基本类型。一般不直接使用 ForkJoinTask ,而是扩展它的两个子类中的任意一个。RecursiveActionRecursiveTask <V>

二、使用

ForkJoinTask代表运行在ForkJoinPool中的任务

ForkJoinTask 有方法:

  • fork() 在当前线程运行的线程池中安排一个异步执行。简单的理解就是再创建一个子任务
  • join() 当任务完成的时候返回计算结果
  • invoke() 开始执行任务,如果必要,等待计算完成

1. RecursiveAction

public class NumberRecursiveAction extends RecursiveAction {
    /**
     * 每个"小任务"最多只打印5个数
     */
    private static final int MAX = 5;

    private int start;
    private int end;

    public NumberRecursiveAction(int start, int end) {
        this.start = start;
        this.end = end;
    }

    @Override
    protected void compute() {
        // 当end-start的值小于MAX时,开始打印
        if ((end - start) < MAX) {
            for (int i = start; i < end; i++) {
                System.out.println(String.format("threadName: %s i: %s", Thread.currentThread().getName(), i));
            }
        } else {
            // 将大任务分解成两个小任务
            int middle = (start + end) / 2;
            NumberRecursiveAction left = new NumberRecursiveAction(start, middle);
            NumberRecursiveAction right = new NumberRecursiveAction(middle, end);
            left.fork();
            right.fork();
        }
    }
}
@Test
public void testRecursiveAction() throws Exception {
    // 创建包含Runtime.getRuntime().availableProcessors()返回值作为个数的并行线程的ForkJoinPool
    ForkJoinPool forkJoinPool = new ForkJoinPool();

    // 提交可分解的PrintTask任务
    forkJoinPool.submit(new NumberRecursiveAction(0, 20));

    // 阻塞当前线程直到 ForkJoinPool 中所有的任务都执行结束
    forkJoinPool.awaitTermination(2, TimeUnit.SECONDS);

    // 关闭线程池
    forkJoinPool.shutdown();
}

结果:

threadName: ForkJoinPool-1-worker-7 i: 0
threadName: ForkJoinPool-1-worker-7 i: 1
threadName: ForkJoinPool-1-worker-13 i: 2
threadName: ForkJoinPool-1-worker-13 i: 3
threadName: ForkJoinPool-1-worker-13 i: 4
threadName: ForkJoinPool-1-worker-3 i: 17
threadName: ForkJoinPool-1-worker-11 i: 5
threadName: ForkJoinPool-1-worker-13 i: 12
threadName: ForkJoinPool-1-worker-5 i: 7
threadName: ForkJoinPool-1-worker-11 i: 6
threadName: ForkJoinPool-1-worker-13 i: 13
threadName: ForkJoinPool-1-worker-5 i: 8
threadName: ForkJoinPool-1-worker-13 i: 14
threadName: ForkJoinPool-1-worker-7 i: 10
threadName: ForkJoinPool-1-worker-3 i: 18
threadName: ForkJoinPool-1-worker-7 i: 11
threadName: ForkJoinPool-1-worker-3 i: 19
threadName: ForkJoinPool-1-worker-5 i: 9
threadName: ForkJoinPool-1-worker-15 i: 15
threadName: ForkJoinPool-1-worker-15 i: 16

2.RecursiveTask

public class NumberRecursiveTask extends RecursiveTask<Integer> {
    /**
     *  每个"小任务"最多计算5个数
     */
    private static final int MAX = 5;

    private int arr[];
    private int start;
    private int end;

    public NumberRecursiveTask(int[] arr, int start, int end) {
        this.arr = arr;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        int sum = 0;
        // 当end-start的值小于MAX时候,开始打印
        if((end - start) < MAX) {
            for (int i = start; i < end; i++) {
                sum += arr[i];
            }
            return sum;
        }else {
            // 将大任务分解成两个小任务
            int middle = (start + end) / 2;
            NumberRecursiveTask left = new NumberRecursiveTask(arr, start, middle);
            NumberRecursiveTask right = new NumberRecursiveTask(arr, middle, end);
            // 并行执行两个小任务
            left.fork();
            right.fork();
            // 把两个小任务累加的结果合并起来
            return left.join() + right.join();
        }
    }
}
@Test
public void testRecursiveTask() throws Exception {
    int expectedSum = 0;

    int[] arr = new int[100];
    for (int i = 0; i < arr.length; i++) {
        arr[i] = i;
        expectedSum += i;
    }
    System.out.println("expected sum: " + expectedSum);

    // 创建包含Runtime.getRuntime().availableProcessors()返回值作为个数的并行线程的ForkJoinPool
    ForkJoinPool forkJoinPool = new ForkJoinPool();
    NumberRecursiveTask numberRecursiveTask = new NumberRecursiveTask(arr, 0, arr.length);

    // 1.execute join
    // forkJoinPool.execute(numberRecursiveTask);
    // System.out.println("actual sum: " + numberRecursiveTask.join());

    // 2.submit get
    // ForkJoinTask<Integer> future = forkJoinPool.submit(numberRecursiveTask);
    // System.out.println("actual sum: " + future.get());

    // 3.invoke
    Integer actualSum = forkJoinPool.invoke(numberRecursiveTask);
    System.out.println("actual sum: " + actualSum);

    // 关闭线程池
    forkJoinPool.shutdown();
}

结果:

expected sum: 4950
actual sum: 4950

三、原理

对于 new ForkJoinPool(),parallelism 为 Runtime.getRuntime().availableProcessors();
对于 ForkJoinPool.commonPool(),parallelism 为 Runtime.getRuntime().availableProcessors() - 1。