Passed
Pull Request — master (#271)
by
unknown
02:36
created

DataTransformer::testSet()   B

Complexity

Conditions 4
Paths 5

Size

Total Lines 24
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
dl 0
loc 24
rs 8.6845
c 0
b 0
f 0
cc 4
eloc 12
nc 5
nop 1
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\SupportVectorMachine;
6
7
use Phpml\Exception\InvalidArgumentException;
8
9
/**
 * Converts PHP arrays to and from the plain-text formats exchanged with libsvm.
 */
class DataTransformer
{
    /**
     * Serialises labelled samples into libsvm's "label index:value ..." line format.
     *
     * When $targets is false, each distinct label is mapped to an integer via numericLabels()
     * and that integer is written; when true, the label values are written verbatim.
     *
     * @param array $samples rows of features, indexed by the same keys as $labels
     * @param array $labels  one label (or target value) per sample
     * @param bool  $targets write labels verbatim instead of mapping them to integers
     *
     * @return string one line per label, each terminated by PHP_EOL
     */
    public static function trainingSet(array $samples, array $labels, bool $targets = false): string
    {
        $numericLabels = $targets ? [] : self::numericLabels($labels);

        return self::withDotDecimalLocale(static function () use ($samples, $labels, $targets, $numericLabels): string {
            $set = '';
            foreach ($labels as $index => $label) {
                $set .= sprintf(
                    '%s %s %s',
                    $targets ? $label : $numericLabels[$label],
                    self::sampleRow($samples[$index]),
                    PHP_EOL
                );
            }

            return $set;
        });
    }

    /**
     * Serialises unlabelled samples into libsvm's line format.
     *
     * A single flat sample (whose first element is not an array) is accepted and treated as a
     * one-row batch.
     *
     * @param array $samples list of feature rows, or a single feature row
     *
     * @return string one line per sample, each terminated by PHP_EOL
     *
     * @throws InvalidArgumentException when $samples is empty
     */
    public static function testSet(array $samples): string
    {
        if ($samples === []) {
            throw new InvalidArgumentException('The array has zero elements');
        }

        // reset() returns the first element whatever its key, so keyed (non-list) input
        // does not trigger an undefined-index read the way $samples[0] did.
        if (!is_array(reset($samples))) {
            $samples = [$samples];
        }

        return self::withDotDecimalLocale(static function () use ($samples): string {
            $set = '';
            foreach ($samples as $sample) {
                // A constant label of 0 is written for every row; presumably a placeholder,
                // since the line format requires a label column and test rows have none.
                $set .= sprintf('0 %s %s', self::sampleRow($sample), PHP_EOL);
            }

            return $set;
        });
    }

    /**
     * Maps libsvm's integer predictions (one per line) back to the original labels.
     *
     * Empty lines are skipped. A prediction that matches no known label yields false
     * (array_search's not-found value).
     *
     * @param string $rawPredictions PHP_EOL-separated integer predictions
     * @param array  $labels         the labels used to build the numeric mapping
     */
    public static function predictions(string $rawPredictions, array $labels): array
    {
        $numericLabels = self::numericLabels($labels);
        $results = [];
        foreach (explode(PHP_EOL, $rawPredictions) as $result) {
            if (isset($result[0])) {
                $results[] = array_search((int) $result, $numericLabels, true);
            }
        }

        return $results;
    }

    /**
     * Parses probability output: a header line whose columns after the first are numeric
     * labels, followed by one row per sample whose columns after the first are probabilities
     * in header order.
     *
     * @param string $rawPredictions PHP_EOL-separated, space-delimited probability output
     * @param array  $labels         the labels used to build the numeric mapping
     *
     * @return array one [originalLabel => probability] map per sample row
     */
    public static function probabilities(string $rawPredictions, array $labels): array
    {
        $numericLabels = self::numericLabels($labels);

        $predictions = explode(PHP_EOL, trim($rawPredictions));

        $header = array_shift($predictions);
        $headerColumns = explode(' ', $header);
        // The first header column is not a label; drop it.
        array_shift($headerColumns);

        $columnLabels = [];
        foreach ($headerColumns as $numericLabel) {
            $columnLabels[] = array_search((int) $numericLabel, $numericLabels, true);
        }

        $results = [];
        foreach ($predictions as $rawResult) {
            $probabilities = explode(' ', $rawResult);
            // The first column of each row is not a probability; drop it.
            array_shift($probabilities);

            $result = [];
            foreach ($probabilities as $i => $prob) {
                $result[$columnLabels[$i]] = (float) $prob;
            }

            $results[] = $result;
        }

        return $results;
    }

    /**
     * Assigns consecutive integers (from 0) to distinct labels in order of first appearance.
     *
     * Labels are used as array keys, so they are assumed to be ints or strings.
     *
     * @return array [label => int]
     */
    public static function numericLabels(array $labels): array
    {
        $numericLabels = [];
        foreach ($labels as $label) {
            if (isset($numericLabels[$label])) {
                continue;
            }

            $numericLabels[$label] = count($numericLabels);
        }

        return $numericLabels;
    }

    /**
     * Runs $callback with LC_NUMERIC temporarily set to "C", so floats render with a dot as the
     * decimal separator (libsvm only accepts dots). "C" is used because it exists on every
     * platform, unlike a named locale such as "en". The previous locale is restored even if
     * $callback throws.
     */
    private static function withDotDecimalLocale(callable $callback): string
    {
        $previousLocale = setlocale(LC_NUMERIC, '0');
        setlocale(LC_NUMERIC, 'C');

        try {
            return $callback();
        } finally {
            // setlocale() returns false on failure; passing false back would be a TypeError
            // under strict_types.
            if ($previousLocale !== false) {
                setlocale(LC_NUMERIC, $previousLocale);
            }
        }
    }

    /**
     * Formats one sample as space-separated "index:value" pairs. libsvm feature indices are
     * 1-based, hence $index + 1; the sample is assumed to have integer keys.
     */
    private static function sampleRow(array $sample): string
    {
        $row = [];
        foreach ($sample as $index => $feature) {
            $row[] = sprintf('%s:%s', $index + 1, $feature);
        }

        return implode(' ', $row);
    }
}
126