Build your own trainer

To build your own trainer, you need to prepare these components:

  • dataset: the dataset you want to use.
  • Model: a wrapper around your own PyTorch module.
  • Optimizer: a wrapper around a PyTorch optimizer.
  • trainer: the training pipeline that assembles the components above.


In this section, you build the dataset you want to use.

Common dataset

Many open datasets are widely used, so jdit lets you build these standard datasets easily, such as:

  • Fashion-MNIST
  • CIFAR-10
  • LSUN

For these common datasets, batch_size is the only parameter you need to set.

>>> from jdit.dataset import FashionMNIST
>>> fashion_data = FashionMNIST(batch_size=64)  # now you have a ``dataset``
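The batch size determines how many steps make up one epoch. Fashion-MNIST has 60,000 training images and 10,000 test images, so with batch_size=64 a training epoch has ceil(60000 / 64) = 938 steps. This assumes the last partial batch is kept, which is the default drop_last=False behaviour of PyTorch's DataLoader. Whether jdit changes that setting is not covered here. The helper below is a small hypothetical illustration:

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields for one pass over the data."""
    if drop_last:
        # The final, smaller batch is discarded.
        return num_samples // batch_size
    # The final batch may be smaller than batch_size but is still used.
    return math.ceil(num_samples / batch_size)


print(steps_per_epoch(60000, 64))        # 938 (the last batch holds 32 images)
print(steps_per_epoch(60000, 64, True))  # 937
print(steps_per_epoch(10000, 64))        # 157
```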

Custom dataset

If you want to build a dataset from your own data, inherit from the class DataLoadersFactory and override its build_transforms() and build_datasets() methods. Overriding build_transforms() is not necessary if the default transforms suit you.

Follow these steps:

  • Optionally, assign your own transforms to self.train_transform_list and self.valid_transform_list.
  • Register your training dataset as self.dataset_train, using self.train_transform_list.
  • Register your validation dataset as self.dataset_valid, using self.valid_transform_list.


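from torchvision import datasets, transforms
from jdit.dataset import DataLoadersFactory

class FashionMNIST(DataLoadersFactory):
    def __init__(self, root=r'.\datasets\fashion_data', batch_size=128, num_workers=-1):
        super(FashionMNIST, self).__init__(root, batch_size, num_workers)

    def build_transforms(self, resize=32):
        # Default transforms; override this method to change them.
        self.train_transform_list = self.valid_transform_list = [
            transforms.Resize(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])]  # Fashion-MNIST images have one channel

    def build_datasets(self):
        # Compose each transform list into a single transform for its dataset.
        self.dataset_train = datasets.FashionMNIST(self.root, train=True, download=True,
                                                   transform=transforms.Compose(self.train_transform_list))
        self.dataset_valid = datasets.FashionMNIST(self.root, train=False, download=True,
                                                   transform=transforms.Compose(self.valid_transform_list))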
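The call order behind these steps can be illustrated without PyTorch. The sketch below is a hypothetical, simplified stand-in, not jdit's code. MiniLoadersFactory, compose and ToyDataset are made-up names. It assumes the base class first calls build_transforms(), then composes each transform list into one callable, and then calls build_datasets():

```python
class MiniLoadersFactory:
    """Hypothetical stand-in illustrating the hook order of a loaders factory."""

    def __init__(self, root):
        self.root = root
        self.build_transforms()  # step 1: optional, a default is provided
        # Turn each transform list into one callable (assumed behaviour).
        self.train_transform = compose(self.train_transform_list)
        self.valid_transform = compose(self.valid_transform_list)
        self.build_datasets()    # steps 2 and 3: subclasses must provide this

    def build_transforms(self):
        # Default: no transforms at all (identity).
        self.train_transform_list = []
        self.valid_transform_list = []

    def build_datasets(self):
        raise NotImplementedError("subclasses must register their datasets")


def compose(funcs):
    """Apply each function in order, like torchvision's transforms.Compose."""
    def apply(x):
        for f in funcs:
            x = f(x)
        return x
    return apply


class ToyDataset(MiniLoadersFactory):
    def build_transforms(self):
        def scale(x):
            return x / 255.0          # [0, 255] -> [0, 1]

        def normalize(x):
            return (x - 0.5) / 0.5    # [0, 1] -> [-1, 1]

        # Training and validation share the same transforms here.
        self.train_transform_list = self.valid_transform_list = [scale, normalize]

    def build_datasets(self):
        raw_train, raw_valid = [0, 255, 51], [102]  # made-up "pixel" values
        self.dataset_train = [self.train_transform(x) for x in raw_train]
        self.dataset_valid = [self.valid_transform(x) for x in raw_valid]


data = ToyDataset(root="unused")
print(data.dataset_train)  # approximately [-1.0, 1.0, -0.6]
```

Keeping the transform lists separate from the composed callables lets a subclass change only the transforms and reuse the default build_datasets() logic, or the other way round.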

You now have your own dataset.


In this section, you build your own network.

First, build a PyTorch module like this:

>>> import torch.nn as nn
>>> class SimpleModel(nn.Module):
...     def __init__(self):
...         super(SimpleModel, self).__init__()
...         self.layer1 = nn.Linear(32, 64)
...         self.layer2 = nn.Linear(64, 1)
...
...     def forward(self, input):
...         out = self.layer1(input)
...         out = self.layer2(out)
...         return out
...
>>> network = SimpleModel()
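Each nn.Linear(in_features, out_features) layer holds an in_features × out_features weight matrix plus out_features biases. That means the size of SimpleModel can be checked by hand, and it matches the total that jdit.Model logs. The helper name below is only for illustration:

```python
def linear_params(in_features, out_features, bias=True):
    """Parameters in one fully connected layer: weight matrix plus optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)


# SimpleModel: Linear(32, 64) followed by Linear(64, 1)
layer1 = linear_params(32, 64)  # 32*64 + 64 = 2112
layer2 = linear_params(64, 1)   # 64*1 + 1 = 65
print(layer1 + layer2)          # 2177
```

With PyTorch installed, sum(p.numel() for p in network.parameters()) should give the same number.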


You don't need to move the network to a GPU or wrap it for data parallelism; jdit.Model does this for you.

Second, wrap your model with jdit.Model, setting which GPUs to use and the weight initialization method.

In PyTorch, GPU IDs always start from 0, which can be confusing when you only want to use some of your GPUs. jdit.Model handles this for you: gpu_ids_abs takes absolute GPU IDs. If you have GPUs [0, 1, 2, 3] and only want to use GPUs 2 and 3, just set gpu_ids_abs=[2, 3].

>>> from jdit import Model
>>> network = SimpleModel()
>>> jdit_model = Model(network, gpu_ids_abs=[2,3], init_method="kaiming")
SimpleModel Total number of parameters: 2177
SimpleModel dataParallel use GPUs[2, 3]!
apply kaiming weight init!
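How jdit maps absolute IDs internally is not described here. One common approach, shown only as an assumption, is to expose just the chosen GPUs through the CUDA_VISIBLE_DEVICES environment variable. CUDA then renumbers the visible devices from 0 in the listed order, so absolute GPUs [2, 3] become relative devices [0, 1]. The helper below is hypothetical:

```python
def visible_devices(gpu_ids_abs):
    """Hypothetical helper: absolute GPU IDs -> (CUDA_VISIBLE_DEVICES value, relative IDs).

    CUDA renumbers visible devices from 0 in the order they are listed,
    so the i-th absolute ID becomes relative device i.
    """
    env_value = ",".join(str(i) for i in gpu_ids_abs)
    relative_ids = list(range(len(gpu_ids_abs)))
    return env_value, relative_ids


env_value, relative_ids = visible_devices([2, 3])
print(env_value)     # 2,3
print(relative_ids)  # [0, 1]
```

If CUDA_VISIBLE_DEVICES=2,3 is set before CUDA is initialized, torch.device("cuda:0") refers to physical GPU 2.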

You now have your own model.


In this section, you build an optimizer.

Compared with the PyTorch optimizer, jdit's Optimizer adds a simple way to decay and reset the learning rate.

do_lr_decay() is called automatically at the end of every epoch, or at the end of certain epochs, so you don't need to do anything to apply learning-rate decay. If you don't want to decay the learning rate, set lr_decay = 1 or set the decay epoch larger than the number of training epochs. The example below shows how it works, so you can implement special strategies of your own.

>>> from jdit import Optimizer
>>> from torch.nn import Linear
>>> network = Linear(10, 1)
>>> # set params
>>> # `optimizer` is the PyTorch class name (torch.optim.RMSprop).
>>> hparams = {
...     "optimizer": "RMSprop",
...     "lr": 0.001,
...     "lr_decay": 0.5,
...     "weight_decay": 2e-5,
...     "momentum": 0}
>>> # define optimizer
>>> opt = Optimizer(network.parameters(), **hparams)
>>> opt.do_lr_decay()
>>> opt.do_lr_decay(reset_lr=1)
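The effect of do_lr_decay() can be modelled as multiplying the learning rate by lr_decay on each call, with reset_lr jumping to a given value instead. This is a simplified, hypothetical model of the behaviour described above. StepDecaySketch is a made-up name, and jdit's exact rules for when decay happens, or any lower bound on the learning rate, are not covered:

```python
class StepDecaySketch:
    """Hypothetical model of multiplicative learning-rate decay with reset."""

    def __init__(self, lr, lr_decay):
        self.lr = lr
        self.lr_decay = lr_decay

    def do_lr_decay(self, reset_lr=None):
        if reset_lr is not None:
            self.lr = reset_lr          # jump to a fixed learning rate
        else:
            self.lr *= self.lr_decay    # after k calls: lr_0 * lr_decay ** k
        return self.lr


opt = StepDecaySketch(lr=0.001, lr_decay=0.5)
print(opt.do_lr_decay())            # lr is now 0.0005
print(opt.do_lr_decay())            # lr is now 0.00025
print(opt.do_lr_decay(reset_lr=1))  # lr is reset to 1
```

With lr_decay = 1 the multiplication leaves the learning rate unchanged, which is why that setting disables decay.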

You can select an optimizer by its PyTorch class name, such as "Adam", "RMSprop" or "SGD".

For spectral normalization, the optimizer filters the parameters and keeps only those that require gradients. So you don't need to write something like filter(lambda p: p.requires_grad, params); passing model.parameters() is enough.

You now have an Optimizer.
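The filtering amounts to keeping only the parameters whose requires_grad flag is True. The sketch below mimics this with a hypothetical Param record instead of real tensors, so it runs without PyTorch. With real tensors, the same comprehension works on model.parameters():

```python
from dataclasses import dataclass


@dataclass
class Param:
    """Hypothetical stand-in for a parameter tensor: only the field the filter reads."""
    name: str
    requires_grad: bool


def trainable(params):
    """Same selection as filter(lambda p: p.requires_grad, params), returned as a list."""
    return [p for p in params if p.requires_grad]


params = [
    Param("layer1.weight", True),
    Param("layer1.bias", True),
    Param("frozen.weight", False),  # e.g. a layer frozen on purpose
]
print([p.name for p in trainable(params)])  # ['layer1.weight', 'layer1.bias']
```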


The final section is a little more complex. jdit supplies several trainer templates, such as SupTrainer, GanTrainer and ClassificationTrainer, as well as concrete instances.

The templates form a three-level inheritance hierarchy:


Top level SupTrainer

SupTrainer is the top class of these templates.

It defines tools for logging, data visualization and so on. It also contains the main epoch loop, which the second-level templates inherit and fill in with the work done in each training epoch.

Something like this:

def train(self):
    for epoch in range(self.nepochs):
        self._record_configs()  # record info
        self.train_epoch()      # filled in by the second-level templates
        # do learning rate decay
        # save model checkpoint

Its methods are overridden by the second-level templates; SupTrainer only defines a rough framework.
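This division of labour between the three levels is an instance of the template method pattern. The sketch below uses plain Python, hypothetical class names and a toy "loss" (the batch mean) so that it runs without PyTorch. It is not jdit's code. The top level owns the epoch loop, the second level defines one epoch, and the third level supplies only the loss:

```python
class SupTrainerSketch:
    """Top level: owns the epoch loop and common bookkeeping."""

    def __init__(self, nepochs):
        self.nepochs = nepochs
        self.step = 0
        self.log = []

    def train(self):
        for epoch in range(1, self.nepochs + 1):
            self.train_epoch()  # filled in by the second level
            self.log.append(("epoch_end", epoch, self.step))

    def train_epoch(self):
        raise NotImplementedError


class ClassificationSketch(SupTrainerSketch):
    """Second level: one pass over the data, one loss per batch."""

    def __init__(self, nepochs, batches):
        super().__init__(nepochs)
        self.batches = batches

    def train_epoch(self):
        for batch in self.batches:
            self.step += 1  # global step counter, kept across epochs
            loss, var_dic = self.compute_loss(batch)
            self.log.append(("step", self.step, var_dic))

    def compute_loss(self, batch):
        raise NotImplementedError


class ToyTrainer(ClassificationSketch):
    """Third level: only the task-specific loss is defined here."""

    def compute_loss(self, batch):
        loss = sum(batch) / len(batch)
        return loss, {"mean": loss}


trainer = ToyTrainer(nepochs=2, batches=[[1, 2], [3, 5]])
trainer.train()
print(trainer.step)  # 4
```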

Second level ClassificationTrainer

On this level the task becomes concrete: classification. There is one model, one optimizer and one dataset, and the data consists of images and labels. A ClassificationTrainer is initialized like this:

class ClassificationTrainer(SupTrainer):
    def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class):
        super(ClassificationTrainer, self).__init__(nepochs, logdir, gpu_ids_abs=gpu_ids)
        self.net = net
        self.opt = opt
        self.datasets = datasets
        self.num_class = num_class
        self.labels = None
        self.output = None

Next, build the training loop for one epoch. You must increment self.step to record the training step.

def train_epoch(self, subbar_disable=False):
    # display training images every epoch
    self._watch_images(show_imgs_num=3, tag="Train")
    for iteration, batch in tqdm(enumerate(self.datasets.loader_train, 1), unit="step", disable=subbar_disable):
        self.step += 1  # necessary!
        # unpack data from one batch and move it to the device
        self.input, self.ground_truth, self.labels = self.get_data_from_batch(batch, self.device)
        # forward pass
        self.output = self.net(self.input)
        # _train_iteration() is defined in SupTrainer;
        # it uses `self.compute_loss` and `self.opt` to do a backward pass.
        self._train_iteration(self.opt, self.compute_loss, tag="Train")

def compute_loss(self):
    """Compute the main loss and observed variables.
    Overridden by the next-level template.
    """
    pass

def compute_valid(self):
    """Compute the validation variables for visualization.
    Overridden by the next-level template.
    """
    pass

compute_loss() and compute_valid() must be overridden in the next-level template.
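If you want Python itself to enforce that the next template overrides these methods, abc.abstractmethod is one option. jdit's templates may not do this, so treat the sketch as illustrative; the class names are hypothetical. A subclass that forgets compute_valid() then fails when it is instantiated instead of at the first validation step:

```python
from abc import ABC, abstractmethod


class ClassificationTemplate(ABC):
    """Hypothetical template: subclasses must supply both hooks."""

    @abstractmethod
    def compute_loss(self):
        """Return (loss, var_dic) for one training step."""

    @abstractmethod
    def compute_valid(self):
        """Return var_dic for one validation step."""


class Incomplete(ClassificationTemplate):
    def compute_loss(self):
        return 0.0, {}


class Complete(Incomplete):
    def compute_valid(self):
        return {}


try:
    Incomplete()  # compute_valid() is still abstract
except TypeError as err:
    print("refused:", err)

Complete()  # fine: both hooks are defined
```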

Third level FashionClassTrainer

At this level everything is concrete, so inherit from ClassificationTrainer and fill in the specific methods.

import torch
from torch import nn

class FashionClassTrainer(ClassificationTrainer):
    def __init__(self, logdir, nepochs, gpu_ids, net, opt, dataset, num_class):
        super(FashionClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, dataset, num_class)
        data, label = self.datasets.samples_train
        # show dataset in tensorboard
        self.watcher.embedding(data, data, label, 1)

    def compute_loss(self):
        var_dic = {}
        var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())
        return loss, var_dic

    def compute_valid(self):
        var_dic = {}
        var_dic["CEP"] = cep = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())

        # index of the largest output in each row, e.g. 0100 => 1, 0010 => 2
        _, predict = torch.max(self.output.detach(), 1)
        total = predict.size(0) * 1.0
        labels = self.labels.squeeze().long()
        correct = predict.eq(labels).cpu().sum().float()
        acc = correct / total
        var_dic["ACC"] = acc
        return var_dic
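The accuracy in compute_valid() takes the index of the largest output in each row, because torch.max(..., 1) returns both values and indices along dimension 1. It then compares those indices with the labels. The same arithmetic with plain lists and made-up numbers:

```python
def accuracy(outputs, labels):
    """Fraction of rows whose arg-max index equals the label."""
    predict = [max(range(len(row)), key=row.__getitem__) for row in outputs]
    correct = sum(p == y for p, y in zip(predict, labels))
    return correct / len(labels)


# Three samples, three classes; the third prediction is wrong.
outputs = [[0.1, 2.0, 0.3],   # arg-max 1
           [1.5, 0.2, 0.1],   # arg-max 0
           [0.3, 0.2, 0.9]]   # arg-max 2
labels = [1, 0, 1]
print(accuracy(outputs, labels))  # 0.6666666666666666
```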

compute_loss() is called at every training step to compute the loss for the backward pass. It returns two values:

  • loss: the main loss; loss.backward() is called on it to update the model weights.
  • var_dic: a dictionary of values, each visualized on TensorBoard as a curve.

In this example, compute_loss() uses nn.CrossEntropyLoss() as the loss for backpropagation and visualizes it on TensorBoard under the name "CEP".
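For reference, nn.CrossEntropyLoss with its default reduction="mean" applies log-softmax to the raw outputs. It then averages the negative log-probability of the correct class over the batch. A plain-Python version of that formula follows; it is for checking numbers by hand, not a replacement for the PyTorch loss:

```python
import math


def cross_entropy(logits, labels):
    """Mean over the batch of -log softmax(logits)[label]."""
    total = 0.0
    for row, y in zip(logits, labels):
        # log-sum-exp with the max subtracted for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[y]
    return total / len(labels)


# Equal logits: every class has probability 1/3, so the loss is log(3).
print(cross_entropy([[0.0, 0.0, 0.0]], [2]))  # 1.0986122886681098
```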

compute_valid() is called at every validation step. It returns one value:

  • var_dic: the same kind of dictionary as the var_dic returned by compute_loss().

compute_valid() is called under torch.no_grad(), so gradients are not computed in this method. If you do need gradients there, wrap the code in torch.enable_grad() to turn gradient computation back on.

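Finally, you have a trainer.

Now you have everything. Put the pieces together and train! In this snippet, batch_size, nepochs, gpus and hparams are values you choose (hparams as in the Optimizer section). The network must accept your image batches and output one score per class (10 for Fashion-MNIST).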

>>> mnist = FashionMNIST(batch_size=batch_size)
>>> net = Model(SimpleModel(), gpu_ids_abs=gpus, init_method="kaiming")
>>> opt = Optimizer(net.parameters(), **hparams)
>>> trainer = FashionClassTrainer("log", nepochs, gpus, net, opt, mnist, 10)
>>> trainer.train()