|
1
|
|
|
<?php |
|
2
|
|
|
|
|
3
|
|
|
declare(strict_types=1); |
|
4
|
|
|
|
|
5
|
|
|
namespace Phpml\Preprocessing; |
|
6
|
|
|
|
|
7
|
|
|
use Phpml\Exception\InvalidArgumentException; |
|
8
|
|
|
|
|
9
|
|
|
/**
 * One-hot encoder for categorical features.
 *
 * fit() records, per positional column, every distinct value seen in the
 * training samples. transform() then replaces each feature with a vector of
 * 0/1 integers (one slot per recorded category of that column) and
 * concatenates the vectors of all columns into a single flat row.
 *
 * Feature values are used as PHP array keys internally, so they are subject
 * to PHP's usual array-key casting rules.
 */
final class OneHotEncoder implements Preprocessor
{
    /**
     * When true, a value that was not seen during fit() is encoded as an
     * all-zero vector instead of raising an exception.
     *
     * @var bool
     */
    private $ignoreUnknown;

    /**
     * Learned encoding: column index => (category value => one-hot vector).
     *
     * @var array<int, array<int|string, int[]>>
     */
    private $categories = [];

    /**
     * @param bool $ignoreUnknown encode unseen values as zeros instead of throwing
     */
    public function __construct(bool $ignoreUnknown = false)
    {
        $this->ignoreUnknown = $ignoreUnknown;
    }

    /**
     * Learns the categories of every column from the given samples.
     *
     * Any previously learned categories are discarded first, so calling
     * fit() again never mixes vectors of the old and the new category set.
     * An empty sample list leaves the encoder without categories.
     *
     * @param array[]    $samples rows of feature values; rows are read positionally
     * @param array|null $targets unused by this method
     */
    public function fit(array $samples, ?array $targets = null): void
    {
        $this->categories = [];

        if ($samples === []) {
            // current() on an empty array yields false, which cannot be
            // treated as a row - there is simply nothing to learn.
            return;
        }

        // Re-index every row positionally, exactly as transform() does, so
        // that the column numbers learned here match those used later even
        // when the rows have string or non-sequential keys.
        $rows = [];
        foreach ($samples as $sample) {
            $rows[] = array_values($sample);
        }

        // The first row defines how many columns are fitted.
        foreach (array_keys($rows[0]) as $column) {
            $this->fitColumn($column, array_values(array_unique(array_column($rows, $column))));
        }
    }

    /**
     * Replaces every sample, in place, with its one-hot encoded form.
     *
     * @param array[]    $samples rows to encode; modified by reference
     * @param array|null $targets unused by this method
     *
     * @throws InvalidArgumentException for a value not seen during fit(),
     *                                  unless unknown values are ignored
     */
    public function transform(array &$samples, ?array &$targets = null): void
    {
        // Write back by key rather than iterating by reference, so no
        // reference to the last element is left behind after the loop.
        foreach ($samples as $key => $sample) {
            $samples[$key] = $this->transformSample(array_values($sample));
        }
    }

    /**
     * Builds and stores the one-hot vector of every distinct value of a column.
     *
     * @param int   $column positional index of the column
     * @param array $values distinct values of that column, indexed 0..n-1
     */
    private function fitColumn(int $column, array $values): void
    {
        $count = count($values);

        // Register the column even when it has no values, so lookups on it
        // always find an array.
        $this->categories[$column] = [];

        foreach ($values as $index => $value) {
            $map = array_fill(0, $count, 0);
            $map[$index] = 1;
            $this->categories[$column][$value] = $map;
        }
    }

    /**
     * Encodes one positionally indexed sample.
     *
     * @param array $sample feature values indexed 0..n-1
     *
     * @return int[] concatenation of the one-hot vectors of all columns
     *
     * @throws InvalidArgumentException for a value not seen during fit(),
     *                                  unless unknown values are ignored
     */
    private function transformSample(array $sample): array
    {
        $parts = [];

        foreach ($sample as $column => $feature) {
            // A column that was never fitted has no categories at all; treat
            // it as an empty category set instead of calling count() on null.
            $known = $this->categories[$column] ?? [];

            if (isset($known[$feature])) {
                $parts[] = $known[$feature];

                continue;
            }

            if (!$this->ignoreUnknown) {
                throw new InvalidArgumentException(sprintf('Missing category "%s" for column %s in trained encoder', $feature, $column));
            }

            // Unknown value, tolerated: contribute zeros for this column.
            $parts[] = array_fill(0, count($known), 0);
        }

        // Merge once at the end; the leading [] keeps the call valid when
        // the sample has no columns.
        return array_merge([], ...$parts);
    }
}
|
67
|
|
|
|