Test Failed
Push — master (37d7fb...c02a6e) by Daniel, created 07:38

amd._emd.network_simplex()   F

Complexity
    Conditions: 19

Size
    Total Lines: 186
    Code Lines: 107

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
eloc      107
dl        0
loc       186
rs        0.4199
c         0
b         0
f         0
cc        19
nop       3

How to fix

Long Method

Small methods make your code easier to understand, especially when combined with a good name. In addition, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

Commonly applied refactorings include Extract Method.
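
For illustration, here is a minimal sketch of Extract Method in Python; the function names and the normalisation step are hypothetical and not taken from amd._emd:

import numpy as np

# Before: one long function with a commented step buried in its body.
def summarise(costs):
    # Normalise the cost matrix so its entries sum to 1.
    total = costs.sum()
    normalised = costs / total
    return normalised.max()

# After: the commented step becomes its own function, named after the comment.
def normalise_costs(costs):
    """Scale the cost matrix so its entries sum to 1."""
    return costs / costs.sum()

def summarise_refactored(costs):
    return normalise_costs(costs).max()

The long function shrinks, and each extracted piece can be named and tested on its own.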

Complexity

Complex classes and functions like amd._emd.network_simplex() often do many different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often the faster option.
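
As a rough sketch of Extract Class, again with hypothetical names rather than anything from the amd package, fields that always travel together (here some spanning-tree bookkeeping arrays) are moved into their own class:

# Before: one class holding two unrelated groups of fields.
class Solver:
    def __init__(self, tails, heads, costs, parent, size, next_node):
        self.tails, self.heads, self.costs = tails, heads, costs
        self.parent, self.size, self.next_node = parent, size, next_node

# After: the fields sharing the "tree" role form a cohesive component.
class SpanningTree:
    def __init__(self, parent, size, next_node):
        self.parent, self.size, self.next_node = parent, size, next_node

class RefactoredSolver:
    def __init__(self, tails, heads, costs, tree):
        self.tails, self.heads, self.costs = tails, heads, costs
        self.tree = tree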

"""An implementation of the Wasserstein metric (Earth Mover's distance)
between two weighted distributions, the appropriate metric for comparing
two PDDs.

Copyright (C) 2020 Cameron Hargreaves. This code is adapted from the
Element Movers Distance project https://github.com/lrcfmd/ElMD.
"""

from typing import Tuple

import numba
import numpy as np
import numpy.typing as npt


@numba.njit(cache=True)
def network_simplex(
        source_demands: npt.NDArray,
        sink_demands: npt.NDArray,
        network_costs: npt.NDArray
) -> Tuple[float, npt.NDArray[np.float64]]:
    """Calculate the Earth mover's distance (Wasserstein metric) between
    two weighted distributions given by two sets of weights and a cost
    matrix.

    This is a port of the network simplex algorithm implemented by Loïc
    Séguin-C for the networkx package, adapted here to allow acceleration
    with numba. Copyright (C) 2010 Loïc Séguin-C. [email protected].
    All rights reserved. BSD license.

    Parameters
    ----------
    source_demands : :class:`numpy.ndarray`
        Weights of the first distribution.
    sink_demands : :class:`numpy.ndarray`
        Weights of the second distribution.
    network_costs : :class:`numpy.ndarray`
        Cost matrix of distances between elements of the two
        distributions. Shape (len(source_demands), len(sink_demands)).

    Returns
    -------
    (emd, plan) : Tuple[float, :class:`numpy.ndarray`]
        A tuple of the Earth mover's distance and the optimal matching.

    References
    ----------
    [1] Z. Kiraly, P. Kovacs.
        Efficient implementation of minimum-cost flow algorithms.
        Acta Universitatis Sapientiae, Informatica 4(1), 67--118 (2012).
    [2] R. Barr, F. Glover, D. Klingman.
        Enhancement of spanning tree labeling procedures for network
        optimization. INFOR 17(1), 16--34 (1979).
    """

    n_sources, n_sinks = source_demands.shape[0], sink_demands.shape[0]
    network_costs = network_costs.ravel()
    n = n_sources + n_sinks
    e = n_sources * n_sinks
    B = np.int64(np.ceil(np.sqrt(e)))
    fp_multiplier = np.int64(1e+6)

    # Add one additional node for a dummy source and sink
    source_d_int = (source_demands * fp_multiplier).astype(np.int64)
    sink_d_int = (sink_demands * fp_multiplier).astype(np.int64)
    sink_source_sum_diff = np.sum(sink_d_int) - np.sum(source_d_int)

    if sink_source_sum_diff > 0:
        source_d_int[np.argmax(source_d_int)] += sink_source_sum_diff
    elif sink_source_sum_diff < 0:
        sink_d_int[np.argmax(sink_d_int)] -= sink_source_sum_diff

    demands = np.empty(n, dtype=np.int64)
    demands[:n_sources] = -source_d_int
    demands[n_sources:] = sink_d_int
    tails = np.empty(e + n, dtype=np.int64)
    heads = np.empty(e + n, dtype=np.int64)

    ind = 0
    for i in range(n_sources):
        for j in range(n_sinks):
            tails[ind] = i
            heads[ind] = n_sources + j
            ind += 1

    for i, demand in enumerate(demands):
        ind = e + i
        if demand > 0:
            tails[ind] = -1
            heads[ind] = -1
        else:
            tails[ind] = i
            heads[ind] = i

    # Create costs and capacities for the arcs between nodes
    network_costs = network_costs * fp_multiplier
    network_capac = np.empty(shape=(e, ), dtype=np.float64)
    ind = 0
    for i in range(n_sources):
        for j in range(n_sinks):
            network_capac[ind] = min(source_demands[i], sink_demands[j])
            ind += 1
    network_capac *= fp_multiplier

    faux_inf = 3 * max(
        np.sum(network_capac), np.sum(np.abs(network_costs)),
        np.amax(source_d_int), np.amax(sink_d_int)
    )

    costs = np.empty(e + n, dtype=np.int64)
    costs[:e] = network_costs
    costs[e:] = faux_inf

    capac = np.empty(e + n, dtype=np.int64)
    capac[:e] = network_capac
    capac[e:] = fp_multiplier

    flows = np.empty(e + n, dtype=np.int64)
    flows[:e] = 0
    flows[e:e+n_sources] = source_d_int
    flows[e+n_sources:] = sink_d_int

    potentials = np.empty(n, dtype=np.int64)
    demands_neg_mask = demands <= 0
    potentials[demands_neg_mask] = faux_inf
    potentials[~demands_neg_mask] = -faux_inf

    parent = np.full(shape=(n + 1, ), fill_value=-1, dtype=np.int64)
    parent[-1] = -2

    size = np.full(shape=(n + 1, ), fill_value=1, dtype=np.int64)
    size[-1] = n + 1

    next_node = np.arange(1, n + 2, dtype=np.int64)
    next_node[-2] = -1
    next_node[-1] = 0

    last_node = np.arange(n + 1, dtype=np.int64)
    last_node[-1] = n - 1

    prev_node = np.arange(-1, n, dtype=np.int64)
    edge = np.arange(e, e + n, dtype=np.int64)

    # Pivot loop

    f = 0
    while True:
        i, p, q, f = find_entering_edges(
            B, e, f, tails, heads, costs, potentials, flows
        )
        # If no entering edges then the optimal score is found
        if p == -1:
            break

        cycle_nodes, cycle_edges = find_cycle(i, p, q, size, edge, parent)
        j, s, t = find_leaving_edge(
            cycle_nodes, cycle_edges, capac, flows, tails, heads
        )
        res_cap = capac[j] - flows[j] if tails[j] == s else flows[j]

        # Augment flows
        for i_, p_ in zip(cycle_edges, cycle_nodes):
            if tails[i_] == p_:
                flows[i_] += res_cap
            else:
                flows[i_] -= res_cap

        # Do nothing more if the entering edge is the same as the
        # leaving edge.
        if i != j:
            if parent[t] != s:
                # Ensure that s is the parent of t.
                s, t = t, s

            # Ensure that q is in the subtree rooted at t.
            for val in cycle_edges:
                if val == j:
                    p, q = q, p
                    break
                if val == i:
                    break

            remove_edge(
                s, t, size, prev_node, last_node, next_node, parent, edge
            )
            make_root(
                q, parent, size, last_node, prev_node, next_node, edge
            )
            add_edge(
                i, p, q, next_node, prev_node, last_node, size, parent, edge
            )
            update_potentials(
                i, p, q, heads, potentials, costs, last_node, next_node
            )

    final_flows = flows[:e] / fp_multiplier
    edge_costs = costs[:e] / fp_multiplier
    emd = final_flows.dot(edge_costs)

    return emd, final_flows.reshape((n_sources, n_sinks))


@numba.njit(cache=True)
def reduced_cost(i, costs, potentials, tails, heads, flows):
    """Return the reduced cost of an edge i."""
    c = costs[i] - potentials[tails[i]] + potentials[heads[i]]
    if flows[i] == 0:
        return c
    return -c


@numba.njit(cache=True)
def find_entering_edges(B, e, f, tails, heads, costs, potentials, flows):
    """Return the next entering edge, or -1 sentinels if none exists.
    Entering edges are found by combining Dantzig's rule and Bland's
    rule: the edges are cyclically grouped into blocks of size B,
    Dantzig's rule is applied within each block to find an entering
    edge, and the block to search is determined following Bland's rule.
    """

    m = 0

    while m < (e + B - 1) // B:
        # Determine the next block of edges.
        l = f + B
        if l <= e:
            edge_inds = np.arange(f, l)
        else:
            l -= e
            edge_inds = np.empty(e - f + l, dtype=np.int64)
            for i, v in enumerate(range(f, e)):
                edge_inds[i] = v
            for i in range(l):
                edge_inds[e-f+i] = i

        # Find the first edge with the lowest reduced cost.
        f = l
        i = edge_inds[0]
        c = reduced_cost(i, costs, potentials, tails, heads, flows)

        for j in edge_inds[1:]:
            cost = reduced_cost(j, costs, potentials, tails, heads, flows)
            if cost < c:
                c = cost
                i = j

        p = q = -1
        if c >= 0:
            m += 1

        # Entering edge found
        else:
            if flows[i] == 0:
                p = tails[i]
                q = heads[i]
            else:
                p = heads[i]
                q = tails[i]

            return i, p, q, f

    # All edges have nonnegative reduced costs, the flow is optimal
    return -1, -1, -1, -1


@numba.njit(cache=True)
def find_apex(p, q, size, parent):
    """Find the lowest common ancestor of nodes p and q in the spanning
    tree.
    """

    size_p = size[p]
    size_q = size[q]

    while True:
        while size_p < size_q:
            p = parent[p]
            size_p = size[p]
        while size_p > size_q:
            q = parent[q]
            size_q = size[q]
        if size_p == size_q:
            if p != q:
                p = parent[p]
                size_p = size[p]
                q = parent[q]
                size_q = size[q]
            else:
                return p


@numba.njit(cache=True)
def trace_path(p, w, edge, parent):
    """Return the nodes and edges on the path from node p to its
    ancestor w.
    """

    cycle_nodes = [p]
    cycle_edges = []

    while p != w:
        cycle_edges.append(edge[p])
        p = parent[p]
        cycle_nodes.append(p)

    return cycle_nodes, cycle_edges


@numba.njit(cache=True)
def find_cycle(i, p, q, size, edge, parent):
    """Return the nodes and edges on the cycle containing edge
    i == (p, q) when the latter is added to the spanning tree. The cycle
    is oriented in the direction from p to q.
    """

    w = find_apex(p, q, size, parent)
    cycle_nodes, cycle_edges = trace_path(p, w, edge, parent)
    cycle_nodes_rev, cycle_edges_rev = trace_path(q, w, edge, parent)
    len_cycle_nodes = len(cycle_nodes)
    add_to_c_nodes = max(len(cycle_nodes_rev) - 1, 0)
    cycle_nodes_ = np.empty(len_cycle_nodes + add_to_c_nodes, dtype=np.int64)

    for j in range(len_cycle_nodes):
        cycle_nodes_[j] = cycle_nodes[-(j+1)]
    for j in range(add_to_c_nodes):
        cycle_nodes_[len_cycle_nodes+j] = cycle_nodes_rev[j]

    len_cycle_edges = len(cycle_edges)
    len_cycle_edges_ = len_cycle_edges + len(cycle_edges_rev)
    if len_cycle_edges < 1 or cycle_edges[-1] != i:
        cycle_edges_ = np.empty(len_cycle_edges_ + 1, dtype=np.int64)
        cycle_edges_[len_cycle_edges] = i
    else:
        cycle_edges_ = np.empty(len_cycle_edges_, dtype=np.int64)

    for j in range(len_cycle_edges):
        cycle_edges_[j] = cycle_edges[-(j+1)]
    for j in range(1, len(cycle_edges_rev) + 1):
        cycle_edges_[-j] = cycle_edges_rev[-j]

    return cycle_nodes_, cycle_edges_


@numba.njit(cache=True)
def find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads):
    """Return the leaving edge in a cycle represented by cycle_nodes and
    cycle_edges.
    """

    j, s = cycle_edges[0], cycle_nodes[0]
    res_caps_min = capac[j] - flows[j] if tails[j] == s else flows[j]

    for ind in range(1, cycle_edges.shape[0]):
        j_, s_ = cycle_edges[ind], cycle_nodes[ind]
        res_cap = capac[j_] - flows[j_] if tails[j_] == s_ else flows[j_]
        if res_cap < res_caps_min:
            res_caps_min = res_cap
            j, s = j_, s_

    t = heads[j] if tails[j] == s else tails[j]
    return j, s, t


@numba.njit(cache=True)
def remove_edge(s, t, size, prev, last, next_node, parent, edge):
    """Remove an edge (s, t) where parent[t] == s from the spanning
    tree.
    """

    size_t = size[t]
    prev_t = prev[t]
    last_t = last[t]
    next_last_t = next_node[last_t]
    # Remove (s, t)
    parent[t] = -2
    edge[t] = -2
    # Remove the subtree rooted at t from the depth-first thread
    next_node[prev_t] = next_last_t
    prev[next_last_t] = prev_t
    next_node[last_t] = t
    prev[t] = last_t

    # Update the subtree sizes & last descendants of the (old) ancestors of t
    while s != -2:
        size[s] -= size_t
        if last[s] == last_t:
            last[s] = prev_t
        s = parent[s]


@numba.njit(cache=True)
def make_root(q, parent, size, last, prev, next_node, edge):
    """Make a node q the root of its containing subtree."""

    ancestors = []
    # -2 means node is checked
    while q != -2:
        ancestors.insert(0, q)
        q = parent[q]

    for i in range(len(ancestors) - 1):
        p = ancestors[i]
        q = ancestors[i+1]
        size_p = size[p]
        last_p = last[p]
        prev_q = prev[q]
        last_q = last[q]
        next_last_q = next_node[last_q]

        # Make p a child of q
        parent[p] = q
        parent[q] = -2
        edge[p] = edge[q]
        edge[q] = -2
        size[p] = size_p - size[q]
        size[q] = size_p

        # Remove the subtree rooted at q from the depth-first thread
        next_node[prev_q] = next_last_q
        prev[next_last_q] = prev_q
        next_node[last_q] = q
        prev[q] = last_q

        if last_p == last_q:
            last[p] = prev_q
            last_p = prev_q

        # Add the remaining parts of the subtree rooted at p as a subtree of q
        # in the depth-first thread
        prev[p] = last_q
        next_node[last_q] = p
        next_node[last_p] = q
        prev[q] = last_p
        last[q] = last_p


@numba.njit(cache=True)
def add_edge(i, p, q, next_node, prev, last, size, parent, edge):
    """Add an edge (p, q) to the spanning tree where q is the root of a
    subtree.
    """

    last_p = last[p]
    next_last_p = next_node[last_p]
    size_q = size[q]
    last_q = last[q]
    # Make q a child of p
    parent[q] = p
    edge[q] = i
    # Insert the subtree rooted at q into the depth-first thread
    next_node[last_p] = q
    prev[q] = last_p
    prev[next_last_p] = last_q
    next_node[last_q] = next_last_p

    # Update the subtree sizes and last descendants of the (new) ancestors of q
    while p != -2:
        size[p] += size_q
        if last[p] == last_p:
            last[p] = last_q
        p = parent[p]


@numba.njit(cache=True)
def update_potentials(i, p, q, heads, potentials, costs, last_node, next_node):
    """Update the potentials of the nodes in the subtree rooted at a
    node q connected to its parent p by an edge i.
    """

    if q == heads[i]:
        d = potentials[p] - costs[i] - potentials[q]
    else:
        d = potentials[p] + costs[i] - potentials[q]

    potentials[q] += d
    l = last_node[q]
    while q != l:
        q = next_node[q]
        potentials[q] += d
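
For reference, a minimal usage sketch, assuming network_simplex can be imported from amd._emd; the weights and cost matrix below are purely illustrative:

import numpy as np
from amd._emd import network_simplex

# Two small weighted distributions and the cost matrix between their elements.
source_weights = np.array([0.5, 0.5])
sink_weights = np.array([0.25, 0.75])
costs = np.array([[0.0, 1.0],
                  [1.0, 0.0]])

emd, plan = network_simplex(source_weights, sink_weights, costs)
print(emd)   # Earth mover's distance between the two distributions
print(plan)  # optimal transport plan with shape (2, 2)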