Passed
Push — master ( 8c12c2...4daa36 )
by Daniel
07:46
created

amd._emd   F

Complexity

Total Complexity 64

Size/Duplication

Total Lines 479
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 64
eloc 286
dl 0
loc 479
rs 3.28
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex modules like amd._emd often do a lot of different things. To break such a module down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/functions that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""An implementation of the Wasserstein metric (Earth Mover's distance)
2
between two weighted distributions, used as the metric for comparing two
3
pointwise distance distributions (PDDs), see
4
:func:`amd.PDD <.calculate.PDD>`
5
6
Copyright (C) 2020 Cameron Hargreaves. This code is adapted from the
7
Element Movers Distance project https://github.com/lrcfmd/ElMD.
8
"""
9
10
from typing import Tuple
11
12
import numba
13
import numpy as np
14
15
# Public API: only the solver is exported; the underscore-prefixed helpers
# defined below are internal to this module.
__all__ = ['network_simplex']
16
17
18
@numba.njit(cache=True)
def network_simplex(
        source_demands: np.ndarray,
        sink_demands: np.ndarray,
        network_costs: np.ndarray
) -> Tuple[float, np.ndarray]:
    """Calculate the Earth mover's distance (Wasserstein metric) between
    two weighted distributions given by two sets of weights and a cost
    matrix.

    This is a port of the network simplex algorithm implemented by Loïc
    Séguin-C for the networkx package to allow acceleration with numba.
    Copyright (C) 2010 Loïc Séguin-C. [email protected]. All rights
    reserved. BSD license.

    Weights and costs are scaled by 10**6 and truncated to int64 so the
    simplex iterations run in integer arithmetic. The returned distance and
    plan are scaled back, so they are only resolved to roughly 1e-6. If the
    two fixed-point totals differ, the difference is added to the largest
    weight of the lighter distribution so that supply equals demand.

    Parameters
    ----------
    source_demands : :class:`numpy.ndarray`
        Weights of the first distribution.
    sink_demands : :class:`numpy.ndarray`
        Weights of the second distribution.
    network_costs : :class:`numpy.ndarray`
        Cost matrix of distances between elements of the two
        distributions. Shape (len(source_demands), len(sink_demands)).

    Returns
    -------
    (emd, plan) : Tuple[float, :class:`numpy.ndarray`]
        A tuple of the Earth mover's distance and the optimal matching,
        the latter of shape (len(source_demands), len(sink_demands)).

    References
    ----------
    [1] Z. Kiraly, P. Kovacs.
        Efficient implementation of minimum-cost flow algorithms.
        Acta Universitatis Sapientiae, Informatica 4(1), 67--118
        (2012).
    [2] R. Barr, F. Glover, D. Klingman.
        Enhancement of spanning tree labeling procedures for network
        optimization. INFOR 17(1), 16--34 (1979).
    """

    n_sources, n_sinks = source_demands.shape[0], sink_demands.shape[0]
    n = n_sources + n_sinks              # number of real nodes
    e = n_sources * n_sinks              # number of real (source->sink) edges
    B = np.int64(np.ceil(np.sqrt(e)))    # pivot block size
    fp_multiplier = np.float64(1_000_000)

    # Convert weights to fixed-point integers, then move any imbalance in
    # the totals onto the largest weight of the lighter side so that total
    # supply equals total demand exactly.
    source_d_int = (source_demands * fp_multiplier).astype(np.int64)
    sink_d_int = (sink_demands * fp_multiplier).astype(np.int64)
    sink_source_sum_diff = np.sum(sink_d_int) - np.sum(source_d_int)

    if sink_source_sum_diff > 0:
        source_d_int[np.argmax(source_d_int)] += sink_source_sum_diff
    elif sink_source_sum_diff < 0:
        sink_d_int[np.argmax(sink_d_int)] -= sink_source_sum_diff

    # Node demands: sources supply (negative), sinks consume (positive).
    demands = np.empty(n, dtype=np.int64)
    demands[:n_sources] = -source_d_int
    demands[n_sources:] = sink_d_int

    tails = np.empty(e + n, dtype=np.int64)
    heads = np.empty(e + n, dtype=np.int64)

    # Edges 0..e-1 are the real edges: source i -> sink j at index
    # i * n_sinks + j, matching the row-major order of network_costs.ravel().
    # A real edge can never carry more than the smaller of its endpoints'
    # (balanced, integer) weights, so that is used as its capacity.
    network_capac = np.empty(e, dtype=np.int64)
    ind = 0
    for i in range(n_sources):
        for j in range(n_sinks):
            tails[ind] = i
            heads[ind] = n_sources + j
            network_capac[ind] = min(source_d_int[i], sink_d_int[j])
            ind += 1

    # Edges e..e+n-1 are artificial edges joining each node to the dummy
    # root, referred to as -1. As in networkx, nodes with positive demand
    # get root -> node and all others (including zero-demand nodes) get
    # node -> root, which makes the initial spanning tree strongly feasible.
    # Each edge must join two *different* nodes: _find_leaving_edge derives
    # the far endpoint from heads/tails.
    for i in range(n):
        ind = e + i
        if demands[i] > 0:
            tails[ind] = -1
            heads[ind] = i
        else:
            tails[ind] = i
            heads[ind] = -1

    # Integer edge costs.
    int_costs = (network_costs.ravel() * fp_multiplier).astype(np.int64)

    # A cost/capacity larger than anything a real solution can reach, used
    # for the artificial edges (abs() keeps this valid for negative costs).
    faux_inf = np.int64(3 * max(
        np.sum(np.abs(int_costs)),
        np.sum(network_capac),
        np.amax(source_d_int),
        np.amax(sink_d_int)
    ))

    costs = np.empty(e + n, dtype=np.int64)
    costs[:e] = int_costs
    costs[e:] = faux_inf

    # Artificial edges are effectively uncapacitated, as in networkx, so
    # their initial flows (the node weights) never exceed their capacity.
    capac = np.empty(e + n, dtype=np.int64)
    capac[:e] = network_capac
    capac[e:] = faux_inf

    # Initial flow: everything routed through the artificial edges.
    flows = np.empty(e + n, dtype=np.int64)
    flows[:e] = 0
    flows[e:e+n_sources] = source_d_int
    flows[e+n_sources:] = sink_d_int

    potentials = np.empty(n, dtype=np.int64)
    demands_neg_mask = demands <= 0
    potentials[demands_neg_mask] = faux_inf
    potentials[~demands_neg_mask] = -faux_inf

    # Initial spanning tree: every node hangs directly off the root. The
    # root occupies the extra slot at index n (also reachable as -1), and a
    # parent of -2 means "no parent". next_node/prev_node/last_node encode
    # the depth-first thread, and edge[k] is node k's tree edge (initially
    # its artificial edge e + k).
    parent = np.full(shape=(n + 1, ), fill_value=-1, dtype=np.int64)
    parent[-1] = -2

    size = np.full(shape=(n + 1, ), fill_value=1, dtype=np.int64)
    size[-1] = n + 1

    next_node = np.arange(1, n + 2, dtype=np.int64)
    next_node[-2] = -1
    next_node[-1] = 0

    last_node = np.arange(n + 1, dtype=np.int64)
    last_node[-1] = n - 1

    prev_node = np.arange(-1, n, dtype=np.int64)
    edge = np.arange(e, e + n, dtype=np.int64)

    # Pivot loop
    f = 0
    while True:
        i, p, q, f = _find_entering_edges(
            B, e, f, tails, heads, costs, potentials, flows
        )
        # If no entering edges then the optimal score is found
        if p == -1:
            break

        cycle_nodes, cycle_edges = _find_cycle(i, p, q, size, edge, parent)
        j, s, t = _find_leaving_edge(
            cycle_nodes, cycle_edges, capac, flows, tails, heads
        )
        res_cap = capac[j] - flows[j] if tails[j] == s else flows[j]

        # Augment flows
        for i_, p_ in zip(cycle_edges, cycle_nodes):
            if tails[i_] == p_:
                flows[i_] += res_cap
            else:
                flows[i_] -= res_cap

        # Do nothing more if the entering edge is the same as the leaving edge.
        if i != j:
            if parent[t] != s:
                # Ensure that s is the parent of t.
                s, t = t, s

            # Ensure that q is in the subtree rooted at t.
            for val in cycle_edges:
                if val == j:
                    p, q = q, p
                    break
                if val == i:
                    break

            _remove_edge(s, t, size, prev_node, last_node, next_node, parent, edge)
            _make_root(q, parent, size, last_node, prev_node, next_node, edge)
            _add_edge(i, p, q, next_node, prev_node, last_node, size, parent, edge)
            _update_potentials(i, p, q, heads, potentials, costs, last_node, next_node)

    final_flows = flows[:e] / fp_multiplier
    edge_costs = costs[:e] / fp_multiplier
    emd = final_flows.dot(edge_costs)

    return emd, final_flows.reshape((n_sources, n_sinks))
197
198
199
@numba.njit(cache=True)
def _reduced_cost(i, costs, potentials, tails, heads, flows):
    """Reduced cost of edge i, sign-flipped when the edge carries flow.

    For an edge with zero flow this is cost - pi(tail) + pi(head). For an
    edge with nonzero flow the value is negated, so the entering-edge
    search can treat any negative result as a candidate.
    """
    reduced = costs[i] + potentials[heads[i]] - potentials[tails[i]]
    return reduced if flows[i] == 0 else -reduced
206
207
208
@numba.njit(cache=True)
def _find_entering_edges(B, e, f, tails, heads, costs, potentials, flows):
    """Return the next entering edge, or (-1, -1, -1, -1) if none exists.

    The real edges 0..e-1 are scanned in cyclic blocks of B edges, starting
    at edge f (Bland-style block order). Within a block, the first edge
    with the most negative reduced cost is taken (Dantzig's rule). If a
    full sweep of all blocks finds no negative reduced cost, the flow is
    optimal.

    Returns (i, p, q, f): the entering edge i, oriented from p to q along
    the direction in which flow would be pushed, and the block start to
    resume from on the next call.
    """

    n_blocks = (e + B - 1) // B
    fruitless_blocks = 0

    while fruitless_blocks < n_blocks:
        # B consecutive edge indices starting at f, wrapping past e.
        block = np.arange(f, f + B) % e
        f = f + B
        if f > e:
            f -= e

        # First edge in the block with the lowest reduced cost.
        best_edge = block[0]
        best_cost = _reduced_cost(best_edge, costs, potentials, tails, heads, flows)
        for k in range(1, block.shape[0]):
            candidate = block[k]
            candidate_cost = _reduced_cost(
                candidate, costs, potentials, tails, heads, flows
            )
            if candidate_cost < best_cost:
                best_cost = candidate_cost
                best_edge = candidate

        if best_cost < 0:
            # Entering edge found: empty edges are pushed forwards, edges
            # that already carry flow are pushed backwards.
            if flows[best_edge] == 0:
                return best_edge, tails[best_edge], heads[best_edge], f
            return best_edge, heads[best_edge], tails[best_edge], f

        fruitless_blocks += 1

    # All edges have nonnegative reduced costs, the flow is optimal
    return -1, -1, -1, -1
259
260
261
@numba.njit(cache=True)
def _find_apex(p, q, size, parent):
    """Return the lowest common ancestor of nodes p and q in the spanning
    tree.

    At each step, the node with the smaller subtree is moved to its parent.
    When both subtrees have equal size but the nodes differ, both move up.
    The walk stops when the two nodes coincide.
    """

    size_p, size_q = size[p], size[q]
    while True:
        if size_p < size_q:
            p = parent[p]
            size_p = size[p]
        elif size_p > size_q:
            q = parent[q]
            size_q = size[q]
        elif p == q:
            return p
        else:
            p = parent[p]
            size_p = size[p]
            q = parent[q]
            size_q = size[q]
285
286
287
@numba.njit(cache=True)
def _trace_path(p, w, edge, parent):
    """Walk the tree from node p up to its ancestor w.

    Returns two lists: the visited nodes (starting with p and ending with
    w) and the tree edges crossed on the way, with one fewer entry than the
    node list.
    """

    path_nodes = [p]
    path_edges = []
    node = p
    while node != w:
        path_edges.append(edge[node])
        node = parent[node]
        path_nodes.append(node)

    return path_nodes, path_edges
302
303
304
@numba.njit(cache=True)
def _find_cycle(i, p, q, size, edge, parent):
    """Return the nodes and edges on the cycle closed by adding edge
    i == (p, q) to the spanning tree, oriented from p to q.

    The node array runs apex -> ... -> p, then q -> ... up to (but not
    including) the apex. The edge array holds the apex->p tree edges, then
    i (unless the p-side path already ends with i), then the q-side tree
    edges. Both are returned as int64 arrays.
    """

    apex = _find_apex(p, q, size, parent)
    p_nodes, p_edges = _trace_path(p, apex, edge, parent)
    q_nodes, q_edges = _trace_path(q, apex, edge, parent)

    # Nodes: reversed p-side path, then the q-side path without the apex.
    n_p_nodes = len(p_nodes)
    n_q_extra = max(len(q_nodes) - 1, 0)
    nodes_out = np.empty(n_p_nodes + n_q_extra, dtype=np.int64)
    for k in range(n_p_nodes):
        nodes_out[k] = p_nodes[n_p_nodes - 1 - k]
    for k in range(n_q_extra):
        nodes_out[n_p_nodes + k] = q_nodes[k]

    # Edges: reversed p-side edges, optionally i, then the q-side edges.
    n_p_edges = len(p_edges)
    n_q_edges = len(q_edges)
    append_entering = n_p_edges == 0 or p_edges[n_p_edges - 1] != i
    n_out = n_p_edges + n_q_edges + (1 if append_entering else 0)
    edges_out = np.empty(n_out, dtype=np.int64)
    for k in range(n_p_edges):
        edges_out[k] = p_edges[n_p_edges - 1 - k]
    if append_entering:
        edges_out[n_p_edges] = i
    for k in range(n_q_edges):
        edges_out[n_out - n_q_edges + k] = q_edges[k]

    return nodes_out, edges_out
337
338
339
@numba.njit(cache=True)
def _residual_capacity(i, p, capac, flows, tails):
    """Residual capacity of edge i when the cycle traverses it from node p:
    spare capacity if traversed forwards, current flow if backwards."""
    if tails[i] == p:
        return capac[i] - flows[i]
    return flows[i]


@numba.njit(cache=True)
def _find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads):
    """Return (j, s, t) for the edge leaving the spanning tree.

    j is the cycle edge with the smallest residual capacity, traversed from
    node s to node t. Ties are resolved in favour of the first such edge in
    cycle order.

    NOTE(review): networkx's find_leaving_edge appears to scan the cycle in
    reverse and therefore keeps the *last* minimal edge (the
    strongly-feasible-tree rule). Confirm whether this difference is
    intentional.
    """

    best = 0
    best_res = _residual_capacity(
        cycle_edges[0], cycle_nodes[0], capac, flows, tails
    )
    for k in range(1, cycle_edges.shape[0]):
        res = _residual_capacity(
            cycle_edges[k], cycle_nodes[k], capac, flows, tails
        )
        if res < best_res:
            best = k
            best_res = res

    j = cycle_edges[best]
    s = cycle_nodes[best]
    t = tails[j] if tails[j] != s else heads[j]
    return j, s, t
357
358
359
@numba.njit(cache=True)
def _remove_edge(s, t, size, prev, last, next_node, parent, edge):
    """Detach the subtree rooted at t from its parent s (requires
    parent[t] == s).

    The subtree's segment of the depth-first thread is cut out and closed
    on itself. The sizes and last descendants of s and all its ancestors
    are then updated. A value of -2 marks "no parent" / "no tree edge".
    """

    detached_size = size[t]
    before_t = prev[t]
    t_last = last[t]
    after_subtree = next_node[t_last]

    # Cut the tree edge so that t becomes a root.
    parent[t] = -2
    edge[t] = -2

    # Splice t's segment out of the thread and close it into its own loop.
    next_node[before_t] = after_subtree
    prev[after_subtree] = before_t
    next_node[t_last] = t
    prev[t] = t_last

    # Walk up from s: shrink sizes and repair last descendants.
    node = s
    while node != -2:
        size[node] -= detached_size
        if last[node] == t_last:
            last[node] = before_t
        node = parent[node]
383
384
385
@numba.njit(cache=True)
def _make_root(q, parent, size, last, prev, next_node, edge):
    """Make a node q the root of its containing subtree.

    The tree edges on the path from the current subtree root down to q are
    reversed one at a time, top first. Each reversal updates ``parent``,
    ``edge``, ``size``, ``last`` and the ``prev``/``next_node`` depth-first
    thread. A parent value of -2 marks a node with no parent.
    """

    # Collect the path q -> root. Appending and then reversing is linear in
    # the path length; inserting at the front of the list was quadratic.
    ancestors = []
    while q != -2:
        ancestors.append(q)
        q = parent[q]
    ancestors.reverse()

    for i in range(len(ancestors) - 1):
        p = ancestors[i]
        q = ancestors[i+1]
        size_p = size[p]
        last_p = last[p]
        prev_q = prev[q]
        last_q = last[q]
        next_last_q = next_node[last_q]
        # Make p a child of q
        parent[p] = q
        parent[q] = -2
        edge[p] = edge[q]
        edge[q] = -2
        size[p] = size_p - size[q]
        size[q] = size_p
        # Remove the subtree rooted at q from the depth-first thread
        next_node[prev_q] = next_last_q
        prev[next_last_q] = prev_q
        next_node[last_q] = q
        prev[q] = last_q
        if last_p == last_q:
            last[p] = prev_q
            last_p = prev_q
        # Add the remaining parts of the subtree rooted at p as a subtree of q
        # in the depth-first thread
        prev[p] = last_q
        next_node[last_q] = p
        next_node[last_p] = q
        prev[q] = last_p
        last[q] = last_p
425
426
427
@numba.njit(cache=True)
def _add_edge(i, p, q, next_node, prev, last, size, parent, edge):
    """Attach the subtree rooted at q below node p through tree edge i.

    q must currently be the root of its own subtree. q's segment of the
    depth-first thread is spliced in right after p's last descendant. The
    sizes and last descendants of p and all its ancestors are then updated.
    A parent value of -2 ends the upward walk.
    """

    p_last = last[p]
    after_p = next_node[p_last]
    grafted_size = size[q]
    q_last = last[q]

    # Hang q off p through edge i.
    parent[q] = p
    edge[q] = i

    # Splice q's segment between p's last descendant and its old successor.
    next_node[p_last] = q
    prev[q] = p_last
    prev[after_p] = q_last
    next_node[q_last] = after_p

    # Walk up from p: grow sizes and extend last descendants.
    node = p
    while node != -2:
        size[node] += grafted_size
        if last[node] == p_last:
            last[node] = q_last
        node = parent[node]
451
452
453
@numba.njit(cache=True)
def _update_potentials(
        i, p, q, heads, potentials, costs, last_node, next_node
):
    """Shift the potentials of every node in the subtree rooted at q.

    q is connected to its parent p by edge i. The shift is chosen so that
    edge i gets zero reduced cost. The subtree is visited by following
    ``next_node`` from q until its last descendant ``last_node[q]`` is
    reached.
    """

    if q == heads[i]:
        shift = potentials[p] - potentials[q] - costs[i]
    else:
        shift = potentials[p] - potentials[q] + costs[i]

    stop = last_node[q]
    node = q
    while True:
        potentials[node] += shift
        if node == stop:
            break
        node = next_node[node]
470