php-ai /
php-ml
| 1 | <?php |
||
| 2 | |||
| 3 | declare(strict_types=1); |
||
| 4 | |||
| 5 | namespace Phpml\Classification; |
||
| 6 | |||
| 7 | use Phpml\Exception\InvalidArgumentException; |
||
| 8 | use Phpml\NeuralNetwork\Network\MultilayerPerceptron; |
||
| 9 | |||
class MLPClassifier extends MultilayerPerceptron implements Classifier
{
    /**
     * Resolves a target label to its key within the list of accepted classes.
     *
     * @param mixed $target label to look up; compared strictly (===) against $this->classes
     *
     * @return int key of $target inside $this->classes
     *
     * @throws InvalidArgumentException when $target is not one of the accepted classes
     */
    public function getTargetClass($target): int
    {
        // A single strict search both validates the label and yields its key,
        // so the class list is scanned once and the "not found" result (false)
        // can never leak into the int return value.
        $index = array_search($target, $this->classes, true);

        if ($index === false) {
            throw new InvalidArgumentException(
                sprintf('Target with value "%s" is not part of the accepted classes', $target)
            );
        }

        return $index;
    }

    /**
     * Feeds the sample through the network and picks the output with the largest value.
     *
     * @return mixed key of the largest output value (the first one on ties),
     *               or null when the network yields no outputs at all
     */
    protected function predictSample(array $sample)
    {
        $output = $this->setInput($sample)->getOutput();

        $predictedClass = null;
        // Seed the maximum from the first output instead of 0: otherwise a run
        // in which every output is zero or negative would select no class.
        $max = null;
        foreach ($output as $class => $value) {
            if ($max === null || $value > $max) {
                $predictedClass = $class;
                $max = $value;
            }
        }

        return $predictedClass;
    }

    /**
     * Performs one training step for a single sample.
     *
     * @param mixed $target expected label for $sample; must be one of the accepted classes
     *
     * @throws InvalidArgumentException when $target is not one of the accepted classes
     */
    protected function trainSample(array $sample, $target): void
    {
        // Feed-forward.
        $this->setInput($sample);

        // Back-propagate against the key of the expected class.
        $this->backpropagation->backpropagate($this->getLayers(), $this->getTargetClass($target));
    }
}
||
| 59 |