Issues (58)

patty/segmentation/dbscan.py (2 issues)

1
"""
2
Point cloud segmentation using the DBSCAN clustering algorithm.
3
4
DBSCAN - Density-Based Spatial Clustering of Applications with Noise.
5
Finds core samples of high density and expands clusters from them.
6
Good for data which contains clusters of similar density.
7
8
See the scikit-learn documentation for reference:
9
http://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html.
10
"""
11 1
import numpy as np
12 1
from sklearn.cluster import dbscan
13 1
from patty.utils import extract_mask
14
15
16 1
def dbscan_labels(pointcloud, epsilon, minpoints, rgb_weight=0,
                  algorithm='ball_tree'):
    '''
    Find an array of point-labels of clusters found by the DBSCAN algorithm.

    Parameters
    ----------
    pointcloud : pcl.PointCloud
        Input pointcloud.
    epsilon : float
        Neighborhood radius for DBSCAN.
    minpoints : integer
        Minimum neighborhood density for DBSCAN.
    rgb_weight : float, optional
        If greater than zero, cluster on color information as well as
        location; specifies the relative weight of the RGB components to
        spatial coordinates in distance computations.
        (RGB values have wildly different scales than spatial coordinates.)
        For zero or negative values the pointcloud is handed to DBSCAN
        unchanged.
    algorithm : string, optional
        Nearest-neighbor search algorithm, forwarded unchanged to
        sklearn.cluster.dbscan.

    Returns
    -------
    labels : numpy.ndarray
        A sequence of labels per point. Label -1 indicates a point does not
        belong to any cluster, other labels indicate the cluster number a
        point belongs to.
    '''
    if rgb_weight > 0:
        features = pointcloud.to_array()
        # Scale every column after the first three in place; these are
        # presumably the RGB components -- TODO confirm against to_array().
        features[:, 3:] *= rgb_weight
    else:
        features = pointcloud

    # dbscan returns a pair; only the per-point labels are needed here.
    _, labels = dbscan(features, eps=epsilon, min_samples=minpoints,
                       algorithm=algorithm)
    return np.asarray(labels)
52
53
54 1
def segment_dbscan(pointcloud, epsilon, minpoints, **kwargs):
    """Split pointcloud into its DBSCAN clusters, dropping outliers.

    Parameters
    ----------
    pointcloud : pcl.PointCloud
        Input pointcloud.
    epsilon : float
        Neighborhood radius for DBSCAN.
    minpoints : integer
        Minimum neighborhood density for DBSCAN.
    **kwargs : keyword arguments, optional
        Extra arguments forwarded to dbscan_labels.

    Returns
    -------
    clusters : iterable
        Lazily yields extract_mask(pointcloud, mask) once per cluster
        label; points labelled -1 (outliers) are not part of any result.
    """
    point_labels = dbscan_labels(pointcloud, epsilon, minpoints, **kwargs)

    # Distinct cluster ids, with the outlier label -1 filtered out first.
    clustered = point_labels != -1
    cluster_ids = np.unique(point_labels[clustered])

    return (extract_mask(pointcloud, point_labels == cluster_id)
            for cluster_id in cluster_ids)
76
77
78 1
def get_largest_dbscan_clusters(pointcloud, min_return_fragment=0.7,
                                epsilon=0.1, minpoints=250, rgb_weight=0):
    '''
    Finds the largest clusters containing together at least min_return_fragment
    of the complete point cloud. In case fewer points belong to clusters, the
    complete point cloud is returned instead.

    Parameters
    ----------
    pointcloud : pcl.PointCloud
        Input pointcloud.
    min_return_fragment : float
        Minimum desired fragment of pointcloud to be returned
    epsilon : float
        Neighborhood radius for DBSCAN.
    minpoints : integer
        Minimum neighborhood density for DBSCAN.
    rgb_weight : float, optional
        If greater than zero, cluster on color information as well as
        location; specifies the relative weight of the RGB components to
        spatial coordinates in distance computations.
        (RGB values have wildly different scales than spatial coordinates.)

    Returns
    -------
    cluster : pcl.PointCloud
        Result of extract_mask on pointcloud: the points of the selected
        largest clusters, or every point when the clusters together do not
        reach min_return_fragment.
    '''
    labels = dbscan_labels(pointcloud, epsilon, minpoints,
                           rgb_weight=rgb_weight).astype(np.int64)
    selection, selected_count = _get_top_labels(labels, min_return_fragment)

    # Not enough clustered points (this includes "no clusters at all"):
    # fall back to an all-true mask, i.e. keep the whole cloud.
    # NOTE(review): an earlier docstring promised "all clustered points"
    # here; the code has always returned every point -- confirm intent.
    if selected_count < min_return_fragment * len(labels):
        return extract_mask(pointcloud, np.ones(len(pointcloud), dtype=bool))

    # Vectorised membership test instead of a per-point Python lookup.
    mask = np.isin(labels, selection)
    return extract_mask(pointcloud, mask)
116
117
118 1
def _get_top_labels(labels, min_return_fragment):
119
    """Return labels of the smallest set of clusters that contain at least
120
    min_return_fragment of the points (or everything)."""
121
122
    # +1 to make bincount happy, [1:] to get rid of outliers.
123 1
    bins = np.bincount(labels + 1)[1:]
124 1
    labelbinpairs = sorted(enumerate(bins), key=lambda x: x[1])
125
126 1
    total = len(labels)
127 1
    minimum = min_return_fragment * total
128 1
    selected = []
129 1
    selected_count = 0
130 1
    while selected_count < minimum and len(labelbinpairs) > 0:
131 1
        label, count = labelbinpairs.pop()
132 1
        selected.append(label)
133 1
        selected_count += count
134
    return selected, selected_count
135