Passed
Push — master ( 7c2d4e...5cae11 )
by Daniel
06:09
created

amd._emd._update_potentials()   A

Complexity

Conditions 3

Size

Total Lines 18
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 11
dl 0
loc 18
rs 9.85
c 0
b 0
f 0
cc 3
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as more or different data is needed.

There are several approaches to avoid long parameter lists:

1
"""An implementation of the Wasserstein metric (Earth Mover's distance)
2
between two weighted distributions, used as the metric for comparing two
3
pointwise distance distributions (PDDs), see
4
:func:`amd.PDD <.calculate.PDD>`
5
6
Copyright (C) 2020 Cameron Hargreaves. This code is adapted from the
7
Element Movers Distance project https://github.com/lrcfmd/ElMD.
8
"""
9
10
from typing import Tuple
11
12
import numba
13
import numpy as np
14
import numpy.typing as npt
15
16
17
@numba.njit(cache=True)
def network_simplex(
        source_demands: npt.NDArray,
        sink_demands: npt.NDArray,
        network_costs: npt.NDArray
) -> Tuple[float, npt.NDArray[np.float64]]:
    """Calculate the Earth mover's distance (Wasserstein metric) between
    two weighted distributions given by two sets of weights and a cost
    matrix.

    This is a port of the network simplex algorithm implemented by Loïc
    Séguin-C for the networkx package to allow acceleration with numba.
    Copyright (C) 2010 Loïc Séguin-C. [email protected]. All rights
    reserved. BSD license.

    Weights and costs are multiplied by ``1e6`` and truncated to integers
    before solving, so the result is only exact up to that precision. If
    the scaled weights of the two sides do not have the same total, the
    difference is added to the largest weight of the lighter side so that
    total supply equals total demand.

    Parameters
    ----------
    source_demands : :class:`numpy.ndarray`
        Weights of the first distribution.
    sink_demands : :class:`numpy.ndarray`
        Weights of the second distribution.
    network_costs : :class:`numpy.ndarray`
        Cost matrix of distances between elements of the two
        distributions. Shape (len(source_demands), len(sink_demands)).

    Returns
    -------
    (emd, plan) : Tuple[float, :class:`numpy.ndarray`]
        A tuple of the Earth mover's distance and the optimal matching,
        the latter with shape (len(source_demands), len(sink_demands)).

    References
    ----------
    [1] Z. Kiraly, P. Kovacs.
        Efficient implementation of minimum-cost flow algorithms.
        Acta Universitatis Sapientiae, Informatica 4(1), 67--118
        (2012).
    [2] R. Barr, F. Glover, D. Klingman.
        Enhancement of spanning tree labeling procedures for network
        optimization. INFOR 17(1), 16--34 (1979).
    """

    n_sources, n_sinks = source_demands.shape[0], sink_demands.shape[0]
    network_costs = network_costs.ravel()
    n = n_sources + n_sinks
    e = n_sources * n_sinks
    # Block size for the block-search pivot rule in _find_entering_edges.
    B = np.int64(np.ceil(np.sqrt(e)))
    fp_multiplier = np.int64(1e+6)

    # Scale weights to integers, then balance total supply and demand.
    source_d_int = (source_demands * fp_multiplier).astype(np.int64)
    sink_d_int = (sink_demands * fp_multiplier).astype(np.int64)
    sink_source_sum_diff = np.sum(sink_d_int) - np.sum(source_d_int)

    if sink_source_sum_diff > 0:
        source_d_int[np.argmax(source_d_int)] += sink_source_sum_diff
    elif sink_source_sum_diff < 0:
        sink_d_int[np.argmax(sink_d_int)] -= sink_source_sum_diff

    # Sources have negative demand (they supply flow), sinks positive.
    demands = np.empty(n, dtype=np.int64)
    demands[:n_sources] = -source_d_int
    demands[n_sources:] = sink_d_int

    # Edges 0..e-1 are the real source -> sink arcs, in row-major order of
    # network_costs. Edges e..e+n-1 are one dummy arc per node joining it
    # to the extra root node of the spanning tree, addressed as index -1
    # (i.e. the last slot of the (n + 1)-long tree arrays below).
    tails = np.empty(e + n, dtype=np.int64)
    heads = np.empty(e + n, dtype=np.int64)

    ind = 0
    for i in range(n_sources):
        for j in range(n_sinks):
            tails[ind] = i
            heads[ind] = n_sources + j
            ind += 1

    # Dummy arcs run from the root to nodes with positive demand, and from
    # every other node (zero-demand nodes included) to the root. Both ends
    # must be set: _find_leaving_edge reads heads[] of dummy arcs.
    for i, demand in enumerate(demands):
        ind = e + i
        if demand > 0:
            tails[ind] = -1
            heads[ind] = i
        else:
            tails[ind] = i
            heads[ind] = -1

    # Create costs and capacities for the arcs between nodes
    network_costs = network_costs * fp_multiplier
    network_capac = np.empty(shape=(e, ), dtype=np.float64)
    ind = 0
    for i in range(n_sources):
        for j in range(n_sinks):
            network_capac[ind] = min(source_demands[i], sink_demands[j])
            ind += 1
    network_capac *= fp_multiplier

    # Stand-in for infinity: three times the largest of the total real
    # capacity, the total absolute cost and any single scaled weight.
    faux_inf = 3 * max(
        np.sum(network_capac), np.sum(np.abs(network_costs)),
        np.amax(source_d_int), np.amax(sink_d_int)
    )

    costs = np.empty(e + n, dtype=np.int64)
    costs[:e] = network_costs
    costs[e:] = faux_inf

    # Dummy arcs have "infinite" capacity, as in the networkx original. A
    # finite cap of fp_multiplier is below the initial dummy flow whenever
    # a scaled (and balanced) weight exceeds 1e6, e.g. an unnormalised
    # weight or a single weight of 1.0 that received the balancing
    # remainder, which would make the starting tree infeasible.
    capac = np.empty(e + n, dtype=np.int64)
    capac[:e] = network_capac
    capac[e:] = faux_inf

    # Initial flow: all demand is routed through the dummy arcs.
    flows = np.empty(e + n, dtype=np.int64)
    flows[:e] = 0
    flows[e:e+n_sources] = source_d_int
    flows[e+n_sources:] = sink_d_int

    potentials = np.empty(n, dtype=np.int64)
    demands_neg_mask = demands <= 0
    potentials[demands_neg_mask] = faux_inf
    potentials[~demands_neg_mask] = -faux_inf

    # Initial spanning tree: every node is a child of the root (index -1),
    # whose own parent is the sentinel -2.
    parent = np.full(shape=(n + 1, ), fill_value=-1, dtype=np.int64)
    parent[-1] = -2

    size = np.full(shape=(n + 1, ), fill_value=1, dtype=np.int64)
    size[-1] = n + 1

    # Depth-first thread over the tree (next/prev/last descendant).
    next_node = np.arange(1, n + 2, dtype=np.int64)
    next_node[-2] = -1
    next_node[-1] = 0

    last_node = np.arange(n + 1, dtype=np.int64)
    last_node[-1] = n - 1

    prev_node = np.arange(-1, n, dtype=np.int64)
    # edge[k] is the index of the tree arc joining node k to its parent.
    edge = np.arange(e, e + n, dtype=np.int64)

    # Pivot loop

    f = 0
    while True:
        i, p, q, f = _find_entering_edges(
            B, e, f, tails, heads, costs, potentials, flows
        )
        # If no entering edges then the optimal score is found
        if p == -1:
            break

        cycle_nodes, cycle_edges = _find_cycle(i, p, q, size, edge, parent)
        j, s, t = _find_leaving_edge(
            cycle_nodes, cycle_edges, capac, flows, tails, heads
        )
        res_cap = capac[j] - flows[j] if tails[j] == s else flows[j]

        # Augment flows
        for i_, p_ in zip(cycle_edges, cycle_nodes):
            if tails[i_] == p_:
                flows[i_] += res_cap
            else:
                flows[i_] -= res_cap

        # Do nothing more if the entering edge is the same as the
        # leaving edge.
        if i != j:
            if parent[t] != s:
                # Ensure that s is the parent of t.
                s, t = t, s

            # Ensure that q is in the subtree rooted at t.
            for val in cycle_edges:
                if val == j:
                    p, q = q, p
                    break
                if val == i:
                    break

            _remove_edge(
                s, t, size, prev_node, last_node, next_node, parent, edge
            )
            _make_root(
                q, parent, size, last_node, prev_node, next_node, edge
            )
            _add_edge(
                i, p, q, next_node, prev_node, last_node, size, parent, edge
            )
            _update_potentials(
                i, p, q, heads, potentials, costs, last_node, next_node
            )

    final_flows = flows[:e] / fp_multiplier
    edge_costs = costs[:e] / fp_multiplier
    emd = final_flows.dot(edge_costs)

    return emd, final_flows.reshape((n_sources, n_sinks))
203
204
205
@numba.njit(cache=True)
def _reduced_cost(i, costs, potentials, tails, heads, flows):
    """Reduced cost of edge ``i``, signed by the direction it can move.

    The raw value is ``costs[i] - potentials[tail] + potentials[head]``.
    It is returned as-is for an edge carrying no flow and negated for an
    edge that carries flow.
    """
    raw = costs[i] - potentials[tails[i]] + potentials[heads[i]]
    return raw if flows[i] == 0 else -raw
212
213
214
@numba.njit(cache=True)
def _find_entering_edges(B, e, f, tails, heads, costs, potentials, flows):
    """Find the next edge to enter the spanning tree.

    The real edges ``0..e-1`` are scanned cyclically in blocks of ``B``,
    starting at offset ``f`` and wrapping back to edge 0. Within a block
    the first edge of lowest reduced cost is taken (Dantzig's rule). The
    search moves on to the next block only when that minimum is
    non-negative (block-wise Bland's rule).

    Returns ``(i, p, q, f)``: the entering edge, its endpoints oriented
    in the direction flow can be pushed, and the offset at which the next
    search should start. Returns ``(-1, -1, -1, -1)`` when every block
    has been scanned without finding a negative reduced cost.
    """

    n_blocks = (e + B - 1) // B
    empty_blocks = 0

    while empty_blocks < n_blocks:
        start = f
        # Offset of the block after this one, wrapped back into range.
        f = start + B
        if f > e:
            f -= e

        # Edge indices of this block are start, start+1, ... taken mod e.
        best = start - e if start >= e else start
        best_cost = _reduced_cost(
            best, costs, potentials, tails, heads, flows
        )
        for k in range(start + 1, start + B):
            cand = k - e if k >= e else k
            cand_cost = _reduced_cost(
                cand, costs, potentials, tails, heads, flows
            )
            if cand_cost < best_cost:
                best_cost = cand_cost
                best = cand

        if best_cost >= 0:
            empty_blocks += 1
            continue

        # Entering edge found: orient it along the direction of improvement.
        if flows[best] == 0:
            return best, tails[best], heads[best], f
        return best, heads[best], tails[best], f

    # All edges have nonnegative reduced costs, the flow is optimal
    return -1, -1, -1, -1
266
267
268
@numba.njit(cache=True)
def _find_apex(p, q, size, parent):
    """Return the lowest common ancestor of nodes ``p`` and ``q``.

    Subtree sizes grow strictly towards the root, so the node with the
    smaller subtree is repeatedly moved to its parent. When the two
    sizes match but the nodes differ, both are moved up together.
    """
    while True:
        while size[p] < size[q]:
            p = parent[p]
        while size[p] > size[q]:
            q = parent[q]
        if size[p] == size[q]:
            if p == q:
                return p
            p = parent[p]
            q = parent[q]
292
293
294
@numba.njit(cache=True)
def _trace_path(p, w, edge, parent):
    """Walk up the spanning tree from node ``p`` to its ancestor ``w``.

    Returns ``(nodes, edges)``. ``nodes`` starts at ``p`` and ends at
    ``w``. ``edges`` holds the parent edges crossed along the way, so it
    is one element shorter than ``nodes``.
    """
    nodes = [p]
    edges = []
    node = p
    while node != w:
        edges.append(edge[node])
        node = parent[node]
        nodes.append(node)
    return nodes, edges
309
310
311
@numba.njit(cache=True)
def _find_cycle(i, p, q, size, edge, parent):
    """Return the nodes and edges of the cycle created by adding edge
    ``i == (p, q)`` to the spanning tree, oriented from ``p`` to ``q``.

    Both arrays start at the apex (the lowest common ancestor of ``p`` and
    ``q``). They run down to ``p``, across edge ``i`` to ``q``, and back
    up towards the apex.
    """

    apex = _find_apex(p, q, size, parent)
    p_nodes, p_edges = _trace_path(p, apex, edge, parent)
    q_nodes, q_edges = _trace_path(q, apex, edge, parent)

    # Nodes: the p-side path reversed (apex .. p), then the q-side path
    # without its final element (the apex, already present).
    n_p_nodes = len(p_nodes)
    n_q_extra = max(len(q_nodes) - 1, 0)
    nodes_out = np.empty(n_p_nodes + n_q_extra, dtype=np.int64)
    for k in range(n_p_nodes):
        nodes_out[k] = p_nodes[n_p_nodes - 1 - k]
    for k in range(n_q_extra):
        nodes_out[n_p_nodes + k] = q_nodes[k]

    # Edges: the p-side edges reversed, the entering edge i (unless the
    # p-side path already ends with it), then the q-side edges in order.
    n_p_edges = len(p_edges)
    n_q_edges = len(q_edges)
    insert_i = n_p_edges < 1 or p_edges[-1] != i
    q_offset = n_p_edges + 1 if insert_i else n_p_edges
    edges_out = np.empty(q_offset + n_q_edges, dtype=np.int64)
    for k in range(n_p_edges):
        edges_out[k] = p_edges[n_p_edges - 1 - k]
    if insert_i:
        edges_out[n_p_edges] = i
    for k in range(n_q_edges):
        edges_out[q_offset + k] = q_edges[k]

    return nodes_out, edges_out
344
345
346
@numba.njit(cache=True)
def _find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads):
    """Return the leaving edge in a cycle represented by cycle_nodes and
    cycle_edges.

    ``cycle_nodes[k]`` is the node at which ``cycle_edges[k]`` starts when
    the cycle is traversed in its orientation, beginning at the apex. The
    residual capacity of an edge is ``capac - flow`` when it is traversed
    forwards and ``flow`` when traversed backwards. The leaving edge is
    the one with the smallest residual capacity. Among ties, the LAST
    such edge in cycle order is chosen. This matches the networkx
    original (which takes the min over the reversed cycle) and the
    strongly-feasible-tree rule that prevents degenerate cycling.

    Returns ``(j, s, t)``: the leaving edge and its endpoints, where ``s``
    is the endpoint the cycle enters the edge from.
    """

    j, s = cycle_edges[0], cycle_nodes[0]
    res_caps_min = capac[j] - flows[j] if tails[j] == s else flows[j]

    for ind in range(1, cycle_edges.shape[0]):
        j_, s_ = cycle_edges[ind], cycle_nodes[ind]
        res_cap = capac[j_] - flows[j_] if tails[j_] == s_ else flows[j_]
        # `<=` so that later ties replace earlier ones (last blocking arc).
        if res_cap <= res_caps_min:
            res_caps_min = res_cap
            j, s = j_, s_

    t = heads[j] if tails[j] == s else tails[j]
    return j, s, t
364
365
366
@numba.njit(cache=True)
def _remove_edge(s, t, size, prev, last, next_node, parent, edge):
    """Detach the subtree rooted at ``t`` from its parent ``s``.

    Requires ``parent[t] == s``. Afterwards ``t`` has parent and parent
    edge ``-2`` (no parent). Its subtree is unlinked from the depth-first
    thread and closed into a loop of its own. Every former ancestor of
    ``t`` has its subtree size and last descendant updated.
    """

    detached_size = size[t]
    before_t = prev[t]
    t_last = last[t]
    after_subtree = next_node[t_last]
    # Cut the tree edge (s, t).
    parent[t] = -2
    edge[t] = -2
    # Splice t's subtree out of the thread and make it circular.
    next_node[before_t] = after_subtree
    prev[after_subtree] = before_t
    next_node[t_last] = t
    prev[t] = t_last

    # Walk up the old ancestors of t, shrinking sizes and fixing last.
    ancestor = s
    while ancestor != -2:
        size[ancestor] -= detached_size
        if last[ancestor] == t_last:
            last[ancestor] = before_t
        ancestor = parent[ancestor]
391
392
393
@numba.njit(cache=True)
def _make_root(q, parent, size, last, prev, next_node, edge):
    """Make node ``q`` the root of the tree that contains it.

    The parent pointers on the path from the current root down to ``q``
    are reversed one edge at a time. Subtree sizes, parent edges and the
    depth-first thread (``prev``/``next_node``/``last``) are kept
    consistent after each step.
    """

    # Collect q and its ancestors, from q up to the current root. A parent
    # of -2 marks a node without a parent. Appending and then walking the
    # list backwards avoids the O(depth**2) cost of repeated insert(0, .).
    path = []
    while q != -2:
        path.append(q)
        q = parent[q]

    # Visit (p, q) pairs from the root downwards: p is the current parent
    # of q on the path.
    for k in range(len(path) - 1, 0, -1):
        p = path[k]
        q = path[k - 1]
        size_p = size[p]
        last_p = last[p]
        prev_q = prev[q]
        last_q = last[q]
        next_last_q = next_node[last_q]

        # Make p a child of q
        parent[p] = q
        parent[q] = -2
        edge[p] = edge[q]
        edge[q] = -2
        size[p] = size_p - size[q]
        size[q] = size_p

        # Remove the subtree rooted at q from the depth-first thread
        next_node[prev_q] = next_last_q
        prev[next_last_q] = prev_q
        next_node[last_q] = q
        prev[q] = last_q

        if last_p == last_q:
            last[p] = prev_q
            last_p = prev_q

        # Add the remaining parts of the subtree rooted at p as a subtree of q
        # in the depth-first thread
        prev[p] = last_q
        next_node[last_q] = p
        next_node[last_p] = q
        prev[q] = last_p
        last[q] = last_p
437
438
439
@numba.njit(cache=True)
def _add_edge(i, p, q, next_node, prev, last, size, parent, edge):
    """Attach the subtree rooted at ``q`` below node ``p`` via edge ``i``.

    ``q`` must currently be the root of its own subtree. Its subtree is
    spliced into the depth-first thread right after ``p``'s last
    descendant. Every new ancestor of ``q`` has its subtree size and last
    descendant updated.
    """

    p_last = last[p]
    after_p_last = next_node[p_last]
    attached_size = size[q]
    q_last = last[q]
    # Hang q under p through edge i.
    parent[q] = p
    edge[q] = i
    # Splice q's subtree into the thread after p's last descendant.
    next_node[p_last] = q
    prev[q] = p_last
    prev[after_p_last] = q_last
    next_node[q_last] = after_p_last

    # Walk up the new ancestors of q, growing sizes and fixing last.
    ancestor = p
    while ancestor != -2:
        size[ancestor] += attached_size
        if last[ancestor] == p_last:
            last[ancestor] = q_last
        ancestor = parent[ancestor]
464
465
466
@numba.njit(cache=True)
def _update_potentials(
        i, p, q, heads, potentials, costs, last_node, next_node
):
    """Shift the potentials of the subtree rooted at ``q``.

    ``q`` hangs from its parent ``p`` via edge ``i``. One common shift is
    applied to every node of the subtree, chosen so that edge ``i`` ends
    up with zero reduced cost
    (``costs[i] - potentials[tail] + potentials[head]``). The subtree is
    visited along the depth-first thread from ``q`` to ``last_node[q]``.
    """

    gap = potentials[p] - potentials[q]
    if q == heads[i]:
        shift = gap - costs[i]
    else:
        shift = gap + costs[i]

    node = q
    stop = last_node[q]
    potentials[node] += shift
    while node != stop:
        node = next_node[node]
        potentials[node] += shift
484