package com.example.forkjoin;

import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

/**
 * Fork/join merge sort over the half-open range {@code [start, end)} of an {@code int[]}.
 *
 * <p>The task returns a <em>new</em> sorted array holding the elements of that range; the input
 * array is only read, never written. Ranges no longer than the threshold are copied and sorted
 * sequentially with {@link Arrays#sort(int[])}; longer ranges are split in half, both halves are
 * sorted as subtasks, and their results are merged.
 *
 * <p>Each step prints the executing thread's name to {@code System.out} so the distribution of
 * work across the pool can be observed.
 */
public class ParallelMergeSort extends RecursiveTask<int[]> {

	/** Default range length at or below which a subrange is sorted sequentially. */
	private static final int THRESHOLD = 10;

	private final int[] array;
	private final int start;
	private final int end;
	/** Range length at or below which this task sorts sequentially instead of splitting. */
	private final int threshold;

	/**
	 * Creates a task that sorts {@code array[start, end)} using the default threshold of 10.
	 *
	 * @param array source array; only read, never modified
	 * @param start inclusive start index of the range
	 * @param end   exclusive end index of the range
	 * @throws NullPointerException      if {@code array} is null
	 * @throws IndexOutOfBoundsException if {@code start < 0}, {@code start > end} or
	 *                                   {@code end > array.length}
	 */
	public ParallelMergeSort(int[] array, int start, int end) {
		this(array, start, end, THRESHOLD);
	}

	/**
	 * Creates a task that sorts {@code array[start, end)} with a custom sequential threshold.
	 *
	 * @param array     source array; only read, never modified
	 * @param start     inclusive start index of the range
	 * @param end       exclusive end index of the range
	 * @param threshold largest range length sorted sequentially; must be at least 1, otherwise a
	 *                  single-element range would keep splitting forever
	 * @throws NullPointerException      if {@code array} is null
	 * @throws IndexOutOfBoundsException if {@code start < 0}, {@code start > end} or
	 *                                   {@code end > array.length}
	 * @throws IllegalArgumentException  if {@code threshold < 1}
	 */
	public ParallelMergeSort(int[] array, int start, int end, int threshold) {
		Objects.requireNonNull(array, "array");
		// Validate eagerly: Arrays.copyOfRange would otherwise zero-pad when end > length, and
		// other bad ranges would only fail later inside a worker thread.
		if (start < 0 || start > end || end > array.length) {
			throw new IndexOutOfBoundsException(
					"Range [" + start + ", " + end + ") out of bounds for length " + array.length);
		}
		if (threshold < 1) {
			throw new IllegalArgumentException("threshold must be >= 1: " + threshold);
		}
		this.array = array;
		this.start = start;
		this.end = end;
		this.threshold = threshold;
	}

	/**
	 * Sorts this task's range, splitting it into two subtasks while it exceeds the threshold.
	 *
	 * @return a new array containing the range's elements in ascending order
	 */
	@Override
	protected int[] compute() {
		if (end - start <= threshold) {
			System.out.println(Thread.currentThread().getName() + " is running compute1...");
			int[] result = Arrays.copyOfRange(array, start, end);
			Arrays.sort(result);
			return result;
		}
		System.out.println(Thread.currentThread().getName() + " is running compute2...");
		// Unsigned shift keeps the midpoint correct even if start + end overflows int.
		int mid = (start + end) >>> 1;
		ParallelMergeSort leftTask = new ParallelMergeSort(array, start, mid, threshold);
		ParallelMergeSort rightTask = new ParallelMergeSort(array, mid, end, threshold);
		// invokeAll runs both halves and waits for them; join() then just returns their results.
		invokeAll(leftTask, rightTask);
		return merge(leftTask.join(), rightTask.join());
	}

	/**
	 * Merges two ascending arrays into a new ascending array.
	 *
	 * @param left  ascending run
	 * @param right ascending run
	 * @return a new array of length {@code left.length + right.length}
	 */
	private static int[] merge(int[] left, int[] right) {
		System.out.println(Thread.currentThread().getName() + " is running merge...");
		int[] result = new int[left.length + right.length];
		int i = 0, j = 0, k = 0;
		while (i < left.length && j < right.length) {
			// Ties take from the left run (the conventional stable merge order).
			result[k++] = (left[i] <= right[j]) ? left[i++] : right[j++];
		}
		// At most one of the runs still has elements; copy its tail.
		System.arraycopy(left, i, result, k, left.length - i);
		k += left.length - i;
		System.arraycopy(right, j, result, k, right.length - j);
		return result;
	}

	/** Demo entry point: sorts a fixed array on a dedicated pool and prints the result. */
	public static void main(String[] args) {
		ForkJoinPool pool = new ForkJoinPool();
		try {
			int[] array = { 5, 3, 8, 6, 2, 7, 4, 1, 5, 3, 8, 6, 2, 7, 4, 1, 5, 3, 8, 6, 2, 7, 4, 1, 5, 3, 8, 6, 2, 7, 4,
					1 };
			ParallelMergeSort task = new ParallelMergeSort(array, 0, array.length);
			int[] sortedArray = pool.invoke(task);
			System.out.println(Arrays.toString(sortedArray));
		} finally {
			pool.shutdown();
		}
	}
}
