Code Duplication    Length = 41-46 lines in 2 locations

blocks/algorithms/__init__.py 2 locations

@@ 522-567 (lines=46) @@
        velocity = _create_algorithm_buffer_for(parameter, "velocity")
        step = self.momentum * velocity + previous_step
        updates = [(velocity, step)]
        return step, updates


class Momentum(CompositeRule):
    """Accumulates step with exponential discount.

    Combines :class:`BasicMomentum` and :class:`Scale` to form the
    usual momentum step rule.

    Parameters
    ----------
    learning_rate : float, optional
        The learning rate by which the previous step is scaled. Defaults to 1.
    momentum : float, optional
        The momentum coefficient. Defaults to 0.

    Attributes
    ----------
    learning_rate : :class:`~tensor.SharedVariable`
        A variable for learning rate.
    momentum : :class:`~tensor.SharedVariable`
        A variable for momentum.

    See Also
    --------
    :class:`SharedVariableModifier`

    """
    def __init__(self, learning_rate=1.0, momentum=0.):
        scale = Scale(learning_rate=learning_rate)
        basic_momentum = BasicMomentum(momentum=momentum)
        self.learning_rate = scale.learning_rate
        self.momentum = basic_momentum.momentum
        self.components = [scale, basic_momentum]
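For orientation (this note and the sketch below are not part of the duplicated region): `Momentum` simply chains the two rules it names, scaling the raw gradient and feeding it into the exponentially discounted velocity buffer shown in `compute_step` above. A rough NumPy sketch under that assumption, with illustrative names (`momentum_step`, `gradient`) that do not appear in the source:

    import numpy as np

    def momentum_step(gradient, velocity, learning_rate=0.1, momentum=0.9):
        # Scale: multiply the raw gradient by the learning rate.
        scaled = learning_rate * gradient
        # BasicMomentum: discounted accumulation; the new velocity is the step.
        velocity = momentum * velocity + scaled
        return velocity, velocity

    gradient = np.array([0.5, -1.0])
    velocity = np.zeros_like(gradient)
    step, velocity = momentum_step(gradient, velocity)
    # A training loop would then apply: parameter -= step

The real classes operate on Theano shared variables and return symbolic update pairs rather than NumPy arrays.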
class AdaDelta(StepRule):
    """Adapts the step size over time using only first order information.

    Parameters
    ----------
    decay_rate : float, optional
        Decay rate in [0, 1]. Defaults to 0.95.
    epsilon : float, optional
        Stabilizing constant for RMS. Defaults to 1e-6.

    Notes
    -----
    For more information, see [ADADELTA]_.
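The AdaDelta docstring is cut off by the duplication window. As a reminder of what `decay_rate` and `epsilon` control, here is a rough NumPy sketch of the rule from [ADADELTA]_ (Zeiler, 2012); the accumulator names are illustrative, and the actual class keeps them as shared buffers and returns symbolic updates:

    import numpy as np

    def adadelta_step(gradient, sq_grad_acc, sq_step_acc,
                      decay_rate=0.95, epsilon=1e-6):
        # Running average of squared gradients, decayed by `decay_rate`.
        sq_grad_acc = decay_rate * sq_grad_acc + (1 - decay_rate) * gradient ** 2
        # Scale the gradient by the ratio of the two RMS estimates;
        # `epsilon` keeps both square roots away from zero.
        step = (np.sqrt(sq_step_acc + epsilon) /
                np.sqrt(sq_grad_acc + epsilon)) * gradient
        # Running average of squared steps, with the same decay.
        sq_step_acc = decay_rate * sq_step_acc + (1 - decay_rate) * step ** 2
        return step, sq_grad_acc, sq_step_acc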
@@ 725-765 (lines=41) @@
            add_role(threshold, ALGORITHM_HYPERPARAMETER)
        self.threshold = threshold

    def compute_steps(self, previous_steps):
        if self.threshold is None:
            steps = previous_steps
        else:
            norm = l2_norm(previous_steps.values())
            multiplier = tensor.switch(norm < self.threshold,
                                       1, self.threshold / norm)
            steps = OrderedDict(
                (parameter, step * multiplier)
                for parameter, step in previous_steps.items())
        return steps, []
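The `compute_steps` above (from what appears to be `StepClipping`) implements global norm clipping: when the combined L2 norm of all steps exceeds `threshold`, every step is rescaled by the same factor. A rough NumPy equivalent for illustration only (the real code is symbolic, uses `tensor.switch`, and returns an `OrderedDict` together with an empty updates list):

    import numpy as np

    def clip_steps(steps, threshold):
        # Single L2 norm over *all* steps, as l2_norm(previous_steps.values()).
        norm = np.sqrt(sum(np.sum(step ** 2) for step in steps.values()))
        # Leave the steps alone below the threshold, otherwise rescale them.
        multiplier = 1.0 if norm < threshold else threshold / norm
        return {parameter: step * multiplier
                for parameter, step in steps.items()}

    steps = {"W": np.array([3.0, 4.0])}          # L2 norm 5.0
    clipped = clip_steps(steps, threshold=1.0)   # rescaled to norm 1.0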
class VariableClipping(StepRule):
    """Clip the maximum norm of individual variables along certain axes.

    This :class:`StepRule` can be used to implement L2 norm constraints on
    e.g. the weight vectors of individual hidden units, convolutional
    filters or entire weight tensors. Combine with :class:`Restrict`
    (and possibly :class:`CompositeRule`) to apply such constraints only
    to certain variables and/or apply different norm constraints to
    different variables.

    Parameters
    ----------
    threshold : float
        Maximum norm for a given (portion of a) tensor.
    axis : int or iterable, optional
        An integer single axis, or an iterable collection of integer
        axes over which to sum in order to calculate the L2 norm. If
        `None` (the default), the norm is computed over all elements
        of the tensor.

    Notes
    -----
    Because of the way the :class:`StepRule` API works, this particular
    rule implements norm clipping of the value *after* update in the
    following way: it computes ``parameter - previous_step``, scales it
    to have (possibly axes-wise) norm(s) of at most `threshold`,
    then subtracts *that* value from `parameter` to yield an 'equivalent
    step' that respects the desired norm constraints. This procedure
    implicitly assumes one is doing simple (stochastic) gradient descent,
    and so steps computed by this step rule may not make sense for use
    in other contexts.
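A rough NumPy sketch of the 'equivalent step' procedure the Notes describe; the function and variable names are illustrative, and the axis handling is simplified relative to the actual implementation:

    import numpy as np

    def variable_clipping_step(parameter, previous_step, threshold, axis=None):
        # Value the parameter would take after a plain SGD update.
        updated = parameter - previous_step
        # L2 norm over the requested axes (all elements when axis is None).
        norm = np.sqrt(np.sum(updated ** 2, axis=axis, keepdims=True))
        # Shrink only where the norm exceeds the threshold; the small constant
        # guards against dividing by a zero norm.
        clipped = updated * np.minimum(1.0, threshold / np.maximum(norm, 1e-12))
        # The 'equivalent step' that lands exactly on the clipped value.
        return parameter - clipped

    weights = np.array([[3.0, 4.0], [0.3, 0.4]])
    step = variable_clipping_step(weights, np.zeros_like(weights),
                                  threshold=1.0, axis=1)
    # weights - step now has per-row L2 norms of at most 1.0.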