Completed
Pull Request — master (#225)
by Chris, created 09:15

_TokenDistance._union()   A

Complexity
    Conditions: 2

Size
    Total Lines: 9
    Code Lines: 5

Duplication
    Lines: 0
    Ratio: 0 %

Code Coverage
    Tests: 5
    CRAP Score: 2

Importance
    Changes: 0

Metric  Value
eloc    5
dl      0
loc     9
ccs     5
cts     5
cp      1
rs      10
c       0
b       0
f       0
cc      2
nop     1
crap    2
# -*- coding: utf-8 -*-

# Copyright 2018-2019 by Christopher C. Little.
# This file is part of Abydos.
#
# Abydos is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Abydos is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Abydos. If not, see <http://www.gnu.org/licenses/>.

"""abydos.distance._token_distance.

The distance._token_distance._TokenDistance module implements the abstract
class _TokenDistance.
"""

from __future__ import (
    absolute_import,
    division,
    print_function,
    unicode_literals,
)

from collections import Counter, OrderedDict
from itertools import product
from math import exp, log1p

from numpy import copy as np_copy
from numpy import zeros as np_zeros

try:
    from scipy.optimize import linear_sum_assignment
except ImportError:  # pragma: no cover
    # If the system lacks the scipy library, we'll fall back to our
    # Python+Numpy implementation of the Hungarian algorithm
    linear_sum_assignment = None

from ._damerau_levenshtein import DamerauLevenshtein
from ._distance import _Distance
from ._lcprefix import LCPrefix
from ._levenshtein import Levenshtein
from ..stats import ConfusionTable
from ..tokenizer import QGrams, QSkipgrams, WhitespaceTokenizer

__all__ = ['_TokenDistance']


class _TokenDistance(_Distance):
    r"""Abstract Token Distance class.

    .. _confusion_table:

    +----------------+--------------+-----------------+-------+
    |                | |in| ``tar`` | |notin| ``tar`` |       |
    +----------------+--------------+-----------------+-------+
    | |in| ``src``   | |a|          | |b|             | |a+b| |
    +----------------+--------------+-----------------+-------+
    | |notin| ``src``| |c|          | |d|             | |c+d| |
    +----------------+--------------+-----------------+-------+
    |                | |a+c|        | |b+d|           | |n|   |
    +----------------+--------------+-----------------+-------+

    .. |in| replace:: :math:`x \in`
    .. |notin| replace:: :math:`x \notin`

    .. |a| replace:: :math:`a = |X \cap Y|`
    .. |b| replace:: :math:`b = |X\setminus Y|`
    .. |c| replace:: :math:`c = |Y \setminus X|`
    .. |d| replace:: :math:`d = |(N\setminus X)\setminus Y|`
    .. |n| replace:: :math:`n = |N|`
    .. |a+b| replace:: :math:`p_1 = a+b = |X|`
    .. |a+c| replace:: :math:`p_2 = a+c = |Y|`
    .. |c+d| replace:: :math:`q_1 = c+d = |N\setminus X|`
    .. |b+d| replace:: :math:`q_2 = b+d = |N\setminus Y|`
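
    A worked sketch of the table's cells (illustrative, assuming the default
    bigram tokenizer):

    >>> td = _TokenDistance()._tokenize('AT', 'TT')
    >>> td._intersection_card(), td._src_only_card(), td._tar_only_card()
    (1, 2, 2)

    Here :math:`a = 1` (the shared bigram ``T#``) and :math:`b = c = 2`.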

    .. versionadded:: 0.3.6
    """

    def __init__(self, tokenizer=None, intersection_type='crisp', **kwargs):
        r"""Initialize _TokenDistance instance.

        .. _intersection_type:

        Parameters
        ----------
        tokenizer : _Tokenizer
            A tokenizer instance from the :py:mod:`abydos.tokenizer` package
        intersection_type : str
            Specifies the intersection type, and set type as a result:

                - ``crisp``: Ordinary intersection, wherein items are entirely
                  members or non-members of the intersection. (Default)
                - ``fuzzy``: Fuzzy intersection, defined by :cite:`Wang:2014`,
                  wherein items can be partially members of the intersection
                  if their similarity meets or exceeds a threshold value. This
                  also takes `metric` (by default :class:`Levenshtein()`) and
                  `threshold` (by default 0.8) parameters.
                - ``soft``: Soft intersection, defined by :cite:`Russ:2014`,
                  wherein items can be partially members of the intersection
                  depending on their similarity. This also takes a `metric`
                  (by default :class:`DamerauLevenshtein()`) parameter.
                - ``linkage``: Group linkage, defined by :cite:`On:2007`. Like
                  the soft intersection, items can be partially members of the
                  intersection, but the method of pairing similar members is
                  somewhat more complex. See the cited paper for details. This
                  also takes `metric`
                  (by default :class:`DamerauLevenshtein()`) and `threshold`
                  (by default 0.1) parameters.
        **kwargs
            Arbitrary keyword arguments


        .. _alphabet:

        Other Parameters
        ----------------
        qval : int
            The length of each q-gram. Using this parameter and tokenizer=None
            will cause the instance to use the QGram tokenizer with this
            q value.
        metric : _Distance
            A string distance measure class for use in the ``soft`` and
            ``fuzzy`` variants.
        threshold : float
            A threshold value, similarities above which are counted as
            members of the intersection for the ``fuzzy`` variant.
        alphabet : Counter, collection, int, or None
            This represents the alphabet of possible tokens.

                - If a Counter is supplied, it is used directly in computing
                  the complement of the tokens in both sets.
                - If a collection is supplied, it is converted to a Counter
                  and used directly. In the case of a single string being
                  supplied and the QGram tokenizer being used, the full
                  alphabet is inferred (i.e.
                  :math:`len(set(alphabet+QGrams.start\_stop))^{QGrams.qval}`
                  is used as the cardinality of the full alphabet).
                - If an int is supplied, it is used as the cardinality of the
                  full alphabet.
                - If None is supplied, the cardinality of the full alphabet
                  is inferred if QGram or QSkipgrams tokenization is used (i.e.
                  :math:`28^{QGrams.qval}` is used as the cardinality of the
                  full alphabet or :math:`26` if QGrams.qval is 1, which
                  assumes the strings are English language strings and only
                  contain letters of a single case). Otherwise, the cardinality
                  of the complement of the total will be 0.
        normalizer : str
            This represents the normalization applied to the values in the
            2x2 contingency table prior to any of the cardinality (\*_card)
            methods returning a value. By default, no normalization is applied,
            but the following values are supported:

                - ``proportional`` : :math:`\frac{x}{n}`, where n is the total
                  population
                - ``log`` : :math:`log(1+x)`
                - ``exp`` : :math:`e^x`
                - ``laplace`` : :math:`x+1`
                - ``inverse`` : :math:`\frac{1}{x}`
                - ``complement`` : :math:`n-x`, where n is the total population
        internal_assignment_problem : bool
            When using ``linkage`` as the intersection type (i.e. group
            linkage), this forces use of the internal implementation to solve
            the assignment problem, rather than scipy's linear_sum_assignment.

        .. versionadded:: 0.4.0

        """
        super(_TokenDistance, self).__init__(
            intersection_type=intersection_type, **kwargs
        )

        qval = 2 if 'qval' not in self.params else self.params['qval']
        self.params['tokenizer'] = (
            tokenizer
            if tokenizer is not None
            else WhitespaceTokenizer()
            if qval == 0
            else QGrams(qval=qval, start_stop='$#', skip=0, scaler=None)
        )

        if hasattr(self.params['tokenizer'], 'qval'):
            if isinstance(self.params['tokenizer'].qval, int):
                qvals = [self.params['tokenizer'].qval]
            else:
                qvals = list(self.params['tokenizer'].qval)
        else:
            qvals = []

        if 'alphabet' in self.params:
            if isinstance(self.params['alphabet'], str):
                self.params['alphabet'] = set(self.params['alphabet'])
                if isinstance(self.params['tokenizer'], (QGrams, QSkipgrams)):
                    self.params['alphabet'] |= set(
                        self.params['tokenizer'].start_stop
                    )
                    self.params['alphabet'] = sum(
                        len(self.params['alphabet']) ** qval for qval in qvals
                    )
            if hasattr(self.params['alphabet'], '__len__') and not isinstance(
                self.params['alphabet'], Counter
            ):
                self.params['alphabet'] = len(self.params['alphabet'])
            elif self.params['alphabet'] is None and isinstance(
                self.params['tokenizer'], (QGrams, QSkipgrams)
            ):
                self.params['alphabet'] = sum(
                    28 ** qval if qval > 1 else 26 for qval in qvals
                )
        else:
            if isinstance(self.params['tokenizer'], (QGrams, QSkipgrams)):
                self.params['alphabet'] = sum(
                    28 ** qval if qval > 1 else 26 for qval in qvals
                )
            else:
                self.params['alphabet'] = None

        if intersection_type == 'soft':
            if 'metric' not in self.params or self.params['metric'] is None:
                self.params['metric'] = DamerauLevenshtein()
            self._lcprefix = LCPrefix()
            self._intersection = self._soft_intersection
        elif intersection_type == 'fuzzy':
            if 'metric' not in self.params or self.params['metric'] is None:
                self.params['metric'] = Levenshtein()
            if 'threshold' not in self.params:
                self.params['threshold'] = 0.8
            self._intersection = self._fuzzy_intersection
        elif intersection_type == 'linkage':
            if 'metric' not in self.params or self.params['metric'] is None:
                self.params['metric'] = DamerauLevenshtein()
            if 'threshold' not in self.params:
                self.params['threshold'] = 0.1
            self._intersection = self._group_linkage_intersection
        else:
            self._intersection = self._crisp_intersection

        self._src_tokens = Counter()
        self._tar_tokens = Counter()
        self._population_card_value = 0

        # initialize normalizer
        self.normalizer = self._norm_none

        self._norm_dict = {
            'proportional': self._norm_proportional,
            'log': self._norm_log,
            'exp': self._norm_exp,
            'laplace': self._norm_laplace,
            'inverse': self._norm_inverse,
            'complement': self._norm_complement,
        }
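
    # An illustrative construction (informal sketch, not part of the original
    # module): _TokenDistance(intersection_type='fuzzy', threshold=0.9)
    # selects Levenshtein() as the default metric and binds
    # self._intersection to self._fuzzy_intersection, per the branches above.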

    def _norm_none(self, x, _squares, _pop):
        return x

    def _norm_proportional(self, x, _squares, pop):
        return x / max(1, pop)

    def _norm_log(self, x, _squares, _pop):
        return log1p(x)

    def _norm_exp(self, x, _squares, _pop):
        return exp(x)

    def _norm_laplace(self, x, squares, _pop):
        return x + squares

    def _norm_inverse(self, x, _squares, pop):
        return 1 / x if x else pop

    def _norm_complement(self, x, _squares, pop):
        return pop - x
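
    # Illustrative sketch: with normalizer='proportional', a raw cardinality
    # x is reported as x / max(1, population), so an intersection of 1 in a
    # population of 785 becomes roughly 0.00127.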

    def _tokenize(self, src, tar):
        """Return the Q-Grams in src & tar.

        Parameters
        ----------
        src : str
            Source string (or QGrams/Counter objects) for comparison
        tar : str
            Target string (or QGrams/Counter objects) for comparison

        Returns
        -------
        tuple of Counters
            Q-Grams

        Examples
        --------
        >>> pe = _TokenDistance()
        >>> pe._tokenize('AT', 'TT')._get_tokens()
        (Counter({'$A': 1, 'AT': 1, 'T#': 1}),
         Counter({'$T': 1, 'TT': 1, 'T#': 1}))


        .. versionadded:: 0.1.0
        .. versionchanged:: 0.3.6
            Encapsulated in class

        """
        self._src_orig = src
        self._tar_orig = tar

        if isinstance(src, Counter):
            self._src_tokens = src
        else:
            self._src_tokens = (
                self.params['tokenizer'].tokenize(src).get_counter()
            )
        if isinstance(tar, Counter):
            self._tar_tokens = tar
        else:
            self._tar_tokens = (
                self.params['tokenizer'].tokenize(tar).get_counter()
            )

        self._population_card_value = self._calc_population_card()

        # Set up the normalizer, a function of two variables:
        # x is the value in the contingency table square(s)
        # n is the number of squares that x represents
        if (
            'normalizer' in self.params
            and self.params['normalizer'] in self._norm_dict
        ):
            self.normalizer = self._norm_dict[self.params['normalizer']]

        return self

    def _get_tokens(self):
        """Return the src and tar tokens as a tuple."""
        return self._src_tokens, self._tar_tokens

    def _src_card(self):
        r"""Return the cardinality of the tokens in the source set."""
        return self.normalizer(
            sum(abs(val) for val in self._src_tokens.values()),
            2,
            self._population_card_value,
        )

    def _src_only(self):
        r"""Return the src tokens minus the tar tokens.

        For (multi-)sets S and T, this is :math:`S \setminus T`.
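
        An illustrative doctest, assuming the default bigram tokenizer:

        >>> td = _TokenDistance()._tokenize('AT', 'TT')
        >>> sorted(td._src_only().items())
        [('$A', 1), ('AT', 1)]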
        """
356 1
        src_only = self._src_tokens - self._intersection()
357 1
        if self.params['intersection_type'] != 'crisp':
358 1
            src_only -= self._intersection() - self._crisp_intersection()
359 1
        return src_only
360
361 1
    def _src_only_card(self):
362
        """Return the cardinality of the tokens only in the source set."""
363 1
        return self.normalizer(
364
            sum(abs(val) for val in self._src_only().values()),
365
            1,
366
            self._population_card_value,
367
        )
368
369 1
    def _tar_card(self):
370
        r"""Return the cardinality of the tokens in the target set."""
371 1
        return self.normalizer(
372
            sum(abs(val) for val in self._tar_tokens.values()),
373
            2,
374
            self._population_card_value,
375
        )
376
377 1
    def _tar_only(self):
378
        r"""Return the tar tokens minus the src tokens.
379
380
        For (multi-)sets S and T, this is :math:`T \setminus S`.
381
        """
382 1
        tar_only = self._tar_tokens - self._intersection()
383 1
        if self.params['intersection_type'] != 'crisp':
384 1
            tar_only -= self._intersection() - self._crisp_intersection()
385 1
        return tar_only
386
387 1
    def _tar_only_card(self):
388
        """Return the cardinality of the tokens only in the target set."""
389 1
        return self.normalizer(
390
            sum(abs(val) for val in self._tar_only().values()),
391
            1,
392
            self._population_card_value,
393
        )
394
395 1
    def _symmetric_difference(self):
396
        r"""Return the symmetric difference of tokens from src and tar.
397
398
        For (multi-)sets S and T, this is :math:`S \triangle T`.
399
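
        A sketch assuming the default bigram tokenizer:

        >>> td = _TokenDistance()._tokenize('AT', 'TT')
        >>> sorted(td._symmetric_difference())
        ['$A', '$T', 'AT', 'TT']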
        """
400 1
        return self._src_only() + self._tar_only()
401
402 1
    def _symmetric_difference_card(self):
403
        """Return the cardinality of the symmetric difference."""
404 1
        return self.normalizer(
405
            sum(abs(val) for val in self._symmetric_difference().values()),
406
            2,
407
            self._population_card_value,
408
        )
409
410 1
    def _total(self):
411
        """Return the sum of the sets.
412
413
        For (multi-)sets S and T, this is :math:`S + T`.
414
415
        In the case of multisets, this counts values in the interesection
416
        twice. In the case of sets, this is identical to the union.
417
        """
418 1
        return self._src_tokens + self._tar_tokens
419
420 1
    def _total_card(self):
421
        """Return the cardinality of the complement of the total."""
422 1
        return self.normalizer(
423
            sum(abs(val) for val in self._total().values()),
424
            3,
425
            self._population_card_value,
426
        )
427

    def _total_complement_card(self):
        """Return the cardinality of the complement of the total."""
        if self.params['alphabet'] is None:
            return self.normalizer(0, 1, self._population_card_value)
        elif isinstance(self.params['alphabet'], Counter):
            return self.normalizer(
                max(
                    0,
                    sum(
                        abs(val)
                        for val in (
                            self.params['alphabet'] - self._total()
                        ).values()
                    ),
                ),
                1,
                self._population_card_value,
            )
        return self.normalizer(
            max(0, self.params['alphabet'] - len(self._total().values())),
            1,
            self._population_card_value,
        )
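
    # Illustrative sketch: with the default bigram tokenizer and no explicit
    # alphabet, _TokenDistance()._tokenize('AT', 'TT')._total_complement_card()
    # is 28**2 - 5 == 779, since five distinct bigrams occur across the two
    # strings and the inferred alphabet holds 28**2 possible bigrams.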

    def _calc_population_card(self):
        """Return the cardinality of the population."""
        save_normalizer = self.normalizer
        self.normalizer = self._norm_none
        pop = self._total_card() + self._total_complement_card()
        self.normalizer = save_normalizer
        return pop

    def _population_card(self):
        """Return the cardinality of the population."""
        return self.normalizer(
            self._population_card_value, 4, self._population_card_value
        )

    def _population_unique_card(self):
        """Return the cardinality of the population minus the intersection."""
        return self.normalizer(
            self._population_card_value - self._intersection_card(),
            4,
            self._population_card_value,
        )

    def _union(self):
        r"""Return the union of tokens from src and tar.

        For (multi-)sets S and T, this is :math:`S \cup T`.
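
        A sketch assuming the default bigram tokenizer:

        >>> td = _TokenDistance()._tokenize('AT', 'TT')
        >>> sorted(td._union().items())
        [('$A', 1), ('$T', 1), ('AT', 1), ('T#', 1), ('TT', 1)]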
        """
479 1
        union = self._total() - self._intersection()
480 1
        if self.params['intersection_type'] != 'crisp':
481 1
            union -= self._intersection() - self._crisp_intersection()
482 1
        return union
483
484 1
    def _union_card(self):
485
        """Return the cardinality of the union."""
486 1
        return self.normalizer(
487
            sum(abs(val) for val in self._union().values()),
488
            3,
489
            self._population_card_value,
490
        )
491
492 1
    def _difference(self):
493
        """Return the difference of the tokens, supporting negative values."""
494 1
        _src_copy = Counter(self._src_tokens)
495 1
        _src_copy.subtract(self._tar_tokens)
496 1
        return _src_copy
497
498 1
    def _crisp_intersection(self):
499
        r"""Return the intersection of tokens from src and tar.
500
501
        For (multi-)sets S and T, this is :math:`S \cap T`.
502
        """
503 1
        return self._src_tokens & self._tar_tokens
504
505 1
    def _soft_intersection(self):
506
        """Return the soft intersection of the tokens in src and tar.
507
508
        This implements the soft intersection defined by :cite:`Russ:2014`.
509
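
        An illustrative doctest, under the default soft-intersection
        parameters:

        >>> td = _TokenDistance(intersection_type='soft')
        >>> td = td._tokenize('AT', 'TT')
        >>> sorted(td._soft_intersection().items())
        [('$A', 0.25), ('$T', 0.25), ('AT', 0.25), ('T#', 1), ('TT', 0.25)]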
        """
510 1
        intersection = self._crisp_intersection()
511 1
        src_only = self._src_tokens - self._tar_tokens
512 1
        tar_only = self._tar_tokens - self._src_tokens
513
514 1
        def _membership(src, tar):
515 1
            greater_length = max(len(src), len(tar))
516 1
            return (
517
                max(
518
                    greater_length - self.params['metric'].dist_abs(src, tar),
519
                    self._lcprefix.dist_abs(src, tar),
520
                )
521
                / greater_length
522
            )
523
524
        # Dictionary ordering is important for reproducibility, so insertion
525
        # order needs to be controlled and retained.
526 1
        memberships = OrderedDict(
527
            ((src, tar), _membership(src, tar))
528
            for src, tar in sorted(product(src_only, tar_only))
529
        )
530
531 1
        while memberships:
532 1
            src_tok, tar_tok = max(memberships, key=memberships.get)
533 1
            if memberships[src_tok, tar_tok] > 0.0:
534 1
                pairings = min(src_only[src_tok], tar_only[tar_tok])
535 1
                if pairings:
536 1
                    intersection[src_tok] += (
537
                        memberships[src_tok, tar_tok] * pairings / 2
538
                    )
539 1
                    intersection[tar_tok] += (
540
                        memberships[src_tok, tar_tok] * pairings / 2
541
                    )
542 1
                    src_only[src_tok] -= pairings
543 1
                    tar_only[tar_tok] -= pairings
544 1
            del memberships[src_tok, tar_tok]
545
546 1
        return intersection
547
548 1
    def _fuzzy_intersection(self):
549
        r"""Return the fuzzy intersection of the tokens in src and tar.
550
551
        This implements the fuzzy intersection defined by :cite:`Wang:2014`.
552
553
        For two sets X and Y, the intersection :cite:`Wang:2014` is the sum of
554
        similarities of all tokens in the two sets that are greater than or
555
        equal to some threshold value (:math:`\delta`).
556
557
        The lower bound of on this intersection and the value when
558
        :math:`\delta = 1.0`, is the crisp intersection. Tokens shorter than
559
        :math:`\frac{\delta}{1-\delta}`, 4 in the case of the default threshold
560
        :math:`\delta = 0.8`, must match exactly to be included in the
561
        intersection.
562
563
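
        An illustrative doctest, with the threshold lowered so that bigrams
        can partially match:

        >>> td = _TokenDistance(intersection_type='fuzzy', threshold=0.5)
        >>> td = td._tokenize('AT', 'TT')
        >>> sorted(td._fuzzy_intersection().items())
        [('$A', 0.25), ('$T', 0.25), ('AT', 0.25), ('T#', 1), ('TT', 0.25)]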

        .. versionadded:: 0.4.0

        """
        intersection = self._crisp_intersection()
        src_only = self._src_tokens - self._tar_tokens
        tar_only = self._tar_tokens - self._src_tokens

        pair = {}
        for src_tok in sorted(src_only):
            for tar_tok in sorted(tar_only):
                sim = self.params['metric'].sim(src_tok, tar_tok)
                if sim >= self.params['threshold']:
                    pair[(src_tok, tar_tok)] = sim

        for src_tok, tar_tok in sorted(pair, key=pair.get, reverse=True):
            pairings = min(src_only[src_tok], tar_only[tar_tok])
            if pairings:
                sim = pair[(src_tok, tar_tok)]

                intersection[src_tok] += sim / 2 * pairings
                intersection[tar_tok] += sim / 2 * pairings

                src_only[src_tok] -= pairings
                tar_only[tar_tok] -= pairings

        """
        # Here is a slightly different optimization method, which is even
        # greedier than the above.
        # ordered by sim*pairings rather than just sim

        pair = {}
        for src_tok in sorted(src_only):
            for tar_tok in sorted(tar_only):
                sim = self.params['metric'].sim(src_tok, tar_tok)
                if sim >= self.params['threshold']:
                    pairings = min(src_only[src_tok], tar_only[tar_tok])
                    pair[(src_tok, tar_tok)] = sim*pairings

        for src_tok, tar_tok in sorted(pair, key=pair.get, reverse=True):
            pairings = min(src_only[src_tok], tar_only[tar_tok])
            if pairings:
                sim = pair[(src_tok, tar_tok)]

                intersection[src_tok] += sim / 2
                intersection[tar_tok] += sim / 2

                src_only[src_tok] -= pairings
                tar_only[tar_tok] -= pairings
        """

        return intersection

    def _group_linkage_intersection(self):
        r"""Return the group linkage intersection of the tokens in src and tar.

        This is based on group linkage, as defined by :cite:`On:2007`.

        Most of this method is concerned with solving the assignment problem,
        in order to find the weight of the maximum weight bipartite matching.
        If the system has SciPy installed, we use its linear_sum_assignment
        function to get the assignments. Otherwise, we use the Hungarian
        algorithm of Munkres :cite:`Munkres:1957`, implemented in Python &
        Numpy.
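
        An illustrative doctest (under the default DamerauLevenshtein metric
        and 0.1 threshold; the card value is a float here):

        >>> td = _TokenDistance(intersection_type='linkage')
        >>> td = td._tokenize('AT', 'TT')
        >>> float(td._intersection_card())
        2.0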

        .. versionadded:: 0.4.0

        """
        intersection = self._crisp_intersection()
        src_only = sorted(self._src_tokens - self._tar_tokens)
        tar_only = sorted(self._tar_tokens - self._src_tokens)

        if linear_sum_assignment and not (
            'internal_assignment_problem' in self.params
            and self.params['internal_assignment_problem']
        ):
            arr = np_zeros((len(tar_only), len(src_only)))

            for col in range(len(src_only)):
                for row in range(len(tar_only)):
                    arr[row, col] = self.params['metric'].dist(
                        src_only[col], tar_only[row]
                    )

            for row, col in zip(*linear_sum_assignment(arr)):
                sim = 1.0 - arr[row, col]
                if sim >= self.params['threshold']:
                    intersection[src_only[col]] += (sim / 2) * (
                        self._src_tokens - self._tar_tokens
                    )[src_only[col]]
                    intersection[tar_only[row]] += (sim / 2) * (
                        self._tar_tokens - self._src_tokens
                    )[tar_only[row]]
        else:
            n = max(len(tar_only), len(src_only))
            arr = np_zeros((n, n), dtype=float)

            for col in range(len(src_only)):
                for row in range(len(tar_only)):
                    arr[row, col] = self.params['metric'].dist(
                        src_only[col], tar_only[row]
                    )

            src_only += [''] * (n - len(src_only))
            tar_only += [''] * (n - len(tar_only))

            orig_sim = 1 - np_copy(arr)

            # Step 1: subtract each row's minimum from the row
            for row in range(n):
                arr[row, :] -= arr[row, :].min()
            # Step 2: subtract each column's minimum from the column
            for col in range(n):
                arr[:, col] -= arr[:, col].min()

            while True:
                # Step 3: tentatively assign rows & columns whose only zero
                # is not yet allocated
                assignments = {}

                allocated_cols = set()
                allocated_rows = set()
                assigned_rows = set()
                assigned_cols = set()

                for row in range(n):
                    if (arr[row, :] == 0.0).sum() == 1:
                        col = arr[row, :].argmin()
                        if col not in allocated_cols:
                            assignments[row, col] = orig_sim[row, col]
                            allocated_cols.add(col)
                            assigned_rows.add(row)
                            assigned_cols.add(col)

                for col in range(n):
                    if (arr[:, col] == 0.0).sum() == 1:
                        row = arr[:, col].argmin()
                        if row not in allocated_rows:
                            assignments[row, col] = orig_sim[row, col]
                            allocated_rows.add(row)
                            assigned_rows.add(row)
                            assigned_cols.add(col)

                if len(assignments) == n:
                    break

                # Mark unassigned rows, their zero columns, and those
                # columns' assigned rows, to find a minimal line cover
                marked_rows = {_ for _ in range(n) if _ not in assigned_rows}
                marked_cols = set()
                for row in sorted(set(marked_rows)):
                    for col, mark in enumerate(arr[row, :] == 0.0):
                        if mark:
                            marked_cols.add(col)
                            for row2 in range(n):
                                if (row2, col) in assignments:
                                    marked_rows.add(row2)

                if n - len(marked_rows) + len(marked_cols) == n:
                    # We have sufficient lines
                    for col in range(n):
                        row = arr[:, col].argmin()
                        assignments[row, col] = orig_sim[row, col]
                    break

                # Step 4: subtract the smallest uncovered value from every
                # uncovered cell and add it to every doubly-covered cell
                min_val = arr[tuple(marked_rows), :][
                    :, sorted(set(range(n)) - marked_cols)
                ].min()
                for row in range(n):
                    for col in range(n):
                        if row in marked_rows and col not in marked_cols:
                            arr[row, col] -= min_val
                        elif row not in marked_rows and col in marked_cols:
                            arr[row, col] += min_val

            for row, col in assignments.keys():
                sim = orig_sim[row, col]
                if sim >= self.params['threshold']:
                    intersection[src_only[col]] += (sim / 2) * (
                        self._src_tokens - self._tar_tokens
                    )[src_only[col]]
                    intersection[tar_only[row]] += (sim / 2) * (
                        self._tar_tokens - self._src_tokens
                    )[tar_only[row]]

        return intersection

    def _intersection_card(self):
        """Return the cardinality of the intersection."""
        return self.normalizer(
            sum(abs(val) for val in self._intersection().values()),
            1,
            self._population_card_value,
        )

    def _intersection(self):
        """Return the intersection.

        This function may be overridden by setting the intersection_type
        during initialization.
        """
        return self._crisp_intersection()  # pragma: no cover

    def _get_confusion_table(self):
        """Return the token counts as a ConfusionTable object."""
        return ConfusionTable(
            self._intersection_card(),
            self._total_complement_card(),
            self._src_only_card(),
            self._tar_only_card(),
        )
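
    # Illustrative sketch: for src='AT' and tar='TT' under the defaults, the
    # ConfusionTable above is built from a=1 true positives, d=779 true
    # negatives, b=2 false positives, and c=2 false negatives.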


if __name__ == '__main__':
    import doctest

    doctest.testmod()