Implement custom optimization algorithms in TensorFlow/Keras
Introduction
You will rarely need to write your own optimization algorithm for TensorFlow/Keras, but since some readers may be curious how it is done, I wrote this article.
Environment
- TensorFlow 2.3.0
- Tested on Google Colab (GPU/TPU)
Basics
A custom optimizer is created by subclassing OptimizerV2 from tensorflow.python.keras.optimizer_v2.optimizer_v2. A minimal VanillaSGD can be implemented as follows:
```python
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2

class VanillaSGD(OptimizerV2):
    def __init__(self, learning_rate=0.01, name='VanillaSGD', **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

    def get_config(self):
        config = super(VanillaSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
        })
        return config

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        lr_t = self._get_hyper("learning_rate", var_dtype)
        return var.assign(var - lr_t * grad)

    def _resource_apply_sparse(self, grad, var, indices):
        raise NotImplementedError("not implemented")
```
- `__init__()` implements the initialization. Its main job is registering hyperparameters with `_set_hyper()`.
  - In this example, `learning_rate` is registered. `kwargs.get("lr", learning_rate)` is there so the optimizer also accepts the `lr=` alias.
- `get_config()` contains the serialization logic and is called when the model is saved.
  - Add all hyperparameters to `config` here, using `_serialize_hyperparameter()` to obtain each value.
- `_resource_apply_dense()` performs the variable update.
  - `grad` is the gradient tensor.
  - `var` is the variable (i.e. weight) tensor, with the same shape as `grad`.
  - `apply_state` is a dictionary containing hyperparameters and other state.
  - Use `_get_hyper()` to obtain hyperparameters.
  - The return value must be an operation that updates the variable, so return the updated `var` via an assign-style function. In this example, the gradient multiplied by the learning rate is subtracted from the weight to obtain the new weight.
- `_resource_apply_sparse()` is used for sparse updates. Normally there is no problem leaving it unimplemented.
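Stripped of the TensorFlow machinery, the update performed by `_resource_apply_dense()` is plain gradient descent. A minimal plain-Python sketch (the function name and values are mine, not part of the Keras API):

```python
def vanilla_sgd_step(weights, grads, lr=0.01):
    """One VanillaSGD update per element: mirrors
    var.assign(var - lr_t * grad) from _resource_apply_dense()."""
    return [w - lr * g for w, g in zip(weights, grads)]

# A single step moves each weight against its gradient.
new_w = vanilla_sgd_step([1.0, -2.0], [0.5, -0.5])
assert abs(new_w[0] - 0.995) < 1e-12 and abs(new_w[1] + 1.995) < 1e-12
```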
When loading a saved model, pass the optimizer through `custom_objects` as shown below.

```python
tf.keras.models.load_model('model.h5', custom_objects={'VanillaSGD': VanillaSGD})
```
Decay support
Because it inherits from OptimizerV2, the optimizer can also support the decay parameter with little effort. The VanillaSGD created above can support it as follows.
```python
class VanillaSGD2(OptimizerV2):
    def __init__(self, learning_rate=0.01, name='VanillaSGD2', **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper('decay', self._initial_decay)

    def get_config(self):
        config = super(VanillaSGD2, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
        })
        return config

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        return var.assign(var - lr_t * grad)

    def _resource_apply_sparse(self, grad, var, indices):
        raise NotImplementedError("not implemented")
```
- In `__init__()`, register 'decay' as a hyperparameter. `self._initial_decay` is already defined by `super().__init__()`, so use it as the value.
- Within `_resource_apply_dense()`, use the predefined `_decayed_lr()` to obtain the learning rate. It automatically returns the decayed learning rate.
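If I recall correctly, `_decayed_lr()` applies Keras's standard inverse-time decay, `lr / (1 + decay * iterations)`. A plain-Python sketch of that formula (the helper name is mine, not a Keras API):

```python
def decayed_lr(lr, decay, iterations):
    """Inverse-time decay: the effective learning rate shrinks as the
    iteration count grows; decay=0 leaves it constant."""
    return lr / (1.0 + decay * iterations)

assert decayed_lr(0.01, 0.0, 100) == 0.01           # no decay: constant
assert abs(decayed_lr(0.1, 0.5, 2) - 0.05) < 1e-12  # halved after 2 steps at decay=0.5
```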
Adding variables
A practical optimization algorithm needs to keep additional variables associated with each weight and use them in the update. As an example, MomentumSGD can be implemented as follows.
```python
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2

class MomentumSGD(OptimizerV2):
    def __init__(self, learning_rate=0.01, momentum=0.0, name='MomentumSGD', **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper('decay', self._initial_decay)
        self._set_hyper('momentum', momentum)

    def get_config(self):
        config = super(MomentumSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        })
        return config

    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, 'm')

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        momentum = self._get_hyper("momentum", var_dtype)
        m = self.get_slot(var, 'm')
        m_t = m.assign(momentum * m + (1.0 - momentum) * grad)
        var_update = var.assign(var - lr_t * m_t)
        updates = [var_update, m_t]
        return tf.group(*updates)

    def _resource_apply_sparse(self, grad, var, indices):
        raise NotImplementedError("not implemented")
```
- In `_create_slots()`, register the variables you want to add.
  - For the momentum (inertia) term, use `add_slot()` to add a slot named 'm' for each `var`.
  - The added variables are retrieved with `get_slot()`.
- As the return value of `_resource_apply_dense()`, the update operations are returned together using `tf.group()`. In this case, `var_update` and `m_t` are included.
By the way, the MomentumSGD created here interprets the learning rate differently from tf.keras's SGD, so the results differ even with the same learning rate. If you multiply this learning rate by (1 - momentum), you get the same result as tf.keras's SGD. For example, MomentumSGD(0.01, momentum=0.9) behaves the same as SGD(0.001, momentum=0.9).
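This equivalence is easy to check with a small plain-Python simulation of the two update rules (no TensorFlow needed; the helper names and the gradient sequence are mine):

```python
def momentum_sgd(grads, lr, momentum):
    # Rule used in this article: m <- momentum*m + (1-momentum)*grad; w <- w - lr*m
    w, m = 1.0, 0.0
    for g in grads:
        m = momentum * m + (1.0 - momentum) * g
        w -= lr * m
    return w

def keras_sgd(grads, lr, momentum):
    # Keras-style rule: v <- momentum*v - lr*grad; w <- w + v
    w, v = 1.0, 0.0
    for g in grads:
        v = momentum * v - lr * g
        w += v
    return w

grads = [0.3, -0.1, 0.2]
a = momentum_sgd(grads, lr=0.01, momentum=0.9)
b = keras_sgd(grads, lr=0.001, momentum=0.9)  # lr scaled by (1 - momentum)
assert abs(a - b) < 1e-12
```

Substituting v = -lr*m into the Keras rule shows the two trajectories coincide exactly when the Keras learning rate equals lr*(1 - momentum).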
Processing that depends on the step count
It is sometimes necessary to adjust coefficients according to the number of update steps. In such cases, the predefined `self.iterations` is available.
The momentum term of MomentumSGD is biased toward its initial value of 0.0; the following example adds a correction for this bias. (Adam applies a similar correction.) For now, I will call it "Centered Momentum SGD".
```python
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2

class CMomentumSGD(OptimizerV2):
    def __init__(self, learning_rate=0.01, momentum=0.0, centered=True, name='CMomentumSGD', **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper('decay', self._initial_decay)
        self._set_hyper('momentum', momentum)
        self.centered = centered if momentum != 0.0 else False

    def get_config(self):
        config = super(CMomentumSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
            'centered': self.centered,
        })
        return config

    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, 'm')

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        momentum = self._get_hyper("momentum", var_dtype)
        m = self.get_slot(var, 'm')
        m_t = m.assign(momentum * m + (1.0 - momentum) * grad)
        if self.centered:
            local_step = tf.cast(self.iterations + 1, var_dtype)
            m_t_hat = m_t * 1.0 / (1.0 - tf.pow(momentum, local_step))
            var_update = var.assign(var - lr_t * m_t_hat)
        else:
            var_update = var.assign(var - lr_t * m_t)
        updates = [var_update, m_t]
        return tf.group(*updates)

    def _resource_apply_sparse(self, grad, var, indices):
        raise NotImplementedError("not implemented")
```
- If `centered` is true, the correction is applied. This is where `self.iterations` is used.
- Although `centered` is a hyperparameter, it does not change during training, so `_set_hyper()` and friends are not used for it.
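The effect of the correction can be sketched in plain Python: dividing by (1 - momentum^t) makes the very first estimate equal to the raw gradient, instead of the heavily damped (1 - momentum) * grad. (The helper name and values are mine.)

```python
def corrected_m(momentum, grads):
    """EMA of gradients with bias correction, as in CMomentumSGD:
    m_hat_t = m_t / (1 - momentum**t), starting from m_0 = 0."""
    m = 0.0
    for t, g in enumerate(grads, start=1):
        m = momentum * m + (1.0 - momentum) * g
        m_hat = m / (1.0 - momentum ** t)
    return m_hat

# At step 1 the corrected estimate equals the gradient itself;
# the uncorrected m would only be (1 - momentum) * grad = 0.05.
assert abs(corrected_m(0.9, [0.5]) - 0.5) < 1e-12
```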
Now that the optimizers are implemented, let's compare them. The comparison method follows this article.
- Comparison of MomentumSGD and VanillaSGD
  - MomentumSGD rises more slowly (it takes longer to pick up speed). Because of the inertia, it takes time to move away from the initial value of 0.
  - Once it reaches the same slope as VanillaSGD, it stays exactly in step with it.
  - Overshooting the optimum (0.0) by a larger margin is the effect of the momentum term.
  - VanillaSGD oscillates finely near the optimum, while MomentumSGD oscillates loosely. This is also the effect of the momentum term.
- CenteredMomentumSGD vs. MomentumSGD
  - Thanks to the initial-value bias correction, the slow rise is fixed, and at first it follows exactly the same trajectory as VanillaSGD.
  - After passing the optimum, it shows the movement characteristic of MomentumSGD.
Again, note that the MomentumSGD implemented here differs slightly from the one normally implemented in Keras and elsewhere. (Personally, I think this formulation makes the true effect of the momentum term easier to understand.)
References
- TensorFlow implementation: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/keras/optimizer_v2
- How to implement an Optimizer, Part 6: Introduction to modern TensorFlow 2.0 customization techniques