Issues (14)

src/FeatureExtraction/TokenCountVectorizer.php (1 issue)

1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\FeatureExtraction;
6
7
use Phpml\Tokenization\Tokenizer;
8
use Phpml\Transformer;
9
10
/**
 * Turns raw text samples into token-count vectors.
 *
 * fit() builds a vocabulary (token => column index) from the samples; transform()
 * replaces every sample string, in place, by an array indexed by vocabulary
 * column whose values are the number of times the corresponding token occurs in
 * that sample. Tokens rejected by the optional stop-word filter are ignored.
 * When $minDF is greater than zero, columns of tokens whose frequency ratio in
 * the transformed batch is below $minDF are zeroed out.
 */
class TokenCountVectorizer implements Transformer
{
    /**
     * Splits each sample into tokens.
     *
     * @var Tokenizer
     */
    private $tokenizer;

    /**
     * Optional stop-word filter; null disables filtering.
     *
     * @var StopWords|null
     */
    private $stopWords;

    /**
     * Minimum ratio "token occurrences / number of samples" a token needs for
     * its column to be kept; a value of 0 (or less) disables the check.
     *
     * @var float
     */
    private $minDF;

    /**
     * Map of token => column index, in order of first appearance during fit().
     *
     * @var array
     */
    private $vocabulary = [];

    /**
     * Map of token => number of occurrences counted during the current
     * transform() call.
     *
     * @var array
     */
    private $frequencies = [];

    /**
     * @param Tokenizer      $tokenizer tokenizer applied to every sample
     * @param StopWords|null $stopWords optional stop-word filter
     * @param float          $minDF     minimum frequency ratio (see $minDF property)
     */
    public function __construct(Tokenizer $tokenizer, ?StopWords $stopWords = null, float $minDF = 0.0)
    {
        $this->tokenizer = $tokenizer;
        $this->stopWords = $stopWords;
        $this->minDF = $minDF;
    }

    /**
     * Extends the vocabulary with every non-stop-word token found in $samples.
     *
     * @param array      $samples text samples
     * @param array|null $targets unused
     */
    public function fit(array $samples, ?array $targets = null): void
    {
        $this->buildVocabulary($samples);
    }

    /**
     * Replaces each sample string by its token-count vector, in place, then
     * applies the minimum-frequency filter to the whole batch.
     *
     * @param array      $samples text samples; overwritten with count vectors
     * @param array|null $targets unused
     */
    public function transform(array &$samples, ?array &$targets = null): void
    {
        // The frequency ratio is computed against the size of THIS batch, so
        // counts left over from an earlier transform() call must not leak in.
        $this->frequencies = [];

        array_walk($samples, function (string &$sample): void {
            $this->transformSample($sample);
        });

        $this->checkDocumentFrequency($samples);
    }

    /**
     * Returns the vocabulary as a map of column index => token.
     *
     * @return array
     */
    public function getVocabulary(): array
    {
        // A static analyser flagged that array_flip() "could return null".
        // That cannot happen here: the argument is always an array whose
        // values are the integers assigned in addTokenToVocabulary().
        return array_flip($this->vocabulary);
    }

    /**
     * Tokenizes every sample and registers each token in the vocabulary.
     *
     * @param array $samples text samples (not modified)
     */
    private function buildVocabulary(array &$samples): void
    {
        foreach ($samples as $sample) {
            foreach ($this->tokenizer->tokenize($sample) as $token) {
                $this->addTokenToVocabulary($token);
            }
        }
    }

    /**
     * Overwrites $sample with its count vector: one entry per vocabulary
     * column, sorted by column index. Tokens that are stop words or are not in
     * the vocabulary are skipped. Every counted occurrence is also added to
     * the per-token frequency table.
     *
     * @param string $sample text on input, array of counts on output
     */
    private function transformSample(string &$sample): void
    {
        $counts = [];

        foreach ($this->tokenizer->tokenize($sample) as $token) {
            $index = $this->getTokenIndex($token);
            if ($index === false) {
                continue;
            }

            $this->updateFrequency($token);
            $counts[$index] = ($counts[$index] ?? 0) + 1;
        }

        // Give every vocabulary column that did not occur an explicit zero;
        // the array union keeps the counts already present.
        $counts += array_fill_keys($this->vocabulary, 0);

        ksort($counts);

        $sample = $counts;
    }

    /**
     * Looks up the vocabulary column of a token.
     *
     * @return int|bool column index, or false for stop words and unknown tokens
     */
    private function getTokenIndex(string $token)
    {
        if ($this->isStopWord($token)) {
            return false;
        }

        return $this->vocabulary[$token] ?? false;
    }

    /**
     * Assigns the next free column index to a token unless it is a stop word
     * or already known.
     */
    private function addTokenToVocabulary(string $token): void
    {
        if ($this->isStopWord($token) || isset($this->vocabulary[$token])) {
            return;
        }

        $this->vocabulary[$token] = count($this->vocabulary);
    }

    /**
     * Tells whether the configured stop-word filter (if any) rejects the token.
     */
    private function isStopWord(string $token): bool
    {
        return $this->stopWords !== null && $this->stopWords->isStopWord($token);
    }

    /**
     * Adds one occurrence of the token to the frequency table.
     */
    private function updateFrequency(string $token): void
    {
        $this->frequencies[$token] = ($this->frequencies[$token] ?? 0) + 1;
    }

    /**
     * Zeroes, in every transformed sample, the columns of tokens whose
     * frequency ratio is below $minDF. Does nothing when $minDF is not
     * positive or when there are no samples.
     *
     * @param array $samples count vectors produced by transformSample()
     */
    private function checkDocumentFrequency(array &$samples): void
    {
        $samplesCount = count($samples);

        // An empty batch has nothing to reset and would otherwise be used as
        // a divisor in getBeyondMinimumIndexes().
        if (!($this->minDF > 0) || $samplesCount === 0) {
            return;
        }

        $beyondMinimum = $this->getBeyondMinimumIndexes($samplesCount);
        foreach ($samples as &$sample) {
            $this->resetBeyondMinimum($sample, $beyondMinimum);
        }
        // Break the reference so later writes to $sample cannot alter the
        // last element of $samples.
        unset($sample);
    }

    /**
     * Sets the given columns of one count vector to zero.
     *
     * @param array $sample        count vector, modified in place
     * @param array $beyondMinimum column indexes to reset
     */
    private function resetBeyondMinimum(array &$sample, array $beyondMinimum): void
    {
        foreach ($beyondMinimum as $index) {
            $sample[$index] = 0;
        }
    }

    /**
     * Collects the column indexes of tokens whose ratio
     * "occurrences / $samplesCount" is below $minDF.
     *
     * @param int $samplesCount number of samples in the batch (non-zero)
     *
     * @return array list of column indexes
     */
    private function getBeyondMinimumIndexes(int $samplesCount): array
    {
        $indexes = [];
        foreach ($this->frequencies as $token => $frequency) {
            if (($frequency / $samplesCount) >= $this->minDF) {
                continue;
            }

            // Array keys that look like integers come back as int, hence the cast.
            $index = $this->getTokenIndex((string) $token);

            // A false here would later be used as array key 0 and wipe the
            // first column, so only real indexes are collected.
            if ($index !== false) {
                $indexes[] = $index;
            }
        }

        return $indexes;
    }
}
167