I have written this small, generic path-search library. It is far from perfect, so I would appreciate any comments that could help me improve it.
com.stackexchange.codereview.graph.model:
AbstractHeuristicFunction.java:
package com.stackexchange.codereview.graph.model;
/**
 * Heuristic estimate of the remaining cost from a node to a search target.
 * <p>
 * {@code AStarPathFinder} calls {@link #setTarget} at the start of every
 * search and orders its frontier by {@code g(node) + h(node)}.
 *
 * @param <T> the concrete node type
 */
public interface AbstractHeuristicFunction<T extends AbstractNode<T>> {

    /**
     * Sets the node the search is heading for; later {@link #h} calls
     * estimate the remaining cost to this node.
     *
     * @param target the search target
     */
    void setTarget(final T target);

    /**
     * Supplies node coordinates for heuristics that need them.
     * Implementations that do not use a layout may ignore it.
     * <p>
     * Typed as {@code PlaneLayout<T>} (previously the raw type) so callers get
     * compile-time checking; existing implementations that override with the
     * raw {@code PlaneLayout} still override this method.
     *
     * @param layout node-to-point mapping
     */
    void setLayout(final PlaneLayout<T> layout);

    /**
     * Estimates the remaining cost from {@code node} to the current target.
     *
     * @param node the node to estimate from
     * @return the estimated remaining cost
     */
    double h(final T node);
}
AbstractNode.java:
package com.stackexchange.codereview.graph.model;
/**
 * Base class for graph nodes identified by a non-null string ID.
 * <p>
 * Equality and hash code depend only on the ID. Iterating a node is defined
 * by subclasses; {@code AStarPathFinder} iterates a node to obtain the nodes
 * it can move to next.
 *
 * @param <T> the concrete node type (self type)
 */
public abstract class AbstractNode<T extends AbstractNode<T>>
        implements Iterable<T> {

    /** Identity of this node; the only input to equals and hashCode. */
    protected final String id;

    /**
     * @param id the node ID
     * @throws IllegalArgumentException if {@code id} is null
     */
    protected AbstractNode(final String id) {
        if (id == null) {
            throw new IllegalArgumentException("The ID string is null.");
        }
        this.id = id;
    }

    /**
     * Connects this node to {@code node}.
     *
     * @return whether the call changed the graph (subclass-defined)
     */
    public abstract boolean connectTo(final T node);

    /**
     * Removes the connection from this node to {@code node}.
     *
     * @return whether the call changed the graph (subclass-defined)
     */
    public abstract boolean disconnectFrom(final T node);

    /** @return whether this node is connected to {@code node} */
    public abstract boolean isConnectedTo(final T node);

    /** @return the nodes connected to this node (subclass-defined) */
    public abstract Iterable<T> parents();

    @Override
    public int hashCode() {
        return id.hashCode();
    }

    @Override
    public boolean equals(final Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AbstractNode)) {
            return false;
        }
        // Only the ID is read, so a wildcard cast suffices; the former cast to
        // AbstractNode<T> was unchecked and could not be verified at run time.
        return ((AbstractNode<?>) o).id.equals(this.id);
    }
}
AbstractWeightFunction.java:
package com.stackexchange.codereview.graph.model;
import java.util.HashMap;
import java.util.Map;
/**
 * Edge-weight table keyed by an ordered pair of nodes.
 * <p>
 * Storage lives here as a two-level map (first node, then second node, then
 * weight); subclasses decide how {@link #put} and {@link #get} use it.
 *
 * @param <T> the concrete node type
 */
public abstract class AbstractWeightFunction<T extends AbstractNode<T>> {

    /** Outer key: first node of the pair; inner key: second node. */
    protected final Map<T, Map<T, Double>> map = new HashMap<>();

    /** Creates a weight function with an empty table. */
    protected AbstractWeightFunction() {
        // The table is initialised at its declaration.
    }

    /**
     * Records the weight of the pair ({@code node1}, {@code node2}).
     *
     * @return the weight previously stored for the pair (subclass-defined)
     */
    public abstract Double put(final T node1,
                               final T node2,
                               final double weight);

    /**
     * Looks up the weight of the pair ({@code node1}, {@code node2}).
     *
     * @return the stored weight
     */
    public abstract double get(final T node1, final T node2);
}
AbstractPathFinder.java:
package com.stackexchange.codereview.graph.model;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
 * Base class for path finders. Provides helpers that rebuild a path from
 * child-to-parent maps.
 *
 * @param <T> the concrete node type
 */
public abstract class AbstractPathFinder<T extends AbstractNode<T>> {

    /**
     * Finds a path from {@code source} to {@code target}.
     * The A* implementation returns an empty list when the target is
     * unreachable.
     */
    public abstract List<T> search(final T source, final T target);

    /**
     * Joins two parent chains that meet at {@code middleNode}.
     * <p>
     * The result starts at the root of chain A, passes through
     * {@code middleNode}, and (when {@code parentMapB} is non-null) continues
     * along chain B to its root. A chain ends at the first node whose parent
     * is {@code null} or absent.
     *
     * @param middleNode node where the two chains meet
     * @param parentMapA child-to-parent links of the first chain
     * @param parentMapB child-to-parent links of the second chain, or null
     * @return the joined path
     */
    protected List<T> constructPath(final T middleNode,
                                    final Map<T, T> parentMapA,
                                    final Map<T, T> parentMapB) {
        final List<T> result = new ArrayList<>();

        // Collect middleNode .. rootA, then flip so the path starts at rootA.
        for (T step = middleNode; step != null; step = parentMapA.get(step)) {
            result.add(step);
        }
        Collections.reverse(result);

        if (parentMapB == null) {
            return result;
        }

        // Append parentB(middleNode) .. rootB; middleNode is already present.
        for (T step = parentMapB.get(middleNode);
                step != null;
                step = parentMapB.get(step)) {
            result.add(step);
        }
        return result;
    }

    /**
     * Rebuilds the path from the root of {@code parentMap} to {@code target}.
     */
    protected List<T> constructPath(final T target, final Map<T, T> parentMap) {
        return constructPath(target, parentMap, null);
    }
}
PlaneLayout.java:
package com.stackexchange.codereview.graph.model;
import java.awt.geom.Point2D;
import java.util.HashMap;
import java.util.Map;
/**
 * Assigns a point in the plane to each node, e.g. for geometric heuristics.
 *
 * @param <T> the concrete node type
 */
public class PlaneLayout<T extends AbstractNode<T>> {

    /** Location of every node placed so far. */
    private final Map<T, Point2D.Double> locations = new HashMap<>();

    /** Creates a layout with no nodes placed. */
    public PlaneLayout() {
        // The table is initialised at its declaration.
    }

    /**
     * Places {@code node} at {@code location}, replacing any earlier position.
     *
     * @return the previous location, or {@code null} if the node had none
     */
    public Point2D.Double put(final T node, final Point2D.Double location) {
        final Point2D.Double previous = locations.put(node, location);
        return previous;
    }

    /**
     * @return the location of {@code node}, or {@code null} if it has none
     */
    public Point2D.Double get(final T node) {
        final Point2D.Double location = locations.get(node);
        return location;
    }
}
com.stackexchange.codereview.graph.model.support:
AStarPathFinder.java:
package com.stackexchange.codereview.graph.model.support;
import static com.stackexchange.codereview.graph.Utils.checkNotNull;
import com.stackexchange.codereview.graph.model.AbstractHeuristicFunction;
import com.stackexchange.codereview.graph.model.AbstractNode;
import com.stackexchange.codereview.graph.model.AbstractWeightFunction;
import com.stackexchange.codereview.graph.model.AbstractPathFinder;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
/**
 * A* shortest-path search.
 * <p>
 * A heuristic function and a weight function must be set before
 * {@link #search} is called. With a heuristic that always returns 0 (such as
 * {@code DijkstraHeuristicFunction}) the search behaves like Dijkstra's
 * algorithm.
 * <p>
 * Search state lives in instance fields and is cleared at the start of every
 * search, so one instance must not be used for concurrent searches.
 *
 * @param <T> the concrete node type
 */
public class AStarPathFinder<T extends AbstractNode<T>>
        extends AbstractPathFinder<T> {

    private AbstractHeuristicFunction<T> heuristicFunction;
    private AbstractWeightFunction<T> weightFunction;

    /** Best-known predecessor of each discovered node; the source maps to null. */
    private final Map<T, T> parents;
    /** Best-known distance (g value) from the source to each discovered node. */
    private final Map<T, Double> distance;
    /** Nodes that have already been expanded. */
    private final Set<T> closed;
    /** Frontier ordered by f = g + h; may contain stale duplicates. */
    private final PriorityQueue<QueueEntry<T>> open;

    public AStarPathFinder() {
        this.parents = new HashMap<>();
        this.distance = new HashMap<>();
        this.closed = new HashSet<>();
        this.open = new PriorityQueue<>();
    }

    /**
     * Finds a least-cost path from {@code source} to {@code target}.
     *
     * @return the path from source to target inclusive, or an empty list if
     *         the target is unreachable
     * @throws NullPointerException if a function, the source or the target is null
     */
    @Override
    public List<T> search(final T source, final T target) {
        checkNotNull(heuristicFunction, "Heuristic function is null.");
        checkNotNull(weightFunction, "Weight function is null.");
        checkNotNull(source, "Source node is null.");
        checkNotNull(target, "Target node is null.");
        clearState();
        heuristicFunction.setTarget(target);

        parents.put(source, null);
        distance.put(source, 0.0);
        // The source is the only entry at this point, so its priority is irrelevant.
        open.add(new QueueEntry<>(source, 0.0));

        while (!open.isEmpty()) {
            final T current = open.poll().node;
            if (!closed.add(current)) {
                // Stale duplicate left behind by a decrease-key; the node was
                // already expanded via a cheaper entry.
                continue;
            }
            if (current.equals(target)) {
                return constructPath(target, parents);
            }
            final double currentDistance = g(current);
            for (final T child : current) {
                if (closed.contains(child)) {
                    continue;
                }
                final double tentative = currentDistance + w(current, child);
                final Double known = distance.get(child);
                if (known == null || tentative < known) {
                    parents.put(child, current);
                    distance.put(child, tentative);
                    // Decrease-key by reinsertion: an O(log n) push with an immutable
                    // priority, instead of the O(n) PriorityQueue.remove(Object)
                    // previously used while the element's key was being mutated.
                    open.add(new QueueEntry<>(child, tentative + h(child)));
                }
            }
        }
        // Empty list denotes that target is not reachable from source.
        return Collections.<T>emptyList();
    }

    public AStarPathFinder<T>
            setWeightFunction(final AbstractWeightFunction<T> function) {
        this.weightFunction = function;
        return this;
    }

    public AStarPathFinder<T>
            setHeuristicFunction(final AbstractHeuristicFunction<T> function) {
        this.heuristicFunction = function;
        return this;
    }

    private double h(final T node) {
        return heuristicFunction.h(node);
    }

    private double w(final T tail, final T head) {
        return weightFunction.get(tail, head);
    }

    private double g(final T node) {
        return distance.get(node);
    }

    private void clearState() {
        parents.clear();
        distance.clear();
        closed.clear();
        open.clear();
    }

    /** Frontier entry holding a node and its f value at insertion time. */
    private static final class QueueEntry<N>
            implements Comparable<QueueEntry<N>> {

        final N node;
        final double f;

        QueueEntry(final N node, final double f) {
            this.node = node;
            this.f = f;
        }

        @Override
        public int compareTo(final QueueEntry<N> other) {
            return Double.compare(f, other.f);
        }
    }
}
DijkstraHeuristicFunction.java:
package com.stackexchange.codereview.graph.model.support;
import com.stackexchange.codereview.graph.model.AbstractHeuristicFunction;
import com.stackexchange.codereview.graph.model.AbstractNode;
import com.stackexchange.codereview.graph.model.PlaneLayout;
/**
 * Heuristic that always estimates zero remaining cost. Used with
 * {@code AStarPathFinder}, this yields Dijkstra's algorithm. The target and
 * the layout are ignored.
 *
 * @param <T> the concrete node type
 */
public class DijkstraHeuristicFunction<T extends AbstractNode<T>>
        implements AbstractHeuristicFunction<T> {

    @Override
    public void setTarget(final T target) {
        // Not needed: every estimate is zero.
    }

    @Override
    public void setLayout(final PlaneLayout layout) {
        // Not needed: every estimate is zero.
    }

    @Override
    public double h(final T node) {
        final double noEstimate = 0.0;
        return noEstimate;
    }
}
DirectedGraphNode.java:
package com.stackexchange.codereview.graph.model.support;
import com.stackexchange.codereview.graph.model.AbstractNode;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Set;
/**
 * Node of a directed graph.
 * <p>
 * Every arc this -> x is stored in this node's {@code out} set and mirrored in
 * x's {@code in} set. Iterating the node yields its out-neighbours; {@link
 * #parents()} yields its in-neighbours. Neither view supports removal.
 */
public class DirectedGraphNode extends AbstractNode<DirectedGraphNode> {

    /** Nodes that have an arc into this node. */
    private final Set<DirectedGraphNode> in;
    /** Nodes this node has an arc to. */
    private final Set<DirectedGraphNode> out;

    public DirectedGraphNode(final String id) {
        super(id);
        // LinkedHashSet iterates in insertion order (deterministic), and its
        // iteration cost is proportional to size rather than table capacity.
        this.in = new LinkedHashSet<>();
        this.out = new LinkedHashSet<>();
    }

    /**
     * Adds the arc this -> {@code node}.
     *
     * @return true if the arc was added, false if it already existed
     * @throws NullPointerException if {@code node} is null (nothing is modified)
     */
    @Override
    public boolean connectTo(final DirectedGraphNode node) {
        // Reject null before touching any set: previously null was inserted
        // into "out" and only then did node.in throw, leaving a broken arc.
        if (node == null) {
            throw new NullPointerException("The node is null.");
        }
        if (!out.add(node)) {
            return false;
        }
        node.in.add(this);
        return true;
    }

    /**
     * Removes the arc this -> {@code node}.
     *
     * @return true if the arc existed and was removed
     */
    @Override
    public boolean disconnectFrom(final DirectedGraphNode node) {
        if (!out.remove(node)) {
            return false;
        }
        node.in.remove(this);
        return true;
    }

    @Override
    public boolean isConnectedTo(final DirectedGraphNode node) {
        return out.contains(node);
    }

    /** @return a read-only view of the nodes with an arc into this node */
    @Override
    public Iterable<DirectedGraphNode> parents() {
        return () -> new IteratorProxy<>(in.iterator());
    }

    /** @return a read-only iterator over the nodes this node has an arc to */
    @Override
    public Iterator<DirectedGraphNode> iterator() {
        return new IteratorProxy<>(out.iterator());
    }

    @Override
    public String toString() {
        return "[DirectedGraphNode " + id + "]";
    }
}
DirectedGraphWeightFunction.java:
package com.stackexchange.codereview.graph.model.support;
import com.stackexchange.codereview.graph.model.AbstractWeightFunction;
import java.util.HashMap;
/**
 * Arc weights of a directed graph: {@code get(tail, head)} is the weight of
 * the arc tail -> head.
 */
public class DirectedGraphWeightFunction
        extends AbstractWeightFunction<DirectedGraphNode> {

    public DirectedGraphWeightFunction() {
        super();
    }

    /**
     * Stores the weight of the arc {@code node1 -> node2}.
     *
     * @return the weight previously stored for that arc, or null if none
     */
    @Override
    public Double put(final DirectedGraphNode node1,
                      final DirectedGraphNode node2,
                      final double weight) {
        return map.computeIfAbsent(node1, ignored -> new HashMap<>())
                  .put(node2, weight);
    }

    /**
     * Returns the weight of the arc {@code node1 -> node2}.
     *
     * @throws IllegalArgumentException if no weight was stored for the arc
     *         (previously an anonymous NullPointerException from unboxing)
     */
    @Override
    public double get(final DirectedGraphNode node1,
                      final DirectedGraphNode node2) {
        final java.util.Map<DirectedGraphNode, Double> row = map.get(node1);
        final Double weight = (row == null) ? null : row.get(node2);
        if (weight == null) {
            throw new IllegalArgumentException(
                    "No weight stored for arc " + node1 + " -> " + node2 + ".");
        }
        return weight;
    }
}
IteratorProxy.java:
package com.stackexchange.codereview.graph.model.support;
import java.util.Iterator;
/**
 * Read-only wrapper around another iterator.
 * <p>
 * {@link #hasNext()} and {@link #next()} are forwarded. {@code remove()} is
 * not overridden, so it uses the {@link Iterator} default, which throws
 * {@link UnsupportedOperationException} without touching the wrapped
 * iterator.
 *
 * @param <T> element type
 */
public class IteratorProxy<T> implements Iterator<T> {

    /** The iterator every call is forwarded to. */
    private final Iterator<T> delegate;

    protected IteratorProxy(final Iterator<T> iterator) {
        delegate = iterator;
    }

    @Override
    public T next() {
        final T element = delegate.next();
        return element;
    }

    @Override
    public boolean hasNext() {
        final boolean more = delegate.hasNext();
        return more;
    }
}
PlaneHeuristicFunction.java:
package com.stackexchange.codereview.graph.model.support;
import com.stackexchange.codereview.graph.model.AbstractHeuristicFunction;
import com.stackexchange.codereview.graph.model.AbstractNode;
import com.stackexchange.codereview.graph.model.PlaneLayout;
import java.awt.geom.Point2D;
/**
 * Heuristic returning the straight-line (Euclidean) distance between a node's
 * layout location and the target's.
 * <p>
 * The estimate is only admissible when every arc weight is at least the arc's
 * straight-line length. {@code Utils.createRandomDigraph} uses weights of at
 * least 1.05 times that length.
 *
 * @param <T> the concrete node type
 */
public class PlaneHeuristicFunction<T extends AbstractNode<T>>
        implements AbstractHeuristicFunction<T> {

    private T target;
    private PlaneLayout<T> layout;
    private Point2D.Double targetLocation;

    public PlaneHeuristicFunction(final PlaneLayout<T> layout,
                                  final T target) {
        this.layout = layout;
        // The target used to be dropped here, so a later setLayout() looked up
        // the location of null and h() then failed with a NullPointerException.
        this.target = target;
        this.targetLocation = layout.get(target);
    }

    /**
     * Replaces the layout and re-resolves the target's location in it.
     * The raw parameter type is kept for override compatibility with
     * {@link AbstractHeuristicFunction#setLayout}.
     */
    @Override
    @SuppressWarnings("unchecked") // raw PlaneLayout assigned to PlaneLayout<T>
    public void setLayout(final PlaneLayout layout) {
        this.layout = layout;
        this.targetLocation = this.layout.get(target);
    }

    @Override
    public void setTarget(final T target) {
        this.target = target;
        this.targetLocation = layout.get(target);
    }

    /**
     * @return Euclidean distance from {@code node}'s location to the target's
     * @throws NullPointerException if the node or the target has no location
     *         in the layout
     */
    @Override
    public double h(final T node) {
        return targetLocation.distance(layout.get(node));
    }
}
com.stackexchange.codereview.graph:
Utils.java:
package com.stackexchange.codereview.graph;
import com.stackexchange.codereview.graph.model.PlaneLayout;
import com.stackexchange.codereview.graph.model.support.DirectedGraphNode;
import com.stackexchange.codereview.graph.model.support.DirectedGraphWeightFunction;
import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * Static helpers: random test-graph generation, random selection and small
 * checks.
 */
public class Utils {

    /**
     * Holder for three values of possibly different types. The references
     * are final; the referenced objects may still be mutable.
     *
     * @param <F> type of the first value
     * @param <S> type of the second value
     * @param <T> type of the third value
     */
    public static class Triple<F, S, T> {

        private final F first;
        private final S second;
        private final T third;

        public Triple(final F first, final S second, final T third) {
            this.first = first;
            this.second = second;
            this.third = third;
        }

        public F first() {
            return first;
        }

        public S second() {
            return second;
        }

        public T third() {
            return third;
        }
    }

    /**
     * Builds a random directed graph whose nodes are scattered over a
     * {@code width} x {@code height} rectangle.
     * <p>
     * Up to {@code edgeLoadFactor * nodeAmount^2} random (tail, head) pairs
     * whose straight-line distance is at most {@code maxDistance} are
     * connected. Pairs drawn twice and self-loops still count towards that
     * budget, so the number of distinct arcs can be lower. Each arc's weight
     * is {@code weightFactor * distance}.
     *
     * @param nodeAmount     number of nodes; IDs are "0" .. "nodeAmount - 1"
     * @param edgeLoadFactor arc budget relative to nodeAmount^2, capped at 0.8
     * @param width          width of the placement rectangle
     * @param height         height of the placement rectangle
     * @param maxDistance    longest allowed arc (straight-line length)
     * @param weightFactor   weight multiplier, raised to at least 1.05
     * @param rnd            source of randomness
     * @return the nodes, their weight function and their plane layout
     * @throws IllegalArgumentException if arcs are requested but
     *         {@code maxDistance} is negative or NaN (no pair could ever
     *         qualify, so generation would never finish)
     */
    public static Triple<List<DirectedGraphNode>,
                         DirectedGraphWeightFunction,
                         PlaneLayout>
        createRandomDigraph(final int nodeAmount,
                            final float edgeLoadFactor,
                            final double width,
                            final double height,
                            final double maxDistance,
                            final double weightFactor,
                            final Random rnd) {
        // Presumably so that weights never fall below straight-line length,
        // which keeps a Euclidean heuristic admissible.
        final double effectiveWeightFactor = Math.max(weightFactor, 1.05);
        final float effectiveLoadFactor = Math.min(edgeLoadFactor, 0.8f);
        int edges = (int) (effectiveLoadFactor * nodeAmount * nodeAmount);

        // With maxDistance >= 0 a tail == head draw (distance 0) always
        // qualifies, so the loop below eventually finishes; otherwise it cannot.
        if (edges > 0 && !(maxDistance >= 0.0)) {
            throw new IllegalArgumentException(
                    "maxDistance must be a non-negative number: " + maxDistance);
        }

        final List<DirectedGraphNode> graph = new ArrayList<>(nodeAmount);
        final PlaneLayout<DirectedGraphNode> layout = new PlaneLayout<>();
        final DirectedGraphWeightFunction weightFunction =
                new DirectedGraphWeightFunction();

        for (int i = 0; i < nodeAmount; ++i) {
            final DirectedGraphNode node = new DirectedGraphNode("" + i);
            layout.put(node, new Point2D.Double(width * rnd.nextDouble(),
                                                height * rnd.nextDouble()));
            graph.add(node);
        }

        while (edges > 0) {
            final DirectedGraphNode tail = choose(graph, rnd);
            final DirectedGraphNode head = choose(graph, rnd);
            final double distance = layout.get(tail).distance(layout.get(head));
            if (distance <= maxDistance) {
                tail.connectTo(head);
                weightFunction.put(tail, head, effectiveWeightFactor * distance);
                --edges;
            }
        }
        return new Triple<List<DirectedGraphNode>,
                          DirectedGraphWeightFunction,
                          PlaneLayout>(graph, weightFunction, layout);
    }

    /**
     * @return a uniformly chosen element of {@code list}, or null if it is empty
     */
    public static <E> E choose(final List<E> list, final Random rnd) {
        if (list.isEmpty()) {
            return null;
        }
        return list.get(rnd.nextInt(list.size()));
    }

    /**
     * @throws NullPointerException with {@code message} if {@code reference} is null
     */
    public static void checkNotNull(final Object reference,
                                    final String message) {
        if (reference == null) {
            throw new NullPointerException(message);
        }
    }

    /**
     * Tells whether both lists hold equal elements in the same order.
     * <p>
     * Delegates to {@link List#equals}. Unlike the former index loop, this
     * handles null elements and runs in linear time on linked lists as well.
     */
    public static <E> boolean listsAreSame(final List<E> list1,
                                           final List<E> list2) {
        return list1.equals(list2);
    }
}
Demo.java:
package com.stackexchange.codereview.graph;
import com.stackexchange.codereview.graph.Utils.Triple;
import static com.stackexchange.codereview.graph.Utils.choose;
import static com.stackexchange.codereview.graph.Utils.listsAreSame;
import com.stackexchange.codereview.graph.model.PlaneLayout;
import com.stackexchange.codereview.graph.model.support.AStarPathFinder;
import com.stackexchange.codereview.graph.model.support.DijkstraHeuristicFunction;
import com.stackexchange.codereview.graph.model.support.DirectedGraphNode;
import com.stackexchange.codereview.graph.model.support.DirectedGraphWeightFunction;
import com.stackexchange.codereview.graph.model.support.PlaneHeuristicFunction;
import java.util.List;
import java.util.Random;
/**
 * Compares A* (Euclidean heuristic) with Dijkstra's algorithm on a random
 * directed graph. It prints the seed, the endpoints, each run's time and path,
 * and whether the two paths coincide.
 */
public class Demo {
    public static final int GRAPH_SIZE = 100000;
    public static final float EDGE_LOAD_FACTOR = 4.0f / GRAPH_SIZE;
    public static final double WIDTH = 2000.0;
    public static final double HEIGHT = 1000.0;
    public static final double MAX_DISTANCE = 100.0;
    public static final double WEIGHT_FACTOR = 1.1;

    public static void main(final String... args) {
        final long seed = System.currentTimeMillis();
        System.out.println("Seed: " + seed);
        final Random rnd = new Random(seed);

        final Triple<List<DirectedGraphNode>,
                     DirectedGraphWeightFunction,
                     PlaneLayout> data =
                Utils.createRandomDigraph(GRAPH_SIZE,
                                          EDGE_LOAD_FACTOR,
                                          WIDTH,
                                          HEIGHT,
                                          MAX_DISTANCE,
                                          WEIGHT_FACTOR,
                                          rnd);

        final DirectedGraphNode source = choose(data.first(), rnd);
        final DirectedGraphNode target = choose(data.first(), rnd);
        System.out.println("Source: " + source);
        System.out.println("Target: " + target);

        // Explicit type arguments: the receiver of the chained call is not in an
        // assignment context, so diamond inference cannot use the declared type.
        final AStarPathFinder<DirectedGraphNode> finder =
                new AStarPathFinder<DirectedGraphNode>()
                        .setHeuristicFunction(
                                new PlaneHeuristicFunction<DirectedGraphNode>(
                                        data.third(), target))
                        .setWeightFunction(data.second());

        final List<DirectedGraphNode> path1 =
                timedSearch("A*", finder, source, target);

        finder.setHeuristicFunction(
                new DijkstraHeuristicFunction<DirectedGraphNode>());
        final List<DirectedGraphNode> path2 =
                timedSearch("Dijkstra's algorithm", finder, source, target);

        System.out.println("Paths are same: " + listsAreSame(path1, path2));
    }

    /**
     * Runs one search and prints "{label} in N ms.", the path (one node per
     * line) and a blank line.
     * <p>
     * Elapsed time comes from {@link System#nanoTime()}, which is meant for
     * interval measurement. The wall clock ({@code currentTimeMillis}) can
     * jump when the system time is adjusted.
     *
     * @return the path found by the search
     */
    private static List<DirectedGraphNode> timedSearch(
            final String label,
            final AStarPathFinder<DirectedGraphNode> finder,
            final DirectedGraphNode source,
            final DirectedGraphNode target) {
        final long start = System.nanoTime();
        final List<DirectedGraphNode> path = finder.search(source, target);
        final long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;
        System.out.println(label + " in " + elapsedMillis + " ms.");
        for (final DirectedGraphNode node : path) {
            System.out.println(node);
        }
        System.out.println();
        return path;
    }
}