Completed
Push — master ( 8ac013...18c36b )
by Arkadiusz
98:48 queued 92:06
created

MnistDataset   A

Complexity

Total Complexity 11

Size/Duplication

Total Lines 89
Duplicated Lines 0 %

Coupling/Cohesion

Components 0
Dependencies 2

Importance

Changes 0
Metric Value
wmc 11
lcom 0
cbo 2
dl 0
loc 89
rs 10
c 0
b 0
f 0

3 Methods

Rating   Name   Duplication   Size   Complexity  
A __construct() 0 9 2
B readImages() 0 41 6
A readLabels() 0 26 3
1
<?php
2
3
declare(strict_types=1);
4
5
namespace Phpml\Dataset;
6
7
use Phpml\Exception\InvalidArgumentException;
8
9
/**
 * MNIST dataset: http://yann.lecun.com/exdb/mnist/
 * original mnist dataset reader: https://github.com/AndrewCarterUK/mnist-neural-network-plain-php
 *
 * Loads one image file and one label file. Every image becomes a flat list of
 * pixel values divided by 255, every label becomes an integer byte value.
 */
final class MnistDataset extends ArrayDataset
{
    /** Magic number expected in the first 4 bytes of an image file. */
    private const MAGIC_IMAGE = 0x00000803;

    /** Magic number expected in the first 4 bytes of a label file. */
    private const MAGIC_LABEL = 0x00000801;

    /** Only images with exactly this many rows are accepted. */
    private const IMAGE_ROWS = 28;

    /** Only images with exactly this many columns are accepted. */
    private const IMAGE_COLS = 28;

    /** Upper bound for a single fread() call, so a corrupt size field cannot request one huge buffer. */
    private const READ_CHUNK = 8192;

    /**
     * @param string $imagePath path of the image file
     * @param string $labelPath path of the label file
     *
     * @throws InvalidArgumentException when a file cannot be opened, is malformed or truncated,
     *                                  or when the number of images and labels differ
     */
    public function __construct(string $imagePath, string $labelPath)
    {
        $this->samples = $this->readImages($imagePath);
        $this->targets = $this->readLabels($labelPath);

        if (count($this->samples) !== count($this->targets)) {
            throw new InvalidArgumentException('Must have the same number of images and labels');
        }
    }

    /**
     * Reads every image of the file as a flat list of scaled pixel values.
     *
     * @return array[] one list of IMAGE_ROWS * IMAGE_COLS values per image
     *
     * @throws InvalidArgumentException when the file cannot be opened, the header is invalid
     *                                  or the file ends before all announced images were read
     */
    private function readImages(string $imagePath): array
    {
        $stream = $this->openStream($imagePath);
        $images = [];

        try {
            // Header: four big-endian unsigned 32-bit integers.
            $fields = $this->readHeader($stream, 'Nmagic/Nsize/Nrows/Ncols', 16, $imagePath);

            if ($fields['magic'] !== self::MAGIC_IMAGE) {
                throw new InvalidArgumentException('Invalid magic number: '.$imagePath);
            }

            if ($fields['rows'] !== self::IMAGE_ROWS) {
                throw new InvalidArgumentException('Invalid number of image rows: '.$imagePath);
            }

            if ($fields['cols'] !== self::IMAGE_COLS) {
                throw new InvalidArgumentException('Invalid number of image cols: '.$imagePath);
            }

            // Rows and cols were validated above, so the record length is loop-invariant.
            $bytesPerImage = self::IMAGE_ROWS * self::IMAGE_COLS;

            for ($i = 0; $i < $fields['size']; ++$i) {
                $imageBytes = $this->readBytes($stream, $bytesPerImage, $imagePath);

                // Convert to float between 0 and 1
                $images[] = array_map(static function ($b) {
                    return $b / 255;
                }, array_values(unpack('C*', $imageBytes)));
            }
        } finally {
            fclose($stream);
        }

        return $images;
    }

    /**
     * Reads every label of the file as an unsigned byte value.
     *
     * @return int[]
     *
     * @throws InvalidArgumentException when the file cannot be opened, the header is invalid
     *                                  or the file ends before all announced labels were read
     */
    private function readLabels(string $labelPath): array
    {
        $stream = $this->openStream($labelPath);

        try {
            // Header: two big-endian unsigned 32-bit integers.
            $fields = $this->readHeader($stream, 'Nmagic/Nsize', 8, $labelPath);

            if ($fields['magic'] !== self::MAGIC_LABEL) {
                throw new InvalidArgumentException('Invalid magic number: '.$labelPath);
            }

            $labels = $this->readBytes($stream, $fields['size'], $labelPath);
        } finally {
            fclose($stream);
        }

        // A file announcing zero labels is valid; skip unpack() on an empty payload.
        if ($labels === '') {
            return [];
        }

        return array_values(unpack('C*', $labels));
    }

    /**
     * Opens a file for binary reading.
     *
     * @return resource
     *
     * @throws InvalidArgumentException when the file cannot be opened
     */
    private function openStream(string $path)
    {
        $stream = fopen($path, 'rb');

        if ($stream === false) {
            throw new InvalidArgumentException('Could not open file: '.$path);
        }

        return $stream;
    }

    /**
     * Reads a fixed-size header and decodes it with the given unpack() format.
     *
     * @param resource $stream open stream positioned at the header
     * @param string   $format unpack() format describing the header fields
     * @param int      $length header length in bytes
     * @param string   $path   file path, used in error messages only
     *
     * @return int[] decoded header fields keyed by the names used in $format
     *
     * @throws InvalidArgumentException when the header is truncated or cannot be decoded
     */
    private function readHeader($stream, string $format, int $length, string $path): array
    {
        $fields = unpack($format, $this->readBytes($stream, $length, $path));

        if ($fields === false) {
            throw new InvalidArgumentException('Invalid file header: '.$path);
        }

        return $fields;
    }

    /**
     * Reads exactly $length bytes from the stream.
     *
     * A single fread() may return fewer bytes than requested, so the read is repeated
     * until the requested amount has been collected. A length of zero yields an empty
     * string without touching the stream.
     *
     * @param resource $stream open stream to read from
     * @param int      $length number of bytes to read
     * @param string   $path   file path, used in error messages only
     *
     * @throws InvalidArgumentException when the stream ends before $length bytes were read
     */
    private function readBytes($stream, int $length, string $path): string
    {
        $buffer = '';
        $remaining = $length;

        while ($remaining > 0) {
            $chunk = fread($stream, min($remaining, self::READ_CHUNK));

            if ($chunk === false || $chunk === '') {
                throw new InvalidArgumentException('Unexpected end of file: '.$path);
            }

            $buffer .= $chunk;
            $remaining -= strlen($chunk);
        }

        return $buffer;
    }
}
102