amd.calculate.PDD() — rating: A
Last analyzed

Complexity

Conditions 3

Size

Total Lines 88
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 18
dl 0
loc 88
rs 9.5
c 0
b 0
f 0
cc 3
nop 6

How to fix: Long Method

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""Calculation of isometry invariants from periodic sets."""
2
3
import warnings
4
import collections
5
from typing import Tuple, Union
6
import itertools
7
8
import numpy as np
9
import numpy.typing as npt
10
import numba
11
from scipy.spatial.distance import pdist, squareform
12
13
from ._types import FloatArray
14
from ._nearest_neighbors import nearest_neighbors, nearest_neighbors_minval
15
from .periodicset import PeriodicSet
16
from .utils import diameter
17
from .globals_ import MAX_DISORDER_CONFIGS
18
19
20
# Public API of this module; private helpers (leading underscore) are
# intentionally not exported.
__all__ = [
    "PDD",
    "AMD",
    "ADA",
    "PDA",
    "PDD_to_AMD",
    "AMD_finite",
    "PDD_finite",
    "PDD_reconstructable",
    "AMD_estimate",
]
31
32
33
def PDD(
    pset: PeriodicSet,
    k: int,
    lexsort: bool = True,
    collapse: bool = True,
    collapse_tol: float = 1e-4,
    return_row_data: bool = False,
) -> Union[FloatArray, Tuple[FloatArray, list]]:
    """Return the pointwise distance distribution (PDD) of a periodic
    set (usually representing a crystal).

    The PDD is a geometry-based descriptor independent of the choice of
    motif and unit cell. It is a matrix with each row corresponding to a
    point in the motif, starting with a weight followed by distances to
    the k nearest neighbors of the point.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Number of neighbors considered for each point in a unit cell.
        The output has k + 1 columns with the first column containing
        weights.
    lexsort : bool, default True
        Lexicographically order rows.
    collapse: bool, default True
        Collapse duplicate rows (within ``collapse_tol`` in the
        Chebyshev metric).
    collapse_tol: float, default 1e-4
        If two rows are closer than ``collapse_tol`` in the Chebyshev
        metric, they are merged and weights are given to rows in
        proportion to their frequency.
    return_row_data: bool, default False
        Return a tuple ``(pdd, groups)`` where ``groups`` contains
        information about which rows in ``pdd`` correspond to which
        points. If ``pset.asym_unit`` is None, then ``groups[i]``
        contains indices of points in ``pset.motif`` corresponding to
        ``pdd[i]``. Otherwise, PDD rows correspond to points in the
        asymmetric unit, and ``groups[i]`` contains indices pointing to
        ``pset.asym_unit``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``pset``, a :class:`numpy.ndarray` with ``k+1``
        columns. If ``return_row_data`` is True, returns a tuple
        (:class:`numpy.ndarray`, list).

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`.

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in data.cif::

        pdds = []
        for periodic_set in amd.CifReader('data.cif'):
            pdd = amd.PDD(periodic_set, 100)
            pdds.append(pdd)

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode
    families (requires csd-python-api)::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            pdds.append(amd.PDD(periodic_set, 10))

    Manually create a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        periodic_set = amd.PeriodicSet(motif, cell)
        cubic_pdd = amd.PDD(periodic_set, 100)
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f"Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}"
        )

    weights, dists, groups = _PDD(
        pset, k, lexsort=lexsort, collapse=collapse, collapse_tol=collapse_tol
    )

    # Assemble the matrix: column 0 holds row weights, columns 1..k the
    # neighbor distances produced by _PDD.
    pdd = np.empty(shape=(len(dists), k + 1), dtype=np.float64)
    pdd[:, 0] = weights
    pdd[:, 1:] = dists

    if return_row_data:
        return pdd, groups
    return pdd
121
122
123
def _PDD(
    pset: PeriodicSet,
    k: int,
    lexsort: bool = True,
    collapse: bool = True,
    collapse_tol: float = 1e-4,
) -> Tuple[FloatArray, FloatArray, list[list[int]]]:
    """See PDD() for documentation. This core function always returns a
    tuple (weights, dists, groups), with weights and dists to be merged
    by PDD() and groups to be optionally returned.

    ``weights`` has one entry per row, ``dists`` has one row per weight
    and ``groups[i]`` lists the indices into ``pset.asym_unit`` whose
    rows ended up in row ``i``.
    """

    asym_unit = pset.motif[pset.asym_unit]
    weights = pset.multiplicities / pset.motif.shape[0]
    # keep index -> indices masked in its favour (substitutional disorder)
    subs_disorder_info = {}

    if pset.disorder:
        weights, dists, inds, subs_disorder_info = _PDD_disordered_rows(
            pset, asym_unit, weights, k
        )
    else:
        # k + 1 neighbours are requested and the first column dropped
        # (presumably each point's zero distance to itself).
        dists = nearest_neighbors(pset.motif, pset.cell, asym_unit, k + 1)
        dists = dists[:, 1:]
        inds = list(range(len(dists)))

    # Collapse rows within tolerance
    groups = None
    if collapse:
        weights, dists, group_labs = _merge_pdd_rows(weights, dists, collapse_tol)
        # group_labs has one label per input row; fewer output rows than
        # labels means at least one merge happened.
        if dists.shape[0] != len(group_labs):
            groups = [[] for _ in range(weights.shape[0])]
            for old_ind, new_ind in enumerate(group_labs):
                groups[new_ind].append(int(inds[old_ind]))

    if groups is None:
        groups = [[int(i)] for i in inds]

    # Add back substitutionally disordered sites to group info
    for keep, masked_inds in subs_disorder_info.items():
        for grp in groups:
            if keep in grp:
                grp.extend(masked_inds)

    if lexsort:
        # np.lexsort treats its last key as primary, so reverse the
        # columns to sort by the first distance column first.
        lex_ordering = np.lexsort(dists.T[::-1])
        weights = weights[lex_ordering]
        dists = dists[lex_ordering]
        groups = [groups[i] for i in lex_ordering]

    return weights, dists, groups


def _classify_disorder(pset, n_asym):
    """Split the disorder assemblies of ``pset`` into substitutional ones
    (handled by masking) and ones needing separate configurations.

    Returns ``(asym_mask, subs_disorder_info, asm_sizes)``: a boolean
    mask over the asymmetric unit hiding all but the first site of each
    substitutional assembly, a dict mapping each kept substitutional site
    to the sites masked in its favour, and a dict mapping the index (in
    ``pset.disorder``) of every other assembly with at least two groups
    to its number of groups.
    """
    asym_mask = np.full((n_asym, ), fill_value=True)
    subs_disorder_info = {}
    asm_sizes = {}
    for i, asm in enumerate(pset.disorder):
        grps = asm.groups

        # Ignore assemblies with 1 group
        if len(grps) < 2:
            continue

        if asm.is_substitutional:
            # For substitutional disorder, mask all but one atom.
            # NOTE(review): these sites are masked in the asymmetric unit
            # only, not in the motif given to nearest_neighbors; confirm
            # this is intended.
            mask_inds = [grps[j].indices[0] for j in range(1, len(grps))]
            keep = grps[0].indices[0]
            subs_disorder_info[keep] = mask_inds
            asym_mask[mask_inds] = False
        else:
            asm_sizes[i] = len(grps)

    return asym_mask, subs_disorder_info, asm_sizes


def _disorder_configs(pset, asm_sizes):
    """Return an iterable of disorder configurations. Each configuration
    is a sequence whose ``i``-th entry is the chosen group index of the
    ``i``-th assembly in ``asm_sizes`` (in its iteration order).

    If the number of configurations exceeds ``MAX_DISORDER_CONFIGS``, a
    warning is issued and only the majority-occupancy configuration is
    returned.
    """
    if _array_product_exceeds(
        np.array(list(asm_sizes.values())), MAX_DISORDER_CONFIGS
    ):
        warnings.warn(
            f"Disorder configs exceeds limit "
            f"amd.globals_.MAX_DISORDER_CONFIGS={MAX_DISORDER_CONFIGS}, "
            "defaulting to majority occupancy config"
        )
        # Built over asm_sizes (not all of pset.disorder) so that entry i
        # lines up with the i-th assembly that configurations index.
        majority = []
        for asm_ind in asm_sizes:
            grps = pset.disorder[asm_ind].groups
            j, _ = max(enumerate(grps), key=lambda g: g[1].occupancy)
            majority.append(j)
        return [majority]

    return itertools.product(*(range(size) for size in asm_sizes.values()))


def _PDD_disordered_rows(pset, asym_unit, weights, k):
    """Compute PDD rows for every disorder configuration of ``pset``.

    Returns ``(weights, dists, inds, subs_disorder_info)``: the rows of
    all configurations stacked, their weights renormalised to sum to 1,
    the asymmetric-unit index of every row, and the substitutional
    disorder bookkeeping from :func:`_classify_disorder`.
    """
    base_asym_mask, subs_disorder_info, asm_sizes = _classify_disorder(
        pset, asym_unit.shape[0]
    )

    # One PDD for each disorder configuration
    dists_list, inds_list = [], []
    for config_inds in _disorder_configs(pset, asm_sizes):

        # Mask groups not selected
        asym_mask = base_asym_mask.copy()
        motif_mask = np.full((pset.motif.shape[0], ), fill_value=True)
        for i, asm_ind in enumerate(asm_sizes):
            for j, grp in enumerate(pset.disorder[asm_ind].groups):
                if j == config_inds[i]:
                    continue
                for t in grp.indices:
                    asym_mask[t] = False
                    # Assumes the motif images of asymmetric-unit point t
                    # form the contiguous slice starting at asym_unit[t].
                    m_i = pset.asym_unit[t]
                    mul = pset.multiplicities[t]
                    motif_mask[m_i : m_i + mul] = False

        dists = nearest_neighbors(
            pset.motif[motif_mask], pset.cell, asym_unit[asym_mask], k + 1
        )
        dists_list.append(dists[:, 1:])
        inds_list.append(np.where(asym_mask)[0])

    dists = np.vstack(dists_list)
    inds = list(np.concatenate(inds_list))
    weights = np.concatenate([weights[i] for i in inds_list])
    weights /= np.sum(weights)
    return weights, dists, inds, subs_disorder_info
234
235
236
def AMD(pset: PeriodicSet, k: int) -> FloatArray:
    """Return the average minimum distance (AMD) of a periodic set
    (usually representing a crystal).

    The AMD is the centroid or average of the PDD (pointwise distance
    distribution) and hence is also independent of the choice of motif
    and unit cell. It is a vector containing average distances from
    points to k neighbouring points.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Number of neighbors considered for each point in a unit cell.

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD of ``pset``, a :class:`numpy.ndarray` shape ``(k, )``.

    Examples
    --------
    Make list of AMDs with k = 100 for crystals in data.cif::

        amds = []
        for periodic_set in amd.CifReader('data.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with k = 10 for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually create a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        periodic_set = amd.PeriodicSet(motif, cell)
        cubic_amd = amd.AMD(periodic_set, 100)
    """
    # Row order is irrelevant to a weighted average, so skip sorting; rows
    # are also left unmerged and averaged directly.
    weights, dists, _ = _PDD(pset, k, lexsort=False, collapse=False)
    return np.average(dists, weights=weights, axis=0)
281
282
283
@numba.njit(cache=True, fastmath=True)
def PDD_to_AMD(pdd: FloatArray) -> FloatArray:
    """Calculate an AMD from a PDD, faster than computing both from
    scratch.

    Each AMD entry is the weight-weighted sum of the corresponding PDD
    distance column; the weights are not renormalised here.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of a periodic set as given by :class:`PDD() <.PDD>`,
        with weights in the first column.

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD of the periodic set, so that
        ``amd.PDD_to_AMD(amd.PDD(pset, k))`` approximately equals
        ``amd.AMD(pset, k)`` when the PDD weights sum to 1 (merging rows
        within ``collapse_tol`` in PDD() can cause small differences).
    """

    amd_ = np.empty((pdd.shape[-1] - 1,), dtype=np.float64)
    for col in range(amd_.shape[0]):
        # Float accumulator keeps the numba-inferred type stable.
        v = 0.0
        for row in range(pdd.shape[0]):
            v += pdd[row, 0] * pdd[row, col + 1]
        amd_[col] = v
    return amd_
306
307
308
def AMD_finite(motif: FloatArray) -> FloatArray:
    """Return the AMD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Collection of points.

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD of ``motif``, a vector shape ``(motif.shape[0] - 1, )``.

    Examples
    --------
    The (L-infinity) AMD distance between finite trapezium and kite
    point sets, which have the same list of inter-point distances::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        l_inf_dist = np.amax(np.abs(trap_amd - kite_amd))
    """

    # Full distance matrix with each row sorted ascending; column 0 is the
    # zero diagonal entry (each point's distance to itself) and is dropped.
    dm = np.sort(squareform(pdist(motif)), axis=-1)[:, 1:]
    return np.average(dm, axis=0)
337
338
339
def PDD_finite(
    motif: FloatArray,
    lexsort: bool = True,
    collapse: bool = True,
    collapse_tol: float = 1e-4,
    return_row_data: bool = False,
) -> Union[FloatArray, Tuple[FloatArray, list]]:
    """Return the PDD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Collection of points.
    lexsort : bool, default True
        Lexicographically order rows.
    collapse: bool, default True
        Collapse duplicate rows (within ``collapse_tol`` in the
        Chebyshev metric).
    collapse_tol: float, default 1e-4
        If two rows are closer than ``collapse_tol`` in the Chebyshev
        metric, they are merged and weights are given to rows in
        proportion to their frequency.
    return_row_data: bool, default False
        If True, return a tuple ``(pdd, groups)`` where ``groups[i]``
        contains indices of points in ``motif`` corresponding to
        ``pdd[i]``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``motif``, a :class:`numpy.ndarray` with ``k+1``
        columns. If ``return_row_data`` is True, returns a tuple
        (:class:`numpy.ndarray`, list).

    Examples
    --------
    The PDD distance between finite trapezium and kite point sets, which
    have the same list of inter-point distances::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.EMD(trap_pdd, kite_pdd)
    """

    m = motif.shape[0]
    # Row i: sorted distances from point i to the other m - 1 points (the
    # zero diagonal entry sorts first and is dropped).
    dists = np.sort(squareform(pdist(motif)), axis=-1)[:, 1:]
    weights = np.full((m,), 1 / m)
    groups = [[i] for i in range(len(dists))]

    # TODO: use _merge_pdd_rows
    if collapse:
        overlapping = pdist(dists, metric="chebyshev") <= collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            weights = np.array([np.sum(weights[group]) for group in groups])
            dists = np.array(
                [np.average(dists[group], axis=0) for group in groups],
                dtype=np.float64,
            )

    # With a single point there are no distance columns: np.lexsort raises
    # on an empty key sequence, and one row needs no ordering anyway.
    if lexsort and dists.shape[1] > 0:
        # Last key is primary in np.lexsort, so reverse the columns.
        lex_ordering = np.lexsort(dists.T[::-1])
        weights = weights[lex_ordering]
        dists = dists[lex_ordering]
        groups = [groups[i] for i in lex_ordering]

    pdd = np.empty(shape=(len(weights), m), dtype=np.float64)
    pdd[:, 0] = weights
    pdd[:, 1:] = dists

    if return_row_data:
        return pdd, groups
    return pdd
417
418
419
def PDD_reconstructable(pset: PeriodicSet, lexsort: bool = True) -> FloatArray:
    """Return the PDD of a periodic set with ``k`` (number of columns)
    large enough such that the periodic set can be reconstructed from
    the PDD with :func:`amd.reconstruct.reconstruct`. Does NOT return
    weights or collapse rows.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    lexsort : bool, default True
        Lexicographically order rows.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``pset`` with enough columns to reconstruct ``pset``
        using :func:`amd.reconstruct.reconstruct`.

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`, or is not 2- or
        3-dimensional.
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f"Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}"
        )

    if pset.ndim not in (2, 3):
        raise ValueError(
            "Reconstructing from PDD is only possible for 2 and 3 dimensions."
        )

    # The distance threshold handed to nearest_neighbors_minval is twice
    # the cell diameter.
    min_val = diameter(pset.cell) * 2
    pdd, _, _ = nearest_neighbors_minval(pset.motif, pset.cell, min_val)

    if lexsort:
        # Last key is primary in np.lexsort, so reverse the columns.
        pdd = pdd[np.lexsort(pdd.T[::-1])]
    return pdd
454
455
456
def AMD_estimate(pset: PeriodicSet, k: int) -> FloatArray:
    r"""Calculate an estimate of :class:`AMD <.AMD>` based on the
    :class:`PPC <.periodicset.PeriodicSet.PPC>` of ``pset``.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Number of neighbors, i.e. the length of the returned vector.

    Returns
    -------
    amd_est : :class:`numpy.ndarray`
        An array shape (k, ), where ``amd_est[i]``
        :math:`= \text{PPC} \sqrt[n]{i + 1}` in n dimensions, whose ratio
        with AMD has been shown to converge to 1.

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`.
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f"Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}"
        )
    # Neighbour ranks 1..k, each raised to the power 1/ndim.
    arange = np.arange(1, k + 1, dtype=np.float64)
    return pset.PPC() * np.power(arange, 1.0 / pset.ndim)
479
480
481
def PDA(
    pset: PeriodicSet,
    k: int,
    lexsort: bool = True,
    collapse: bool = True,
    collapse_tol: float = 1e-4,
    return_row_data: bool = False,
) -> Union[FloatArray, Tuple[FloatArray, list]]:
    """Return the pointwise deviation from asymptotic distribution,
    essentially a normalisation of the pointwise distance distribution
    of ``pset``. The PDA records how much the distances in the PDD
    deviate from what is expected based on the asymptotic estimate.

    The PDD of ``pset`` is a geometry-based descriptor independent of
    the choice of motif and unit cell. Its asymptotic behaviour is well
    understood and depends on the point density of the periodic set.
    The PDA is the difference between the PDD and its asymptotic curve.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Number of neighbors considered for each point in a unit cell.
        The output has k + 1 columns with the first column containing
        weights.
    lexsort : bool, default True
        Lexicographically order rows.
    collapse: bool, default True
        Collapse duplicate rows (within ``collapse_tol`` in the
        Chebyshev metric).
    collapse_tol: float, default 1e-4
        If two rows are closer than ``collapse_tol`` in the Chebyshev
        metric, they are merged and weights are given to rows in
        proportion to their frequency.
    return_row_data: bool, default False
        Return a tuple ``(pda, groups)`` where ``groups`` contains
        information about which rows in ``pda`` correspond to which
        points. If ``pset.asym_unit`` is None, then ``groups[i]``
        contains indices of points in ``pset.motif`` corresponding to
        ``pda[i]``. Otherwise, PDA rows correspond to points in the
        asymmetric unit, and ``groups[i]`` contains indices pointing to
        ``pset.asym_unit``.

    Returns
    -------
    pda : :class:`numpy.ndarray`
        The PDA of ``pset``, a :class:`numpy.ndarray` with ``k+1``
        columns. If ``return_row_data`` is True, returns a tuple
        (:class:`numpy.ndarray`, list).
    """
    pdd, groups = PDD(
        pset,
        k,
        collapse=collapse,
        collapse_tol=collapse_tol,
        lexsort=lexsort,
        return_row_data=True,
    )
    # Subtract the asymptotic estimate from every distance column; the
    # weight column (index 0) is left untouched.
    pdd[:, 1:] -= AMD_estimate(pset, k)
    if return_row_data:
        return pdd, groups
    return pdd
544
545
546
def ADA(pset: PeriodicSet, k: int) -> FloatArray:
    """Return the average deviation from asymptotic, essentially a
    normalisation of the average minimum distance of ``pset``. The ADA
    records how much the distances in the AMD deviate from what is
    expected based on the asymptotic estimate.

    The AMD of ``pset`` is a geometry-based descriptor independent of
    the choice of motif and unit cell. Its asymptotic behaviour is well
    understood and depends on the point density of the periodic set.
    The ADA is the difference between the AMD and its asymptotic curve.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Number of neighbors considered for each point in a unit cell.

    Returns
    -------
    :class:`numpy.ndarray`
        The ADA of ``pset``, a :class:`numpy.ndarray` shape ``(k, )``.
    """
    return AMD(pset, k) - AMD_estimate(pset, k)
570
571
572
@numba.njit(cache=True, fastmath=True)
def _array_product_exceeds(values, limit):
    """Return True if the running product of ``values`` exceeds
    ``limit`` at any point, otherwise False.

    Stops at the first partial product above ``limit`` instead of
    multiplying the whole array. For values >= 1 (such as disorder
    assembly sizes) this is equivalent to ``np.prod(values) > limit``.
    An empty ``values`` gives a product of 1.
    """
    tot = 1
    for value in values:
        tot *= value
        if tot > limit:
            return True
    return False
581
582
583
@numba.njit(cache=True, fastmath=True)
def _merge_pdd_rows(weights, dists, collapse_tol):
    """Collapse weights & rows of a PDD, and return an array of group
    labels (new indices of old rows).

    Rows are merged greedily in order: each row not yet assigned starts
    a new group, which absorbs every later unassigned row whose
    distances all differ from it by at most ``collapse_tol`` (Chebyshev
    metric, inclusive). Merged weights are summed and merged distance
    rows are averaged. If no rows merge, ``weights`` and ``dists`` are
    returned unchanged.

    Returns
    -------
    (weights, dists, group_labels)
        ``group_labels[i]`` is the index of the output row that input
        row ``i`` belongs to.
    """

    n, k = dists.shape
    group_labels = np.empty((n,), dtype=np.int64)
    done = set()  # rows already absorbed into an earlier group
    group = 0

    for i in range(n):
        if i in done:
            continue

        group_labels[i] = group

        for j in range(i + 1, n):
            if j in done:
                continue

            grouped = True
            for col in range(k):
                if np.abs(dists[i, col] - dists[j, col]) > collapse_tol:
                    grouped = False
                    break

            if grouped:
                group_labels[j] = group
                done.add(j)

        group += 1

    # Every row formed its own group: nothing to merge.
    if group == n:
        return weights, dists, group_labels

    weights_ = np.zeros((group,), dtype=np.float64)
    dists_ = np.zeros((group, k), dtype=np.float64)
    group_counts = np.zeros((group,), dtype=np.int64)

    for i in range(n):
        row = group_labels[i]
        weights_[row] += weights[i]
        dists_[row] += dists[i]
        group_counts[row] += 1

    # Turn the summed distance rows into per-group means.
    for i in range(group):
        dists_[i] /= group_counts[i]

    return weights_, dists_, group_labels
633
634
635
def _collapse_into_groups(overlapping: npt.NDArray[np.bool_]) -> list:
636
    """Return a list of groups of indices where all indices in the same
637
    group overlap. ``overlapping`` indicates for each pair of items in a
638
    set whether or not the items overlap, in the shape of a condensed
639
    distance matrix.
640
    """
641
642
    overlapping = squareform(overlapping)
643
    group_nums = {}
644
    group = 0
645
    for i, row in enumerate(overlapping):
646
        if i not in group_nums:
647
            group_nums[i] = group
648
            group += 1
649
            for j in np.argwhere(row).T[0]:
650
                if j not in group_nums:
651
                    group_nums[j] = group_nums[i]
652
653
    groups = collections.defaultdict(list)
654
    for row_ind, group_num in sorted(group_nums.items()):
655
        groups[group_num].append(row_ind)
656
657
    return list(groups.values())
658