StratifiedRandomSplit   A
last analyzed

Complexity

Total Complexity 6

Size/Duplication

Total Lines 38
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 6
eloc 15
dl 0
loc 38
rs 10
c 0
b 0
f 0

3 Methods

Rating   Name   Duplication   Size   Complexity  
A splitByTarget() 0 14 2
A createDatasets() 0 8 2
A splitDataset() 0 6 2
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\CrossValidation;
6
7
use Phpml\Dataset\ArrayDataset;
8
use Phpml\Dataset\Dataset;
9
10
/**
 * Random train/test split that is performed separately for every target class,
 * so each class contributes roughly $testSize of its own samples to the test set.
 */
class StratifiedRandomSplit extends RandomSplit
{
    /**
     * Splits every per-target subset of $dataset with the parent RandomSplit logic.
     *
     * NOTE(review): this relies on RandomSplit::splitDataset() accumulating its
     * train/test output across repeated calls (one call per target class) —
     * confirm against the parent implementation before changing it.
     */
    protected function splitDataset(Dataset $dataset, float $testSize): void
    {
        foreach ($this->splitByTarget($dataset) as $targetSet) {
            parent::splitDataset($targetSet, $testSize);
        }
    }

    /**
     * Partitions the samples of $dataset into one Dataset per distinct target value.
     *
     * Targets are grouped by their string form, which is the same equality that
     * array_unique() applies by default. The raw target must not be used as an
     * array key: PHP truncates float keys to int (1.2 and 1.7 would both become
     * key 1, merging and mislabelling their samples), coerces bool/null keys, and
     * PHP 8.1+ emits a deprecation notice for fractional float keys.
     *
     * @return Dataset[] one dataset per distinct target, in order of first appearance
     */
    private function splitByTarget(Dataset $dataset): array
    {
        $targets = $dataset->getTargets();

        $labels = [];
        $split = [];
        foreach ($dataset->getSamples() as $key => $sample) {
            $target = $targets[$key];
            $groupKey = (string) $target;

            if (!isset($split[$groupKey])) {
                // The first target seen labels the whole group, just as
                // array_unique() keeps the first occurrence of equal values.
                $labels[$groupKey] = $target;
                $split[$groupKey] = [];
            }

            $split[$groupKey][] = $sample;
        }

        return $this->createDatasets($labels, $split);
    }

    /**
     * Builds one ArrayDataset per group, labelling every sample of a group with that group's target.
     *
     * @param array $labels group key => target value used as the label for the group
     * @param array $split  group key => list of samples belonging to the group
     *
     * @return Dataset[] a list (not keyed by target, to avoid the same key-coercion collisions)
     */
    private function createDatasets(array $labels, array $split): array
    {
        $datasets = [];
        foreach ($labels as $groupKey => $target) {
            $samples = $split[$groupKey];
            $datasets[] = new ArrayDataset($samples, array_fill(0, count($samples), $target));
        }

        return $datasets;
    }
}
50