Passed
Push — master ( 8cfb03...7c5bc8 )
by Daniel
07:13
created

amd._network_simplex.find_leaving_edge()   A

Complexity

Conditions 3

Size

Total Lines 21
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 21
rs 9.7
c 0
b 0
f 0
cc 3
nop 6
1
"""An implementation of the Wasserstein metric between two compositional vectors.

Also known as the Earth Mover's distance, an appropriate metric for comparing
two PDDs.

Copyright (C) 2020 Cameron Hargreaves. This code is part of the Element Movers
Distance project https://github.com/lrcfmd/ElMD.
"""
7
8
import numba
9
import numpy as np
10
11
12
@numba.njit(cache=True)
13
def network_simplex(source_demands, sink_demands, network_costs):
    """Solve a transportation problem with the network simplex algorithm.

    This is a port of the network simplex algorithm implemented by
    Loïc Séguin-C for the networkx package to allow acceleration via the
    numba package.
    Copyright (C) 2010 Loïc Séguin-C. [email protected]
    All rights reserved.
    BSD license.

    Demands and costs are multiplied by 1_000_000 and truncated to 64-bit
    integers, so the whole pivot loop runs in integer arithmetic; the
    results are divided by the same factor before being returned.

    Parameters
    ----------
    source_demands : array
        Amount supplied by each source; its first-axis length is the number
        of sources.
    sink_demands : array
        Amount required by each sink; its first-axis length is the number
        of sinks.
    network_costs : array
        Cost of moving one unit from source ``i`` to sink ``j``. It is
        flattened and read at index ``i * n_sinks + j``.

    Returns
    -------
    final : float
        Dot product of the final flows with the edge costs.
    flows : array of shape (n_sources, n_sinks)
        Flow sent from every source to every sink.

    References
    ----------
    [1] Z. Kiraly, P. Kovacs.
        Efficient implementation of minimum-cost flow algorithms.
        Acta Universitatis Sapientiae, Informatica 4(1):67--118. 2012.
    [2] R. Barr, F. Glover, D. Klingman.
        Enhancement of spanning tree labeling procedures for network
        optimization.
        INFOR 17(1):16--34. 1979.
    """

    n_sources, n_sinks = source_demands.shape[0], sink_demands.shape[0]
    network_costs = network_costs.ravel()
    n = n_sources + n_sinks
    e = n_sources * n_sinks
    B = np.int64(np.ceil(np.sqrt(e)))
    fp_multiplier = np.int64(1_000_000)

    # Convert the demands to fixed-point integers.
    source_d_int = (source_demands * fp_multiplier).astype(np.int64)
    sink_d_int = (sink_demands * fp_multiplier).astype(np.int64)

    # FP conversion error correction: after truncation the two totals may
    # differ, so the shortfall is added to the largest entry of the smaller
    # side to make the totals equal again.
    source_sum = np.sum(source_d_int)
    sink_sum = np.sum(sink_d_int)
    sink_source_sum_diff = sink_sum - source_sum

    if sink_source_sum_diff > 0:
        source_d_int[np.argmax(source_d_int)] += sink_source_sum_diff
    elif sink_source_sum_diff < 0:
        sink_d_int[np.argmax(sink_d_int)] -= sink_source_sum_diff

    # Sources carry negative demand, sinks positive demand.
    demands = np.empty(n, dtype=np.int64)
    demands[:n_sources] = -source_d_int
    demands[n_sources:] = sink_d_int

    # The first e edges join every source to every sink; the last n edges
    # are one artificial edge per node.
    tails = np.empty(e + n, dtype=np.int64)
    heads = np.empty(e + n, dtype=np.int64)

    for i in range(n_sources):
        for j in range(n_sinks):
            ind = i * n_sinks + j
            tails[ind] = i
            heads[ind] = j + n_sources

    for i, demand in enumerate(demands):
        if demand > 0:
            tails[e + i] = -1
            heads[e + i] = -1
        else:
            tails[e + i] = i
            heads[e + i] = i

    # Create costs and capacities for the arcs between nodes
    network_costs = network_costs * fp_multiplier

    network_capac = np.array([min(source_demands[i], sink_demands[j])
                              for i in range(n_sources)
                              for j in range(n_sinks)],
                             dtype=np.float64) * fp_multiplier

    # Stand-in for infinity: three times the largest of the listed totals.
    faux_inf = 3 * np.max(np.array((
        np.sum(network_capac),
        np.sum(np.abs(network_costs)),
        np.amax(source_d_int),
        np.amax(sink_d_int)),
        dtype=np.int64))

    # allocate arrays
    costs = np.empty(e + n, dtype=np.int64)
    costs[:e] = network_costs
    costs[e:] = faux_inf

    capac = np.empty(e + n, dtype=np.int64)
    capac[:e] = network_capac
    capac[e:] = fp_multiplier

    # Initially all flow is carried by the artificial edges.
    flows = np.empty(e + n, dtype=np.int64)
    flows[:e] = 0
    flows[e:e+n_sources] = source_d_int
    flows[e+n_sources:] = sink_d_int

    potentials = np.empty(n, dtype=np.int64)
    demands_neg_mask = demands <= 0
    potentials[demands_neg_mask] = faux_inf
    potentials[~demands_neg_mask] = -faux_inf

    # Spanning-tree bookkeeping. The arrays have one extra slot (index -1)
    # for the root; a parent of -2 marks "no parent".
    parent = np.empty(n + 1, dtype=np.int64)
    parent[:-1] = -1
    parent[-1] = -2

    size = np.empty(n + 1, dtype=np.int64)
    size[:-1] = 1
    size[-1] = n + 1

    next_node = np.arange(1, n + 2, dtype=np.int64)
    next_node[-2] = -1
    next_node[-1] = 0

    last_node = np.arange(n + 1, dtype=np.int64)
    last_node[-1] = n - 1

    prev_node = np.arange(-1, n, dtype=np.int64)
    edge = np.arange(e, e + n, dtype=np.int64)

    ### Pivot loop ###

    f = 0
    while True:
        i, p, q, f = find_entering_edges(B, e, f, tails, heads, costs, potentials, flows)
        if p == -1:  # If no entering edges then the optimal score is found
            break

        cycle_nodes, cycle_edges = find_cycle(i, p, q, size, edge, parent)
        j, s, t = find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads)
        augment_flow(cycle_nodes, cycle_edges, residual_capacity(j, s, capac, flows, tails), tails, flows)

        if i != j:  # Do nothing more if the entering edge is the same as the leaving edge.
            if parent[t] != s:
                # Ensure that s is the parent of t.
                s, t = t, s

            # Ensure that q is in the subtree rooted at t.
            for val in cycle_edges:
                if val == j:
                    p, q = q, p
                    break
                if val == i:
                    break

            remove_edge(s, t, size, prev_node, last_node, next_node, parent, edge)
            make_root(q, parent, size, last_node, prev_node, next_node, edge)
            add_edge(i, p, q, next_node, prev_node, last_node, size, parent, edge)
            update_potentials(i, p, q, heads, potentials, costs, last_node, next_node)

    # Undo the fixed-point scaling before reporting the result.
    final_flows = flows[:e] / fp_multiplier
    edge_costs = costs[:e] / fp_multiplier
    final = final_flows.dot(edge_costs)

    return final, final_flows.reshape((n_sources, n_sinks))
161
162
163
@numba.njit(cache=True)
164
def reduced_cost(i, costs, potentials, tails, heads, flows):
    """Return the reduced cost of an edge i.

    The value is ``costs[i] - potentials[tails[i]] + potentials[heads[i]]``
    when the edge carries no flow, and the negation of that value when
    ``flows[i]`` is non-zero.
    """
    c = costs[i] - potentials[tails[i]] + potentials[heads[i]]
    return c if flows[i] == 0 else -c
173
174
175
@numba.njit(cache=True)
176
def find_entering_edges(B, e, f, tails, heads, costs, potentials, flows):
    """Find the next entering edge, or report that none exists.

    Entering edges are found by combining Dantzig's rule and Bland's rule.
    The first ``e`` edges are cyclically grouped into blocks of size ``B``.
    Within each block, Dantzig's rule is applied (the first edge with the
    lowest reduced cost is the candidate). The block to search is determined
    following Bland's rule, starting at edge ``f``.

    Parameters
    ----------
    B : int
        Block size.
    e : int
        Number of edges that are searched (indices ``0`` to ``e - 1``).
    f : int
        Index of the first edge of the next block to search.
    tails, heads, costs, potentials, flows : arrays
        Edge and node data forwarded to ``reduced_cost``; ``tails``,
        ``heads`` and ``flows`` are also read to orient the entering edge.

    Returns
    -------
    tuple
        ``(i, p, q, f)`` where ``i`` is the entering edge, ``(p, q)`` is
        ``(tail, head)`` of ``i`` when it carries no flow and
        ``(head, tail)`` otherwise, and ``f`` is the start of the next
        block. ``(-1, -1, -1, -1)`` is returned when every block has been
        searched without finding an edge of negative reduced cost.
    """
    n_blocks = (e + B - 1) // B    # number of blocks needed to cover all edges
    n_scanned = 0

    while n_scanned < n_blocks:
        # Determine the next block of edges.
        block_end = f + B
        if block_end <= e:
            edge_inds = np.arange(f, block_end)
        else:
            # The block runs past the last edge: wrap around to edge 0.
            block_end -= e
            n_tail = e - f
            edge_inds = np.empty(n_tail + block_end, dtype=np.int64)
            for k in range(n_tail):
                edge_inds[k] = f + k
            for k in range(block_end):
                edge_inds[n_tail + k] = k

        f = block_end

        # Find the first edge with the lowest reduced cost.
        best = edge_inds[0]
        best_cost = reduced_cost(best, costs, potentials, tails, heads, flows)
        for ind in edge_inds[1:]:
            cost = reduced_cost(ind, costs, potentials, tails, heads, flows)
            if cost < best_cost:
                best_cost = cost
                best = ind

        if best_cost >= 0:
            # Nothing useful in this block; move on to the next one.
            n_scanned += 1
            continue

        # Entering edge found.
        if flows[best] == 0:
            return best, tails[best], heads[best], f
        return best, heads[best], tails[best], f

    # All edges have nonnegative reduced costs. The flow is optimal.
    return -1, -1, -1, -1
229
230
231
@numba.njit(cache=True)
232
def find_apex(p, q, size, parent):
    """Find the lowest common ancestor of nodes p and q in the spanning tree.

    ``size[v]`` is read as the subtree size of node ``v`` and ``parent[v]``
    as its parent. The node with the smaller subtree is moved up until the
    sizes match; if the two nodes still differ, both move up one level and
    the comparison starts over.
    """
    size_p = size[p]
    size_q = size[q]

    while True:
        while size_p < size_q:
            p = parent[p]
            size_p = size[p]
        while size_p > size_q:
            q = parent[q]
            size_q = size[q]
        if size_p == size_q:
            if p == q:
                return p
            p = parent[p]
            size_p = size[p]
            q = parent[q]
            size_q = size[q]
253
254
255
@numba.njit(cache=True)
256
def trace_path(p, w, edge, parent):
    """Return the nodes and edges on the path from node p to its ancestor w.

    Returns a pair of lists ``(nodes, edges)``. ``nodes`` starts with ``p``
    and ends with ``w``; ``edges[k]`` is ``edge[nodes[k]]``, the edge stored
    for the k-th node on the way up. When ``p == w`` the result is
    ``([p], [])``.
    """
    nodes = [p]
    edges = []

    node = p
    while node != w:
        edges.append(edge[node])
        node = parent[node]
        nodes.append(node)

    return nodes, edges
268
269
270
@numba.njit(cache=True)
271
def find_cycle(i, p, q, size, edge, parent):
    """Return the nodes and edges on the cycle containing edge i == (p, q)
    when the latter is added to the spanning tree.

    The cycle is oriented in the direction from p to q: it runs from the
    apex (lowest common ancestor of p and q) down to p, across edge i, and
    from q back up towards the apex. The apex is listed once, at the start
    of the node array.

    Returns
    -------
    cycle_nodes_, cycle_edges_ : int64 arrays
        Nodes and edges of the cycle in that order.
    """

    w = find_apex(p, q, size, parent)
    up_nodes, up_edges = trace_path(p, w, edge, parent)
    down_nodes, down_edges = trace_path(q, w, edge, parent)

    n_up_nodes = len(up_nodes)
    n_up_edges = len(up_edges)
    n_down_edges = len(down_edges)
    # trace_path always returns at least one node; the last one is the apex,
    # which is already present at the start of the cycle.
    n_down_nodes = len(down_nodes) - 1

    cycle_nodes_ = np.empty(n_up_nodes + n_down_nodes, dtype=np.int64)
    for j in range(n_up_nodes):
        # The path p -> apex is reversed so that the cycle starts at the apex.
        cycle_nodes_[j] = up_nodes[n_up_nodes - 1 - j]
    for j in range(n_down_nodes):
        cycle_nodes_[n_up_nodes + j] = down_nodes[j]

    # Edge i is inserted between the two paths unless the path from p
    # already ends with it.
    append_i_to_edges = (n_up_edges < 1 or up_edges[-1] != i)
    if append_i_to_edges:
        n_edges = n_up_edges + n_down_edges + 1
        cycle_edges_ = np.empty(n_edges, dtype=np.int64)
        cycle_edges_[n_up_edges] = i
    else:
        n_edges = n_up_edges + n_down_edges
        cycle_edges_ = np.empty(n_edges, dtype=np.int64)

    for j in range(n_up_edges):
        cycle_edges_[j] = up_edges[n_up_edges - 1 - j]
    # The edges of the path q -> apex fill the tail of the array.
    offset = n_edges - n_down_edges
    for j in range(n_down_edges):
        cycle_edges_[offset + j] = down_edges[j]

    return cycle_nodes_, cycle_edges_
301
302
303
@numba.njit(cache=True)
304
def residual_capacity(i, p, capac, flows, tails):
    """Return the residual capacity of an edge i in the direction away
    from its endpoint p.

    If ``p`` is the tail of the edge this is the unused capacity
    ``capac[i] - flows[i]``; otherwise it is the current flow ``flows[i]``.
    """
    if tails[i] == p:
        return capac[i] - flows[i]
    return flows[i]
312
313
314
@numba.njit(cache=True)
315
def find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads):
    """Return the leaving edge in a cycle represented by cycle_nodes and
    cycle_edges.

    Each edge ``cycle_edges[k]`` is paired with node ``cycle_nodes[k]`` and
    its residual capacity away from that node is computed. The first pair
    with the smallest residual capacity is selected.

    Returns
    -------
    tuple
        ``(j, s, t)``: the leaving edge ``j``, the node ``s`` it was paired
        with, and ``t``, the head of ``j`` if ``s`` is its tail and the tail
        of ``j`` otherwise.
    """
    j, s = cycle_edges[0], cycle_nodes[0]
    min_res_cap = residual_capacity(j, s, capac, flows, tails)

    for ind in range(1, cycle_edges.shape[0]):
        cand_edge, cand_node = cycle_edges[ind], cycle_nodes[ind]
        res_cap = residual_capacity(cand_edge, cand_node, capac, flows, tails)
        # Strict comparison: ties keep the earliest candidate.
        if res_cap < min_res_cap:
            min_res_cap = res_cap
            j, s = cand_edge, cand_node

    t = heads[j] if tails[j] == s else tails[j]
    return j, s, t
330
331
332
@numba.njit(cache=True)
333
def augment_flow(cycle_nodes, cycle_edges, f, tails, flows):
    """Augment f units of flow along a cycle representing Wn with cycle_edges.

    Edges and nodes are paired positionally. ``flows`` is modified in
    place: an edge whose tail is its paired node gains ``f`` units, any
    other edge loses ``f`` units.
    """
    for i, p in zip(cycle_edges, cycle_nodes):
        if tails[i] == p:
            flows[i] += f
        else:
            flows[i] -= f
341
342
343
@numba.njit(cache=True)
344
def remove_edge(s, t, size, prev, last, next_node, parent, edge):
    """Remove an edge (s, t) where parent[t] == s from the spanning tree.

    All arrays are modified in place. Node ``t`` gets parent and edge -2,
    the subtree rooted at ``t`` is cut out of the depth-first thread
    (``next_node`` / ``prev``) and closed into its own loop, and the
    subtree sizes and last descendants of the former ancestors of ``t`` are
    updated.
    """
    size_t = size[t]
    prev_t = prev[t]
    last_t = last[t]
    next_last_t = next_node[last_t]

    # Remove (s, t).
    parent[t] = -2
    edge[t] = -2

    # Remove the subtree rooted at t from the depth-first thread.
    next_node[prev_t] = next_last_t
    prev[next_last_t] = prev_t
    next_node[last_t] = t
    prev[t] = last_t

    # Update the subtree sizes and last descendants of the (old) ancestors
    # of t.
    ancestor = s
    while ancestor != -2:
        size[ancestor] -= size_t
        if last[ancestor] == last_t:
            last[ancestor] = prev_t
        ancestor = parent[ancestor]
367
368
369
@numba.njit(cache=True)
370
def make_root(q, parent, size, last, prev, next_node, edge):
    """Make a node q the root of its containing subtree.

    The chain of ancestors from the current root down to ``q`` is walked
    top-down; at each step the parent/child relation between two
    consecutive nodes is reversed. ``parent``, ``edge``, ``size``, ``last``
    and the depth-first thread (``next_node`` / ``prev``) are updated in
    place. A parent of -2 marks the root of the subtree.
    """
    # Collect q and its ancestors bottom-up (q first, current root last).
    # Appending and then walking the list backwards avoids the repeated
    # front insertions that would shift the whole list each time.
    path = []
    node = q
    while node != -2:
        path.append(node)
        node = parent[node]

    # Visit (upper, lower) pairs starting from the current root.
    for k in range(len(path) - 1, 0, -1):
        upper = path[k]
        lower = path[k - 1]

        size_upper = size[upper]
        last_upper = last[upper]
        prev_lower = prev[lower]
        last_lower = last[lower]
        next_last_lower = next_node[last_lower]

        # Make upper a child of lower.
        parent[upper] = lower
        parent[lower] = -2
        edge[upper] = edge[lower]
        edge[lower] = -2
        size[upper] = size_upper - size[lower]
        size[lower] = size_upper

        # Remove the subtree rooted at lower from the depth-first thread.
        next_node[prev_lower] = next_last_lower
        prev[next_last_lower] = prev_lower
        next_node[last_lower] = lower
        prev[lower] = last_lower

        if last_upper == last_lower:
            last[upper] = prev_lower
            last_upper = prev_lower

        # Add the remaining parts of the subtree rooted at upper as a
        # subtree of lower in the depth-first thread.
        prev[upper] = last_lower
        next_node[last_lower] = upper
        next_node[last_upper] = lower
        prev[lower] = last_upper
        last[lower] = last_upper
411
412
413
@numba.njit(cache=True)
414
def add_edge(i, p, q, next_node, prev, last, size, parent, edge):
    """Add an edge (p, q) to the spanning tree where q is the root of a
    subtree.

    All arrays are modified in place. Node ``q`` becomes a child of ``p``
    through edge ``i``, the subtree rooted at ``q`` is spliced into the
    depth-first thread (``next_node`` / ``prev``) right after the last
    descendant of ``p``, and the subtree sizes and last descendants of
    ``p`` and its ancestors are updated.
    """
    last_p = last[p]
    next_last_p = next_node[last_p]
    size_q = size[q]
    last_q = last[q]

    # Make q a child of p.
    parent[q] = p
    edge[q] = i

    # Insert the subtree rooted at q into the depth-first thread.
    next_node[last_p] = q
    prev[q] = last_p
    prev[next_last_p] = last_q
    next_node[last_q] = next_last_p

    # Update the subtree sizes and last descendants of the (new) ancestors
    # of q.
    ancestor = p
    while ancestor != -2:
        size[ancestor] += size_q
        if last[ancestor] == last_p:
            last[ancestor] = last_q
        ancestor = parent[ancestor]
438
439
440
@numba.njit(cache=True)
441
def update_potentials(i, p, q, heads, potentials, costs, last_node, next_node):
    """Update the potentials of the nodes in the subtree rooted at a node
    q connected to its parent p by an edge i.

    The shift ``d`` is chosen so that, after the update,
    ``potentials[q] == potentials[p] - costs[i]`` when ``q`` is the head of
    edge ``i`` and ``potentials[p] + costs[i]`` otherwise. The same shift is
    added in place to every node reached by following ``next_node`` from
    ``q`` up to and including ``last_node[q]``.
    """
    if q == heads[i]:
        d = potentials[p] - costs[i] - potentials[q]
    else:
        d = potentials[p] + costs[i] - potentials[q]

    potentials[q] += d

    # Walk the depth-first thread until the last descendant of q is reached.
    stop = last_node[q]
    node = q
    while node != stop:
        node = next_node[node]
        potentials[node] += d
455