jdit.optimizer¶
Optimizer¶
class jdit.Optimizer(params: parameters of model, optimizer: [Adam, RMSprop, SGD, ...], lr_decay: float = 1.0, decay_position: Union[int, tuple, list] = -1, lr_reset: Dict[int, float] = None, position_type: ('epoch', 'step') = 'epoch', **kwargs)[source]¶
This is a wrapper of an optimizer class in PyTorch. It adds a few new features that give finer control over the optimizer.
params is the set of model parameters that need to be updated. A filter is applied automatically to keep only the parameters that require gradients, like this: filter(lambda p: p.requires_grad, params). So you can pass model.all_params() without any filtering of your own.
learning rate decay
Calling do_lr_decay() performs a learning rate decay:
\[lr = lr * decay\]
learning rate reset
Resets the learning rate; it can change the learning rate and the decay factor directly.
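As a concrete illustration of the filtering and of the decay rule above, here is a minimal sketch; the Linear model is an ordinary torch.nn.Module used only for this example and is not part of jdit:

import torch.nn as nn

model = nn.Linear(8, 2)
# What the wrapper is described as doing internally: keep only trainable parameters.
trainable = filter(lambda p: p.requires_grad, model.parameters())

# The decay rule applied by do_lr_decay(): lr <- lr * decay.
lr, lr_decay = 1.0, 0.92
lr = lr * lr_decay   # 0.92
# A reset entry such as {4: 0.99} instead overwrites the learning rate at that position.
lr = 0.99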
Parameters: - params – parameters of the model which need to be updated.
- optimizer (torch.optim.Optimizer) – an optimizer class in PyTorch, such as torch.optim.Adam.
- lr_decay (float, optional) – learning rate decay factor. Default: 1.0.
- decay_position (int, tuple, list, optional) – the position(s) at which learning rate decay is applied. Default: -1.
- lr_reset (Dict[position(int), lr(float)], optional) – reset the learning rate to the given value at a certain position. Default: None.
- position_type ('epoch', 'step') – whether decay_position and lr_reset are counted in epochs or in steps. Default: 'epoch'.
- kwargs – hyper-parameters passed to the optimizer, such as lr, betas, weight_decay.
Example:
>>> from torch.nn import Sequential, Conv3d
>>> from torch.optim import Adam
>>> module = Sequential(Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)))
>>> opt = Optimizer(module.parameters(), "Adam", 0.5, 10, {4: 0.99}, "epoch", lr=1.0, betas=(0.9, 0.999), weight_decay=1e-5)
>>> print(opt)
(Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 1.0
    weight_decay: 1e-05
)
lr_decay:0.5
decay_position:10
lr_reset:{4: 0.99}
position_type:epoch
))
>>> opt.lr
1.0
>>> opt.lr_decay
0.5
>>> opt.do_lr_decay()
>>> opt.lr
0.5
>>> opt.do_lr_decay(reset_lr=1)
>>> opt.lr
1
>>> opt.opt
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 1
    weight_decay: 1e-05
)
>>> opt.is_decay_lr(1)
False
>>> opt.is_decay_lr(10)
True
>>> opt.is_decay_lr(20)
True
>>> opt.is_reset_lr(4)
0.99
>>> opt.is_reset_lr(5)
False
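A hedged sketch of how these calls might sit in a training loop follows; model, loader, criterion and epochs are placeholders and not part of jdit, and the underlying torch optimizer is reached through opt.opt as shown above:

for epoch in range(1, epochs + 1):
    for batch, target in loader:
        opt.opt.zero_grad()                  # drive the wrapped torch optimizer directly
        loss = criterion(model(batch), target)
        loss.backward()
        opt.opt.step()
    reset_value = opt.is_reset_lr(epoch)     # a float such as 0.99, or False
    if reset_value:
        opt.do_lr_decay(reset_lr=reset_value)
    elif opt.is_decay_lr(epoch):             # True once decay_position is reached
        opt.do_lr_decay()                    # lr <- lr * lr_decay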
do_lr_decay(reset_lr_decay: float = None, reset_lr: float = None)[source]¶
Do a learning rate decay, or reset the learning rate and decay factor.
- Passing both parameters as None:
do a learning rate decay by self.lr = self.lr * self.lr_decay.
- Passing reset_lr_decay or reset_lr:
reset the learning rate or the decay factor by self.lr = reset_lr and self.lr_decay = reset_lr_decay.
Parameters: - reset_lr_decay – if not None, use this value to reset self.lr_decay. Default: None.
- reset_lr – if not None, use this value to reset self.lr. Default: None.
Returns:
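The rule above can be read as the following sketch over torch.optim param_groups; this mirrors the documented behaviour only and is not jdit's actual implementation:

def do_lr_decay(self, reset_lr_decay=None, reset_lr=None):
    # Sketch only: self.opt is assumed to be the wrapped torch optimizer,
    # self.lr and self.lr_decay the values shown in the class example above.
    if reset_lr_decay is not None:
        self.lr_decay = reset_lr_decay       # reset the decay factor
    if reset_lr is not None:
        self.lr = reset_lr                   # reset the learning rate directly
    if reset_lr_decay is None and reset_lr is None:
        self.lr = self.lr * self.lr_decay    # default path: plain decay
    for group in self.opt.param_groups:      # propagate the new lr to every param group
        group["lr"] = self.lr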