Introduction
In this post, I will present a parallel sorting algorithm for sorting primitive integer arrays.
Treesort
Treesort iterates over the input array and builds a binary search tree from the array components. Once the whole input range has been processed, it traverses the nodes in order and writes their keys back into the input array. The algorithm I am presenting here relies on treesort. However, instead of building a balanced binary search tree, I shuffle the input range in linear time and build an unbalanced tree, which should have logarithmic height on average (a result proved for randomly built binary search trees in the book Introduction to Algorithms). Also, I am making two optimisations:
- In each node, I cache the number of occurrences of its key.
- I maintain a hashtable mapping each key to its corresponding tree node.
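The above arrangement means that only the first occurrence of each key descends the tree; every later occurrence costs a hashtable lookup and a counter increment. The tree insertions therefore take \$\mathcal{O}(k \log k)\$ time in total (where \$k\$ is the number of distinct integers), on top of the \$\mathcal{O}(n)\$ expected time spent on hashtable lookups for \$n\$ array components, assuming that the shuffle produced a reasonably random insertion order.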
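To make the two optimisations concrete, here is a minimal single-threaded sketch. It is not the code under review: the class name CountingTreesortSketch is made up for illustration, it uses java.util.HashMap instead of the hand-rolled hashtable, and it swaps in a full Fisher-Yates shuffle.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/** Hypothetical, simplified sketch of treesort with per-key counters. */
final class CountingTreesortSketch {

    private static final class Node {
        final int key;
        int count = 1;          // Occurrences of 'key' seen so far.
        Node left;
        Node right;

        Node(final int key) {
            this.key = key;
        }
    }

    static void sort(final int[] array) {
        if (array.length < 2) {
            return;             // Trivially sorted.
        }

        shuffle(array, new Random());

        final Map<Integer, Node> nodeByKey = new HashMap<>();
        Node root = null;

        for (final int key : array) {
            final Node existing = nodeByKey.get(key);

            if (existing != null) {
                existing.count++;   // Duplicate: no tree descent needed.
            } else {
                final Node fresh = new Node(key);
                nodeByKey.put(key, fresh);
                root = insert(root, fresh);
            }
        }

        // Iterative in-order traversal; writes each key 'count' times.
        final Deque<Node> stack = new ArrayDeque<>();
        Node node = root;
        int index = 0;

        while (node != null || !stack.isEmpty()) {
            while (node != null) {
                stack.push(node);
                node = node.left;
            }

            node = stack.pop();

            for (int i = 0; i < node.count; ++i) {
                array[index++] = node.key;
            }

            node = node.right;
        }
    }

    // Plain unbalanced insertion; keys in the tree are always distinct.
    private static Node insert(final Node root, final Node fresh) {
        if (root == null) {
            return fresh;
        }

        Node current = root;

        while (true) {
            if (fresh.key < current.key) {
                if (current.left == null) {
                    current.left = fresh;
                    return root;
                }
                current = current.left;
            } else {
                if (current.right == null) {
                    current.right = fresh;
                    return root;
                }
                current = current.right;
            }
        }
    }

    // Fisher-Yates shuffle: uniform random permutation in linear time.
    private static void shuffle(final int[] array, final Random random) {
        for (int i = array.length - 1; i > 0; --i) {
            final int j = random.nextInt(i + 1);
            final int tmp = array[i];
            array[i] = array[j];
            array[j] = tmp;
        }
    }
}
```

With many duplicates, most loop iterations take the counter-increment branch, which is where the saving over a plain treesort comes from.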
The algorithm
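The actual algorithm computes a "reasonable" number of threads \$T\$ (a power of two derived from the number of available processors), splits the input range into \$T\$ contiguous chunks, sorts each chunk in its own thread with the treesort described above, and finally merges the sorted chunks pairwise.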
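The splitting and merging steps can be sketched in isolation as follows. This is only an illustration under my own naming (SplitAndMergeSketch, chunkBounds and mergeRuns are hypothetical helpers, not part of the class below); it assumes the last chunk absorbs the remainder of the division and that two adjacent sorted runs are merged from a source array into a target array.

```java
import java.util.Arrays;

/** Hypothetical helpers illustrating the split and merge steps. */
final class SplitAndMergeSketch {

    /** Smallest power of two that is at least 'threads' (threads >= 1). */
    static int roundUpToPowerOfTwo(final int threads) {
        int result = 1;

        while (result < threads) {
            result <<= 1;
        }

        return result;
    }

    /**
     * Returns threads + 1 boundaries; chunk i covers
     * [bounds[i], bounds[i + 1]). The last chunk takes the remainder.
     */
    static int[] chunkBounds(final int fromIndex,
                             final int toIndex,
                             final int threads) {
        final int chunkSize = (toIndex - fromIndex) / threads;
        final int[] bounds = new int[threads + 1];

        for (int i = 0; i < threads; ++i) {
            bounds[i] = fromIndex + i * chunkSize;
        }

        bounds[threads] = toIndex;
        return bounds;
    }

    /**
     * Merges the sorted runs source[from, mid) and source[mid, to) into
     * target, starting at targetFrom. Ties go to the left run (stable).
     */
    static void mergeRuns(final int[] source,
                          final int from,
                          final int mid,
                          final int to,
                          final int[] target,
                          final int targetFrom) {
        int left = from;
        int right = mid;
        int out = targetFrom;

        while (left < mid && right < to) {
            target[out++] = source[right] < source[left] ?
                            source[right++] :
                            source[left++];
        }

        // At most one of the two runs still has elements left.
        System.arraycopy(source, left, target, out, mid - left);
        out += mid - left;
        System.arraycopy(source, right, target, out, to - right);
    }

    public static void main(final String[] args) {
        // Prints [0, 2, 4, 6, 10]
        System.out.println(Arrays.toString(chunkBounds(0, 10, 4)));

        final int[] source = {1, 4, 7, 2, 3, 9};
        final int[] target = new int[source.length];
        mergeRuns(source, 0, 3, 6, target, 0);

        // Prints [1, 2, 3, 4, 7, 9]
        System.out.println(Arrays.toString(target));
    }
}
```

Note that when the range is shorter than the thread count, chunkSize becomes 0 and all but the last chunk are empty, so the per-chunk sort has to cope with empty ranges.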
Performance
On average, the resulting sort is not competitive with Arrays.parallelSort, yet it is faster on arrays of around 2 000 000 elements that contain a relatively small set of distinct integers. For example, I get the following performance figures on arrays of 2 million elements with 11 000 distinct values:
[STATUS] Warming up...
[STATUS] Warming up done.
Seed = 168244858505017
ParallelTreesort in 76 milliseconds.
Arrays.parallelSort in 240 milliseconds.
Algorithms agree: true
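For reference, the shape of such an input can be checked with a few lines. This is a small sketch with a hypothetical class name (DistinctValuesSketch) and an arbitrary fixed seed; it generates values the same way as the demo in the code below, namely from the half-open range [-1000, 10 000), so there can be at most 11 000 distinct values.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical helper for inspecting how many distinct keys an input has. */
final class DistinctValuesSketch {

    static long countDistinct(final int[] array) {
        return Arrays.stream(array).distinct().count();
    }

    public static void main(final String[] args) {
        // Same length and value bounds as the demo; the seed is arbitrary.
        final int[] array = new Random(42L).ints(2_000_000,
                                                 -1000,
                                                 10_000).toArray();

        System.out.println("Elements:        " + array.length);
        System.out.println("Distinct values: " + countDistinct(array));
    }
}
```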
Implementation
The code snippet follows:
ParallelTreesort.java:
package net.coderodde.util;
import java.util.Arrays;
import java.util.Random;
/**
* This class implements a parallel tree sort for primitive integer arrays. The
* algorithm splits the input range into a power-of-two number of contiguous
* subranges, sorts each in its own thread using an unbalanced tree sort
* algorithm, and merges the resulting sorted subranges.
*
* @author Rodion "rodde" Efremov
* @version 1.6 (May 23, 2016)
*/
public class ParallelTreesort {
private static final int MINIMUM_THREAD_WORK_LOAD = 1;
private ParallelTreesort() {}
public static void sort(final int[] array) {
sort(array, 0, array.length);
}
public static void sort(final int[] array,
final int fromIndex,
final int toIndex) {
final int rangeLength = toIndex - fromIndex;
if (rangeLength < 2) {
// Trivially sorted.
return;
}
final int cores = Runtime.getRuntime().availableProcessors();
final int commitThreads =
Math.max(1, rangeLength / MINIMUM_THREAD_WORK_LOAD);
int tentativeThreads = Math.min(cores, commitThreads);
// Make sure that the number of threads is a power of two after all.
tentativeThreads = fixThreadCount(tentativeThreads);
final TreesortThread[] sorterThreads =
new TreesortThread[tentativeThreads];
int tmpFromIndex = fromIndex;
final int basicChunkSize = rangeLength / tentativeThreads;
for (int i = 0; i < sorterThreads.length - 1; ++i) {
sorterThreads[i] =
new TreesortThread(array,
tmpFromIndex,
tmpFromIndex += basicChunkSize);
sorterThreads[i].start();
}
sorterThreads[sorterThreads.length - 1] =
new TreesortThread(array, tmpFromIndex, toIndex);
// Run the last sorter thread in this thread.
sorterThreads[sorterThreads.length - 1].run();
for (int i = 0; i < sorterThreads.length - 1; ++i) {
try {
sorterThreads[i].join();
} catch (final InterruptedException ex) {
throw new IllegalStateException(
"" + sorterThreads[i].getClass().getSimpleName() +
" \"" + sorterThreads[i].getName() + "\" " +
"threw an " + ex.getClass().getSimpleName(), ex);
}
}
if (sorterThreads.length == 1) {
// Single threaded sorting; no need to merge sort results from
// multiple threads.
return;
}
final int[] aux = Arrays.copyOfRange(array, fromIndex, toIndex);
new MergerThread(aux,
array,
0,
fromIndex,
sorterThreads,
0,
sorterThreads.length).run();
}
private static int getNumberOfMergePasses(final int threads) {
return 32 - Integer.numberOfLeadingZeros(threads - 1);
}
private static int fixThreadCount(final int threads) {
int ret = 1;
while (ret < threads) {
ret <<= 1;
}
return ret;
}
private static final class TreesortThread extends Thread {
private final int[] array;
private final int fromIndex;
private final int toIndex;
private TreeNode root;
private final HashTableEntry[] hashtable;
private final int mask;
private final int rangeLength;
TreesortThread(final int[] array,
final int fromIndex,
final int toIndex) {
this.array = array;
this.fromIndex = fromIndex;
this.toIndex = toIndex;
this.rangeLength = toIndex - fromIndex;
final int tableCapacity = fixCapacity(rangeLength);
this.mask = tableCapacity - 1;
this.hashtable = new HashTableEntry[tableCapacity];
}
int getFromIndex() {
return fromIndex;
}
int getToIndex() {
return toIndex;
}
int getRunLength() {
return toIndex - fromIndex;
}
@Override
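public void run() {
if (rangeLength < 2) {
// Trivially sorted. A chunk is empty when the input range is shorter
// than the thread count; reading or writing the array here would
// race with the thread that owns those components.
return;
}
shuffle();
constructTree();
dump();
}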
private void shuffle() {
// Performs rangeLength / 2 random swaps. This is cheaper than a full
// Fisher-Yates shuffle, but the result is not a uniformly random
// permutation.
final Random random = new Random();
final int to = fromIndex + rangeLength / 2;
for (int i = fromIndex; i < to; ++i) {
final int randomIndex = fromIndex + random.nextInt(rangeLength);
swap(i, randomIndex);
}
}
private void constructTree() {
final int initialKey = array[fromIndex];
root = new TreeNode(initialKey);
hashtable[getHashTableIndex(initialKey)] =
new HashTableEntry(initialKey, root, null);
for (int i = fromIndex + 1; i < toIndex; ++i) {
final int currentArrayComponent = array[i];
final int hashtableIndex =
getHashTableIndex(currentArrayComponent);
final HashTableEntry entry =
findEntry(currentArrayComponent,
hashtable[hashtableIndex]);
if (entry != null) {
entry.treeNode.count++;
} else {
final TreeNode newnode =
new TreeNode(currentArrayComponent);
hashtable[hashtableIndex] =
new HashTableEntry(currentArrayComponent,
newnode,
hashtable[hashtableIndex]);
insertTreeNode(newnode);
}
}
}
private void dump() {
int index = fromIndex;
TreeNode node = root.getMinimum();
while (node != null) {
final int count = node.count;
final int key = node.key;
for (int i = 0; i < count; ++i) {
array[index++] = key;
}
node = node.getSuccessor();
}
}
private void insertTreeNode(final TreeNode node) {
final int key = node.key;
TreeNode current = root;
TreeNode parentOfCurrent = null;
while (current != null) {
parentOfCurrent = current;
if (key < current.key) {
current = current.left;
} else {
// We don't check 'key > current.key' as there is no risk
// of duplicate keys in the tree.
current = current.right;
}
}
if (key < parentOfCurrent.key) {
parentOfCurrent.left = node;
} else {
parentOfCurrent.right = node;
}
node.parent = parentOfCurrent;
}
private HashTableEntry
findEntry(final int key, final HashTableEntry collisionChainHead) {
HashTableEntry currentEntry = collisionChainHead;
while (currentEntry != null && currentEntry.key != key) {
currentEntry = currentEntry.next;
}
return currentEntry;
}
private int fixCapacity(final int capacity) {
int ret = 1;
while (ret < capacity) {
ret <<= 1;
}
return ret;
}
private int getHashTableIndex(final int key) {
return key & mask;
}
private void swap(final int index1, final int index2) {
final int tmp = array[index1];
array[index1] = array[index2];
array[index2] = tmp;
}
private static final class TreeNode {
TreeNode left;
TreeNode right;
TreeNode parent;
final int key;
int count = 1;
TreeNode(final int key) {
this.key = key;
}
TreeNode getMinimum() {
TreeNode minimumNode = this;
while (minimumNode.left != null) {
minimumNode = minimumNode.left;
}
return minimumNode;
}
TreeNode getSuccessor() {
if (this.right != null) {
return this.right.getMinimum();
}
TreeNode parentNode = this.parent;
TreeNode currentNode = this;
while (parentNode != null && parentNode.right == currentNode) {
currentNode = parentNode;
parentNode = parentNode.parent;
}
return parentNode;
}
}
private static final class HashTableEntry {
int key;
TreeNode treeNode;
HashTableEntry next;
HashTableEntry(final int key,
final TreeNode treeNode,
final HashTableEntry next) {
this.key = key;
this.treeNode = treeNode;
this.next = next;
}
}
}
private static final class MergerThread extends Thread {
private final int[] source;
private final int[] target;
private final int sourceOffset;
private final int targetOffset;
private final TreesortThread[] threads;
private final int threadStartIndex;
private final int threadEndIndex;
MergerThread(final int[] source,
final int[] target,
final int sourceOffset,
final int targetOffset,
final TreesortThread[] threads,
final int threadStartIndex,
final int threadEndIndex) {
this.source = source;
this.target = target;
this.sourceOffset = sourceOffset;
this.targetOffset = targetOffset;
this.threads = threads;
this.threadStartIndex = threadStartIndex;
this.threadEndIndex = threadEndIndex;
}
@Override
public void run() {
final int threadCount = threadEndIndex - threadStartIndex;
if (threadCount == 1) {
return;
}
final MergerThread leftMergerThread =
new MergerThread(target,
source,
targetOffset,
sourceOffset,
threads,
threadStartIndex,
threadStartIndex + threadCount / 2);
final int leftMergerThreadCoverage =
leftMergerThread.getCoverageLength();
final MergerThread rightMergerThread =
new MergerThread(target,
source,
targetOffset + leftMergerThreadCoverage,
sourceOffset + leftMergerThreadCoverage,
threads,
threadStartIndex + threadCount / 2,
threadEndIndex);
rightMergerThread.start();
leftMergerThread.run();
try {
rightMergerThread.join();
} catch (final InterruptedException ex) {
throw new IllegalStateException(
"" + rightMergerThread.getClass().getSimpleName() +
" \"" + rightMergerThread.getName() + "\" " +
"threw an " + ex.getClass().getSimpleName(), ex);
}
final int leftRunLength = leftMergerThread.getCoverageLength();
final int rightRunLength = rightMergerThread.getCoverageLength();
int left = sourceOffset;
int right = sourceOffset + leftRunLength;
final int leftEnd = right;
final int rightEnd = right + rightRunLength;
int targetIndex = targetOffset;
while (left < leftEnd && right < rightEnd) {
target[targetIndex++] =
source[right] < source[left] ?
source[right++] :
source[left++];
}
System.arraycopy(source,
left,
target,
targetIndex,
leftEnd - left);
System.arraycopy(source,
right,
target,
targetIndex,
rightEnd - right);
}
private int getCoverageLength() {
return threads[threadEndIndex - 1].getToIndex() -
threads[threadStartIndex].getFromIndex();
}
}
private static final class Warmup {
static final int ITERATIONS = 100;
static final int ARRAY_LENGTH = 500_000;
static final int MINIMUM_VALUE = -100;
static final int MAXIMUM_VALUE = 100;
}
private static final class Demo {
static final int ARRAY_LENGTH = 2_000_000;
static final int MINIMUM_VALUE = -1000;
static final int MAXIMUM_VALUE = 10_000;
static final int FROM_INDEX = 10;
static final int TO_INDEX = ARRAY_LENGTH - 15;
}
public static void main(final String... args) {
final long seed = System.nanoTime();
final Random random = new Random(seed);
final int[] array1 = random.ints(Demo.ARRAY_LENGTH,
Demo.MINIMUM_VALUE,
Demo.MAXIMUM_VALUE).toArray();
final int[] array2 = array1.clone();
System.out.println("[STATUS] Warming up...");
warmup(random);
System.out.println("[STATUS] Warming up done.");
System.out.println("Seed = " + seed);
long startTime = System.nanoTime();
sort(array1, Demo.FROM_INDEX, Demo.TO_INDEX);
long endTime = System.nanoTime();
System.out.printf("ParallelTreesort in %.0f milliseconds.\n",
(endTime - startTime) / 1e6);
startTime = System.nanoTime();
Arrays.parallelSort(array2, Demo.FROM_INDEX, Demo.TO_INDEX);
endTime = System.nanoTime();
System.out.printf("Arrays.parallelSort in %.0f milliseconds.\n",
(endTime - startTime) / 1e6);
System.out.println("Algorithms agree: " + Arrays.equals(array1,
array2));
}
private static void warmup(final Random random) {
for (int i = 0; i < Warmup.ITERATIONS; ++i) {
final int[] array1 = random.ints(Warmup.ARRAY_LENGTH,
Warmup.MINIMUM_VALUE,
Warmup.MAXIMUM_VALUE).toArray();
final int[] array2 = array1.clone();
ParallelTreesort.sort(array1);
Arrays.parallelSort(array2);
}
}
}
Critique request
I would like to receive critique on naming conventions, coding style, API design, and, especially, optimization opportunities.