Source code for jdit.trainer.single.sup_single

from ..super import SupTrainer
from tqdm import tqdm
import torch
from jdit.optimizer import Optimizer
from jdit.model import Model
from jdit.dataset import DataLoadersFactory


class SupSingleModelTrainer(SupTrainer):
    """This is a single-model trainer: you only have one model.

    ::

        input, ground_truth
        output = model(input)
        loss(output, ground_truth)

    """

    def __init__(self, logdir, nepochs, gpu_ids_abs, net: Model, opt: Optimizer, datasets: DataLoadersFactory):
        super(SupSingleModelTrainer, self).__init__(nepochs, logdir, gpu_ids_abs=gpu_ids_abs)
        self.net = net
        self.opt = opt
        self.datasets = datasets
        self.fixed_input = None
        self.input = None
        self.output = None
        self.ground_truth = None
    def train_epoch(self, subbar_disable=False):
        for iteration, batch in tqdm(enumerate(self.datasets.loader_train, 1), unit="step",
                                     disable=subbar_disable):
            self.step += 1
            self.input, self.ground_truth = self.get_data_from_batch(batch, self.device)
            self.output = self.net(self.input)
            self._train_iteration(self.opt, self.compute_loss, csv_filename="Train")
            if iteration == 1:
                self._watch_images("Train")
    def get_data_from_batch(self, batch_data: list, device: torch.device):
        """Load and wrap data from the data loader.

        Split one batch of data into the variables you need.

        Example::

            # batch_data looks like [input_data, ground_truth_data]
            input_cpu, ground_truth_cpu = batch_data[0], batch_data[1]
            # then move them to the device and return them
            return input_cpu.to(self.device), ground_truth_cpu.to(self.device)

        :param batch_data: one batch of data loaded from ``DataLoader``
        :param device: a ``torch.device`` to move the tensors to
        :return: input tensor, ground_truth tensor
        """
        input_tensor, ground_truth_tensor = batch_data[0], batch_data[1]
        # Move the tensors to the target device, as the docstring example describes.
        return input_tensor.to(device), ground_truth_tensor.to(device)
    def _watch_images(self, tag: str, grid_size: tuple = (3, 3), shuffle=False, save_file=True):
        """Show images in TensorBoard.

        If you want to show a fixed input and its output, use ``shuffle=False``
        to fix the visualized data. Otherwise, the data is sampled and
        visualized randomly.

        Example::

            # show output
            self.watcher.image(self.output, self.current_epoch, tag="%s/output" % tag,
                               grid_size=grid_size, shuffle=shuffle, save_file=save_file)
            # show ground_truth
            self.watcher.image(self.ground_truth, self.current_epoch, tag="%s/ground_truth" % tag,
                               grid_size=grid_size, shuffle=shuffle, save_file=save_file)
            # show input
            self.watcher.image(self.input, self.current_epoch, tag="%s/input" % tag,
                               grid_size=grid_size, shuffle=shuffle, save_file=save_file)

        :param tag: TensorBoard tag
        :param grid_size: a tuple giving the grid size of the visualization
        :param shuffle: whether to shuffle the data
        :param save_file: whether to save the images
        :return:
        """
        self.watcher.image(self.output, self.current_epoch, tag="%s/output" % tag,
                           grid_size=grid_size, shuffle=shuffle, save_file=save_file)
        self.watcher.image(self.ground_truth, self.current_epoch, tag="%s/ground_truth" % tag,
                           grid_size=grid_size, shuffle=shuffle, save_file=save_file)
    def compute_loss(self) -> (torch.Tensor, dict):
        """Rewrite this method to compute your own loss.

        Use ``self.input``, ``self.output`` and ``self.ground_truth`` to
        compute the loss. Return the **loss** in the first position. You can
        return a ``dict`` of values that you want to visualize in the second
        position.

        Example::

            var_dic = {}
            var_dic["LOSS"] = loss = ((self.output - self.ground_truth) ** 2).mean()
            return loss, var_dic

        """
        # This is an abstract hook: subclasses must override it.
        raise NotImplementedError
    def compute_valid(self) -> dict:
        """Rewrite this method to compute your validation values.

        Use ``self.input``, ``self.output`` and ``self.ground_truth`` to
        compute the validation loss. You can return a ``dict`` of validation
        values that you want to visualize.

        Example::

            # It will do the same thing as ``compute_loss()``
            var_dic, _ = self.compute_loss()
            return var_dic

        """
        # By default this does the same thing as ``compute_loss()``.
        var_dic, _ = self.compute_loss()
        return var_dic
    def valid_epoch(self):
        """Validate the model each epoch.

        It is called when each training epoch finishes, so do your
        verification here.

        Example::

            avg_dic: dict = {}
            self.net.eval()
            for iteration, batch in enumerate(self.datasets.loader_valid, 1):
                self.input, self.ground_truth = self.get_data_from_batch(batch, self.device)
                with torch.no_grad():
                    self.output = self.net(self.input)
                    dic: dict = self.compute_valid()
                if avg_dic == {}:
                    avg_dic = dic
                else:
                    for key in dic.keys():
                        avg_dic[key] += dic[key]

            for key in avg_dic.keys():
                avg_dic[key] = avg_dic[key] / self.datasets.nsteps_valid

            self.watcher.scalars(avg_dic, self.step, tag="Valid")
            self.loger.write(self.step, self.current_epoch, avg_dic, "Valid",
                             header=self.current_epoch <= 1)
            self._watch_images(tag="Valid")
            self.net.train()

        """
        avg_dic: dict = {}
        self.net.eval()
        for iteration, batch in enumerate(self.datasets.loader_valid, 1):
            self.input, self.ground_truth = self.get_data_from_batch(batch, self.device)
            with torch.no_grad():
                self.output = self.net(self.input)
                dic: dict = self.compute_valid()
            if avg_dic == {}:
                avg_dic = dic
            else:
                # sum the values of each key across the validation set
                for key in dic.keys():
                    avg_dic[key] += dic[key]

        for key in avg_dic.keys():
            avg_dic[key] = avg_dic[key] / self.datasets.nsteps_valid

        self.watcher.scalars(avg_dic, self.step, tag="Valid")
        self.loger.write(self.step, self.current_epoch, avg_dic, "Valid",
                         header=self.current_epoch <= 1)
        self._watch_images(tag="Valid")
        self.net.train()
    def test(self):
        pass
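
Example usage (illustrative, not part of the library source). A minimal sketch of a concrete subclass: the ``MyTrainer`` name and the L1 loss are assumed choices, and ``trainer.train()`` is assumed to be the entry point inherited from ``SupTrainer``; ``net``, ``opt`` and ``datasets`` stand for pre-built ``Model``, ``Optimizer`` and ``DataLoadersFactory`` instances whose construction is omitted here::

    import torch.nn.functional as F

    class MyTrainer(SupSingleModelTrainer):
        # Concrete override of the abstract loss hook; L1 is an illustrative choice.
        # The default compute_valid() then reuses this loss for validation.
        def compute_loss(self):
            var_dic = {}
            var_dic["L1"] = loss = F.l1_loss(self.output, self.ground_truth)
            return loss, var_dic

    # With pre-built components, training reduces to:
    # trainer = MyTrainer("log/sup_single_demo", nepochs=10, gpu_ids_abs=[],
    #                     net=net, opt=opt, datasets=datasets)
    # trainer.train()  # assumed SupTrainer entry point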