Sorting in Parallel in Java using Executors and Arrays.sort

Below is HeavySort, a parallel sorting algorithm in Java. The algorithm takes the following steps:

Input: A list A of length n, to be sorted with m threads.

  1. Divide A into m pieces, and sort each piece onto a new list Bi.
  2. Build a list C with first the smallest element of A, then m-1 points evenly spaced from the interior of B1, then the largest element of A.
  3. Compute D[i], the number of elements of A that are between C[i] and C[i+1] (use binary searches).
  4. For each i = 1,...,m, merge the values of the sorted lists B1,...,Bm that lie between C[i] and C[i+1]. Copy these values back onto A starting at position D[1] + ... + D[i-1].
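Step 3 can be illustrated with a small standalone sketch. It assumes plain int arrays rather than the generic T[] used in the class further down, and the names BucketCounts, lowerBound and bucketSizes are made up for this illustration. Each bucket is taken to hold the values x with C[i] <= x < C[i+1], and the two outer bounds are left open.

```java
import java.util.Arrays;

public class BucketCounts {

    // First index in a[from, to) whose value is >= key, or to if there is none.
    public static int lowerBound(int[] a, int from, int to, int key) {
        int lo = from, hi = to;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (a[mid] < key) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    // data holds the sorted pieces back to back: piece j is data[starts[j], starts[j+1]),
    // and the last piece ends at data.length. dividers must be in nondecreasing order.
    // Returns the bucket sizes: one bucket below the first divider, one between each
    // pair of neighbouring dividers, and one at or above the last divider.
    public static int[] bucketSizes(int[] data, int[] starts, int[] dividers) {
        int[] sizes = new int[dividers.length + 1];
        for (int j = 0; j < starts.length; j++) {
            int from = starts[j];
            int to = (j + 1 < starts.length) ? starts[j + 1] : data.length;
            int prev = from;
            for (int i = 0; i < dividers.length; i++) {
                int cut = lowerBound(data, from, to, dividers[i]);
                sizes[i] += cut - prev;
                prev = cut;
            }
            sizes[dividers.length] += to - prev;
        }
        return sizes;
    }

    public static void main(String[] args) {
        // Three sorted pieces of length 4; the dividers 4 and 7 come from the first piece.
        int[] data = {1, 4, 7, 10,  2, 5, 8, 11,  3, 6, 9, 12};
        int[] starts = {0, 4, 8};
        int[] dividers = {4, 7};
        System.out.println(Arrays.toString(bucketSizes(data, starts, dividers)));
        // prints [3, 3, 6]
    }
}
```

The prefix sums of these sizes give the offsets at which each merged bucket is copied back in step 4. The uneven last bucket shows that sampling only the first piece does not guarantee balanced work.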

Here is a brief analysis of the running time. Step 1 takes O((n/m) log(n/m)) for each of the m sorts, which run in parallel. Step 2 takes O(m) time, as we can find the smallest and largest elements of A by checking the end points of each Bi. For step 3, we can compute each D[i] in parallel: for j = 1,...,m, we do two binary searches of Bj, one for C[i] and one for C[i+1], taking O(m log(n/m)) in each thread. For step 4, there should be about n/m elements of A with values between C[i] and C[i+1], and for each element we spend O(m) time to determine which of the Bj holds the next smallest value, so we can complete step 4 in O(n) using m threads. Thus the total running time is O((n/m) log(n/m) + n). Note that in the final step, performance could be improved by using a binary heap to determine which Bj is next, reducing the cost per element from O(m) to O(log m), but m would have to be quite large for this to be practical.
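As a rough illustration of the binary heap idea, here is a sketch of a k-way merge built on java.util.PriorityQueue. It is not part of HeavySort: the names HeapMerge and merge are hypothetical, it works on int arrays for brevity, and it uses Java 8 lambdas, so it is not a drop-in replacement for the linear scan in MergeSubsequences.

```java
import java.util.Arrays;
import java.util.PriorityQueue;

public class HeapMerge {

    // Merges the sorted runs data[from[j], to[j]) into one sorted array.
    public static int[] merge(int[] data, int[] from, int[] to) {
        int total = 0;
        for (int j = 0; j < from.length; j++) {
            total += Math.max(0, to[j] - from[j]);
        }
        // Heap entries are {run index, current position}.
        // Equal values are ordered by run index, so ties go to the earlier run.
        PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) -> {
            int c = Integer.compare(data[a[1]], data[b[1]]);
            return c != 0 ? c : Integer.compare(a[0], b[0]);
        });
        for (int j = 0; j < from.length; j++) {
            if (from[j] < to[j]) {
                heap.add(new int[] {j, from[j]});
            }
        }
        int[] out = new int[total];
        int k = 0;
        while (!heap.isEmpty()) {
            int[] head = heap.poll();      // run with the smallest current value
            out[k++] = data[head[1]];
            head[1]++;                     // advance within that run
            if (head[1] < to[head[0]]) {
                heap.add(head);            // re-insert only if the run has more values
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[] data = {1, 4, 7, 10,  2, 5, 8, 11,  3, 6, 9, 12};
        int[] merged = merge(data, new int[] {0, 4, 8}, new int[] {4, 8, 12});
        System.out.println(Arrays.toString(merged));
        // prints [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    }
}
```

Each poll and add costs O(log m), compared with the O(m) scan per element in the analysis above. For a handful of threads the scan is likely just as fast, since the heap adds allocation and comparator overhead.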

Implicit in this analysis is the assumption that the values in B1 are representative of A. While this assumption would usually be true if A began in uniformly random order, it would be violated if A began sorted or nearly sorted: every divider taken from B1 would then fall at one extreme of A's range. When the assumption is violated, nearly the entire merge would be done by a single thread, making step 4 take O(mn).
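The skew is easy to see by counting bucket sizes directly. The sketch below is an illustration only: SortedInputSkew and bucketSizes are made-up names, it uses int arrays, and it picks dividers from the first piece at positions (pieceLength * i) / m, which mirrors how the class further down samples them.

```java
import java.util.Arrays;

public class SortedInputSkew {

    // Sizes of the m merge buckets when the m-1 dividers are sampled from the
    // first piece of a. Assumes a.length >= m so that the first piece is not empty.
    public static int[] bucketSizes(int[] a, int m) {
        int[] piece = Arrays.copyOfRange(a, 0, a.length / m);
        Arrays.sort(piece);
        int[] dividers = new int[m - 1];
        for (int i = 1; i < m; i++) {
            dividers[i - 1] = piece[(piece.length * i) / m];
        }
        int[] sizes = new int[m];
        for (int x : a) {
            // Bucket b holds the values x with dividers[b-1] <= x < dividers[b].
            int bucket = 0;
            while (bucket < m - 1 && x >= dividers[bucket]) {
                bucket++;
            }
            sizes[bucket]++;
        }
        return sizes;
    }

    public static void main(String[] args) {
        int[] sorted = new int[1200];
        for (int i = 0; i < sorted.length; i++) {
            sorted[i] = i;
        }
        System.out.println(Arrays.toString(bucketSizes(sorted, 4)));
        // prints [75, 75, 75, 975]
    }
}
```

With four threads and already sorted input, the last bucket receives 975 of the 1200 elements, so one merge task does most of the work while the other three finish almost immediately.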

The algorithm is called HeavySort because it uses twice the memory required to store A. The extra memory allocation is only necessary because we merge in parallel, and the same is true of step 3. Since for small m the sorting in step 1 will be dominant, it would be interesting to see whether the cost of the memory allocation and the binary searches in step 3 is actually recovered by parallelizing the merge.

My objective while designing this algorithm was that it would be easy to implement the sorting with Java’s built-in Arrays.sort(), and the parallelization with Java’s ExecutorService. The algorithm is implemented in a single class that can be dropped into any Java project, version 1.6 or greater (it relies on Arrays.copyOf and the ranged Arrays.binarySearch, both added in Java 6). Below are an example, the source code, and JUnit tests. In practice, I saw a 3-4x speedup using 6 threads.

You can download the code for HeavySort here.





package parallelSorting;

import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import parallelSorting.HeavySort.ArrayFactory;

public class HeavySortMain {
	
	public static void main(String [] args){
		// 200 million boxed Integers need several GB of heap; lower this for a quick run.
		int problemSize = 200000000;
		int numThreads = 6;
		Random random = new Random();
		Integer[] sortArray = new Integer[problemSize];
		for (int i=0; i< problemSize ;i++){
			sortArray[i] = random.nextInt(Integer.MAX_VALUE );
		}
		long startTime2 = System.currentTimeMillis();
		Arrays.sort(sortArray);
		System.out.println("Single Threaded Sort: time taken " + 
				(System.currentTimeMillis() - startTime2));
		Collections.shuffle(Arrays.asList(sortArray));
		final ExecutorService executor = Executors.newFixedThreadPool(numThreads);
		long startTime = System.currentTimeMillis();
		ArrayFactory<Integer> factory = new ArrayFactory<Integer>(){

			@Override
			public Integer[] buildArray(int length) {
				return new Integer[length];
			}

		};
		HeavySort.sort(sortArray,executor,numThreads,factory);
		System.out.println("Multi-Threaded sort: time taken " + 
				(System.currentTimeMillis() - startTime));
		

		for (int i=0; i<sortArray.length-1; i++){
			if(sortArray[i] > sortArray[i+1]){
				System.err.println("Error: element at " + i 
						+ " : " + sortArray[i]  );
				System.err.println("Error: element at " + 
						(i+1) + " : " + sortArray[i+1]  );
			}

		}
		executor.shutdown();
	}

}






package parallelSorting;


import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

public class HeavySort {

	private static boolean noisy = true;
	
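	/**
	 * Sorts data in place using numThreads tasks submitted to service.
	 * The caller owns the executor and is responsible for shutting it down.
	 * arrayFactory supplies the temporary arrays that hold the merged buckets.
	 */
	public static <T extends Comparable<T>> void sort(T[] data, 
			ExecutorService service, int numThreads, 
			ArrayFactory<T> arrayFactory){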
		if(data.length <= 1){
			return;
		}
		List<Callable<Boolean>> tasks = new ArrayList<Callable<Boolean>>();
		int[] startingPoints = new int[numThreads];
		for(int i = 0; i < numThreads;i++){
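			// Use long arithmetic: data.length*i can overflow int for large arrays.
			int lo = (int) ((long) data.length * i / numThreads);
			startingPoints[i] = lo;
			int hi = (int) ((long) data.length * (i + 1) / numThreads);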
			tasks.add(new SortSubsequence<T>(data,lo,hi));
		}
		long sortingTime = System.currentTimeMillis();
		try {
			List<Future<Boolean>> results = service.invokeAll(tasks);
			for(Future<Boolean> result: results){
				if(!result.get().booleanValue()){
					throw new RuntimeException();
				}
			}
		} catch (InterruptedException e) {
			throw new RuntimeException(e);
		} catch (ExecutionException e) {
			throw new RuntimeException(e);
		}
		if(noisy){
			System.out.println("Sorting Time " + 
					(System.currentTimeMillis() - sortingTime));
		}
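		// Dividers: numThreads-1 evenly spaced samples from the first sorted piece.
		// The null entries stand for the open ends (no lower or upper bound).
		List<T> dividers = new ArrayList<T>();
		int hi = data.length/numThreads;
		dividers.add(null);
		for(int i = 1; i < numThreads; i++){
			dividers.add(data[(hi*i)/numThreads]);
		}
		dividers.add(null);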
		List<Callable<T[]>> merges = new ArrayList<Callable<T[]>>();
		for(int i = 0 ; i < numThreads; i++){
			merges.add(new MergeSubsequences<T>(dividers.get(i),
					dividers.get(i+1),data,startingPoints,arrayFactory));
		}
		List<T[]> resultsCollected = new ArrayList<T[]>();
		long mergingTime = System.currentTimeMillis();
		try {
			List<Future<T[]>> results = service.invokeAll(merges);
			for(Future<T[]> result: results){
				resultsCollected.add(result.get());
			}
		} catch (InterruptedException e) {
			throw new RuntimeException(e);
		} catch (ExecutionException e) {
			throw new RuntimeException(e);
		}
		if(noisy){
			System.out.println("Merging Time " + 
					(System.currentTimeMillis() - mergingTime));
		}
		List<Callable<Boolean>> pastes = new ArrayList<Callable<Boolean>>();
		int startingPoint = 0;
		for(int i = 0 ; i < numThreads; i++){
			pastes.add(new Paste<T>(startingPoint, 
					data,resultsCollected.get(i)));
			startingPoint+= resultsCollected.get(i).length;
		}
		long pastingTime = System.currentTimeMillis();
		try {
			List<Future<Boolean>> pastesResults = service.invokeAll(pastes);
			for(Future<Boolean> result: pastesResults){
				if(!result.get().booleanValue()){
					throw new RuntimeException();
				}
			}
		} catch (InterruptedException e) {
			throw new RuntimeException(e);
		} catch (ExecutionException e) {
			throw new RuntimeException(e);
		}
		if(noisy){
			System.out.println("Pasting Time " + 
					(System.currentTimeMillis() - pastingTime));
		}

	}

	public static interface ArrayFactory<T extends Comparable<T>>{
		public T[] buildArray(int length);
	}


	private static class Paste<T extends Comparable<T>> implements Callable<Boolean>{

		private int lo;
		private T[] data;
		private T[] source;




		public Paste(int lo, T[] data, T[] source) {
			super();
			this.lo = lo;
			this.data = data;
			this.source = source;
		}




		@Override
		public Boolean call() throws Exception {
			System.arraycopy(source, 0, data, lo, source.length);
			return Boolean.valueOf(true);
		}

	}

	private static class MergeSubsequences<T extends Comparable<T>> implements Callable<T[]>{

		private T lo;
		private T hi;
		private T[] data;
		private int[] startingPoints;
		private int[] endPoints;
		private ArrayFactory<T> arrayFactory;



		public MergeSubsequences(T lo, T hi, T[] data, 
				int[] startingPoints, ArrayFactory<T> arrayFactory) {
			super();
			this.arrayFactory = arrayFactory;
			this.lo = lo;
			this.hi = hi;
			this.data = data;
			this.startingPoints = startingPoints;
			this.endPoints = new int[startingPoints.length];
			for(int i = 0; i < startingPoints.length-1; i++){
				endPoints[i] = startingPoints[i+1];
			}
			endPoints[endPoints.length-1] = data.length;

		}



		@Override
		public T[] call() throws Exception {

			int[] currentLocationBySection = Arrays.copyOf(
					startingPoints, startingPoints.length);
			int[] upperBoundsBySection = Arrays.copyOf(
					endPoints, endPoints.length);

			// Arrays.binarySearch may return any index among equal keys, so after a
			// hit we step back to the first occurrence. Every element equal to a
			// divider then lands in the same bucket, which keeps the merge stable.
			if(lo != null){
				for(int i = 0; i < currentLocationBySection.length; i++){
					int idx = Arrays.binarySearch(
							data, startingPoints[i], endPoints[i], lo);
					if(idx < 0){
						idx = -idx - 1;
					}
					while(idx > startingPoints[i] 
							&& data[idx-1].compareTo(lo) == 0){
						idx--;
					}
					currentLocationBySection[i] = idx;
				}
			}
			if(hi != null){
				for(int i = 0; i < upperBoundsBySection.length; i++){
					int idx = Arrays.binarySearch(
							data, startingPoints[i], endPoints[i], hi);
					if(idx < 0){
						idx = -idx - 1;
					}
					while(idx > startingPoints[i] 
							&& data[idx-1].compareTo(hi) == 0){
						idx--;
					}
					upperBoundsBySection[i] = idx;
				}
			}
			boolean[] sectionsInBounds = 
				new boolean[currentLocationBySection.length];
			Arrays.fill(sectionsInBounds, true);
			int numSectionsInBounds = sectionsInBounds.length;
			int totalItems = 0;
			for(int i = 0; i < sectionsInBounds.length; i++){
				if(currentLocationBySection[i] >= upperBoundsBySection[i]){
					sectionsInBounds[i] = false;
					numSectionsInBounds--;
				}
				else{
					totalItems += upperBoundsBySection[i] -
					currentLocationBySection[i];
				}		
			}
			T[] ans = arrayFactory.buildArray(totalItems);
			int ansInd = 0;
			while(numSectionsInBounds > 0){
				int bestSection = -1;
				T best = null;
				for(int i = 0; i < sectionsInBounds.length; i++){
					if(sectionsInBounds[i]){
						if(best == null || 
								data[currentLocationBySection[i]].compareTo(best)
								< 0){
							bestSection = i;
							best = data[currentLocationBySection[i]];
						}
					}				
				}				
				ans[ansInd] = best;
				ansInd++;
				currentLocationBySection[bestSection]++;
				if(currentLocationBySection[bestSection] 
				                            >= upperBoundsBySection[bestSection]){
					sectionsInBounds[bestSection] = false;
					numSectionsInBounds--;
				}
			}
			return ans;
		}

	}


	private static class SortSubsequence<T extends Comparable<T>> implements Callable<Boolean>{

		private T[] data;
		private int lo;
		private int hi;



		public SortSubsequence(T[] data, int lo, int hi) {
			super();
			this.data = data;
			this.lo = lo;
			this.hi = hi;
		}



		@Override
		public Boolean call() throws Exception {
			Arrays.sort(data,lo,hi);
			return Boolean.valueOf(true);
		}

	}

}






package parallelSorting;

import static org.junit.Assert.*;

import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.junit.Test;

import parallelSorting.HeavySort.ArrayFactory;

public class HeavySortTest {
	
	private static ArrayFactory<Integer> integerArrayFactory = 
		new ArrayFactory<Integer>(){

		@Override
		public Integer[] buildArray(int length) {
			return new Integer[length];
		}
    	
    };
    
    private static Integer[] zeroArray(){
    	return new Integer[0];
    }
    
    private static Integer[] oneArray(){
    	return new Integer[]{Integer.valueOf(-12)};
    }
    
    private static Integer[] twoArray(){
    	return new Integer[]{Integer.valueOf(300), Integer.valueOf(100)};
    }
    
    private static Integer[] threeArray(){
    	return new Integer[]{Integer.valueOf(-10), 
    			Integer.valueOf(-5), Integer.valueOf(-1)};
    }    
    
    private static Integer[] nineArray(){
    	return new Integer[]{14, 4 ,100,140,-4,8,30,4,-20 };
    }
    
    

	@Test
	public void test() {
		for(int i : new int[]{1,3,4,5,8,10}){
			ExecutorService exec = Executors.newFixedThreadPool(i);
			{
				Integer[] zero = zeroArray();
				HeavySort.sort(zero, exec, i, integerArrayFactory);
				Integer[] ans = zeroArray();
				Arrays.sort(ans);
				assertArrayEquals(ans,zero);
			}
			{
				Integer[] one = oneArray();
				HeavySort.sort(one, exec, i, integerArrayFactory);
				Integer[] ans = oneArray();
				Arrays.sort(ans);
				assertArrayEquals(ans,one);
			}
			{
				Integer[] two = twoArray();
				HeavySort.sort(two, exec, i, integerArrayFactory);
				Integer[] ans = twoArray();
				Arrays.sort(ans);
				assertArrayEquals(ans,two);
			}
			{
				Integer[] three = threeArray();
				HeavySort.sort(three, exec, i, integerArrayFactory);
				Integer[] ans = threeArray();
				Arrays.sort(ans);
				assertArrayEquals(ans,three);
			}
			{
				Integer[] nine = nineArray();
				HeavySort.sort(nine, exec, i, integerArrayFactory);
				Integer[] ans = nineArray();
				Arrays.sort(ans);
				assertArrayEquals(ans,nine);
			}
			// Release the pool's threads before moving on to the next pool size.
			exec.shutdown();
		}
	}

}


2 Comments

  1. cool, thanks. it’s nice that it’s just a class and not a huge package with its own dependencies.
    it seems like the sorting algorithm isn’t stable though. is there an obvious way to fix that?

    • Actually I think the algorithm should be stable. I believe that Arrays.sort() is stable. Then when we merge, as seen in call() from MergeSubsequences, in the event of a tie we pick the element from the B_i with the lowest index, as we iterate through the B_i from low to high and only pick a new element if there is a strict improvement.

      I haven’t looked at this in quite a while, so it’s possible I could be overlooking something. Did you have a counterexample?
