Total Complexity | 9 |
Total Lines | 61 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | <?php |
||
10 | abstract class Split |
||
11 | { |
||
12 | /** |
||
13 | * @var array Training samples; starts as an empty array (presumably filled by splitDataset() implementations - confirm in subclasses). |
||
14 | */ |
||
15 | protected $trainSamples = []; |
||
16 | |||
17 | /** |
||
18 | * @var array Test samples; starts as an empty array and is returned as-is by getTestSamples(). |
||
19 | */ |
||
20 | protected $testSamples = []; |
||
21 | |||
22 | /** |
||
23 | * @var array Training labels; starts as an empty array and is returned as-is by getTrainLabels(). |
||
24 | */ |
||
25 | protected $trainLabels = []; |
||
26 | |||
27 | /** |
||
28 | * @var array Test labels; starts as an empty array and is returned as-is by getTestLabels(). |
||
29 | */ |
||
30 | protected $testLabels = []; |
||
31 | |||
/**
 * Validate the requested test size, seed the random generator and hand the
 * actual partitioning over to the concrete subclass.
 *
 * @param Dataset  $dataset  Data to partition; passed unchanged to splitDataset().
 * @param float    $testSize Test-set size, strictly between 0 and 1
 *                           (presumably a fraction of the dataset - confirm in subclasses).
 * @param int|null $seed     Forwarded unchanged to seedGenerator().
 *
 * @throws InvalidArgumentException When $testSize is NAN or not strictly between 0 and 1.
 */
public function __construct(Dataset $dataset, float $testSize = 0.3, ?int $seed = null)
{
    // Expressed as a negated "inside the range" test: every ordered comparison
    // with NAN is false, so `$testSize <= 0 || $testSize >= 1` would let NAN
    // slip through, whereas this form rejects it.
    if (!($testSize > 0 && $testSize < 1)) {
        throw new InvalidArgumentException('testsize must be between 0.0 and 1.0');
    }

    // Seed first so that any randomness used while splitting is reproducible.
    $this->seedGenerator($seed);

    $this->splitDataset($dataset, $testSize);
}
||
42 | |||
43 | public function getTrainSamples(): array |
||
46 | } |
||
47 | |||
/**
 * Expose the current contents of the $testSamples property.
 *
 * @return array
 */
public function getTestSamples(): array
{
    return $this->testSamples;
}
||
52 | |||
/**
 * Expose the current contents of the $trainLabels property.
 *
 * @return array
 */
public function getTrainLabels(): array
{
    return $this->trainLabels;
}
||
57 | |||
/**
 * Expose the current contents of the $testLabels property.
 *
 * @return array
 */
public function getTestLabels(): array
{
    return $this->testLabels;
}
||
62 | |||
/**
 * Partition the dataset; invoked by the constructor after the test size has
 * been validated and the random generator has been seeded.
 *
 * NOTE(review): implementations presumably fill the train/test sample and
 * label properties - confirm against the concrete subclasses.
 *
 * @param Dataset $dataset  Data handed to the constructor.
 * @param float   $testSize Test-set size as validated by the constructor.
 */
abstract protected function splitDataset(Dataset $dataset, float $testSize): void;
||
64 | |||
65 | protected function seedGenerator(?int $seed = null): void |
||
71 | } |
||
72 | } |
||
73 | } |
||
74 |