class jdit.Model(proto_model: Module, gpu_ids_abs: Union[list, tuple] = (), init_method: Union[str, function, None] = 'kaiming', show_structure=False, check_point_pos=None, verbose=True)

A wrapper of a PyTorch module.

In the simplest case, a raw PyTorch module is used to assemble a Model of this class. The wrapper adds convenient feature methods, such as _check_point, load_weights and so on.

  • proto_model is the core model of this class. It is not necessary to pass a module when you init a Model. You can build a model later by using Model.define(module), or load a model from a file.

  • gpu_ids_abs controls which GPUs you want to use. You should use the absolute ids of the GPUs.

  • init_method controls the weight init method.

    • With init_method="xavier", init.xavier_normal_ from torch.nn.init is used to init the Conv layers of the model.
    • With init_method="kaiming", init.kaiming_normal_ from torch.nn.init is used to init the Conv layers of the model.
    • With init_method=your_own_method, your function is applied to the weights, just like a torch.nn.init method.
  • show_structure controls whether to show your network structure.


Don't try to pass a DataParallel model; only a plain module is accepted. It is converted to a DataParallel model automatically when you pass multiple GPU ids, like [0, 1].


gpu_ids_abs must be a tuple or list. If you want to use the CPU, just pass an empty list like [].
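A minimal sketch of how absolute GPU ids might be turned into a device choice. It assumes, for illustration only and not as jdit's confirmed implementation, that the absolute ids are exposed through the CUDA_VISIBLE_DEVICES environment variable and then addressed inside the process as 0..n-1. resolve_devices is a hypothetical helper; it only computes values and does not modify the environment.

```python
from typing import List, Sequence, Tuple


def resolve_devices(gpu_ids_abs: Sequence[int]) -> Tuple[str, List[int], str]:
    """Return (device, relative_ids, cuda_visible_devices) for the given absolute ids.

    Hypothetical helper: assumes the absolute ids are made visible via
    CUDA_VISIBLE_DEVICES and renumbered from 0 inside the process.
    """
    if not isinstance(gpu_ids_abs, (list, tuple)):
        raise TypeError("gpu_ids_abs must be a list or tuple")
    if len(gpu_ids_abs) == 0:
        # An empty list or tuple means: run on the CPU.
        return "cpu", [], ""
    # Value that would be written to CUDA_VISIBLE_DEVICES, e.g. "2,3".
    visible = ",".join(str(i) for i in gpu_ids_abs)
    # After masking, the visible GPUs are renumbered 0..n-1.
    relative_ids = list(range(len(gpu_ids_abs)))
    return "cuda:0", relative_ids, visible
```

For example, resolve_devices([2, 3]) gives ("cuda:0", [0, 1], "2,3"); having more than one relative id is the case in which the module would be wrapped in DataParallel.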


Parameters:

proto_model (Module): A PyTorch module. Default: None

gpu_ids_abs (tuple or list): The absolute ids of the GPUs. If [], the CPU is used. Default: ()

init_method (str or function): Weight init method. Default: "kaiming"

show_structure (bool): Whether to show the network structure. Default: False


Attributes:

num_params (int): The total number of parameters in this model.

gpu_ids_abs (list or tuple): The devices this model is on.


>>> import torch
>>> from torch.nn import Sequential, Conv3d
>>> from jdit import Model
>>> # using a non-square kernel, unequal stride and padding
>>> module = Sequential(Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)))
>>> # init a Model from the module, on the CPU
>>> net = Model(module, [], show_structure=False)
Sequential Total number of parameters: 15873
Sequential model use CPU!
apply kaiming weight init!
>>> input_tensor = torch.randn(20, 16, 10, 50, 100)
>>> output = net(input_tensor)
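The 15873 parameters reported above are 33·16·3·5·2 = 15840 kernel weights plus 33 biases. To see what shape output should have, the following pure-Python sketch (no torch needed) applies the standard PyTorch convolution output-size formula to the same Conv3d configuration. The helper names are illustrative and not part of jdit.

```python
from typing import Sequence, Tuple


def conv_out_size(size: int, kernel: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Output length along one dimension, using PyTorch's Conv*d formula."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv3d_output_shape(input_shape: Sequence[int], out_channels: int,
                        kernel: Sequence[int], stride: Sequence[int],
                        padding: Sequence[int]) -> Tuple[int, ...]:
    """Shape (N, C_out, D_out, H_out, W_out) produced by a Conv3d layer."""
    batch = input_shape[0]
    spatial = input_shape[2:]  # (D, H, W)
    out_spatial = tuple(
        conv_out_size(s, k, st, p)
        for s, k, st, p in zip(spatial, kernel, stride, padding)
    )
    return (batch, out_channels) + out_spatial


# Same configuration as the Conv3d example above.
print(conv3d_output_shape((20, 16, 10, 50, 100), 33, (3, 5, 2), (2, 1, 1), (4, 2, 0)))
# prints (20, 33, 8, 50, 99)
```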
convert_to_distributed(device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False)

Wrap this Model's underlying module (self.model) in torch.nn.parallel.DistributedDataParallel. The arguments are forwarded to DistributedDataParallel.

Args:

  • device_ids (list of int or torch.device): CUDA devices. This should only be provided when the wrapped module resides on a single CUDA device. For single-device modules, the i-th module replica is placed on device_ids[i]. For multi-device modules and CPU modules, device_ids must be None or an empty list, and input data for the forward pass must be placed on the correct device. (default: all devices for single-device modules)
  • output_device (int or torch.device): Device location of the output for single-device CUDA modules. For multi-device modules and CPU modules, it must be None, and the module itself dictates the output location. (default: device_ids[0] for single-device modules)
  • broadcast_buffers (bool): Flag that enables syncing (broadcasting) buffers of the module at the beginning of the forward function. (default: True)
  • process_group: The process group to be used for distributed data all-reduction. If None, the default process group, which is created by torch.distributed.init_process_group, will be used. (default: None)
  • bucket_cap_mb: DistributedDataParallel buckets parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in megabytes (MB). (default: 25)
  • find_unused_parameters (bool): Traverse the autograd graph of all tensors contained in the return value of the wrapped module's forward function. Parameters that don't receive gradients as part of this graph are preemptively marked as being ready to be reduced. (default: False)
  • check_reduction: When set to True, it enables DistributedDataParallel to automatically check, at the beginning of every iteration's forward function, whether the previous iteration's backward reductions were successfully issued. You normally don't need this option unless you observe weird behaviors such as different ranks getting different gradients, which should not happen if DistributedDataParallel is used correctly. (default: False)

Attributes: module (Module): the module to be parallelized


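>>> import torch
>>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
>>> pg = torch.distributed.new_group()  # a group containing all ranks
>>> net.convert_to_distributed(process_group=pg)
>>> # equivalent to:
>>> net.model = torch.nn.parallel.DistributedDataParallel(net.model, process_group=pg)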
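To make bucket_cap_mb more concrete, here is a simplified, pure-Python illustration of size-capped bucketing. It is an assumption-laden sketch, not DistributedDataParallel's actual algorithm (which differs in details): parameters are visited in reverse registration order, roughly the order in which gradients become ready during backward, and a new bucket is started whenever adding the next parameter would exceed the cap. bucket_parameters is a hypothetical name, and float32 (4 bytes per element) is assumed by default.

```python
from typing import List, Sequence


def bucket_parameters(param_numels: Sequence[int], bucket_cap_mb: float = 25,
                      bytes_per_element: int = 4) -> List[List[int]]:
    """Greedily group parameter indices into buckets of at most bucket_cap_mb.

    Simplified illustration only. A single parameter larger than the cap
    still gets a bucket of its own.
    """
    cap_bytes = bucket_cap_mb * 1024 * 1024
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    # Visit parameters in reverse order, as gradients tend to arrive in backward.
    for idx in reversed(range(len(param_numels))):
        size = param_numels[idx] * bytes_per_element
        if current and current_bytes + size > cap_bytes:
            buckets.append(current)  # close the full bucket
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets
```

With a 1 MB cap, parameters of 200000, 50000, 100000 and 300000 float32 elements end up in the buckets [[3], [2, 1], [0]]: smaller buckets let gradient all-reduce start earlier, while larger ones mean fewer communication calls.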
static count_params(proto_model: Module)

Count the total number of parameters of a model.

Parameters: proto_model – PyTorch module
Returns: number of parameters
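In PyTorch such a count is commonly written as sum(p.numel() for p in proto_model.parameters()); whether count_params uses exactly that is an assumption. The sketch below shows the same arithmetic in pure Python on a hypothetical mapping of parameter names to shapes: each tensor contributes the product of its dimensions.

```python
from math import prod
from typing import Dict, Tuple


def count_params_from_shapes(shapes: Dict[str, Tuple[int, ...]]) -> int:
    """Total number of scalar parameters, given each tensor's shape."""
    # Each tensor holds prod(shape) elements; a 0-d tensor counts as 1.
    return sum(prod(shape) for shape in shapes.values())


# Shapes of Sequential(Conv3d(16, 33, (3, 5, 2))) and Linear(10, 1).
conv3d_shapes = {"0.weight": (33, 16, 3, 5, 2), "0.bias": (33,)}
linear_shapes = {"weight": (1, 10), "bias": (1,)}
print(count_params_from_shapes(conv3d_shapes))  # 15873
print(count_params_from_shapes(linear_shapes))  # 11
```

These totals match the "Total number of parameters" lines printed in the examples for the Conv3d and Linear models.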
define(proto_model: Module, gpu_ids_abs: Union[list, tuple], init_method: Union[str, function, None], show_structure: bool)

Define and wrap a pytorch module, according to CPU, GPU and multi-GPUs.

  • Print the module's info.
  • Move the module to the specified device.
  • Apply the weight init method.

Parameters:
  • proto_model – Network, a module.
  • gpu_ids_abs – Ids of the GPUs to use, as a tuple or list. To run without a GPU, pass ().
  • init_method – Weight init method (e.g. "kaiming"), or False to skip weight initialization.
  • show_structure – Whether to print the structure of the model.
load_point(model_name: str, epoch: int, logdir='log')

Load the model and weights from a certain checkpoint.

This method works together with self._check_point().
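Because load_point receives the same model_name, epoch and logdir that were used when the checkpoint was written, both sides only need to agree on one path-building rule. The sketch below shows such a rule with a hypothetical helper; the "checkpoint" subdirectory and the Weights_<name>_<epoch>.pth file pattern are assumptions for illustration and may not match jdit's actual layout.

```python
import os


def checkpoint_path(model_name: str, epoch: int, logdir: str = "log") -> str:
    """Hypothetical naming scheme shared by the save and load sides of a checkpoint.

    The directory layout and file name pattern are assumptions;
    jdit's actual scheme may differ.
    """
    return os.path.join(logdir, "checkpoint", f"Weights_{model_name}_{epoch}.pth")


# Saving and loading call the same function, so the paths always agree.
print(checkpoint_path("netG", 10))  # log/checkpoint/Weights_netG_10.pth on POSIX
```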

load_weights(weights: Union[dict, str], strict=True)

Load weights into the model, either from a file path or from weights passed directly.

The weights can be given as a path to a weights file or as a state dict.

  • weights – PyTorch weights (a state dict) or the path of a weights file.
  • strict – Same meaning as in PyTorch's model.load_state_dict(weights, strict=strict). Default: True



>>> from torchvision.models.resnet import resnet18
>>> from jdit import Model
>>> model = Model(resnet18())
ResNet Total number of parameters: 11689512
ResNet model use CPU!
apply kaiming weight init!
>>> model.save_weights("model.pth")
try to remove 'module.' in keys of weights dict...
>>> model.load_weights("model.pth", strict=True)
Try to remove `moudle.` to keys of weights dict
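The log lines above refer to the 'module.' prefix that DataParallel adds to every key of a wrapped model's state dict. The following sketch shows that key rewriting on a plain dict; strip_module_prefix is a hypothetical helper illustrating the idea, not jdit's internal function, and it leaves the values (tensors, in practice) untouched.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def strip_module_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a copy of state_dict with a leading prefix removed from every key that has it."""
    fixed: Dict[str, V] = {}
    for key, value in state_dict.items():
        # Keys without the prefix (e.g. from a plain module) are kept as-is.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        fixed[new_key] = value
    return fixed


weights = {"module.conv.weight": 1, "module.conv.bias": 2, "fc.weight": 3}
print(strip_module_prefix(weights))
# prints {'conv.weight': 1, 'conv.bias': 2, 'fc.weight': 3}
```

Making the rewrite a no-op for keys without the prefix is what lets the same loading code accept weights saved from both plain and DataParallel models.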
print_network(proto_model: Module, show_structure=False)

Print the total number of parameters and the structure of the network.

  • proto_model – PyTorch module
  • show_structure – Whether to show the network's structure. Default: False

Returns: Total number of parameters

save_weights(weights_path: str, fix_weights=True)

Save the model's weights to a file.

You can save a model, its weights, or both to a file.


This method handles models on different devices when saving, so you don't need to care which devices your model is on.

  • weights_path – Path of the file to save the weights to.
  • fix_weights – If True, the 'module.' prefix is removed from the keys when saving a DataParallel model, without any moving operation. Otherwise, the weights are moved to the CPU, especially for a DataParallel model. Default: True


>>> from torch.nn import Linear
>>> from jdit import Model
>>> model = Model(Linear(10,1))
Linear Total number of parameters: 11
Linear model use CPU!
apply kaiming weight init!
>>> model.save_weights("weights.pth")
try to remove 'module.' in keys of weights dict...
>>> model.load_weights("weights.pth")
Try to remove `moudle.` to keys of weights dict