Take the 2-minute tour ×
Code Review Stack Exchange is a question and answer site for peer programmer code reviews. It's 100% free, no registration required.

The Edmonds-Karp algorithm relies on breadth-first search to find an augmenting path in the residual flow network. This version of the algorithm uses bidirectional breadth-first search to speed up the entire algorithm. My profiling consistently reports speedups of at least a factor of 3.

The files DirectedGraphNode.java, EdgeWeightFunction.java, AbstractMaximumFlowFinder.java and EdmondsKarpMaximumFlowFinder.java can be found in the earlier question "Edmonds-Karp algorithm for maximum flow in Java".

The new and modified files are:

BidirectionalEdmondsKarpMaximumFlowFinder.java:

package net.coderodde.graph.model.flow.support;

import java.util.ArrayDeque;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import net.coderodde.graph.model.DirectedGraphNode;
import net.coderodde.graph.model.EdgeWeightFunction;
import net.coderodde.graph.model.flow.AbstractMaximumFlowFinder;

/**
 * Maximum-flow finder based on the Edmonds-Karp method, in which each
 * augmenting path is located with a bidirectional breadth-first search.
 * One frontier grows forward from the source over edges with positive residual
 * capacity. The other grows backward from the sink over edges whose residual
 * capacity into the current node is positive. The path is assembled through
 * the best node at which the two search trees meet.
 * <p>
 * The queues and maps below are instance fields that are cleared and reused by
 * every search without synchronization, so an instance must not be used by
 * several threads at once.
 */
public class BidirectionalEdmondsKarpMaximumFlowFinder 
extends AbstractMaximumFlowFinder {

    /** Frontier of the forward search (from the source). */
    private final Queue<DirectedGraphNode> queueA;
    /** Frontier of the backward search (from the sink). */
    private final Queue<DirectedGraphNode> queueB;
    /** Forward search tree: node -> predecessor towards the source. */
    private final Map<DirectedGraphNode, DirectedGraphNode> parentMapA;
    /** Backward search tree: node -> successor towards the sink. */
    private final Map<DirectedGraphNode, DirectedGraphNode> parentMapB;
    /** Edge count from the source to each node reached by the forward search. */
    private final Map<DirectedGraphNode, Integer> distanceMapA;
    /** Edge count from each node reached by the backward search to the sink. */
    private final Map<DirectedGraphNode, Integer> distanceMapB;

    public BidirectionalEdmondsKarpMaximumFlowFinder() {
        this.queueA = new ArrayDeque<>();
        this.queueB = new ArrayDeque<>();
        this.parentMapA = new HashMap<>();
        this.parentMapB = new HashMap<>();
        this.distanceMapA = new HashMap<>();
        this.distanceMapB = new HashMap<>();
    }

    /**
     * Computes a maximum flow from {@code source} to {@code sink}.
     *
     * @param source           the source node.
     * @param sink             the sink node; must differ from {@code source}.
     * @param capacityFunction the edge capacities.
     *
     * @return the flow function and the total flow value; the value is
     *         {@code 0.0} when the sink is not reachable from the source.
     *
     * @throws IllegalArgumentException if {@code source} equals {@code sink}.
     */
    @Override
    public FlowData findMaximumFlow(final DirectedGraphNode source, 
                                    final DirectedGraphNode sink, 
                                    final EdgeWeightFunction capacityFunction) {
        if (source.equals(sink)) {
            throw new IllegalArgumentException(
                    "The source node is the same as the sink node.");
        }

        final EdgeWeightFunction f = new EdgeWeightFunction();

        // A null component means the sink is not reachable from the source.
        // The node list itself is not needed by this implementation.
        if (getConnectedComponent(source, sink) == null) {
            return new FlowData(f, 0.0);
        }

        double flow = 0.0;
        List<DirectedGraphNode> path;

        // findAugmentingPath returns an empty list once the residual network
        // no longer connects source and sink, which ends the loop.
        while ((path = findAugmentingPath(source, sink, f, capacityFunction))
                .size() > 1) {
            flow += findMinimumEdgeAndRemove(path, f, capacityFunction);
        }

        return new FlowData(f, flow);
    }

    /**
     * Finds an augmenting path whose every edge has positive residual
     * capacity. The path is intended to be shortest in terms of edge count.
     * Frontier A grows forward from {@code source}, frontier B grows backward
     * from {@code sink}, and the path is traced through the meeting node with
     * the smallest combined distance found so far.
     *
     * @param source the source node.
     * @param sink   the sink node.
     * @param f      the current flow function.
     * @param c      the capacity function.
     *
     * @return       the nodes of the path as produced by
     *               {@code tracebackPath}, or an empty list if no augmenting
     *               path remains.
     */
    private List<DirectedGraphNode> 
        findAugmentingPath(final DirectedGraphNode source,
                           final DirectedGraphNode sink,
                           final EdgeWeightFunction f,
                           final EdgeWeightFunction c) {
        queueA.clear();
        queueB.clear();
        parentMapA.clear();
        parentMapB.clear();
        distanceMapA.clear();
        distanceMapB.clear();

        queueA.add(source);
        queueB.add(sink);
        parentMapA.put(source, null);
        parentMapB.put(sink, null);
        distanceMapA.put(source, 0);
        distanceMapB.put(sink, 0);

        DirectedGraphNode touchNode = null;
        int bestDistance = Integer.MAX_VALUE;

        while (!queueA.isEmpty() && !queueB.isEmpty()) {
            // Once the two frontier distances add up to at least the best
            // meeting distance, no unexplored meeting node can be closer.
            if (touchNode != null
                    && distanceMapA.get(queueA.peek()) +
                       distanceMapB.get(queueB.peek()) >= bestDistance) {
                return tracebackPath(touchNode, parentMapA, parentMapB);
            }

            // Forward step: expand one node of the source-side frontier.
            DirectedGraphNode current = queueA.remove();
            final int currentDistanceA = distanceMapA.get(current);

            if (parentMapB.containsKey(current)) {
                final int total = currentDistanceA + distanceMapB.get(current);

                if (total < bestDistance) {
                    bestDistance = total;
                    touchNode = current;
                }
            }

            for (final DirectedGraphNode neighbor : current.all()) {
                if (!parentMapA.containsKey(neighbor)
                        && residualEdgeWeight(current, neighbor, f, c) > 0) {
                    parentMapA.put(neighbor, current);
                    distanceMapA.put(neighbor, currentDistanceA + 1);
                    queueA.add(neighbor);
                }
            }

            // Backward step: expand one node of the sink-side frontier,
            // following edges that carry residual capacity INTO the node.
            current = queueB.remove();
            final int currentDistanceB = distanceMapB.get(current);

            if (parentMapA.containsKey(current)) {
                final int total = distanceMapA.get(current) + currentDistanceB;

                if (total < bestDistance) {
                    bestDistance = total;
                    touchNode = current;
                }
            }

            for (final DirectedGraphNode neighbor : current.all()) {
                if (!parentMapB.containsKey(neighbor)
                        && residualEdgeWeight(neighbor, current, f, c) > 0) {
                    parentMapB.put(neighbor, current);
                    distanceMapB.put(neighbor, currentDistanceB + 1);
                    queueB.add(neighbor);
                }
            }
        }

        // A frontier ran out. If the searches had already met, the meeting
        // node still lies on a valid augmenting path. Discarding it would
        // stop the outer loop early with a sub-maximal flow.
        if (touchNode != null) {
            return tracebackPath(touchNode, parentMapA, parentMapB);
        }

        return Collections.emptyList();
    }
}

Demo.java:

package net.coderodde.graph;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import net.coderodde.graph.model.DirectedGraphNode;
import net.coderodde.graph.model.EdgeWeightFunction;
import net.coderodde.graph.model.flow.AbstractMaximumFlowFinder.FlowData;
import net.coderodde.graph.model.flow.support.BidirectionalEdmondsKarpMaximumFlowFinder;
import net.coderodde.graph.model.flow.support.EdmondsKarpMaximumFlowFinder;

/**
 * Demonstrates the plain and the bidirectional Edmonds-Karp maximum-flow
 * finders. It runs them on a small textbook network and on a large random
 * network, and cross-checks the results.
 */
public class Demo {

    /** Exclusive upper bound for random edge capacities. */
    private static final int MAX_CAPACITY = 25;
    private static final int SIZE = 2000;
    private static final float EDGE_LOAD_FACTOR = 0.03f;

    private static final long NANOS_PER_MILLI = 1000000L;

    public static void main(final String... args) {
        final long seed = System.currentTimeMillis();
        final Random rnd = new Random(seed);
        System.out.println("Seed: " + seed);

        smallDemo();
        largeDemo(rnd);
    }

    // This sample graph is from Introduction to Algorithms, 3rd edition,
    // page 727.
    public static void smallDemo() {
        final List<DirectedGraphNode> network = new ArrayList<>(7);

        final DirectedGraphNode s = new DirectedGraphNode("s");
        final DirectedGraphNode v1 = new DirectedGraphNode("v1");
        final DirectedGraphNode v2 = new DirectedGraphNode("v2");
        final DirectedGraphNode v3 = new DirectedGraphNode("v3");
        final DirectedGraphNode v4 = new DirectedGraphNode("v4");
        final DirectedGraphNode t = new DirectedGraphNode("t");

        s.connectTo(v1);
        s.connectTo(v2);

        v1.connectTo(v3);

        v2.connectTo(v1);
        v2.connectTo(v4);

        v3.connectTo(v2);
        v3.connectTo(t);

        v4.connectTo(v3);
        v4.connectTo(t);

        // Self-loops without an explicit capacity. NOTE(review): presumably
        // these check that the finder tolerates them; confirm intent.
        v3.connectTo(v3);
        v4.connectTo(v4);
        v2.connectTo(v2);

        final EdgeWeightFunction capacities = new EdgeWeightFunction();

        capacities.put(s, v1, 16);
        capacities.put(s, v2, 13);
        capacities.put(v1, v3, 12);
        capacities.put(v2, v1, 4);
        capacities.put(v2, v4, 14);
        capacities.put(v3, v2, 9);
        capacities.put(v3, t, 20);
        capacities.put(v4, v3, 7);
        capacities.put(v4, t, 4);

        final FlowData fd = 
                new EdmondsKarpMaximumFlowFinder()
                .findMaximumFlow(s, t, capacities);

        System.out.println("Total flow: " + fd.flowValue);

        System.out.println(fd.flowFunction.getOrDefault(s,  v1, 0));
        System.out.println(fd.flowFunction.getOrDefault(s,  v2, 0));
        System.out.println(fd.flowFunction.getOrDefault(v1, v3, 0));
        System.out.println(fd.flowFunction.getOrDefault(v2, v1, 0));
        System.out.println(fd.flowFunction.getOrDefault(v2, v4, 0));
        System.out.println(fd.flowFunction.getOrDefault(v3, v2, 0));
        System.out.println(fd.flowFunction.getOrDefault(v3, t,  0));
        System.out.println(fd.flowFunction.getOrDefault(v4, v3, 0));
        System.out.println(fd.flowFunction.getOrDefault(v4, t,  0));
    }

    /**
     * Runs both finders on the same random network and the same
     * source/sink pair. For each run it prints the time, the flow value and
     * the flow-conservation checks, then reports whether the two flow values
     * agree.
     *
     * @param rnd the random source used to build the network and pick nodes.
     */
    public static void largeDemo(final Random rnd) {
        final FlowNetwork flowNetwork = createRandomNetwork(SIZE,
                                                             EDGE_LOAD_FACTOR,
                                                             rnd);

        // The sink must differ from the source; the finders reject equal ones.
        final DirectedGraphNode source = choose(flowNetwork.network, rnd);
        final DirectedGraphNode sink =
                chooseOther(flowNetwork.network, source, rnd);

        long ta = System.nanoTime();
        final FlowData flowData1 = new EdmondsKarpMaximumFlowFinder()
                .findMaximumFlow(source, sink, flowNetwork.capacityFunction);
        long tb = System.nanoTime();

        report(source, sink, flowData1, tb - ta);
        System.out.println("----");
        System.out.println();

        // Bidirectional Edmonds-Karp.
        ta = System.nanoTime();
        final FlowData flowData2 = 
                new BidirectionalEdmondsKarpMaximumFlowFinder()
                .findMaximumFlow(source, sink, flowNetwork.capacityFunction);
        tb = System.nanoTime();

        report(source, sink, flowData2, tb - ta);

        // Capacities are integers, so both flow values are exact in a double.
        System.out.println("Flows agree:  "
                + (flowData1.flowValue == flowData2.flowValue));
    }

    /**
     * Prints the timing, flow value and conservation checks of one run.
     *
     * @param source       the source node of the run.
     * @param sink         the sink node of the run.
     * @param flowData     the result of the run.
     * @param elapsedNanos the wall-clock duration of the run in nanoseconds.
     */
    private static void report(final DirectedGraphNode source,
                               final DirectedGraphNode sink,
                               final FlowData flowData,
                               final long elapsedNanos) {
        System.out.println("Time: " + (elapsedNanos / NANOS_PER_MILLI)
                + " ms.");
        System.out.println("Flow: " + flowData.flowValue);

        final boolean sourceCorrect = verifySource(source,
                                                   flowData.flowFunction,
                                                   flowData.flowValue);

        final boolean sinkCorrect = verifySink(sink,
                                               flowData.flowFunction,
                                               flowData.flowValue);

        System.out.println("Source check: " + sourceCorrect);
        System.out.println("Sink check:   " + sinkCorrect);
    }

    /** Pairs a list of nodes with the capacities of their edges. */
    public static class FlowNetwork {
        public final List<DirectedGraphNode> network;
        public final EdgeWeightFunction capacityFunction;

        public FlowNetwork(final List<DirectedGraphNode> network,
                           final EdgeWeightFunction capacityFunction) {
            this.network = network;
            this.capacityFunction = capacityFunction;
        }
    }

    /**
     * Builds a random network of {@code size} nodes. The number of random
     * edges is {@code min(0.95, edgeLoadFactor) * size * size}, and endpoints
     * are drawn uniformly, so self-loops and repeated pairs may occur; a
     * repeated pair overwrites the earlier capacity. Capacities are drawn
     * uniformly from {@code [0, MAX_CAPACITY)}.
     *
     * @param size           the number of nodes.
     * @param edgeLoadFactor the requested edge density.
     * @param rnd            the random source.
     *
     * @return the generated network.
     */
    public static FlowNetwork createRandomNetwork(final int size,
                                                   final float edgeLoadFactor,
                                                   final Random rnd) {
        final List<DirectedGraphNode> network = new ArrayList<>(size);
        final EdgeWeightFunction capacityFunction = new EdgeWeightFunction();

        for (int i = 0; i < size; ++i) {
            network.add(new DirectedGraphNode("" + i));
        }

        int edges = (int)(Math.min(0.95f, edgeLoadFactor) * size * size);

        while (edges > 0) {
            final DirectedGraphNode tail = choose(network, rnd);
            final DirectedGraphNode head = choose(network, rnd);
            tail.connectTo(head);
            capacityFunction.put(tail, head, rnd.nextInt(MAX_CAPACITY));
            --edges;
        }

        return new FlowNetwork(network, capacityFunction);
    }

    /**
     * Returns a uniformly random element of {@code list}, or {@code null} if
     * the list is empty.
     */
    private static <E> E choose(final List<E> list, final Random rnd) {
        if (list.isEmpty()) {
            return null;
        }

        return list.get(rnd.nextInt(list.size()));
    }

    /**
     * Returns a uniformly random element of {@code list} that is not equal to
     * {@code excluded}. For lists with fewer than two elements no such choice
     * is guaranteed, so this falls back to {@link #choose}.
     */
    private static <E> E chooseOther(final List<E> list,
                                     final E excluded,
                                     final Random rnd) {
        if (list.size() < 2) {
            return choose(list, rnd);
        }

        E candidate;

        do {
            candidate = choose(list, rnd);
        } while (candidate.equals(excluded));

        return candidate;
    }

    /**
     * Checks that the net flow leaving {@code source} equals the reported
     * maximum flow. The net flow is the sum of flows on edges to its children
     * minus the sum of flows on edges from its parents.
     */
    private static boolean verifySource(final DirectedGraphNode source,
                                        final EdgeWeightFunction f,
                                        final double actualMaximumFlow) {
        double flow = 0.0;

        for (final DirectedGraphNode child : source) {
            flow += f.getOrDefault(source, child, 0);
        }

        for (final DirectedGraphNode parent : source.parents()) {
            flow -= f.getOrDefault(parent, source, 0);
        }

        return flow == actualMaximumFlow;
    }

    /**
     * Checks that the net flow entering {@code sink} equals the reported
     * maximum flow. The net flow is the sum of flows on edges from its parents
     * minus the sum of flows on edges to its children.
     */
    private static boolean verifySink(final DirectedGraphNode sink,
                                      final EdgeWeightFunction f,
                                      final double actualMaximumFlow) {
        double flow = 0.0;

        for (final DirectedGraphNode child : sink) {
            flow -= f.getOrDefault(sink, child, 0);
        }

        for (final DirectedGraphNode parent : sink.parents()) {
            flow += f.getOrDefault(parent, sink, 0);
        }

        return flow == actualMaximumFlow;
    }
}
share|improve this question
    
This is being discussed on meta: meta.codereview.stackexchange.com/questions/5072/… –  rolfl Feb 17 at 18:22

Your Answer

 
discard

By posting your answer, you agree to the privacy policy and terms of service.

Browse other questions tagged or ask your own question.