Conditions | 19 |
Total Lines | 186 |
Code Lines | 107 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include: Extract Method.
If many parameters/temporary variables are present: Replace Method with Method Object, Introduce Parameter Object, or Preserve Whole Object.
Complex units of code like amd._emd.network_simplex() (here a function rather than a class) often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods (or, inside a function, variables) that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | """An implementation of the Wasserstein metric (Earth Mover's distance) |
||
@numba.njit(cache=True)
def network_simplex(
        source_demands: npt.NDArray,
        sink_demands: npt.NDArray,
        network_costs: npt.NDArray
) -> Tuple[float, npt.NDArray[np.float64]]:
    """Calculate the Earth mover's distance (Wasserstein metric) between
    two weighted distributions given by two sets of weights and a cost
    matrix.

    This is a port of the network simplex algorithm implemented by Loïc
    Séguin-C for the networkx package to allow acceleration with numba.
    Copyright (C) 2010 Loïc Séguin-C. All rights reserved. BSD license.

    Parameters
    ----------
    source_demands : :class:`numpy.ndarray`
        Weights of the first distribution.
    sink_demands : :class:`numpy.ndarray`
        Weights of the second distribution.
    network_costs : :class:`numpy.ndarray`
        Cost matrix of distances between elements of the two
        distributions. Shape (len(source_demands), len(sink_demands)).

    Returns
    -------
    (emd, plan) : Tuple[float, :class:`numpy.ndarray`]
        A tuple of the Earth mover's distance and the optimal matching.
        The matching has shape (len(source_demands), len(sink_demands)).

    Notes
    -----
    Weights, costs and capacities are multiplied by 1e6 and stored as
    int64 so the simplex runs in integer arithmetic. The returned
    distance and matching are therefore only as precise as that scaling.

    References
    ----------
    [1] Z. Kiraly, P. Kovacs.
        Efficient implementation of minimum-cost flow algorithms.
        Acta Universitatis Sapientiae, Informatica 4(1), 67--118
        (2012).
    [2] R. Barr, F. Glover, D. Klingman.
        Enhancement of spanning tree labeling procedures for network
        optimization. INFOR 17(1), 16--34 (1979).
    """

    n_sources, n_sinks = source_demands.shape[0], sink_demands.shape[0]
    network_costs = network_costs.ravel()
    n = n_sources + n_sinks   # number of nodes, excluding the tree root
    e = n_sources * n_sinks   # number of real (source -> sink) edges
    B = np.int64(np.ceil(np.sqrt(e)))   # pivot block size
    fp_multiplier = np.int64(1e+6)

    # Scale the weights to integers, then make total supply equal total
    # demand by adding the rounding discrepancy to the largest weight on
    # the lighter side.
    source_d_int = (source_demands * fp_multiplier).astype(np.int64)
    sink_d_int = (sink_demands * fp_multiplier).astype(np.int64)
    sink_source_sum_diff = np.sum(sink_d_int) - np.sum(source_d_int)

    if sink_source_sum_diff > 0:
        source_d_int[np.argmax(source_d_int)] += sink_source_sum_diff
    elif sink_source_sum_diff < 0:
        sink_d_int[np.argmax(sink_d_int)] -= sink_source_sum_diff

    # Nodes 0..n_sources-1 are sources (negative demand, i.e. supply);
    # nodes n_sources..n-1 are sinks (positive demand).
    demands = np.empty(n, dtype=np.int64)
    demands[:n_sources] = -source_d_int
    demands[n_sources:] = sink_d_int

    # Edges 0..e-1 join every source to every sink in row-major order, so
    # edge index i * n_sinks + j matches network_costs.ravel(). Edges
    # e..e+n-1 are artificial edges joining each node to the root (-1).
    tails = np.empty(e + n, dtype=np.int64)
    heads = np.empty(e + n, dtype=np.int64)

    ind = 0
    for i in range(n_sources):
        for j in range(n_sinks):
            tails[ind] = i
            heads[ind] = n_sources + j
            ind += 1

    for i, demand in enumerate(demands):
        ind = e + i
        # As in networkx: nodes with positive demand are fed from the root,
        # all other nodes (including zero-demand ones, which keeps the
        # initial tree strongly feasible) drain into it. Each artificial
        # edge must have the root at one end and node i at the other, since
        # the flow augmentation and leaving-edge steps read these endpoints
        # to decide the direction of flow along the edge.
        if demand > 0:
            tails[ind] = -1
            heads[ind] = i
        else:
            tails[ind] = i
            heads[ind] = -1

    # Costs and capacities of the real edges, scaled like the demands.
    network_costs = network_costs * fp_multiplier
    network_capac = np.empty(shape=(e, ), dtype=np.float64)
    ind = 0
    for i in range(n_sources):
        for j in range(n_sinks):
            network_capac[ind] = min(source_demands[i], sink_demands[j])
            ind += 1
    network_capac *= fp_multiplier

    # Stand-in for infinity: three times the largest of the total capacity,
    # the total absolute cost and the largest single demand. Used as the
    # cost of artificial edges and as the initial node potentials.
    faux_inf = 3 * max(
        np.sum(network_capac), np.sum(np.abs(network_costs)),
        np.amax(source_d_int), np.amax(sink_d_int)
    )

    costs = np.empty(e + n, dtype=np.int64)
    costs[:e] = network_costs
    costs[e:] = faux_inf

    capac = np.empty(e + n, dtype=np.int64)
    capac[:e] = network_capac
    capac[e:] = fp_multiplier

    # Initial feasible flow: everything is routed through the artificial
    # edges, none through the real ones.
    flows = np.empty(e + n, dtype=np.int64)
    flows[:e] = 0
    flows[e:e+n_sources] = source_d_int
    flows[e+n_sources:] = sink_d_int

    potentials = np.empty(n, dtype=np.int64)
    demands_neg_mask = demands <= 0
    potentials[demands_neg_mask] = faux_inf
    potentials[~demands_neg_mask] = -faux_inf

    # Initial spanning tree: every node's parent is the root (-1). The root
    # itself occupies the extra last slot of the (n + 1)-length arrays.
    parent = np.full(shape=(n + 1, ), fill_value=-1, dtype=np.int64)
    parent[-1] = -2

    size = np.full(shape=(n + 1, ), fill_value=1, dtype=np.int64)
    size[-1] = n + 1

    # Tree traversal links; presumably the depth-first thread of networkx
    # (next_node_dft / prev_node_dft / last_descendent_dft) - confirm
    # against the helper functions.
    next_node = np.arange(1, n + 2, dtype=np.int64)
    next_node[-2] = -1
    next_node[-1] = 0

    last_node = np.arange(n + 1, dtype=np.int64)
    last_node[-1] = n - 1

    prev_node = np.arange(-1, n, dtype=np.int64)
    # edge[k]: index of the tree edge between node k and parent[k];
    # initially node k's artificial edge.
    edge = np.arange(e, e + n, dtype=np.int64)

    # Pivot loop

    f = 0   # first edge of the next pivot block
    while True:
        i, p, q, f = _find_entering_edges(
            B, e, f, tails, heads, costs, potentials, flows
        )
        # If no entering edges then the optimal score is found
        if p == -1:
            break

        cycle_nodes, cycle_edges = _find_cycle(i, p, q, size, edge, parent)
        j, s, t = _find_leaving_edge(
            cycle_nodes, cycle_edges, capac, flows, tails, heads
        )
        # Residual capacity of the leaving edge j when traversed from s.
        res_cap = capac[j] - flows[j] if tails[j] == s else flows[j]

        # Augment flows: edges traversed tail-first gain flow, the others
        # lose it.
        for i_, p_ in zip(cycle_edges, cycle_nodes):
            if tails[i_] == p_:
                flows[i_] += res_cap
            else:
                flows[i_] -= res_cap

        # Do nothing more if the entering edge is the same as the
        # leaving edge.
        if i != j:
            if parent[t] != s:
                # Ensure that s is the parent of t.
                s, t = t, s

            # Ensure that q is in the subtree rooted at t: swap p and q
            # when the leaving edge j precedes the entering edge i in the
            # cycle.
            for val in cycle_edges:
                if val == j:
                    p, q = q, p
                    break
                if val == i:
                    break

            _remove_edge(
                s, t, size, prev_node, last_node, next_node, parent, edge
            )
            _make_root(
                q, parent, size, last_node, prev_node, next_node, edge
            )
            _add_edge(
                i, p, q, next_node, prev_node, last_node, size, parent, edge
            )
            _update_potentials(
                i, p, q, heads, potentials, costs, last_node, next_node
            )

    # Undo the integer scaling on the real edges only.
    final_flows = flows[:e] / fp_multiplier
    edge_costs = costs[:e] / fp_multiplier
    emd = final_flows.dot(edge_costs)

    return emd, final_flows.reshape((n_sources, n_sinks))
484 |