Completed
Push — develop ( 021320...f04cc0 )
by Arkadiusz
03:15
created

StratifiedRandomSplit   A

Complexity

Total Complexity 6

Size/Duplication

Total Lines 53
Duplicated Lines 0 %

Coupling/Cohesion

Components 0
Dependencies 3

Importance

Changes 0
Metric Value
wmc 6
c 0
b 0
f 0
lcom 0
cbo 3
dl 0
loc 53
rs 10

3 Methods

Rating   Name   Duplication   Size   Complexity  
A splitDataset() 0 8 2
A splitByTarget() 0 16 2
A createDatasets() 0 9 2
1
<?php
2
3
declare (strict_types = 1);
4
5
namespace Phpml\CrossValidation;
6
7
use Phpml\Dataset\ArrayDataset;
8
use Phpml\Dataset\Dataset;
9
10
class StratifiedRandomSplit extends RandomSplit
{
    /**
     * Splits the dataset one target class at a time.
     *
     * The samples are grouped by their target value and every group is then
     * handed to parent::splitDataset() with the same $testSize, in the order
     * in which the targets first appear in the dataset.
     *
     * @param Dataset $dataset
     * @param float   $testSize forwarded unchanged to parent::splitDataset()
     */
    protected function splitDataset(Dataset $dataset, float $testSize)
    {
        foreach ($this->splitByTarget($dataset) as $targetSet) {
            parent::splitDataset($targetSet, $testSize);
        }
    }

    /**
     * Groups the samples of a dataset by target value.
     *
     * @param Dataset $dataset
     *
     * @return Dataset[]|array one dataset per distinct target
     */
    private function splitByTarget(Dataset $dataset): array
    {
        $targets = $dataset->getTargets();
        $samples = $dataset->getSamples();

        // array_unique() compares values by their string representation.
        $uniqueTargets = array_unique($targets);

        // Buckets are keyed by the string cast of the target so that bucket
        // identity follows the same rule as array_unique(). Using the raw
        // target as an array key would truncate fractional floats (1.5 => 1)
        // and merge distinct classes into one bucket.
        $split = [];
        foreach ($uniqueTargets as $target) {
            $split[(string) $target] = [];
        }

        foreach ($samples as $key => $sample) {
            $split[(string) $targets[$key]][] = $sample;
        }

        return $this->createDatasets($uniqueTargets, $split);
    }

    /**
     * Builds one ArrayDataset per distinct target.
     *
     * @param array $uniqueTargets distinct target values, in their original type
     * @param array $split         samples grouped under the string cast of their target
     *
     * @return array datasets keyed by the string cast of the target
     */
    private function createDatasets(array $uniqueTargets, array $split): array
    {
        $datasets = [];
        foreach ($uniqueTargets as $target) {
            $groupKey = (string) $target;
            $groupSamples = $split[$groupKey];

            // Every sample in the group carries the same, original-typed target.
            $groupTargets = array_fill(0, count($groupSamples), $target);

            $datasets[$groupKey] = new ArrayDataset($groupSamples, $groupTargets);
        }

        return $datasets;
    }
}
63