Implement custom optimization algorithms in TensorFlow/Keras

Introduction

You will rarely need to write your own optimization algorithm with TensorFlow/Keras, but I thought some people might be interested, so I wrote this article.

Environment

  • TensorFlow (2.3.0)
  • Tested on Google Colab (GPU/TPU)

Basics

A custom optimizer is created by inheriting from tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2.

Implementing VanillaSGD is as follows:

VanillaSGD.py
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
class VanillaSGD(OptimizerV2):
    def __init__(self, learning_rate=0.01, name='VanillaSGD', **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

    def get_config(self):
        config = super(VanillaSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
        })
        return config

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        lr_t = self._get_hyper("learning_rate", var_dtype)
        return var.assign(var - lr_t * grad)

    def _resource_apply_sparse(self, grad, var, indices):
        raise NotImplementedError("not implemented")

  • __init__() performs the various initialization steps.
    • Hyperparameters are registered mainly with _set_hyper().
    • In this example learning_rate is registered; kwargs.get("lr", learning_rate) is there so that the optimizer can also be configured with the "lr=" alias.
  • get_config() contains the processing for serialization.
    • It is called when the model is saved.
    • Add all hyperparameters to config here.
    • Use _serialize_hyperparameter() to get each value.
  • The variable update is performed in _resource_apply_dense().
    • grad is the gradient tensor.
    • var is the variable (i.e. weight) tensor. It has the same shape as grad.
    • apply_state is a dictionary containing hyperparameters and other state.
    • Use _get_hyper() to get hyperparameters.
    • The return value must be an "operation that updates the variable", so return the updated var through an assign-type function. In this example, the gradient multiplied by the learning rate is subtracted from the weight to obtain the new weight.
  • _resource_apply_sparse() is used to update variables that receive sparse gradients. Normally there is no problem even if it is not implemented; a minimal sketch follows below.
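
If your model does produce sparse gradients (for example from an Embedding layer), the stub can be filled in. The following is only a sketch; it assumes the _resource_scatter_add() helper that OptimizerV2 provides in its source, so treat it as an illustration rather than a tested implementation.

    # Sketch only: replaces the NotImplementedError stub in VanillaSGD.
    def _resource_apply_sparse(self, grad, var, indices):
        var_dtype = var.dtype.base_dtype
        lr_t = self._get_hyper("learning_rate", var_dtype)
        # _resource_scatter_add() adds the given values to var at the given rows,
        # so passing -lr_t * grad applies the SGD update only to those rows.
        return self._resource_scatter_add(var, indices, -lr_t * grad)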

When loading a saved model, pass the class via custom_objects as shown below.

tf.keras.models.load_model('model.h5', custom_objects={'VanillaSGD': VanillaSGD})
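
For reference, here is a minimal save/load round-trip sketch. The toy model, the random data, and the file name 'model.h5' are only placeholders for illustration.

import numpy as np
import tensorflow as tf

# Toy model compiled with the custom optimizer (illustrative only).
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=VanillaSGD(learning_rate=0.01), loss='mse')
model.fit(np.random.rand(32, 4), np.random.rand(32, 1), epochs=1, verbose=0)

# Saving works as usual; loading needs custom_objects as described above.
model.save('model.h5')
restored = tf.keras.models.load_model('model.h5', custom_objects={'VanillaSGD': VanillaSGD})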

Decay support

Optimizers that inherit from OptimizerV2 can also support the decay parameter. The VanillaSGD created earlier can be extended to support it as follows.

VanillaSGD.py
class VanillaSGD2(OptimizerV2):
    def __init__(self, learning_rate=0.01, name='VanillaSGD2', **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper('decay', self._initial_decay)

    def get_config(self):
        config = super(VanillaSGD2, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
        })
        return config

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        return var.assign(var - lr_t * grad)

    def _resource_apply_sparse(self, grad, var, indices):
        raise NotImplementedError("not implemented")

  • Register 'decay' as a hyperparameter in __init__(). self._initial_decay can be used because it is already set in super().__init__().
  • Within _resource_apply_dense(), use the predefined _decayed_lr() to obtain the learning rate; it automatically returns the decayed learning rate (a small sketch of the schedule follows below).
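
As far as I can tell from the OptimizerV2 source, _decayed_lr() applies a simple time-based schedule, dividing the base learning rate by (1 + decay * iterations). A tiny sketch of that schedule (not the actual implementation):

def decayed_lr(base_lr, decay, step):
    # Sketch of the time-based schedule that _decayed_lr() appears to apply.
    return base_lr / (1.0 + decay * step)

print(decayed_lr(0.1, 0.01, 0))    # 0.1   (no decay at step 0)
print(decayed_lr(0.1, 0.01, 100))  # 0.05  (halved after 100 steps)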

Adding variables

In a practical optimization algorithm, it is necessary to hold the variables associated with each weight and use them in calculations. As an example of this, MomentumSGD is implemented as follows.

MomentumSGD.py
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2

class MomentumSGD(OptimizerV2):
    def __init__(self, learning_rate=0.01, momentum=0.0, name='MomentumSGD', **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper('decay', self._initial_decay)
        self._set_hyper('momentum', momentum)

    def get_config(self):
        config = super(MomentumSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        })
        return config

    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, 'm')

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        momentum = self._get_hyper("momentum", var_dtype)
        m = self.get_slot(var, 'm')
        m_t = m.assign( momentum*m + (1.0-momentum)*grad)
        var_update = var.assign(var - lr_t*m_t)
        updates = [var_update, m_t]
        return tf.group(*updates)

    def _resource_apply_sparse(self, grad, var, indices):
        raise NotImplementedError("not implemented")

  • In _create_slots(), register the variables you want to add.
    • For the inertia (momentum) term, use add_slot() to add a slot named 'm' for each var.
    • The added variables are retrieved with get_slot().
  • As the return value of _resource_apply_dense(), the operations that update the variables are returned together using tf.group(). In this case, var_update and m_t are the targets. (The update is summarized in equation form below.)
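
In equation form, the update performed by _resource_apply_dense() above is (with $\mu$ the momentum, $\eta_t$ the decayed learning rate, and $g_t$ the gradient):

$$m_t = \mu\, m_{t-1} + (1 - \mu)\, g_t, \qquad w_t = w_{t-1} - \eta_t\, m_t$$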

By the way, the MomentumSGD created here interprets the learning rate differently from tf.keras's SGD, so the results differ even with the same learning rate. If you multiply this learning rate by (1 - momentum), you get the same result as SGD in tf.keras. For example, MomentumSGD(0.01, momentum=0.9) matches SGD(0.001, momentum=0.9).
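
A quick way to check this, assuming the update rule documented for tf.keras.optimizers.SGD, namely $v_t = \mu\, v_{t-1} - \eta'\, g_t$ and $w_t = w_{t-1} + v_t$: substituting $v_t := -\eta\, m_t$ into the MomentumSGD update above gives

$$v_t = \mu\, v_{t-1} - \eta (1 - \mu)\, g_t,$$

so the two optimizers produce identical weights when $\eta' = \eta (1 - \mu)$, e.g. $0.01 \times (1 - 0.9) = 0.001$.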

Process differently depending on the number of steps

It may be necessary to adjust a coefficient according to the number of steps executed. In such cases, the predefined self.iterations counter is available. The inertia term of MomentumSGD is biased toward its initial value of 0.0, and the following example adds a correction for that bias. (Adam applies a similar correction.) For now, I will call it "Centered Momentum SGD".

CMomentumSGD.py
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2

class CMomentumSGD(OptimizerV2):
    def __init__(self, learning_rate=0.01, momentum=0.0, centered=True, name='CMomentumSGD', **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper('decay', self._initial_decay)
        self._set_hyper('momentum', momentum)
        self.centered = centered if momentum!=0.0 else False

    def get_config(self):
        config = super(CMomentumSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
            'centered': self.centered,
        })
        return config

    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, 'm')

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        momentum = self._get_hyper("momentum", var_dtype)
        m = self.get_slot(var, 'm')
        m_t = m.assign( momentum*m + (1.0-momentum)*grad)
        if self.centered:
            local_step = tf.cast(self.iterations+1, var_dtype)
            m_t_hat = m_t * 1.0 / (1.0-tf.pow(momentum, local_step))
            var_update = var.assign(var - lr_t*m_t_hat)
        else:
            var_update = var.assign(var - lr_t*m_t)
        updates = [var_update, m_t]
        return tf.group(*updates)

    def _resource_apply_sparse(self, grad, var, indices):
        raise NotImplementedError("not implemented")

  • If centered is True, the correction is applied. This is where self.iterations is used (a small numeric sketch follows below).
  • Although centered is a hyperparameter, it does not change during training, so _set_hyper() and the like are not used for it.
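
To get a feel for the size of the correction factor 1 / (1 - momentum^step) used above, here is a tiny numeric sketch (values are approximate):

momentum = 0.9
for step in (1, 10, 50):
    print(step, 1.0 / (1.0 - momentum ** step))
# 1  -> 10.0    (m_t is scaled up strongly right after initialization)
# 10 -> ~1.54
# 50 -> ~1.005  (the correction fades out as training proceeds)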

Now that the optimizers are implemented, let's compare them. The comparison method follows the article listed under "See also" below.

[Figure: customeOpt.png — comparison of the optimizers implemented above]

  • Comparison of MomentumSGD and VanillaSGD
    • MomentumSGD rises more slowly (it takes longer for the curve to steepen). Since there is inertia, it takes time to get moving from the initial value of 0.
    • Once it reaches the same slope as VanillaSGD, the two follow exactly the same course.
    • Overshooting the optimal value (0.0) by a larger margin is the effect of the inertia term.
    • VanillaSGD oscillates finely near the optimal value, while MomentumSGD oscillates slowly and broadly. This is also the effect of the inertia term.
  • Comparison of CMomentumSGD and MomentumSGD
    • Thanks to the correction of the initial-value bias, the slow rise is fixed and it initially follows exactly the same trajectory as VanillaSGD.
    • After passing the optimal value, it shows the behavior characteristic of MomentumSGD.

Again, note that the MomentumSGD implemented here is slightly different from the usual implementation in Keras and elsewhere. (Personally I think this version is better, because it makes the true effect of the inertia term easier to see.)

References

TensorFlow optimizer implementations: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/keras/optimizer_v2

How to implement Optimizer Part 6: Introduction to TensorFlow 2.0 Modern Writing Techniques for Customization

See also

Comparing Optimization Algorithms in a Single Run (SGD)

Updated on December 01, 2020