package com.example.forkjoin;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class SumTask extends RecursiveTask<Integer> {

	private static final int THRESHOLD = 10;
	private final int[] array;

	private final int start;
	private final int end;

	public SumTask(int[] array, int start, int end) {
		this.array = array;
		this.start = start;
		this.end = end;
	}

	@Override
	protected Integer compute() {
		if (end - start <= THRESHOLD) {
			int sum = 0;
			for (int i = start; i < end; i++) {
				System.out.println(Thread.currentThread().getName() + " is computing...");
				sum += array[i];
			}
			return sum;
		} else {
			int mid = (start + end) / 2;
			SumTask leftTask = new SumTask(array, start, mid);
			SumTask rightTask = new SumTask(array, mid, end);
			leftTask.fork(); // divide/map
			int rightResult = rightTask.compute(); // execute the task
			int leftResult = leftTask.join(); // conquer/reduce
			return leftResult + rightResult;
		}
	}

	public static void main(String[] args) {
		ForkJoinPool pool = new ForkJoinPool();
		int[] array = new int[1000];
		for (int i = 0; i < array.length; i++) {
			array[i] = i + 1;
		}
		SumTask task = new SumTask(array, 0, array.length);
		int sum = pool.invoke(task);
		System.out.println("Sum: " + sum);
	}
}
