|
1
|
|
|
"""An implementation of the Wasserstein metric (Earth Mover's distance) |
|
2
|
|
|
between two weighted distributions, the appropriate metric for comparing
two PDDs.
|
4
|
|
|
|
|
5
|
|
|
Copyright (C) 2020 Cameron Hargreaves. This code is adapted from the |
|
6
|
|
|
Element Movers Distance project https://github.com/lrcfmd/ElMD. |
|
7
|
|
|
""" |
|
8
|
|
|
|
|
9
|
|
|
from typing import Tuple |
|
10
|
|
|
|
|
11
|
|
|
import numba |
|
12
|
|
|
import numpy as np |
|
13
|
|
|
import numpy.typing as npt |
|
14
|
|
|
|
|
15
|
|
|
|
|
16
|
|
|
@numba.njit(cache=True)
def network_simplex(
        source_demands: npt.NDArray,
        sink_demands: npt.NDArray,
        network_costs: npt.NDArray
) -> Tuple[float, npt.NDArray[np.float64]]:
    """Calculate the Earth mover's distance (Wasserstein metric) between
    two weighted distributions given by two sets of weights and a cost
    matrix.

    This is a port of the network simplex algorithm implemented by Loïc
    Séguin-C for the networkx package to allow acceleration with numba.
    Copyright (C) 2010 Loïc Séguin-C. [email protected]. All rights
    reserved. BSD license.

    Weights and costs are converted to fixed-point integers (scaled by
    1e6 and truncated) before solving, so results are accurate to about
    six decimal places.

    Parameters
    ----------
    source_demands : :class:`numpy.ndarray`
        Weights of the first distribution.
    sink_demands : :class:`numpy.ndarray`
        Weights of the second distribution.
    network_costs : :class:`numpy.ndarray`
        Cost matrix of distances between elements of the two
        distributions. Shape (len(source_demands), len(sink_demands)).

    Returns
    -------
    (emd, plan) : Tuple[float, :class:`numpy.ndarray`]
        A tuple of the Earth mover's distance and the optimal matching,
        the latter of shape (len(source_demands), len(sink_demands)).

    References
    ----------
    [1] Z. Kiraly, P. Kovacs.
        Efficient implementation of minimum-cost flow algorithms.
        Acta Universitatis Sapientiae, Informatica 4(1), 67--118
        (2012).
    [2] R. Barr, F. Glover, D. Klingman.
        Enhancement of spanning tree labeling procedures for network
        optimization. INFOR 17(1), 16--34 (1979).
    """

    n_sources, n_sinks = source_demands.shape[0], sink_demands.shape[0]
    network_costs = network_costs.ravel()
    n = n_sources + n_sinks
    e = n_sources * n_sinks
    B = np.int64(np.ceil(np.sqrt(e)))
    fp_multiplier = np.int64(1e+6)

    # Convert the weights to fixed-point integers.
    source_d_int = (source_demands * fp_multiplier).astype(np.int64)
    sink_d_int = (sink_demands * fp_multiplier).astype(np.int64)
    sink_source_sum_diff = np.sum(sink_d_int) - np.sum(source_d_int)

    # Truncation can leave the totals unequal; add the difference to the
    # largest weight on the smaller side so total supply == total demand.
    if sink_source_sum_diff > 0:
        source_d_int[np.argmax(source_d_int)] += sink_source_sum_diff
    elif sink_source_sum_diff < 0:
        sink_d_int[np.argmax(sink_d_int)] -= sink_source_sum_diff

    # Sources supply flow (negative demand), sinks consume it.
    demands = np.empty(n, dtype=np.int64)
    demands[:n_sources] = -source_d_int
    demands[n_sources:] = sink_d_int
    tails = np.empty(e + n, dtype=np.int64)
    heads = np.empty(e + n, dtype=np.int64)

    # Edges 0..e-1: one edge from every source i to every sink j, in the
    # same row-major order as network_costs.ravel().
    ind = 0
    for i in range(n_sources):
        for j in range(n_sinks):
            tails[ind] = i
            heads[ind] = n_sources + j
            ind += 1

    # Edges e..e+n-1: one artificial edge per node joining it to a dummy
    # root node (index -1). Sinks (demand > 0) get an edge root -> node;
    # all other nodes get an edge node -> root.
    for i, demand in enumerate(demands):
        ind = e + i
        if demand > 0:
            tails[ind] = -1
            heads[ind] = i
        else:
            tails[ind] = i
            heads[ind] = -1

    # Create costs and capacities for the arcs between nodes.
    network_costs = network_costs * fp_multiplier
    network_capac = np.empty(shape=(e, ), dtype=np.float64)
    ind = 0
    for i in range(n_sources):
        for j in range(n_sinks):
            # Use the balanced integer demands so that arcs touching the
            # node that absorbed the rounding difference can carry it.
            network_capac[ind] = min(source_d_int[i], sink_d_int[j])
            ind += 1

    faux_inf = 3 * max(
        np.sum(network_capac), np.sum(np.abs(network_costs)),
        np.amax(source_d_int), np.amax(sink_d_int)
    )

    costs = np.empty(e + n, dtype=np.int64)
    costs[:e] = network_costs
    # Artificial edges are prohibitively expensive so optimal flow avoids them.
    costs[e:] = faux_inf

    capac = np.empty(e + n, dtype=np.int64)
    capac[:e] = network_capac
    # NOTE(review): artificial edges get capacity fp_multiplier (one unit of
    # weight) while their initial flow is |demand|, so a single weight > 1
    # starts above capacity. networkx uses faux_inf here; confirm that
    # inputs are always normalised.
    capac[e:] = fp_multiplier

    # Initial feasible flow: everything is routed through the root.
    flows = np.empty(e + n, dtype=np.int64)
    flows[:e] = 0
    flows[e:e+n_sources] = source_d_int
    flows[e+n_sources:] = sink_d_int

    # NOTE(review): potentials has n entries, so the root index -1 reads
    # node n-1's potential. Confirm this aliasing is intended.
    potentials = np.empty(n, dtype=np.int64)
    demands_neg_mask = demands <= 0
    potentials[demands_neg_mask] = faux_inf
    potentials[~demands_neg_mask] = -faux_inf

    # Initial spanning tree: each real node is a child of the root, which
    # occupies the extra last slot (index -1) of these n + 1 sized arrays.
    # -2 is the "no parent" sentinel.
    parent = np.full(shape=(n + 1, ), fill_value=-1, dtype=np.int64)
    parent[-1] = -2

    size = np.full(shape=(n + 1, ), fill_value=1, dtype=np.int64)
    size[-1] = n + 1

    # Depth-first thread: root -> 0 -> 1 -> ... -> n-1 -> root.
    next_node = np.arange(1, n + 2, dtype=np.int64)
    next_node[-2] = -1
    next_node[-1] = 0

    last_node = np.arange(n + 1, dtype=np.int64)
    last_node[-1] = n - 1

    prev_node = np.arange(-1, n, dtype=np.int64)
    # edge[k] is the tree edge joining node k to its parent.
    edge = np.arange(e, e + n, dtype=np.int64)

    # Pivot loop

    f = 0
    while True:
        i, p, q, f = find_entering_edges(
            B, e, f, tails, heads, costs, potentials, flows
        )
        # If no entering edges then the optimal score is found
        if p == -1:
            break

        cycle_nodes, cycle_edges = find_cycle(i, p, q, size, edge, parent)
        j, s, t = find_leaving_edge(
            cycle_nodes, cycle_edges, capac, flows, tails, heads
        )
        # Residual capacity of the leaving edge = amount pushed round the cycle.
        res_cap = capac[j] - flows[j] if tails[j] == s else flows[j]

        # Augment flows
        for i_, p_ in zip(cycle_edges, cycle_nodes):
            if tails[i_] == p_:
                flows[i_] += res_cap
            else:
                flows[i_] -= res_cap

        # Do nothing more if the entering edge is the same as the
        # leaving edge.
        if i != j:
            if parent[t] != s:
                # Ensure that s is the parent of t.
                s, t = t, s

            # Ensure that q is in the subtree rooted at t.
            for val in cycle_edges:
                if val == j:
                    p, q = q, p
                    break
                if val == i:
                    break

            remove_edge(
                s, t, size, prev_node, last_node, next_node, parent, edge
            )
            make_root(
                q, parent, size, last_node, prev_node, next_node, edge
            )
            add_edge(
                i, p, q, next_node, prev_node, last_node, size, parent, edge
            )
            update_potentials(
                i, p, q, heads, potentials, costs, last_node, next_node
            )

    final_flows = flows[:e] / fp_multiplier
    edge_costs = costs[:e] / fp_multiplier
    emd = final_flows.dot(edge_costs)

    return emd, final_flows.reshape((n_sources, n_sinks))
|
202
|
|
|
|
|
203
|
|
|
|
|
204
|
|
|
@numba.njit(cache=True)
def reduced_cost(i, costs, potentials, tails, heads, flows):
    """Return the reduced cost of edge ``i`` under the current potentials.

    The base value is ``costs[i] - potentials[tail] + potentials[head]``.
    It is negated when the edge carries non-zero flow; find_entering_edges
    then treats the edge as oriented head -> tail.
    """
    tail_potential = potentials[tails[i]]
    head_potential = potentials[heads[i]]
    rc = costs[i] - tail_potential + head_potential
    return -rc if flows[i] != 0 else rc
|
211
|
|
|
|
|
212
|
|
|
|
|
213
|
|
|
@numba.njit(cache=True)
def find_entering_edges(B, e, f, tails, heads, costs, potentials, flows):
    """Return the next entering edge, combining Dantzig's and Bland's rules.

    The first ``e`` (real) edges are grouped cyclically into blocks of ``B``
    consecutive edges, the current block starting at edge ``f``. Within a
    block, Dantzig's rule picks the edge with the lowest reduced cost (the
    first one on ties). Blocks are visited in cyclic order (Bland's rule)
    until one holds an edge with negative reduced cost, or every block has
    been searched without success.

    Returns
    -------
    (i, p, q, f) : tuple of int
        The entering edge ``i``, its orientation ``p -> q`` (reversed when
        the edge currently carries flow), and the start of the next block
        to search. ``(-1, -1, -1, -1)`` when no entering edge exists, i.e.
        the current flow is optimal.
    """

    # No real edges (or an invalid block size): nothing can enter. This
    # also avoids a division by zero when computing the block count.
    if e <= 0 or B <= 0:
        return -1, -1, -1, -1

    n_blocks = (e + B - 1) // B
    m = 0  # consecutive blocks searched without finding an eligible edge

    while m < n_blocks:
        # Start of the following block, wrapping past the last real edge.
        next_f = f + B
        if next_f > e:
            next_f -= e

        # Find the first edge with the lowest reduced cost in this block.
        # The block's edges are (f + k) % e for k in [0, B).
        i = f % e
        c = reduced_cost(i, costs, potentials, tails, heads, flows)
        for k in range(1, B):
            j = (f + k) % e
            cost = reduced_cost(j, costs, potentials, tails, heads, flows)
            if cost < c:
                c = cost
                i = j

        f = next_f

        if c >= 0:
            # No eligible edge in this block.
            m += 1
        else:
            # Entering edge found; orient it in the direction flow can move.
            if flows[i] == 0:
                p = tails[i]
                q = heads[i]
            else:
                p = heads[i]
                q = tails[i]
            return i, p, q, f

    # All edges have nonnegative reduced costs, the flow is optimal
    return -1, -1, -1, -1
|
265
|
|
|
|
|
266
|
|
|
|
|
267
|
|
|
@numba.njit(cache=True)
def find_apex(p, q, size, parent):
    """Return the lowest common ancestor of nodes ``p`` and ``q``.

    Each step advances whichever node currently has the smaller subtree.
    An ancestor's subtree is always larger than any of its descendants'
    subtrees, so the two walks meet at the first shared ancestor.
    """
    sp = size[p]
    sq = size[q]
    while True:
        while sp < sq:
            p = parent[p]
            sp = size[p]
        while sp > sq:
            q = parent[q]
            sq = size[q]
        if sp != sq:
            continue
        if p == q:
            return p
        p = parent[p]
        sp = size[p]
        q = parent[q]
        sq = size[q]
|
291
|
|
|
|
|
292
|
|
|
|
|
293
|
|
|
@numba.njit(cache=True)
def trace_path(p, w, edge, parent):
    """Walk from node ``p`` up the spanning tree to its ancestor ``w``.

    Returns two lists in walk order. The first holds the visited nodes,
    starting at ``p`` and ending at ``w``. The second holds the parent edge
    crossed at each step, so it is one element shorter.
    """
    node = p
    path_nodes = [node]
    path_edges = []
    while node != w:
        path_edges.append(edge[node])
        node = parent[node]
        path_nodes.append(node)
    return path_nodes, path_edges
|
308
|
|
|
|
|
309
|
|
|
|
|
310
|
|
|
@numba.njit(cache=True)
def find_cycle(i, p, q, size, edge, parent):
    """Return the nodes and edges on the cycle containing edge
    i == (p, q) when the latter is added to the spanning tree. The cycle
    is oriented in the direction from p to q.

    The node sequence starts at the apex w (the lowest common ancestor of p
    and q) and descends the tree to p. It then crosses to q and climbs back
    up to the child of w on q's side; w is not repeated at the end. The
    edge sequence lists, in the same order, the tree edges from w down to
    p, then edge i, then the tree edges from q up to w.

    Returns
    -------
    (cycle_nodes, cycle_edges) : Tuple[numpy.ndarray, numpy.ndarray]
        int64 arrays. The caller zips them, pairing each node with the
        edge traversed out of it.
    """

    w = find_apex(p, q, size, parent)
    # Paths p -> w and q -> w (both listed bottom-up).
    cycle_nodes, cycle_edges = trace_path(p, w, edge, parent)
    cycle_nodes_rev, cycle_edges_rev = trace_path(q, w, edge, parent)
    len_cycle_nodes = len(cycle_nodes)
    # q's path without its final node w, which is already the first node.
    add_to_c_nodes = max(len(cycle_nodes_rev) - 1, 0)
    cycle_nodes_ = np.empty(len_cycle_nodes + add_to_c_nodes, dtype=np.int64)

    # Nodes: reversed p-path (w ... p), then q-path (q ... child of w).
    for j in range(len_cycle_nodes):
        cycle_nodes_[j] = cycle_nodes[-(j+1)]
    for j in range(add_to_c_nodes):
        cycle_nodes_[len_cycle_nodes+j] = cycle_nodes_rev[j]

    len_cycle_edges = len(cycle_edges)
    len_cycle_edges_ = len_cycle_edges + len(cycle_edges_rev)
    # Append the entering edge i after the p-path edges unless it is already
    # the tree edge directly above p.
    if len_cycle_edges < 1 or cycle_edges[-1] != i:
        cycle_edges_ = np.empty(len_cycle_edges_ + 1, dtype=np.int64)
        cycle_edges_[len_cycle_edges] = i
    else:
        cycle_edges_ = np.empty(len_cycle_edges_, dtype=np.int64)

    # Edges: reversed p-path edges, [i], then q-path edges filling the tail.
    for j in range(len_cycle_edges):
        cycle_edges_[j] = cycle_edges[-(j+1)]
    for j in range(1, len(cycle_edges_rev) + 1):
        cycle_edges_[-j] = cycle_edges_rev[-j]

    return cycle_nodes_, cycle_edges_
|
343
|
|
|
|
|
344
|
|
|
|
|
345
|
|
|
@numba.njit(cache=True)
def find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads):
    """Return the leaving edge in a cycle represented by cycle_nodes and
    cycle_edges.

    The leaving edge is the cycle edge with the smallest residual capacity
    in the cycle's direction. Ties go to the LAST such edge in cycle order,
    counting from the apex. The networkx implementation gets the same
    result by taking ``min`` over the reversed cycle. This is the standard
    anti-cycling choice for strongly feasible spanning trees; picking the
    first tied edge can make degenerate pivots cycle.

    Returns
    -------
    (j, s, t) : tuple of int
        The leaving edge ``j``, the node ``s`` it is traversed from along
        the cycle, and its other endpoint ``t``.
    """

    j, s = cycle_edges[0], cycle_nodes[0]
    res_caps_min = capac[j] - flows[j] if tails[j] == s else flows[j]

    for ind in range(1, cycle_edges.shape[0]):
        j_, s_ = cycle_edges[ind], cycle_nodes[ind]
        res_cap = capac[j_] - flows[j_] if tails[j_] == s_ else flows[j_]
        # `<=` keeps the last edge among those with minimal residual capacity.
        if res_cap <= res_caps_min:
            res_caps_min = res_cap
            j, s = j_, s_

    t = heads[j] if tails[j] == s else tails[j]
    return j, s, t
|
363
|
|
|
|
|
364
|
|
|
|
|
365
|
|
|
@numba.njit(cache=True)
def remove_edge(s, t, size, prev, last, next_node, parent, edge):
    """Remove an edge (s, t) where parent[t] == s from the spanning
    tree.

    Afterwards t is the root of a detached subtree: its parent and parent
    edge are set to the -2 sentinel. The subtree's nodes are cut out of the
    depth-first thread and closed into their own circular thread (last_t ->
    t). The subtree sizes and last descendants of every former ancestor of
    t are updated.
    """

    size_t = size[t]
    prev_t = prev[t]
    last_t = last[t]
    next_last_t = next_node[last_t]
    # Remove (s, t)
    parent[t] = -2
    edge[t] = -2
    # Remove the subtree rooted at t from the depth-first thread
    next_node[prev_t] = next_last_t
    prev[next_last_t] = prev_t
    next_node[last_t] = t
    prev[t] = last_t

    # Update the subtree sizes & last descendants of the (old) ancestors of t
    while s != -2:
        size[s] -= size_t
        # An ancestor whose thread ended inside t's subtree now ends just
        # before it.
        if last[s] == last_t:
            last[s] = prev_t
        s = parent[s]
|
390
|
|
|
|
|
391
|
|
|
|
|
392
|
|
|
@numba.njit(cache=True)
def make_root(q, parent, size, last, prev, next_node, edge):
    """Make a node q the root of its containing subtree.

    First the path from ``q`` up to the subtree's current root (the first
    node whose parent is the -2 sentinel) is collected. The path is then
    processed from the top down. For each parent/child pair (p, q) on it,
    p becomes a child of q: parent pointers, parent edges, subtree sizes
    and the depth-first thread (prev / next_node / last) are updated
    accordingly.
    """

    # Collect q and its ancestors bottom-up. Appending and iterating in
    # reverse avoids the O(depth^2) cost of repeated list.insert(0, ...).
    ancestors = []
    node = q
    # -2 means node is checked
    while node != -2:
        ancestors.append(node)
        node = parent[node]

    # Walk the path top-down: ancestors[k] is the parent of ancestors[k-1].
    for k in range(len(ancestors) - 1, 0, -1):
        p = ancestors[k]
        q = ancestors[k - 1]
        size_p = size[p]
        last_p = last[p]
        prev_q = prev[q]
        last_q = last[q]
        next_last_q = next_node[last_q]

        # Make p a child of q
        parent[p] = q
        parent[q] = -2
        edge[p] = edge[q]
        edge[q] = -2
        size[p] = size_p - size[q]
        size[q] = size_p

        # Remove the subtree rooted at q from the depth-first thread
        next_node[prev_q] = next_last_q
        prev[next_last_q] = prev_q
        next_node[last_q] = q
        prev[q] = last_q

        if last_p == last_q:
            last[p] = prev_q
            last_p = prev_q

        # Add the remaining parts of the subtree rooted at p as a subtree of q
        # in the depth-first thread
        prev[p] = last_q
        next_node[last_q] = p
        next_node[last_p] = q
        prev[q] = last_p
        last[q] = last_p
|
436
|
|
|
|
|
437
|
|
|
|
|
438
|
|
|
@numba.njit(cache=True)
def add_edge(i, p, q, next_node, prev, last, size, parent, edge):
    """Add an edge (p, q) to the spanning tree where q is the root of a
    subtree.

    q is attached as a child of p through edge i. q's subtree thread is
    spliced into p's depth-first thread immediately after p's current last
    descendant. The subtree sizes and last descendants of p and its
    ancestors are then updated.
    """

    last_p = last[p]
    next_last_p = next_node[last_p]
    size_q = size[q]
    last_q = last[q]
    # Make q a child of p
    parent[q] = p
    edge[q] = i
    # Insert the subtree rooted at q into the depth-first thread
    next_node[last_p] = q
    prev[q] = last_p
    prev[next_last_p] = last_q
    next_node[last_q] = next_last_p

    # Update the subtree sizes and last descendants of the (new) ancestors of q
    while p != -2:
        size[p] += size_q
        # Ancestors whose thread ended where p's did now end at q's last node.
        if last[p] == last_p:
            last[p] = last_q
        p = parent[p]
|
463
|
|
|
|
|
464
|
|
|
|
|
465
|
|
|
@numba.njit(cache=True)
def update_potentials(i, p, q, heads, potentials, costs, last_node, next_node):
    """Shift the potentials of the subtree rooted at ``q`` so edge ``i``,
    joining ``q`` to its parent ``p``, has zero reduced cost.

    The same offset is added to every node of the subtree. The subtree is
    visited by following the depth-first thread from ``q`` to its last
    descendant.
    """
    # The edge cost enters with a minus sign when q is the edge's head.
    signed_cost = -costs[i] if q == heads[i] else costs[i]
    delta = potentials[p] + signed_cost - potentials[q]

    stop = last_node[q]
    node = q
    while True:
        potentials[node] += delta
        if node == stop:
            break
        node = next_node[node]
|
481
|
|
|
|