Passed
Push — dev ( 3b058b...d73523 )
by Konstantinos
01:35
created

ClusteringFactory.from_som()   A

Complexity

Conditions 2

Size

Total Lines 17
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 14
nop 5
dl 0
loc 17
rs 9.7
c 0
b 0
f 0
1
from .clustering import ReportingClustering
2
from .cluster import SOMCluster
3
4
import attr
5
6
@attr.s
class ClusteringFactory(object):
    """Factory that builds clustering objects out of trained models."""
    # algorithms = attr.ib(init=True)

    @classmethod
    def inferred(cls, x):
        """Placeholder alternate constructor; not implemented (returns None)."""
        pass

    def from_som(self, dataset, som, nb_clusters, **kwargs):
        """Cluster the neurons of a self-organizing map and group the datapoints accordingly.

        Calls ``som.cluster`` and then assigns every datapoint to the cluster
        of the neuron it is attributed to (its best-matching unit).

        :param dataset: object exposing ``datapoints``, indexable by datapoint position
        :param som: map object exposing ``cluster(nb_clusters, random_state=...)``,
            ``bmus`` and ``clusters``
        :param int nb_clusters: number of clusters to request; cluster ids are
            expected to lie in ``range(nb_clusters)``
        :param kwargs: only ``random_state`` is read (defaults to None) and is
            forwarded to ``som.cluster``
        :return: a ``ReportingClustering`` holding one ``SOMCluster`` per cluster id,
            labelled ``'<dataset>-<som>'``
        :raises KeyError: if a neuron carries a cluster id outside ``range(nb_clusters)``
        """
        # cluster id => members set mapping.
        # Every id must own a distinct set: dict.fromkeys(keys, set()) would bind
        # all ids to one shared set, so each cluster would end up with every datapoint.
        id2members = {cluster_id: set() for cluster_id in range(nb_clusters)}
        som.cluster(nb_clusters, random_state=kwargs.get('random_state'))

        # som.bmus is an array of shape [nb_datapoints, 2]; each row holds the coordinates
        # of the neuron the datapoint gets attributed to (closest distance)
        for i, arr in enumerate(som.bmus):
            attributed_cluster = som.clusters[arr[0], arr[1]]  # >= 0
            id2members[attributed_cluster].add(dataset.datapoints[i])

        def ex1(a_cluster):
            """Return the members of a cluster as a list."""
            return list(a_cluster)

        def ex2(datapoints, attribute):
            """Look up an attribute (by its string name) on the given datapoints."""
            return datapoints[str(attribute)]

        return ReportingClustering(
            [SOMCluster(cluster_members) for cluster_members in id2members.values()],
            str(dataset) + '-' + str(som),
            ex1,
            ex2,
        )
32