jdit.optimizer

Optimizer

class jdit.Optimizer(params: parameters of model, optimizer: [Adam,RMSprop,SGD...], lr_decay: float = 1.0, decay_position: Union[int, tuple, list] = -1, lr_reset: Dict[int, float] = None, position_type: ('epoch','step') = 'epoch', **kwargs)[source]

This is a wrapper of the optimizer classes in PyTorch.

It adds some new features for finer control of the optimizer.

  • params is the parameters of the model which need to be updated. A filter is applied automatically to keep only the parameters that require grad (see the sketch after this list), like this:

    filter(lambda p: p.requires_grad, params)

    So, you can pass model.all_params() without any filters.

  • learning rate decay. When do_lr_decay() is called, it performs a learning rate decay, like:

    \[lr = lr \times lr\_decay\]
  • learning rate reset. Resetting the learning rate lets you change the learning rate and the decay value directly.
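
For instance, if part of the model is frozen by setting requires_grad = False, the filter keeps only the trainable parameters. A minimal sketch with plain torch.nn modules; net is an illustrative model, not part of the jdit API:

>>> from torch.nn import Linear, Sequential
>>> net = Sequential(Linear(4, 8), Linear(8, 2))
>>> for p in net[0].parameters():  # freeze the first layer
...     p.requires_grad = False
>>> trainable = filter(lambda p: p.requires_grad, net.parameters())
>>> len(list(trainable))  # only the second layer's weight and bias remain
2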

Parameters:
  • params (iterable) – parameters of the model which need to be updated.
  • optimizer – an optimizer class in PyTorch, such as torch.optim.Adam, or its name as a string (as in the example below, "Adam").
  • lr_decay (float, optional) – learning rate decay factor. Default: 1.0.
  • decay_position (int, tuple, list, optional) – the position(s) at which lr decay is applied; an int is applied periodically, i.e. every decay_position epochs/steps, as the example below shows. Default: -1.
  • lr_reset (Dict[int, float], optional) – reset the learning rate to a given value at a certain position. Default: None.
  • position_type ('epoch', 'step') – whether positions are counted in epochs or in steps. Default: 'epoch'.
  • kwargs – hyper-parameters passed to the optimizer, such as lr, betas, weight_decay.

Example:

>>> from torch.nn import Sequential, Conv3d
>>> from torch.optim import Adam
>>> module = Sequential(Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)))
>>> opt = Optimizer(module.parameters(), "Adam", 0.5, 10, {4: 0.99}, "epoch",
...                 lr=1.0, betas=(0.9, 0.999), weight_decay=1e-5)
>>> print(opt)
(Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 1.0
    weight_decay: 1e-05
)
    lr_decay:0.5
    decay_position:10
    lr_reset:{4: 0.99}
    position_type:epoch
))
>>> opt.lr
1.0
>>> opt.lr_decay
0.5
>>> opt.do_lr_decay()
>>> opt.lr
0.5
>>> opt.do_lr_decay(reset_lr=1)
>>> opt.lr
1
>>> opt.opt
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 1
    weight_decay: 1e-05
)
>>> opt.is_decay_lr(1)
False
>>> opt.is_decay_lr(10)
True
>>> opt.is_decay_lr(20)
True
>>> opt.is_reset_lr(4)
0.99
>>> opt.is_reset_lr(5)
False
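
Building on the example above, decay_position can also be a tuple or list of positions, and positions can be counted in steps rather than epochs via position_type="step". A sketch only; opt_step and the concrete values are illustrative, and the keyword arguments are assumed to mirror the signature above:

>>> opt_step = Optimizer(module.parameters(), "Adam",
...                      lr_decay=0.9, decay_position=(1000, 5000),
...                      position_type="step",
...                      lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-5)
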
do_lr_decay(reset_lr_decay: float = None, reset_lr: float = None)[source]

Do a learning rate decay, or reset the learning rate and decay values.

If both parameters are None:
Do a learning rate decay by self.lr = self.lr * self.lr_decay.
If reset_lr_decay or reset_lr is passed:
Reset the learning rate or the decay value by self.lr = reset_lr and self.lr_decay = reset_lr_decay.
Parameters:
  • reset_lr_decay – if not None, use this value to reset self.lr_decay. Default: None.
  • reset_lr – if not None, use this value to reset self.lr. Default: None.
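
Continuing the example above, passing reset_lr_decay only replaces the stored decay factor; the assertion below follows from the parameter description rather than from library internals:

>>> opt.do_lr_decay(reset_lr_decay=0.8)
>>> opt.lr_decay
0.8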

is_decay_lr(position: Optional[int]) → bool[source]

Judge whether learning rate decay should be applied at this position.

Parameters: position – (int) a position counted in steps or epochs.
Returns: bool
is_reset_lr(position: Optional[int]) → bool[source]

Judge whether the learning rate should be reset at this position.

Parameters: position – (int) a position counted in steps or epochs.
Returns: the new learning rate (float) if a reset is scheduled at this position, otherwise False
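
As a closing illustration, the decay and reset hooks can be driven manually from a training loop. A hedged sketch only: model, loader, compute_loss and epochs are placeholders, and the wrapped torch optimizer is reached through opt.opt as shown in the example above.

# Sketch of manual use inside a training loop. `model`, `loader`,
# `compute_loss` and `epochs` are illustrative placeholders, not part of jdit.
for epoch in range(1, epochs + 1):
    for batch in loader:
        opt.opt.zero_grad()              # the wrapped torch optimizer lives at opt.opt
        loss = compute_loss(model, batch)
        loss.backward()
        opt.opt.step()
    new_lr = opt.is_reset_lr(epoch)      # returns the new lr for this epoch, or False
    if new_lr:
        opt.do_lr_decay(reset_lr=new_lr)
    elif opt.is_decay_lr(epoch):         # True if decay is scheduled at this epoch
        opt.do_lr_decay()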