Passed
Push — master ( 8cfb03...7c5bc8 )
by Daniel
07:13
created

amd._network_simplex.add_edge()   A

Complexity

Conditions 3

Size

Total Lines 25
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 25
rs 9.55
c 0
b 0
f 0
cc 3
nop 9

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
"""An implementation of the Wasserstein metric between two compositional vectors.
2
Aka Earth Mover's distance, an appropriate metric for comparing two PDDs.
3
4
Copyright (C) 2020 Cameron Hargreaves. This code is part of the Element Movers
5
Distance project https://github.com/lrcfmd/ElMD.
6
"""
7
8
import numba
9
import numpy as np
10
11
12
@numba.njit(cache=True)
def network_simplex(source_demands, sink_demands, network_costs):
    """Solve a dense transportation problem with the network simplex method.

    This is a port of the network simplex algorithm implemented by
    Loïc Séguin-C for the networkx package, written so that it can be
    accelerated with the numba package.
    Copyright (C) 2010 Loïc Séguin-C.
    All rights reserved.
    BSD license.

    Every source is connected to every sink by one arc. Each node also
    gets one artificial arc that carries its whole demand at the start
    and is priced at a very large cost, so the pivots push the flow onto
    the real arcs.

    Parameters
    ----------
    source_demands : numpy.ndarray
        1-D array with the amount supplied by each source.
    sink_demands : numpy.ndarray
        1-D array with the amount required by each sink.
    network_costs : numpy.ndarray
        Cost of moving one unit from source ``i`` to sink ``j``. The array
        is flattened with ``ravel`` and read at index ``i * n_sinks + j``.

    Returns
    -------
    final : float
        Sum over the real arcs of flow times cost, both scaled back from
        fixed point.
    final_flows : numpy.ndarray
        Array of shape ``(n_sources, n_sinks)`` holding the flow on each
        source-to-sink arc.

    Notes
    -----
    Demands and costs are multiplied by 1_000_000 and truncated to
    ``int64`` before solving, so results are only resolved to about six
    decimal places. The artificial arcs are given a capacity of 1_000_000
    (one unit after scaling); presumably each side's demands are expected
    to sum to at most 1 -- TODO confirm against callers.

    References
    ----------
    [1] Z. Kiraly, P. Kovacs.
        Efficient implementation of minimum-cost flow algorithms.
        Acta Universitatis Sapientiae, Informatica 4(1):67--118. 2012.
    [2] R. Barr, F. Glover, D. Klingman.
        Enhancement of spanning tree labeling procedures for network
        optimization.
        INFOR 17(1):16--34. 1979.
    """

    n_sources, n_sinks = source_demands.shape[0], sink_demands.shape[0]
    network_costs = network_costs.ravel()
    n = n_sources + n_sinks
    e = n_sources * n_sinks
    # Block size used by the entering-edge search.
    B = np.int64(np.ceil(np.sqrt(e)))
    fp_multiplier = np.int64(1_000_000)

    # Convert the demands to fixed-point integers.
    source_d_int = (source_demands * fp_multiplier).astype(np.int64)
    sink_d_int = (sink_demands * fp_multiplier).astype(np.int64)

    # FP conversion error correction: truncation can leave the two sides
    # with different totals, so the largest entry of the smaller side is
    # bumped until the totals agree.
    source_sum = np.sum(source_d_int)
    sink_sum = np.sum(sink_d_int)
    sink_source_sum_diff = sink_sum - source_sum

    if sink_source_sum_diff > 0:
        source_d_int[np.argmax(source_d_int)] += sink_source_sum_diff
    elif sink_source_sum_diff < 0:
        sink_d_int[np.argmax(sink_d_int)] -= sink_source_sum_diff

    # Sources carry negative demand, sinks positive demand.
    demands = np.empty(n, dtype=np.int64)
    demands[:n_sources] = -source_d_int
    demands[n_sources:] = sink_d_int

    # The first e arcs are the real source -> sink arcs; the last n arcs
    # are the artificial ones (one per node).
    tails = np.empty(e + n, dtype=np.int64)
    heads = np.empty(e + n, dtype=np.int64)

    for i in range(n_sources):
        for j in range(n_sinks):
            ind = i * n_sinks + j
            tails[ind] = i
            heads[ind] = j + n_sources

    # Index -1 addresses the extra root slot at the end of the tree arrays.
    # NOTE(review): both endpoints of an artificial arc are stored as the
    # same value here, whereas the networkx implementation stores
    # (root, node) for positive demand and (node, root) otherwise --
    # confirm this difference is intentional.
    for i, demand in enumerate(demands):
        if demand > 0:
            tails[e + i] = -1
            heads[e + i] = -1
        else:
            tails[e + i] = i
            heads[e + i] = i

    # Create costs and capacities for the arcs between nodes
    network_costs = network_costs * fp_multiplier

    # Capacities are derived from the corrected integer demands rather
    # than the raw floats, so that the node adjusted by the correction
    # above can still ship its full amount over a single arc.
    network_capac = np.array([min(source_d_int[i], sink_d_int[j])
                              for i in range(n_sources)
                              for j in range(n_sinks)],
                             dtype=np.float64)

    # A value large enough to act as infinity for costs and potentials.
    faux_inf = 3 * np.max(np.array((
        np.sum(network_capac),
        np.sum(np.abs(network_costs)),
        np.amax(source_d_int),
        np.amax(sink_d_int)),
        dtype=np.int64))

    # allocate arrays
    costs = np.empty(e + n, dtype=np.int64)
    costs[:e] = network_costs
    costs[e:] = faux_inf

    capac = np.empty(e + n, dtype=np.int64)
    capac[:e] = network_capac
    capac[e:] = fp_multiplier

    # Initially every unit of demand travels over the artificial arcs.
    flows = np.empty(e + n, dtype=np.int64)
    flows[:e] = 0
    flows[e:e+n_sources] = source_d_int
    flows[e+n_sources:] = sink_d_int

    potentials = np.empty(n, dtype=np.int64)
    demands_neg_mask = demands <= 0
    potentials[demands_neg_mask] = faux_inf
    potentials[~demands_neg_mask] = -faux_inf

    # Spanning-tree bookkeeping. Each array has one extra trailing slot for
    # the root; every node starts as a direct child of the root (-1), and
    # the root's own parent is the sentinel -2.
    parent = np.empty(n + 1, dtype=np.int64)
    parent[:-1] = -1
    parent[-1] = -2

    size = np.empty(n + 1, dtype=np.int64)
    size[:-1] = 1
    size[-1] = n + 1

    next_node = np.arange(1, n + 2, dtype=np.int64)
    next_node[-2] = -1
    next_node[-1] = 0

    last_node = np.arange(n + 1, dtype=np.int64)
    last_node[-1] = n - 1

    prev_node = np.arange(-1, n, dtype=np.int64)
    # edge[v] is the arc joining v to its parent: initially the artificial one.
    edge = np.arange(e, e + n, dtype=np.int64)

    ### Pivot loop ###

    f = 0
    while True:
        i, p, q, f = find_entering_edges(B, e, f, tails, heads, costs, potentials, flows)
        if p == -1:  # If no entering edges then the optimal score is found
            break

        cycle_nodes, cycle_edges = find_cycle(i, p, q, size, edge, parent)
        j, s, t = find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads)
        augment_flow(cycle_nodes, cycle_edges, residual_capacity(j, s, capac, flows, tails), tails, flows)

        if i != j:  # Do nothing more if the entering edge is the same as the leaving edge.
            if parent[t] != s:
                # Ensure that s is the parent of t.
                s, t = t, s

            # Ensure that q is in the subtree rooted at t.
            for val in cycle_edges:
                if val == j:
                    p, q = q, p
                    break
                if val == i:
                    break

            remove_edge(s, t, size, prev_node, last_node, next_node, parent, edge)
            make_root(q, parent, size, last_node, prev_node, next_node, edge)
            add_edge(i, p, q, next_node, prev_node, last_node, size, parent, edge)
            update_potentials(i, p, q, heads, potentials, costs, last_node, next_node)

    # Scale the fixed-point flows and costs back before combining them.
    final_flows = flows[:e] / fp_multiplier
    edge_costs = costs[:e] / fp_multiplier
    final = final_flows.dot(edge_costs)

    return final, final_flows.reshape((n_sources, n_sinks))
161
162
163
@numba.njit(cache=True)
def reduced_cost(i, costs, potentials, tails, heads, flows):
    """Return the reduced cost of edge ``i``.

    The value is the edge cost minus the potential of its tail plus the
    potential of its head. The sign is flipped when the edge currently
    carries flow.
    """
    tail_potential = potentials[tails[i]]
    head_potential = potentials[heads[i]]
    value = costs[i] - tail_potential + head_potential

    return value if flows[i] == 0 else -value
173
174
175
@numba.njit(cache=True)
def find_entering_edges(B, e, f, tails, heads, costs, potentials, flows):
    """Return the next entering edge, or a sentinel when none exists.

    Entering edges are found by combining Dantzig's rule and Bland's
    rule. The first ``e`` edges are cyclically grouped into blocks of size
    ``B``. Within each block, Dantzig's rule is applied to find the edge
    with the lowest reduced cost (the first one on ties). The block to
    search next is determined following Bland's rule, starting at ``f``.

    Parameters
    ----------
    B : int
        Number of edges examined per block.
    e : int
        Number of candidate edges; only indices ``0 .. e - 1`` are searched.
    f : int
        Index at which the search starts.
    tails, heads, costs, potentials, flows : numpy.ndarray
        Arrays forwarded to ``reduced_cost``; ``tails``, ``heads`` and
        ``flows`` are also used to orient the chosen edge.

    Returns
    -------
    tuple
        ``(i, p, q, f)``: the entering edge ``i``, its endpoints ordered
        so that flow is pushed from ``p`` to ``q`` (tail to head when the
        edge is empty, head to tail otherwise), and the start index for
        the next call. ``(-1, -1, -1, -1)`` is returned when a full sweep
        finds no edge with a negative reduced cost.
    """
    num_blocks = (e + B - 1) // B    # number of blocks needed to cover all edges
    blocks_checked = 0

    while blocks_checked < num_blocks:
        # The block holds B consecutive edges starting at `start`, wrapping
        # around to edge 0 after edge e - 1. Indexing modulo e avoids
        # building a temporary index array for every block.
        start = f
        f = start + B
        if f > e:
            f -= e

        # Find the first edge with the lowest reduced cost.
        i = start % e
        c = reduced_cost(i, costs, potentials, tails, heads, flows)
        for offset in range(1, B):
            ind = (start + offset) % e
            cost = reduced_cost(ind, costs, potentials, tails, heads, flows)
            if cost < c:
                c = cost
                i = ind

        if c < 0:
            # Entering edge found.
            if flows[i] == 0:
                p = tails[i]
                q = heads[i]
            else:
                p = heads[i]
                q = tails[i]

            return i, p, q, f

        # Nothing in this block improves the objective; try the next one.
        blocks_checked += 1

    # All edges have nonnegative reduced costs. The flow is optimal.
    return -1, -1, -1, -1
229
230
231
@numba.njit(cache=True)
def find_apex(p, q, size, parent):
    """Return the lowest common ancestor of nodes p and q in the spanning tree.

    The node with the smaller subtree is moved to its parent until both
    walks land on the same node; when the subtree sizes tie but the nodes
    differ, both are moved up together.
    """
    size_p = size[p]
    size_q = size[q]

    while p != q:
        if size_p < size_q:
            p = parent[p]
            size_p = size[p]
        elif size_p > size_q:
            q = parent[q]
            size_q = size[q]
        else:
            # Equal sizes but different nodes: climb on both sides.
            p = parent[p]
            size_p = size[p]
            q = parent[q]
            size_q = size[q]

    return p
253
254
255
@numba.njit(cache=True)
def trace_path(p, w, edge, parent):
    """Collect the nodes and edges met while climbing from p to its ancestor w.

    Returns two lists: the nodes visited (starting with p and ending with
    w) and the edge joining each visited node to its parent.
    """
    nodes = [p]
    edges = []
    current = p

    while current != w:
        edges.append(edge[current])
        current = parent[current]
        nodes.append(current)

    return nodes, edges
268
269
270
@numba.njit(cache=True)
def find_cycle(i, p, q, size, edge, parent):
    """Build the cycle created by adding edge i == (p, q) to the spanning tree.

    The cycle is oriented in the direction from p to q: it starts at the
    common ancestor of p and q, descends to p, crosses edge i and climbs
    from q back towards the ancestor. Both results are int64 arrays.
    """
    apex = find_apex(p, q, size, parent)
    p_nodes, p_edges = trace_path(p, apex, edge, parent)
    q_nodes, q_edges = trace_path(q, apex, edge, parent)

    n_p_nodes = len(p_nodes)
    n_p_edges = len(p_edges)
    n_q_edges = len(q_edges)
    # The last node on the q side is the apex, which the p side already has.
    n_q_nodes = max(len(q_nodes) - 1, 0)

    # Nodes: the p side reversed (apex first), then the q side as traced.
    cycle_nodes_ = np.empty(n_p_nodes + n_q_nodes, dtype=np.int64)
    for k in range(n_p_nodes):
        cycle_nodes_[k] = p_nodes[n_p_nodes - 1 - k]
    for k in range(n_q_nodes):
        cycle_nodes_[n_p_nodes + k] = q_nodes[k]

    # Edge i is inserted between the two sides unless the p side already
    # ends with it.
    include_i = n_p_edges < 1 or p_edges[n_p_edges - 1] != i
    q_start = n_p_edges + 1 if include_i else n_p_edges

    cycle_edges_ = np.empty(q_start + n_q_edges, dtype=np.int64)
    for k in range(n_p_edges):
        cycle_edges_[k] = p_edges[n_p_edges - 1 - k]
    if include_i:
        cycle_edges_[n_p_edges] = i
    for k in range(n_q_edges):
        cycle_edges_[q_start + k] = q_edges[k]

    return cycle_nodes_, cycle_edges_
301
302
303
@numba.njit(cache=True)
def residual_capacity(i, p, capac, flows, tails):
    """Return how much more flow edge i can take in the direction away from p.

    When p is the tail of the edge this is the unused capacity; otherwise
    it is the flow currently on the edge.
    """
    leaves_from_tail = tails[i] == p
    return capac[i] - flows[i] if leaves_from_tail else flows[i]
312
313
314
@numba.njit(cache=True)
def find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads):
    """Pick the leaving edge of the cycle given by cycle_nodes and cycle_edges.

    The leaving edge is the first edge on the cycle with the smallest
    residual capacity. Returns ``(j, s, t)``: the edge, the cycle node it
    is traversed from, and its other endpoint.
    """
    best = 0
    smallest = residual_capacity(cycle_edges[0], cycle_nodes[0], capac, flows, tails)

    for k in range(1, cycle_edges.shape[0]):
        candidate = residual_capacity(cycle_edges[k], cycle_nodes[k], capac, flows, tails)
        # Strict comparison keeps the earliest edge on ties.
        if candidate < smallest:
            smallest = candidate
            best = k

    j = cycle_edges[best]
    s = cycle_nodes[best]

    if tails[j] == s:
        t = heads[j]
    else:
        t = tails[j]

    return j, s, t
330
331
332
@numba.njit(cache=True)
def augment_flow(cycle_nodes, cycle_edges, f, tails, flows):
    """Push f units of flow around the cycle, updating ``flows`` in place.

    Each edge is paired with the cycle node at the same position. The flow
    on the edge grows by f when that node is the edge's tail and shrinks
    by f otherwise.
    """
    # Pair edges and nodes position by position, stopping at the shorter one.
    steps = min(len(cycle_edges), len(cycle_nodes))

    for k in range(steps):
        arc = cycle_edges[k]
        if tails[arc] == cycle_nodes[k]:
            flows[arc] += f
        else:
            flows[arc] -= f
341
342
343
@numba.njit(cache=True)
def remove_edge(s, t, size, prev, last, next_node, parent, edge):
    """Detach the subtree rooted at t from its parent s in the spanning tree.

    Requires parent[t] == s. All tree arrays are modified in place.
    """
    subtree_size = size[t]
    before_t = prev[t]
    subtree_end = last[t]
    after_subtree = next_node[subtree_end]

    # Cut the edge (s, t); -2 marks "no parent" / "no edge".
    parent[t] = -2
    edge[t] = -2

    # Splice the subtree out of the depth-first thread and close it into
    # a loop of its own.
    next_node[before_t] = after_subtree
    prev[after_subtree] = before_t
    next_node[subtree_end] = t
    prev[t] = subtree_end

    # Shrink every former ancestor of t and repair its last descendant.
    ancestor = s
    while ancestor != -2:
        size[ancestor] -= subtree_size
        if last[ancestor] == subtree_end:
            last[ancestor] = before_t
        ancestor = parent[ancestor]
367
368
369
@numba.njit(cache=True)
def make_root(q, parent, size, last, prev, next_node, edge):
    """Make a node q the root of its containing subtree.

    The path from q up to the current root of its subtree is collected
    first. The parent/child relation is then reversed along that path,
    one edge at a time, starting at the old root and finishing at q. All
    tree arrays are modified in place.
    """
    # Collect q and its ancestors, q first. -2 is the "no parent" marker.
    # Appending and walking the list backwards below avoids the repeated
    # front insertions and slice copies of the list.
    path = []
    node = q
    while node != -2:
        path.append(node)
        node = parent[node]

    # Walk from the old root towards q; `child` is the next node on the path.
    for k in range(len(path) - 1, 0, -1):
        p = path[k]
        child = path[k - 1]

        size_p = size[p]
        last_p = last[p]
        prev_child = prev[child]
        last_child = last[child]
        next_last_child = next_node[last_child]

        # Make p a child of `child`.
        parent[p] = child
        parent[child] = -2
        edge[p] = edge[child]
        edge[child] = -2
        size[p] = size_p - size[child]
        size[child] = size_p

        # Remove the subtree rooted at `child` from the depth-first thread.
        next_node[prev_child] = next_last_child
        prev[next_last_child] = prev_child
        next_node[last_child] = child
        prev[child] = last_child

        if last_p == last_child:
            last[p] = prev_child
            last_p = prev_child

        # Add the remaining parts of the subtree rooted at p as a subtree
        # of `child` in the depth-first thread.
        prev[p] = last_child
        next_node[last_child] = p
        next_node[last_p] = child
        prev[child] = last_p
        last[child] = last_p
411
412
413
@numba.njit(cache=True)
def add_edge(i, p, q, next_node, prev, last, size, parent, edge):
    """Attach the subtree rooted at q beneath p using edge i.

    q must be the root of its own subtree. All tree arrays are modified
    in place.
    """
    p_end = last[p]
    after_p = next_node[p_end]
    q_size = size[q]
    q_end = last[q]

    # Make q a child of p, joined by edge i.
    parent[q] = p
    edge[q] = i

    # Thread q's subtree right after the last descendant of p.
    next_node[p_end] = q
    prev[q] = p_end
    prev[after_p] = q_end
    next_node[q_end] = after_p

    # Grow p and every ancestor of p, and extend their last descendant
    # where it used to be the last descendant of p.
    ancestor = p
    while ancestor != -2:
        size[ancestor] += q_size
        if last[ancestor] == p_end:
            last[ancestor] = q_end
        ancestor = parent[ancestor]
438
439
440
@numba.njit(cache=True)
def update_potentials(i, p, q, heads, potentials, costs, last_node, next_node):
    """Shift the potentials of every node in the subtree rooted at q.

    q is connected to its parent p by edge i. The shift is chosen so that
    the potential of q becomes the potential of p minus the edge cost when
    q is the head of i, and plus the edge cost otherwise. ``potentials``
    is modified in place.
    """
    if q == heads[i]:
        target = potentials[p] - costs[i]
    else:
        target = potentials[p] + costs[i]
    shift = target - potentials[q]

    # Follow the depth-first thread from q to its last descendant.
    stop = last_node[q]
    node = q
    potentials[node] += shift
    while node != stop:
        node = next_node[node]
        potentials[node] += shift
455