Source code for jdit.trainer.gan.sup_gan

from ..super import SupTrainer
from tqdm import tqdm
import torch
from jdit.optimizer import Optimizer
from jdit.model import Model
from jdit.dataset import DataLoadersFactory

class SupGanTrainer(SupTrainer):
    """Supervised GAN trainer built on ``SupTrainer``.

    Holds a generator ``netG``, a discriminator ``netD``, their optimizers and a
    dataset factory. Each training step recomputes ``self.fake`` from
    ``self.input``, updates the discriminator, and every ``d_turn`` steps also
    updates the generator. Subclasses are expected to override
    :meth:`compute_d_loss` and :meth:`compute_g_loss`, and may override
    :meth:`get_data_from_batch`, :meth:`compute_valid` and :meth:`_watch_images`.
    """

    # Number of Discriminator updates per Generator update: the Discriminator is
    # trained on every step, the Generator only when ``step % d_turn == 0``.
    d_turn = 1

    def __init__(self, logdir, nepochs, gpu_ids_abs, netG: Model, netD: Model, optG: Optimizer, optD: Optimizer,
                 datasets: DataLoadersFactory):
        # NOTE: the parent takes (nepochs, logdir, ...) -- the order differs from this signature on purpose.
        super(SupGanTrainer, self).__init__(nepochs, logdir, gpu_ids_abs=gpu_ids_abs)
        self.netG = netG
        self.netD = netD
        self.optG = optG
        self.optD = optD
        self.datasets = datasets
        self.fake = None
        self.fixed_input = None

    def train_epoch(self, subbar_disable=False):
        """Run one training epoch over ``self.datasets.loader_train``.

        For every batch: load ``self.input`` / ``self.ground_truth``, compute
        ``self.fake = netG(input)``, update the Discriminator, and update the
        Generator every ``d_turn`` steps. Images of the first batch are sent to
        the watcher under the ``"Train"`` tag.

        :param subbar_disable: If ``True``, hide the per-step progress bar.
        """
        for iteration, batch in tqdm(enumerate(self.datasets.loader_train, 1), unit="step",
                                     disable=subbar_disable):
            self.step += 1
            self.input, self.ground_truth = self.get_data_from_batch(batch, self.device)
            self.fake = self.netG(self.input)
            self._train_iteration(self.optD, self.compute_d_loss, csv_filename="Train_D")
            if (self.step % self.d_turn) == 0:
                self._train_iteration(self.optG, self.compute_g_loss, csv_filename="Train_G")
            if iteration == 1:
                self._watch_images("Train")

    @staticmethod
    def _to_device(value, device):
        """Move ``value`` to ``device`` if it is a tensor; return other objects unchanged."""
        if device is not None and isinstance(value, torch.Tensor):
            return value.to(device)
        return value

    def get_data_from_batch(self, batch_data: list, device: torch.device):
        """Split one batch from the data loader into input and ground truth.

        Override this to match your dataset's batch layout. The default
        implementation assumes ``batch_data`` is indexable as
        ``[input, ground_truth]``. It moves each of the two items onto
        ``device`` when it is a ``torch.Tensor``; other objects are returned
        unchanged.

        Example::

            # batch_data looks like [input_data, ground_truth_data]
            input_cpu, ground_truth_cpu = batch_data[0], batch_data[1]
            # move them to the device and return them
            return input_cpu.to(device), ground_truth_cpu.to(device)

        :param batch_data: One batch loaded from a ``DataLoader``.
        :param device: Target device, a ``torch.device``.
        :return: ``(input_tensor, ground_truth_tensor)``
        """
        input_tensor, ground_truth_tensor = batch_data[0], batch_data[1]
        # Previously ``device`` was ignored, so GPU runs received CPU tensors.
        return self._to_device(input_tensor, device), self._to_device(ground_truth_tensor, device)

    def _watch_images(self, tag: str, grid_size: tuple = (3, 3), shuffle=False, save_file=True):
        """Show images in tensorboard.

        Shows ``self.fake`` and ``self.ground_truth``. To visualize the same
        fixed input and its output every time, use ``shuffle=False``; otherwise
        the data are sampled randomly.

        Example::

            # show fake data
            self.watcher.image(self.fake, self.current_epoch, tag="%s/fake" % tag,
                               grid_size=grid_size, shuffle=shuffle, save_file=save_file)
            # show ground_truth
            self.watcher.image(self.ground_truth, self.current_epoch, tag="%s/real" % tag,
                               grid_size=grid_size, shuffle=shuffle, save_file=save_file)
            # show input
            self.watcher.image(self.input, self.current_epoch, tag="%s/input" % tag,
                               grid_size=grid_size, shuffle=shuffle, save_file=save_file)

        :param tag: Tensorboard tag prefix.
        :param grid_size: A tuple giving the grid size of the images to visualize.
        :param shuffle: Whether to shuffle the data before visualizing.
        :param save_file: Whether to also save the images to file.
        """
        self.watcher.image(self.fake, self.current_epoch, tag="%s/fake" % tag, grid_size=grid_size,
                           shuffle=shuffle, save_file=save_file)
        self.watcher.image(self.ground_truth, self.current_epoch, tag="%s/real" % tag, grid_size=grid_size,
                           shuffle=shuffle, save_file=save_file)

    def compute_d_loss(self) -> (torch.Tensor, dict):
        """Compute the Discriminator loss. Must be overridden.

        Return the **loss** in the first position. The second position is a
        ``dict`` of values you want to visualize.

        Example::

            d_fake = self.netD(self.fake.detach())
            d_real = self.netD(self.ground_truth)
            var_dic = {}
            var_dic["GP"] = gp = gradPenalty(self.netD, self.ground_truth, self.fake, input=self.input,
                                             use_gpu=self.use_gpu)
            var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach()
            var_dic["LOSS_D"] = loss_d = d_fake.mean() - d_real.mean() + gp
            return loss_d, var_dic

        :raises NotImplementedError: always, in this base class.
        """
        # The previous body returned an unassigned annotated local, which failed
        # with an unhelpful UnboundLocalError.
        raise NotImplementedError("%s must override compute_d_loss()" % type(self).__name__)

    def compute_g_loss(self) -> (torch.Tensor, dict):
        """Compute the Generator loss. Must be overridden.

        Return the **loss** in the first position. The second position is a
        ``dict`` of values you want to visualize.

        Example::

            d_fake = self.netD(self.fake)
            var_dic = {}
            var_dic["JC"] = jc = jcbClamp(self.netG, self.input, use_gpu=self.use_gpu)
            var_dic["LOSS_G"] = loss_g = -d_fake.mean() + jc
            return loss_g, var_dic

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("%s must override compute_g_loss()" % type(self).__name__)

    def compute_valid(self) -> dict:
        """Compute validation values. Override to customize.

        The default reuses :meth:`compute_g_loss` and :meth:`compute_d_loss`
        and returns their losses.

        :return: ``{"LOSS_D": d_loss, "LOSS_G": g_loss}``
        """
        g_loss, _ = self.compute_g_loss()
        d_loss, _ = self.compute_d_loss()
        var_dic = {"LOSS_D": d_loss, "LOSS_G": g_loss}
        return var_dic

    def valid_epoch(self):
        """Validate the models on ``self.datasets.loader_valid``.

        Called once per epoch after training. Both nets are put in eval mode,
        ``self.fake`` is produced under ``torch.no_grad()``, and the values from
        :meth:`compute_valid` are averaged over the batches. The averages are
        sent to the watcher and logger under ``"Valid"``, and the nets are put
        back in train mode even if an error occurs.
        """
        avg_dic: dict = {}
        nsteps = 0
        self.netG.eval()
        self.netD.eval()
        try:
            for nsteps, batch in enumerate(self.datasets.loader_valid, 1):
                self.input, self.ground_truth = self.get_data_from_batch(batch, self.device)
                with torch.no_grad():
                    self.fake = self.netG(self.input)
                # Kept outside no_grad: user losses (e.g. gradient penalty) may need autograd.
                dic: dict = self.compute_valid()
                # Accumulate (sum). Values are detached so autograd graphs are not kept
                # alive across the whole loop, and added out of place so the dict and
                # tensors returned by compute_valid() are never mutated.
                for key, value in dic.items():
                    if isinstance(value, torch.Tensor):
                        value = value.detach()
                    avg_dic[key] = avg_dic[key] + value if key in avg_dic else value
            if nsteps:
                # Average over the batches actually seen.
                for key in avg_dic:
                    avg_dic[key] = avg_dic[key] / nsteps
            self.watcher.scalars(avg_dic, self.step, tag="Valid")
            self.loger.write(self.step, self.current_epoch, avg_dic, "Valid", header=self.current_epoch <= 1)
            self._watch_images(tag="Valid")
        finally:
            self.netG.train()
            self.netD.train()

    def test(self):
        """Test hook; does nothing in this base class."""
        pass

    @property
    def configure(self):
        """Parent configuration dict extended with ``d_turn`` (as a string)."""
        config_dic = super(SupGanTrainer, self).configure
        config_dic["d_turn"] = str(self.d_turn)
        return config_dic