Passed
Push — master ( 7f7a74...2c655b )
by Daniel
05:50
created

amd.compare.emd()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""Functions for comparing AMDs and PDDs of crystals.

Includes :func:`compare` (read crystals and return a distance matrix or
nearest-neighbor table as a pandas DataFrame), the Earth mover's
distance :func:`EMD` between two PDDs, and ``cdist``/``pdist``-style
functions computing distance matrices between collections of AMDs
(:func:`AMD_cdist`, :func:`AMD_pdist`) or PDDs (:func:`PDD_cdist`,
:func:`PDD_pdist`).
"""
3
4
import inspect
5
from typing import List, Optional, Union, Tuple, Callable, Sequence
6
from functools import partial
7
from itertools import combinations
8
from pathlib import Path
9
10
import numpy as np
11
import numpy.typing as npt
12
import pandas as pd
13
from scipy.spatial.distance import cdist, pdist, squareform
14
from sklearn.neighbors import NearestNeighbors
15
from joblib import Parallel, delayed
16
import tqdm
17
18
from .io import CifReader, CSDReader
19
from .calculate import AMD, PDD
20
from ._emd import network_simplex
21
from .periodicset import PeriodicSet
22
23
# Annotation aliases for NumPy arrays used in the signatures below:
# arrays of any floating-point dtype and of any integer dtype.
FloatArray = npt.NDArray[np.floating]
IntArray = npt.NDArray[np.integer]
25
26
# Public names exported by a star-import of this module.
__all__ = [
    'compare',
    # Earth mover's distance between two PDDs ('emd' below is an alias).
    'EMD',
    # Distance matrices between collections of AMDs.
    'AMD_cdist', 'AMD_pdist',
    # Distance matrices between collections of PDDs.
    'PDD_cdist', 'PDD_pdist',
    'emd',
]
35
36
# One item accepted by compare(): a PeriodicSet or a string (a path, or a
# CSD refcode when compare(csd_refcodes=True)).
_SingleCompareInput = Union[PeriodicSet, str]
# compare() takes either one such item or a list of them.
CompareInput = Union[_SingleCompareInput, List[_SingleCompareInput]]
38
39
40
def compare(
        crystals: CompareInput,
        crystals_: Optional[CompareInput] = None,
        by: str = 'AMD',
        k: int = 100,
        n_neighbors: Optional[int] = None,
        csd_refcodes: bool = False,
        verbose: bool = True,
        **kwargs
) -> pd.DataFrame:
    r"""Given one or two sets of crystals, compare by AMD or PDD and
    return a pandas DataFrame of the distance matrix.

    Given one or two paths to CIFs, periodic sets, CSD refcodes or lists
    thereof, compare by AMD or PDD and return a pandas DataFrame of the
    distance matrix. Default is to compare by AMD with k = 100. Accepts
    any keyword arguments accepted by
    :class:`CifReader <.io.CifReader>`,
    :class:`CSDReader <.io.CSDReader>` and functions from
    :mod:`.compare`; keyword arguments not recognised by any of them are
    ignored.

    Parameters
    ----------
    crystals : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>` or a list
        of those (CSD refcodes instead of paths if ``csd_refcodes`` is
        True).
    crystals\_ : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`, optional
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>` or a list
        of those. If not given, ``crystals`` are compared with each other.
    by : str, default 'AMD'
        Use AMD or PDD to compare crystals (case insensitive).
    k : int, default 100
        Parameter for AMD/PDD, the number of neighbor atoms to consider
        for each atom in a unit cell.
    n_neighbors : int, default None
        Find a number of nearest neighbors instead of a full distance
        matrix between crystals.
    csd_refcodes : bool, optional, csd-python-api only
        Interpret ``crystals`` and ``crystals_`` as CSD refcodes or
        lists thereof, rather than paths.
    verbose: bool, optional
        If True, prints a progress bar during reading, calculating and
        comparing items.
    **kwargs :
        Any keyword arguments accepted by the ``amd.CifReader``,
        ``amd.CSDReader``, ``amd.PDD`` and functions used to compare:
        ``reader``, ``remove_hydrogens``, ``disorder``,
        ``heaviest_component``, ``molecular_centres``,
        ``show_warnings``, (from :class:`CifReader <.io.CifReader>`),
        ``refcode_families`` (from :class:`CSDReader <.io.CSDReader>`),
        ``collapse_tol`` (from :func:`PDD <.calculate.PDD>`),
        ``metric``, ``low_memory``
        (from :func:`AMD_pdist <.compare.AMD_pdist>`), ``metric``,
        ``backend``, ``n_jobs``
        (from :func:`PDD_pdist <.compare.PDD_pdist>`), ``algorithm``,
        ``leaf_size``, ``metric``, ``p``, ``metric_params``, ``n_jobs``
        (from :func:`_nearest_items <.compare._nearest_items>`).

    Returns
    -------
    df : :class:`pandas.DataFrame`
        DataFrame of the distance matrix for the given crystals compared
        by the chosen invariant. If ``n_neighbors`` is given, the rows are
        the crystals and the columns are ``ID 1``, ``DIST 1``, ``ID 2``,
        ``DIST 2``, ... naming each crystal's nearest neighbors and the
        distances to them.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given have no valid
        crystals to compare, or if crystals or crystals\_ are an invalid
        type.
    TypeError
        If ``csd_refcodes`` is True and a non-string item is given.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', csd_refcodes=True, by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', csd_refcodes=True, refcode_families=True)
    """

    def _default_kwargs(func: Callable) -> dict:
        """Return the parameters of ``func`` that have a default value,
        mapped to that default. Overrides from ``kwargs`` are applied
        afterwards by the caller.
        """
        return {
            name: param.default
            for name, param in inspect.signature(func).parameters.items()
            if param.default is not inspect.Parameter.empty
        }

    def _unwrap_refcode_list(
            refcodes: List[str], **reader_kwargs
    ) -> List[PeriodicSet]:
        """Given string or list of strings, interpret as CSD refcodes
        and return a list of ``PeriodicSet`` objects.
        """
        if not all(isinstance(refcode, str) for refcode in refcodes):
            raise TypeError(
                'amd.compare(csd_refcodes=True) expects a string or list of '
                'strings.'
            )
        return list(CSDReader(refcodes, **reader_kwargs))

    def _unwrap_pset_list(
            psets: List[Union[str, PeriodicSet]], **reader_kwargs
    ) -> List[PeriodicSet]:
        """Given a list of strings or ``PeriodicSet`` objects, interpret
        strings as paths and unwrap all items into one list of
        ``PeriodicSet``s.
        """
        ret = []
        for item in psets:
            if isinstance(item, PeriodicSet):
                ret.append(item)
            else:
                try:
                    path = Path(item)
                except TypeError as err:
                    raise ValueError(
                        'amd.compare() expects strings or amd.PeriodicSets, '
                        f'got {item.__class__.__name__}'
                    ) from err
                ret.extend(CifReader(path, **reader_kwargs))
        return ret

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        raise ValueError(
            "'by' parameter of amd.compare() must be 'AMD' or 'PDD' (passed "
            f"'{by}')"
        )

    # Collect the defaults of every function kwargs may be forwarded to,
    # then override them with any matching user-supplied kwargs.
    cifreader_kwargs = _default_kwargs(CifReader.__init__)
    csdreader_kwargs = _default_kwargs(CSDReader.__init__)
    csdreader_kwargs.pop('refcodes', None)
    pdd_kwargs = _default_kwargs(PDD)
    pdd_kwargs.pop('return_row_groups', None)
    compare_amds_kwargs = _default_kwargs(AMD_pdist)
    compare_pdds_kwargs = _default_kwargs(PDD_pdist)
    nearest_items_kwargs = _default_kwargs(_nearest_items)
    nearest_items_kwargs.pop('XB', None)
    cifreader_kwargs['verbose'] = verbose
    csdreader_kwargs['verbose'] = verbose
    compare_pdds_kwargs['verbose'] = verbose

    for default_kwargs in (
        cifreader_kwargs, csdreader_kwargs, pdd_kwargs, compare_amds_kwargs,
        compare_pdds_kwargs, nearest_items_kwargs
    ):
        for kw in default_kwargs:
            if kw in kwargs:
                default_kwargs[kw] = kwargs[kw]

    # 'p' and 'metric_params' reach NearestNeighbors through
    # _nearest_items' **kwargs, so they have no default to pick up above
    # and must be forwarded explicitly.
    for kw in ('p', 'metric_params'):
        if kw in kwargs:
            nearest_items_kwargs[kw] = kwargs[kw]

    # Get list of periodic sets from first input
    if not isinstance(crystals, list):
        crystals = [crystals]
    if csd_refcodes:
        crystals = _unwrap_refcode_list(crystals, **csdreader_kwargs)
    else:
        crystals = _unwrap_pset_list(crystals, **cifreader_kwargs)
    if not crystals:
        raise ValueError(
            'First argument passed to amd.compare() contains no valid '
            'crystals/periodic sets to compare.'
        )
    names = [s.name for s in crystals]
    if verbose:
        crystals = tqdm.tqdm(crystals, desc='Calculating', delay=1)

    # Get list of periodic sets from second input if given
    if crystals_ is None:
        names_ = names
    else:
        if not isinstance(crystals_, list):
            crystals_ = [crystals_]
        if csd_refcodes:
            crystals_ = _unwrap_refcode_list(crystals_, **csdreader_kwargs)
        else:
            crystals_ = _unwrap_pset_list(crystals_, **cifreader_kwargs)
        if not crystals_:
            raise ValueError(
                'Second argument passed to amd.compare() contains no '
                'valid crystals/periodic sets to compare.'
            )
        names_ = [s.name for s in crystals_]
        if verbose:
            crystals_ = tqdm.tqdm(crystals_, desc='Calculating', delay=1)

    if by == 'AMD':

        amds = np.empty((len(names), k), dtype=np.float64)
        for i, s in enumerate(crystals):
            amds[i] = AMD(s, k)

        if crystals_ is None:
            if n_neighbors is None:
                dm = squareform(AMD_pdist(amds, **compare_amds_kwargs))
                return pd.DataFrame(dm, index=names, columns=names_)
            else:
                nn_dm, inds = _nearest_items(
                    n_neighbors, amds, **nearest_items_kwargs
                )
                return _nearest_neighbors_dataframe(nn_dm, inds, names, names_)
        else:
            amds_ = np.empty((len(names_), k), dtype=np.float64)
            for i, s in enumerate(crystals_):
                amds_[i] = AMD(s, k)

            if n_neighbors is None:
                dm = AMD_cdist(amds, amds_, **compare_amds_kwargs)
                return pd.DataFrame(dm, index=names, columns=names_)
            else:
                nn_dm, inds = _nearest_items(
                    n_neighbors, amds, amds_, **nearest_items_kwargs
                )
                return _nearest_neighbors_dataframe(nn_dm, inds, names, names_)

    elif by == 'PDD':

        pdds = [PDD(s, k, **pdd_kwargs) for s in crystals]

        if crystals_ is None:
            # Condensed matrix; only expanded when a full table is wanted,
            # since the neighbor search accepts the condensed form.
            dm = PDD_pdist(pdds, **compare_pdds_kwargs)
            if n_neighbors is None:
                dm = squareform(dm)
        else:
            pdds_ = [PDD(s, k, **pdd_kwargs) for s in crystals_]
            dm = PDD_cdist(pdds, pdds_, **compare_pdds_kwargs)

        if n_neighbors is None:
            return pd.DataFrame(dm, index=names, columns=names_)
        else:
            nn_dm, inds = _neighbors_from_distance_matrix(n_neighbors, dm)
            return _nearest_neighbors_dataframe(nn_dm, inds, names, names_)
288
289
290
def EMD(
        pdd: FloatArray,
        pdd_: FloatArray,
        metric: Optional[Union[str, Callable]] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs
) -> Union[float, Tuple[float, FloatArray]]:
    r"""Calculate the Earth mover's distance (EMD) between two PDDs, aka
    the Wasserstein metric.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        PDD of a crystal, a 2D array. Column 0 is passed to the network
        simplex solver as the weight of each row; the remaining columns
        are compared between rows with ``metric``.
    pdd\_ : :class:`numpy.ndarray`
        PDD of a crystal, with the same number of columns as ``pdd``.
    metric : str or callable, default 'chebyshev'
        EMD between PDDs requires defining a distance between PDD rows.
        By default, Chebyshev (L-infinity) distance is chosen like with
        AMDs. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    return_transport: bool, default False
        Instead return a tuple ``(emd, transport_plan)`` where
        transport_plan describes the optimal flow.
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between two PDDs. If ``return_transport``
        is True, return a tuple (emd, transport_plan).

    Raises
    ------
    ValueError
        Thrown if ``pdd`` or ``pdd_`` is not 2D, or if they do not have
        the same number of columns.
    """

    pdd = np.asarray(pdd)
    pdd_ = np.asarray(pdd_)
    if pdd.ndim != 2 or pdd_.ndim != 2:
        raise ValueError(
            'amd.EMD() expects two 2D arrays (PDDs), got arrays with '
            f'{pdd.ndim} and {pdd_.ndim} dimensions.'
        )
    if pdd.shape[1] != pdd_.shape[1]:
        raise ValueError(
            'amd.EMD() requires PDDs with the same number of columns, '
            f'got {pdd.shape[1]} and {pdd_.shape[1]}.'
        )

    # Ground distances between rows use every column except the weights.
    dm = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    emd_dist, transport_plan = network_simplex(pdd[:, 0], pdd_[:, 0], dm)

    if return_transport:
        return emd_dist, transport_plan
    return emd_dist
334
335
336
def AMD_cdist(
        amds,
        amds_,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> FloatArray:
    r"""Compare two sets of AMDs with each other, returning a distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.cdist` with the default metric
    ``chebyshev`` and a low memory option.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs.
    amds\_ : ArrayLike
        A list/array of AMDs.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only).
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist` (unused if ``low_memory``).

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix shape ``(len(amds), len(amds_))``.
        ``dm[i, j]`` is the distance (given by ``metric``) between
        ``amds[i]`` and ``amds_[j]``.

    Raises
    ------
    ValueError
        If ``low_memory`` is True and ``metric`` is not 'chebyshev'.
    """
    if not low_memory:
        return cdist(amds, amds_, metric=metric, **kwargs)

    if metric != 'chebyshev':
        raise ValueError(
            "'low_memory' parameter of amd.AMD_cdist() only implemented "
            "with metric='chebyshev'."
        )

    # Convert once so list inputs support the vectorised subtraction below.
    amds = np.asarray(amds, dtype=np.float64)
    amds_ = np.asarray(amds_, dtype=np.float64)
    dm = np.empty((len(amds), len(amds_)), dtype=np.float64)
    if dm.size == 0:
        return dm

    # One row at a time keeps memory at O(len(amds_) * k).
    for i, amd_vec in enumerate(amds):
        dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    return dm
384
385
386
def AMD_pdist(
        amds,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> FloatArray:
    """Compare a set of AMDs pairwise, returning a condensed distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.pdist` with the default metric
    ``chebyshev`` and a low memory parameter.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only).
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.pdist` (unused if ``low_memory``).

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.

    Raises
    ------
    ValueError
        If ``low_memory`` is True and ``metric`` is not 'chebyshev'.
    """
    if not low_memory:
        return pdist(amds, metric=metric, **kwargs)

    if metric != 'chebyshev':
        raise ValueError(
            "'low_memory' parameter of amd.AMD_pdist() only implemented "
            "with metric='chebyshev'."
        )

    # Convert once so list inputs support the vectorised subtraction below.
    amds = np.asarray(amds, dtype=np.float64)
    m = len(amds)
    cdm = np.empty((m * (m - 1)) // 2, dtype=np.float64)
    start = 0
    # Row i of the upper triangle (pairs (i, i+1), ..., (i, m-1)) occupies
    # the next m - i - 1 entries of the condensed matrix.
    for i in range(m - 1):
        stop = start + m - i - 1
        cdm[start:stop] = np.amax(np.abs(amds[i + 1:] - amds[i]), axis=-1)
        start = stop
    return cdm
436
437
438
def PDD_cdist(
        pdds: List[FloatArray],
        pdds_: List[FloatArray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> FloatArray:
    r"""Compare two sets of PDDs with each other, returning a distance
    matrix. Supports parallel processing via joblib. If using
    parallelisation, make sure to include an if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs. None, 0 and 1 run
        serially.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing, the verbose
        argument of :class:`joblib.Parallel` is used, otherwise uses
        tqdm.
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist`. ``return_transport`` is
        ignored.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``. The
        :math:`ij` th entry is the distance between ``pdds[i]`` and
        ``pdds_[j]`` given by Earth mover's distance.
    """

    # Only scalar distances are stored, so transport plans are never wanted.
    kwargs.pop('return_transport', None)
    n, m = len(pdds), len(pdds_)
    if n == 0 or m == 0:
        # Nothing to compare; also avoids indexing pdds[0] below.
        return np.empty((n, m), dtype=np.float64)

    k = pdds[0].shape[-1] - 1

    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        _verbose = 3 if verbose else 0
        dm = Parallel(backend=backend, n_jobs=n_jobs, verbose=_verbose)(
            delayed(emd_func)(pdd, pdd_) for pdd in pdds for pdd_ in pdds_
        )
        return np.array(dm, dtype=np.float64).reshape((n, m))

    dm = np.empty((n, m), dtype=np.float64)
    desc = f'Comparing {n}x{m} PDDs (k={k})'
    progress_bar = tqdm.tqdm(desc=desc, total=n * m, disable=not verbose)
    with progress_bar:
        for i, pdd in enumerate(pdds):
            for j, pdd_ in enumerate(pdds_):
                dm[i, j] = EMD(pdd, pdd_, metric=metric, **kwargs)
                progress_bar.update(1)
    return dm
515
516
517
def PDD_pdist(
        pdds: List[FloatArray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> FloatArray:
    """Compare a set of PDDs pairwise, returning a condensed distance
    matrix. Supports parallelisation via joblib. If using
    parallelisation, make sure to include an if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs. None, 0 and 1 run
        serially.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing, the verbose
        argument of :class:`joblib.Parallel` is used, otherwise uses
        tqdm.
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist`. ``return_transport`` is
        ignored.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    # Only scalar distances are stored, so transport plans are never wanted.
    kwargs.pop('return_transport', None)
    m = len(pdds)
    cdm_len = (m * (m - 1)) // 2
    if cdm_len == 0:
        # Fewer than two PDDs: no pairs to compare (and no pdds[0] to read).
        return np.empty(0, dtype=np.float64)

    k = pdds[0].shape[-1] - 1
    # Pairs (i, j), i < j, in condensed-matrix order.
    pairs = combinations(range(m), 2)

    # Same rule as PDD_cdist: any n_jobs other than None, 0 or 1 runs in
    # parallel, including negative values such as -1 (all CPUs).
    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        _verbose = 3 if verbose else 0
        cdm = Parallel(backend=backend, n_jobs=n_jobs, verbose=_verbose)(
            delayed(emd_func)(pdds[i], pdds[j]) for i, j in pairs
        )
        return np.array(cdm, dtype=np.float64)

    cdm = np.empty(cdm_len, dtype=np.float64)
    desc = f'Comparing {m} PDDs pairwise (k={k})'
    progress_bar = tqdm.tqdm(desc=desc, total=cdm_len, disable=not verbose)
    with progress_bar:
        for r, (i, j) in enumerate(pairs):
            cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
            progress_bar.update(1)
    return cdm
592
593
594
def emd(
        pdd: FloatArray, pdd_: FloatArray, **kwargs
) -> Union[float, Tuple[float, FloatArray]]:
    """Lower-case alias of :func:`EMD() <.compare.EMD>`; every keyword
    argument is forwarded to it unchanged.
    """
    result = EMD(pdd, pdd_, **kwargs)
    return result
599
600
601
def _neighbors_from_distance_matrix(
        n: int, dm: FloatArray
) -> Tuple[FloatArray, IntArray]:
    """Given a distance matrix, find the n nearest neighbors of each
    item.

    Parameters
    ----------
    n : int
        Number of nearest neighbors to find for each item. Must satisfy
        ``0 <= n <= dm.shape[1]`` for a 2D matrix, or ``0 <= n <= m - 1``
        for a condensed matrix of ``m`` items.
    dm : :class:`numpy.ndarray`
        2D distance matrix (rows and columns index two sets) or 1D
        condensed distance matrix of one set compared pairwise, as
        returned by :func:`scipy.spatial.distance.pdist`. For a condensed
        matrix an item is never listed as its own neighbor.

    Returns
    -------
    (nn_dm, inds) : tuple of :class:`numpy.ndarray` s
        ``nn_dm[i][j]`` is the distance from item :math:`i` to its
        :math:`j+1` st nearest neighbor, and ``inds[i][j]`` is the
        index of this neighbor (:math:`j+1` since index 0 is the first
        nearest neighbor). Each row is sorted by increasing distance.

    Raises
    ------
    ValueError
        If ``dm`` is neither 1D nor 2D, or ``n`` is out of range.
    """

    dm = np.asarray(dm)
    if dm.ndim == 2:
        search_dm = dm
        n_candidates = dm.shape[1]
    elif dm.ndim == 1:
        dm = squareform(dm)
        # Mask the diagonal so an item is never picked as its own neighbor.
        search_dm = dm.copy()
        np.fill_diagonal(search_dm, np.inf)
        n_candidates = dm.shape[1] - 1
    else:
        raise ValueError(
            'amd.neighbors_from_distance_matrix() accepts a distance matrix, '
            'either a 2D distance matrix or a condensed distance matrix as '
            'returned by scipy.spatial.distance.pdist().'
        )

    if not 0 <= n <= n_candidates:
        raise ValueError(
            f'Cannot find {n} nearest neighbors when each item has '
            f'{n_candidates} candidates.'
        )

    # With kth = n - 1 the first n entries of each partitioned row are the
    # n smallest distances (in arbitrary order).
    inds = np.argpartition(search_dm, n - 1, axis=-1)[:, :n].astype(np.int64)
    nn_dm = np.take_along_axis(dm, inds, axis=-1)
    sorted_inds = np.argsort(nn_dm, axis=-1)
    inds = np.take_along_axis(inds, sorted_inds, axis=-1)
    nn_dm = np.take_along_axis(nn_dm, sorted_inds, axis=-1)
    return nn_dm, inds
648
649
650
def _nearest_items(
        n_neighbors: int,
        XA: FloatArray,
        XB: Optional[FloatArray] = None,
        algorithm: str = 'kd_tree',
        leaf_size: int = 5,
        metric: str = 'chebyshev',
        n_jobs=None,
        **kwargs
) -> Tuple[FloatArray, IntArray]:
    """Find nearest neighbor distances and indices between all
    items/observations/rows in ``XA`` and ``XB``. If ``XB`` is None,
    find neighbors in ``XA`` for all items in ``XA``.

    Parameters
    ----------
    n_neighbors : int
        Number of neighbors to return for each row of ``XA``.
    XA : :class:`numpy.ndarray`
        Query items, one per row.
    XB : :class:`numpy.ndarray`, optional
        Items to search for neighbors in. If None, ``XA`` is searched and
        an item is not returned as its own neighbor.
    algorithm, leaf_size, metric, n_jobs, **kwargs :
        Passed to :class:`sklearn.neighbors.NearestNeighbors`.

    Returns
    -------
    (dists, inds) : tuple of :class:`numpy.ndarray` s
        Arrays of shape ``(len(XA), n_neighbors)``: distances to, and
        indices (into ``XB``, or ``XA`` if ``XB`` is None) of, the nearest
        neighbors of each row of ``XA``.
    """

    self_query = XB is None
    fit_data = XA if self_query else XB
    # Querying XA against itself usually returns each item as its own
    # neighbor, so one extra candidate is requested and removed below.
    n_query = n_neighbors + 1 if self_query else n_neighbors

    dists, inds = NearestNeighbors(
        n_neighbors=n_query,
        algorithm=algorithm,
        leaf_size=leaf_size,
        metric=metric,
        n_jobs=n_jobs,
        **kwargs
    ).fit(fit_data).kneighbors(XA)

    if not self_query:
        return dists, inds

    n_rows = dists.shape[0]
    is_self = inds == np.arange(n_rows)[:, None]
    # Each row holds its own index at most once. Rows where it is absent
    # (e.g. crowded out by exact duplicates) drop their last candidate
    # instead, so every row keeps exactly n_neighbors entries in order.
    is_self[~is_self.any(axis=1), -1] = True
    keep = ~is_self
    dists_ = dists[keep].reshape(n_rows, n_neighbors).astype(np.float64)
    inds_ = inds[keep].reshape(n_rows, n_neighbors).astype(np.int64)
    return dists_, inds_
699
700
701
def _nearest_neighbors_dataframe(nn_dm, inds, names, names_=None):
    """Build a ``pandas.DataFrame`` describing nearest neighbors.

    ``nn_dm`` and ``inds`` hold, row by row, the distances to and the
    indices of each item's nearest neighbors (as returned by
    ``_neighbors_from_distance_matrix()`` or ``_nearest_items()``). The
    DataFrame is indexed by ``names`` and has columns ``ID 1``,
    ``DIST 1``, ``ID 2``, ``DIST 2``, ..., where each ID is looked up in
    ``names_`` (defaulting to ``names``).
    """

    lookup = names if names_ is None else names_
    columns = {}
    n_cols = nn_dm.shape[-1]
    for c in range(n_cols):
        label = c + 1
        columns[f'ID {label}'] = [lookup[idx] for idx in inds[:, c]]
        columns[f'DIST {label}'] = nn_dm[:, c]
    return pd.DataFrame(columns, index=names)
716