jdit.optimizer
Optimizer
class jdit.Optimizer(params: parameters of model, optimizer: [Adam,RMSprop,SGD...], lr_decay: float = 1.0, decay_position: Union[int, tuple, list] = -1, lr_reset: Dict[int, float] = None, position_type: ('epoch','step') = 'epoch', **kwargs)

This is a wrapper of the optimizer class in PyTorch. It adds a few features for further control of the optimizer.

params is the parameters of the model that need to be updated. The wrapper automatically filters them to keep only the parameters that require gradients, like this:

filter(lambda p: p.requires_grad, params)

So you can pass model.all_params() without any filtering.

Learning rate decay. Calling do_lr_decay() applies one learning rate decay step:

\[lr = lr \times decay\]

Learning rate reset. A reset changes the learning rate and the decay factor directly.
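The filter expression quoted above can be tried in isolation. The following is a minimal sketch, not jdit code: the `SimpleNamespace` objects and their names are hypothetical stand-ins for model parameters, which in PyTorch would be `torch.nn.Parameter` objects carrying the same `requires_grad` attribute.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for model parameters. Only the requires_grad
# attribute matters for the filter; the names are made up for illustration.
params = [
    SimpleNamespace(name="conv.weight", requires_grad=True),
    SimpleNamespace(name="conv.bias", requires_grad=True),
    SimpleNamespace(name="frozen.weight", requires_grad=False),
]

# The same filter expression as in the description: keep only the
# parameters that require gradients.
trainable = list(filter(lambda p: p.requires_grad, params))
trainable_names = [p.name for p in trainable]
print(trainable_names)
```

Frozen parameters are dropped before they reach the wrapped optimizer, which is why an unfiltered parameter collection can be passed in.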
Parameters:
- params – parameters of the model that need to be updated.
- optimizer – the PyTorch optimizer to wrap, such as torch.optim.Adam. The example passes it by name, as "Adam".
- lr_decay (float, optional) – learning rate decay factor. Default: 1.0.
- decay_position (int, tuple, list, optional) – the position, or positions, at which the learning rate decay is applied. Default: -1.
- lr_reset (Dict[int, float], optional) – maps a position to the learning rate that is set at that position. Default: None.
- position_type ('epoch', 'step') – whether positions are counted in epochs or in steps. Default: 'epoch'.
- kwargs – hyper-parameters passed to the optimizer, such as lr, betas, weight_decay.

Example:
>>> from jdit import Optimizer
>>> from torch.nn import Sequential, Conv3d
>>> from torch.optim import Adam
>>> module = Sequential(Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)))
>>> opt = Optimizer(module.parameters(), "Adam", 0.5, 10, {4: 0.99}, "epoch", lr=1.0, betas=(0.9, 0.999), weight_decay=1e-5)
>>> print(opt)
(Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 1.0
    weight_decay: 1e-05
)
    lr_decay:0.5
    decay_position:10
    lr_reset:{4: 0.99}
    position_type:epoch
))
>>> opt.lr
1.0
>>> opt.lr_decay
0.5
>>> opt.do_lr_decay()
>>> opt.lr
0.5
>>> opt.do_lr_decay(reset_lr=1)
>>> opt.lr
1
>>> opt.opt
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 1
    weight_decay: 1e-05
)
>>> opt.is_decay_lr(1)
False
>>> opt.is_decay_lr(10)
True
>>> opt.is_decay_lr(20)
True
>>> opt.is_reset_lr(4)
0.99
>>> opt.is_reset_lr(5)
False
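The position checks in the example can be reproduced with a small standalone sketch. It is inferred from the example output only and is not taken from the jdit source: the function names `is_decay_position` and `reset_value` are hypothetical, and the handling of list or tuple positions and of the default `-1` is an assumption.

```python
from typing import Dict, List, Optional, Tuple, Union


def is_decay_position(position: int,
                      decay_position: Union[int, Tuple[int, ...], List[int]]) -> bool:
    """Return True if a decay would be applied at `position` (sketch only)."""
    if isinstance(decay_position, int):
        # An int behaves as a period in the example: with 10, positions
        # 10 and 20 decay and position 1 does not.
        # Assumption: non-positive values such as -1 disable the decay.
        return decay_position > 0 and position > 0 and position % decay_position == 0
    # Assumption: a tuple or list names the exact positions to decay at.
    return position in decay_position


def reset_value(position: int,
                lr_reset: Optional[Dict[int, float]]) -> Union[float, bool]:
    """Return the learning rate to reset to at `position`, or False if none."""
    if not lr_reset:
        return False
    return lr_reset.get(position, False)


print(is_decay_position(10, 10), reset_value(4, {4: 0.99}))
```

Returning the new learning rate rather than a plain True mirrors the example, where `is_reset_lr(4)` gives `0.99` and `is_reset_lr(5)` gives `False`.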
do_lr_decay(reset_lr_decay: float = None, reset_lr: float = None)

Do a learning rate decay, or reset the learning rate and the decay factor.

- Both parameters are None: apply one decay step, self.lr = self.lr * self.lr_decay.
- reset_lr_decay or reset_lr is passed: reset the given values instead, self.lr = reset_lr and self.lr_decay = reset_lr_decay.

Parameters:
- reset_lr_decay – if not None, use this value to reset self.lr_decay. Default: None.
- reset_lr – if not None, use this value to reset self.lr. Default: None.
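The two branches of do_lr_decay can be illustrated with a small stand-in class. This is a sketch of the documented behaviour, not jdit's implementation: `LrState` is a hypothetical name, and it only tracks the two numbers without touching a real optimizer.

```python
from typing import Optional


class LrState:
    """Hypothetical stand-in that tracks lr and lr_decay only."""

    def __init__(self, lr: float, lr_decay: float) -> None:
        self.lr = lr
        self.lr_decay = lr_decay

    def do_lr_decay(self,
                    reset_lr_decay: Optional[float] = None,
                    reset_lr: Optional[float] = None) -> None:
        if reset_lr_decay is None and reset_lr is None:
            # No arguments: apply one decay step.
            self.lr = self.lr * self.lr_decay
            return
        # Otherwise reset only the values that were passed.
        if reset_lr_decay is not None:
            self.lr_decay = reset_lr_decay
        if reset_lr is not None:
            self.lr = reset_lr


state = LrState(lr=1.0, lr_decay=0.5)
state.do_lr_decay()            # decay step: 1.0 * 0.5
lr_after_decay = state.lr
state.do_lr_decay(reset_lr=1)  # reset the learning rate, keep the decay
lr_after_reset = state.lr
print(lr_after_decay, lr_after_reset, state.lr_decay)
```

The values follow the class example, where `opt.lr` goes from 1.0 to 0.5 after a decay and back to 1 after `do_lr_decay(reset_lr=1)`. A wrapper around a real PyTorch optimizer would presumably also write the new value into each entry of the optimizer's `param_groups`; that step is left out here.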