<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Optimizer\ConjugateGradient;

class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm (default)
     */
    public const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    public const ONLINE_TRAINING = 2;

    /**
     * Conjugate batch: Conjugate gradient algorithm
     */
    public const CONJUGATE_GRAD_TRAINING = 3;

    /**
     * Cost function to optimize: 'log' and 'sse' are supported <br>
     * - 'log' : log likelihood <br>
     * - 'sse' : sum of squared errors <br>
     *
     * @var string
     */
    protected $costFunction = 'log';

    /**
     * Regularization term: only 'L2' is supported
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of the regularization term. If λ is set to 0,
     * then the regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;

    /**
     * Initializes a Logistic Regression classifier with a maximum number of
     * iterations and the learning rule to be applied <br>
     *
     * Maximum number of iterations can be an integer value greater than 0 <br>
     * If normalizeInputs is set to true, every input given to the algorithm
     * will be standardized by use of its mean and standard deviation <br>
     *
     * Cost function can be 'log' for log-likelihood or 'sse' for sum of squared errors <br>
     *
     * Penalty (regularization term) can be 'L2', or an empty string to cancel the penalty term
     *
     * @throws InvalidArgumentException
     */
    public function __construct(
        int $maxIterations = 500,
        bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING,
        string $cost = 'log',
        string $penalty = 'L2'
    ) {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (!in_array($trainingType, $trainingTypes, true)) {
            throw new InvalidArgumentException(
                'Logistic regression can only be trained with '.
                'batch (gradient descent), online (stochastic gradient descent) '.
                'or conjugate batch (conjugate gradients) algorithms'
            );
        }

        if (!in_array($cost, ['log', 'sse'], true)) {
            throw new InvalidArgumentException(
                "Logistic regression cost function can be one of the following: \n".
                "'log' for log-likelihood and 'sse' for sum of squared errors"
            );
        }

        if ($penalty !== '' && strtoupper($penalty) !== 'L2') {
            throw new InvalidArgumentException('Logistic regression supports only \'L2\' regularization');
        }

        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        $this->trainingType = $trainingType;
        $this->costFunction = $cost;
        $this->penalty = $penalty;
    }

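    // Usage sketch (illustrative only): train() and predict() are assumed to
    // be inherited from the php-ml Perceptron/Adaline base classes via the
    // library's Classifier interface; names outside this file are assumptions.
    //
    //   $clf = new LogisticRegression(500, true, LogisticRegression::CONJUGATE_GRAD_TRAINING, 'log', 'L2');
    //   $clf->train($samples, $targets);           // targets must contain exactly two classes
    //   $predictions = $clf->predict($newSamples);
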
    /**
     * Sets the learning rate if the gradient descent algorithm is
     * selected for training
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Lambda (λ) parameter of the regularization term. If 0 is given,
     * then the regularization term is cancelled
     */
    public function setLambda(float $lambda): void
    {
        $this->lambda = $lambda;
    }

    /**
     * Adapts the weights with respect to the given samples and targets
     * by use of the selected solver
     *
     * @throws \Exception
     */
    protected function runTraining(array $samples, array $targets): void
    {
        $callback = $this->getCostFunction();

        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, true);

                return;

            case self::ONLINE_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, false);

                return;

            case self::CONJUGATE_GRAD_TRAINING:
                $this->runConjugateGradient($samples, $targets, $callback);

                return;

            default:
                // Not reached: the constructor validates the training type
                throw new Exception(sprintf('Logistic regression has invalid training type: %d.', $this->trainingType));
        }
    }

    /**
     * Executes the Conjugate Gradient method to optimize the weights of the LogReg model
     */
    protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
    {
        if ($this->optimizer === null) {
            $this->optimizer = (new ConjugateGradient($this->featureCount))
                ->setMaxIterations($this->maxIterations);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Returns the appropriate callback function for the selected cost function
     *
     * @throws \Exception
     */
    protected function getCostFunction(): Closure
    {
        $penalty = 0;
        if ($this->penalty === 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of the log-likelihood cost function to be minimized:
                 *        J(x) = ∑( -y . log(h(x)) - (1 - y) . log(1 - h(x)))
                 *
                 * If a regularization term is given, then it is added to the cost:
                 *        for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *        ∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                return function ($weights, $sample, $y) use ($penalty): array {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // In cases where $hX = 1 or $hX = 0, the log-likelihood
                    // value will give a NaN, so we clamp these values
                    if ($hX == 1) {
                        $hX = 1 - 1e-10;
                    }

                    if ($hX == 0) {
                        $hX = 1e-10;
                    }

                    $y = $y < 0 ? 0 : 1;

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };
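                // Why the log-loss gradient above reduces to (h(x) - y): for
                // the sigmoid, h'(z) = h(z)(1 - h(z)), and differentiating the
                // log-loss w.r.t. the weighted sum z cancels that factor:
                //     dJ/dz = -y(1 - h) + (1 - y)h = h - y
                // The optimizer is then assumed to scale this per-sample
                // scalar by each feature value to obtain dJ/dw.
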
            case 'sse':
                /*
                 * Sum of squared errors (least squares) cost function:
                 *        J(x) = ∑ (y - h(x))^2
                 *
                 * If a regularization term is given, then it is added to the cost:
                 *        for L2 : J(x) = J(x) + λ/m . w
                 *
                 * The gradient of the cost function:
                 *        ∇J(x) = -(y - h(x)) . h(x) . (1 - h(x))
                 */
                return function ($weights, $sample, $y) use ($penalty): array {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    $y = $y < 0 ? 0 : 1;

                    $error = (($y - $hX) ** 2);
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };
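                // Note on the SSE gradient above: by the chain rule,
                // d/dh (y - h)^2 = -2(y - h), times h(1 - h) from the sigmoid.
                // The constant factor 2 is dropped here, which only rescales
                // the step and is presumably absorbed by the learning rate.
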
            default:
                // Not reached: the constructor validates the cost function
                throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
        }
    }

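    // Each cost callback above returns [error, gradient, penalty] for a
    // single sample. How the third element is applied is up to the selected
    // optimizer; per the docblocks above it is assumed to enter the cost as
    // the L2 term λ/m . w.
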
    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     */
    protected function output(array $sample): float
    {
        $sum = parent::output($sample);

        return 1.0 / (1.0 + exp(-$sum));
    }

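    // The expression above is the logistic sigmoid σ(z) = 1 / (1 + e^{-z}),
    // applied to the raw weighted sum from parent::output(). Since σ(0) = 0.5,
    // outputClass() below thresholds at 0.5.
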
    /**
     * Returns the class value (either -1 or 1) for the given input
     */
    protected function outputClass(array $sample): int
    {
        $output = $this->output($sample);

        if ($output > 0.5) {
            return 1;
        }

        return -1;
    }

    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability is taken from the sigmoid output, which is a monotonic
     * function of the sample's signed distance to the decision boundary.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $sample = $this->checkNormalizedSample($sample);
        $probability = $this->output($sample);

        if (array_search($label, $this->labels, true) > 0) {
            return $probability;
        }

        return 1 - $probability;
    }

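    // Illustration for predictProbability() above (hypothetical labels):
    // with $this->labels = [-1, 1] and a sigmoid output of 0.8,
    // predictProbability($sample, 1) returns 0.8 and
    // predictProbability($sample, -1) returns 0.2.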
}