from abc import ABCMeta, abstractmethod
from types import FunctionType
from tqdm import tqdm
from torch.utils.data import random_split
import traceback
import shutil
from typing import Union
from jdit.dataset import DataLoadersFactory
from jdit.model import Model
from jdit.optimizer import Optimizer
import torch
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
import os
import random
import csv
import numpy as np
from functools import wraps
class SupTrainer(object):
    """this is a super class of all trainers

    It defines:

    * The basic tools, ``Performance()``, ``Watcher()``, ``Loger()``.
    * The basic loop of epochs.
    * Learning rate decay and model check point.
    """
    # NOTE(review): `__metaclass__` is Python-2 syntax; under Python 3 it has no
    # effect, so `@abstractmethod` below is not actually enforced at runtime.
    __metaclass__ = ABCMeta

    def __new__(cls, *args, **kwargs):
        # The registries are created in __new__ (not __init__) so that they
        # already exist when a subclass __init__ assigns its first attribute:
        # the overridden __setattr__ below files every Model / Optimizer /
        # DataLoadersFactory attribute into these dicts.
        instance = super(SupTrainer, cls).__new__(cls)
        instance._opts = dict()
        instance._datasets = dict()
        instance._models = dict()
        return instance

    def __init__(self, nepochs: int, logdir: str, gpu_ids_abs: Union[list, tuple] = ()):
        """
        :param nepochs: total number of training epochs.
        :param logdir: directory for logs, tensorboard files and checkpoints.
        :param gpu_ids_abs: absolute GPU ids to use; empty tuple means CPU.
        """
        # os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
        # self.gpu_ids = [i for i in range(len(gpu_ids_abs))]
        self.gpu_ids = gpu_ids_abs
        self.logdir = logdir
        self.performance = Performance(gpu_ids_abs)
        self.watcher = Watcher(logdir)
        self.loger = Loger(logdir)
        self.use_gpu = True if (len(self.gpu_ids) > 0) and torch.cuda.is_available() else False
        self.device = torch.device("cuda:%d" % self.gpu_ids[0]) if self.use_gpu else torch.device("cpu")
        self.input = torch.Tensor()
        self.ground_truth = torch.Tensor()
        self.nepochs = nepochs
        # Assigning 0 here does NOT fire the lr-decay/checkpoint hooks in
        # __setattr__, which explicitly ignore the value 0.
        self.current_epoch = 0
        self.step = 0
        self.start_epoch = 1

    def train(self, process_bar_header: str = None, process_bar_position: int = None, subbar_disable=False,
              record_configs=True, show_network=False, **kwargs):
        """The main training loop of epochs.

        :param process_bar_header: The tag name of process bar header,
            which is used in ``tqdm(desc=process_bar_header)``
        :param process_bar_position: The process bar's position. It is useful in multitask,
            which is used in ``tqdm(position=process_bar_position)``
        :param subbar_disable: If show the info of every training set,
        :param record_configs: If record the training processing data.
        :param show_network: If show the structure of network. It will cost extra memory,
        :param kwargs: Any other parameters that passing to ``tqdm()`` to control the behavior of process bar.
        """
        if record_configs:
            self._record_configs()
        if show_network:
            self.plot_graphs_lazy()
        for epoch in tqdm(range(self.start_epoch, self.nepochs + 1), total=self.nepochs,
                          unit="epoch", desc=process_bar_header, position=process_bar_position, **kwargs):
            # Assigning `current_epoch` triggers the __setattr__ hook below
            # (learning-rate decay, checkpointing, config recording).
            self.current_epoch = epoch
            self.train_epoch(subbar_disable)
            self.valid_epoch()
        self.test()
        self.watcher.close()

    def dist_train(self, process_bar_header: str = None, process_bar_position: int = None,
                   subbar_disable=False,
                   record_configs=True, show_network=False, **kwargs):
        """The main training loop of epochs (distributed variant).

        Identical to :meth:`train` except that it calls ``set_epoch`` on the
        distributed sampler of ``self._datasets["datasets"]`` every epoch so
        each process reshuffles consistently.

        :param process_bar_header: The tag name of process bar header,
            which is used in ``tqdm(desc=process_bar_header)``
        :param process_bar_position: The process bar's position. It is useful in multitask,
            which is used in ``tqdm(position=process_bar_position)``
        :param subbar_disable: If show the info of every training set,
        :param record_configs: If record the training processing data.
        :param show_network: If show the structure of network. It will cost extra memory,
        :param kwargs: Any other parameters that passing to ``tqdm()`` to control the behavior of process bar.
        """
        if record_configs:
            self._record_configs()
        if show_network:
            self.plot_graphs_lazy()
        for epoch in tqdm(range(self.start_epoch, self.nepochs + 1), total=self.nepochs,
                          unit="epoch", desc=process_bar_header, position=process_bar_position, **kwargs):
            # Requires a DataLoadersFactory registered under the attribute
            # name `datasets`, whose train loader uses a DistributedSampler.
            self._datasets["datasets"].loader_train.sampler.set_epoch(epoch)
            self.current_epoch = epoch
            self.train_epoch(subbar_disable)
            self.valid_epoch()
        self.test()
        self.watcher.close()

    def __setattr__(self, key, value):
        """Attribute hook that drives the bookkeeping of training.

        * ``step``/``current_epoch`` assignments (non-zero) trigger learning
          rate decay and optimizer-config recording; epoch changes also run
          model checkpointing and performance recording.
        * ``Model`` / ``Optimizer`` / ``DataLoadersFactory`` values are filed
          into ``_models`` / ``_opts`` / ``_datasets`` under the attribute name.

        ``super().__getattribute__`` is used throughout to reach the real
        methods without going through the overridden ``__getattribute__``.
        """
        super(SupTrainer, self).__setattr__(key, value)
        if key == "step" and value != 0:
            is_change = super(SupTrainer, self).__getattribute__("_change_lr")("step", value)
            if is_change:
                super(SupTrainer, self).__getattribute__("_record_configs")("optimizer")
        elif key == "current_epoch" and value != 0:
            is_change_lr = super(SupTrainer, self).__getattribute__("_change_lr")("epoch", value)
            if is_change_lr:
                super(SupTrainer, self).__getattribute__("_record_configs")("optimizer")
            super(SupTrainer, self).__getattribute__("_check_point")()
            super(SupTrainer, self).__getattribute__("_record_configs")("performance")
        elif isinstance(value, Model):
            super(SupTrainer, self).__getattribute__("_models").update({key: value})
        elif isinstance(value, Optimizer):
            super(SupTrainer, self).__getattribute__("_opts").update({key: value})
        elif isinstance(value, DataLoadersFactory):
            super(SupTrainer, self).__getattribute__("_datasets").update({key: value})
        else:
            pass

    def __delattr__(self, item):
        # NOTE(review): `item` is the attribute *name* (a str) here, so these
        # isinstance checks against Model/Optimizer/DataLoadersFactory can
        # never be true -- confirm whether registry cleanup was intended.
        if isinstance(item, Model):
            super(SupTrainer, self).__getattribute__("_models").pop(item)
        elif isinstance(item, Optimizer):
            super(SupTrainer, self).__getattribute__("_opts").pop(item)
        elif isinstance(item, DataLoadersFactory):
            super(SupTrainer, self).__getattribute__("_datasets").pop(item)

    def __getattribute__(self, name):
        # Transparently wrap `get_data_from_batch` with `_mv_device` so the
        # values it returns are moved onto `self.device` automatically.
        v = super(SupTrainer, self).__getattribute__(name)
        if name == "get_data_from_batch":
            new_fc = super(SupTrainer, self).__getattribute__("_mv_device")(v)
            return new_fc
        return v

    def debug(self):
        """Debug the trainer.

        It will check the function

        * ``self._record_configs()`` save all module's configures.
        * ``self.train_epoch()`` train one epoch with several samples. So, it is vary fast.
        * ``self.valid_epoch()`` valid one epoch using dataset_valid.
        * ``self._change_lr()`` do learning rate change.
        * ``self._check_point()`` do model check point.
        * ``self.test()`` do test by using dataset_test.

        Before debug, it will reset the ``datasets`` and only pick up several samples to do fast test.
        For test, it build a ``log_debug`` directory to save the log.

        :return: bool. It will return ``True``, if passes all the tests.
        """
        self.watcher.close()
        self.logdir = "log_debug"
        # reset `log_debug`
        if os.path.exists(self.logdir):
            try:
                shutil.rmtree("log_debug")  # recursively delete the directory
            except Exception as e:
                print('Can not remove logdir `log_debug`\n', e)
                traceback.print_exc()
        self.watcher = Watcher(self.logdir)
        self.loger = Loger(self.logdir)
        self.performance = Performance()
        # reset datasets and dataloaders: shrink every dataset to 2 samples so
        # one "epoch" runs in seconds
        for item in vars(self).values():
            if isinstance(item, DataLoadersFactory):
                item.batch_size = 2
                item.shuffle = False
                item.num_workers = 2
                item.dataset_train, _ = random_split(item.dataset_train, [2, len(item.dataset_train) - 2])
                item.dataset_valid, _ = random_split(item.dataset_valid, [2, len(item.dataset_valid) - 2])
                item.dataset_test, _ = random_split(item.dataset_test, [2, len(item.dataset_test) - 2])
                item.build_loaders()
                item.sample_dataset_size = 1
                print("datas range: (%s, %s)" % (item.samples_train[0].min().cpu().numpy(),
                                                 item.samples_train[0].max().cpu().numpy()))
            if isinstance(item, Model):
                item.check_point_pos = 2
            if isinstance(item, Optimizer):
                item.decay_position = 2
                item.position_type = "step"
        # the tested functions
        debug_fcs = [self._record_configs, self.train_epoch, self.valid_epoch,
                     self._change_lr, self._check_point, self.test]
        print("{:=^30}".format(">Debug<"))
        success = True
        for fc in debug_fcs:
            print("{:_^30}".format(fc.__name__ + "()"))
            try:
                if fc.__name__ == "_change_lr":
                    # step/decay_position were both forced to 2 above, so a
                    # decay MUST happen here; otherwise the test fails.
                    self.step = 2
                    is_lr_change = fc()
                    if not is_lr_change:
                        raise AssertionError("doesn't change learning rate!")
                elif fc.__name__ == "_check_point":
                    self.current_epoch = 2
                    fc()
                else:
                    fc()
            except Exception as e:
                print('Error:', e)
                traceback.print_exc()
                success = False
            else:
                print("pass!")
        self.watcher.close()
        if success:
            print("\033[1;32;40m" + "{:=^30}".format(">Debug Successful!<"))
        else:
            print("\033[1;31;40m" + "{:=^30}".format(">Debug Failed!<"))
        if os.path.exists(self.logdir):
            try:
                shutil.rmtree("log_debug")  # recursively delete the directory
            except Exception as e:
                print('Can not remove logdir `log_debug`\n', e)
                traceback.print_exc()
        return success

    @abstractmethod
    def train_epoch(self, subbar_disable=False):
        """
        You get train loader and do a loop to deal with data.

        .. Caution::

            You must record your training step on ``self.step`` in your loop by doing things like this ``self.step +=
            1``.

        Example::

            for iteration, batch in tqdm(enumerate(self.datasets.loader_train, 1)):
                self.step += 1
                self.input_cpu, self.ground_truth_cpu = self.get_data_from_batch(batch, self.device)
                self._train_iteration(self.opt, self.compute_loss, csv_filename="Train")

        :return:
        """
        pass

    def _mv_device(self, f):
        """Decorator: move every value returned by `f` that supports ``.to``
        onto ``self.device``; other values pass through unchanged."""
        @wraps(f)
        def wrapper(*args, **kwargs):
            variables = f(*args, **kwargs)
            # Explicit two-arg super() because `wrapper` is a plain closure.
            device = super(SupTrainer, self).__getattribute__("device")
            variables = tuple(v.to(device) if hasattr(v, "to") else v for v in variables)
            return variables

        return wrapper

    def get_data_from_batch(self, batch_data: list, device: torch.device):
        """ Split your data from one batch data to specify .

        If your dataset return something like

        ``return input_data, label``.

        It means that two values need unpack.
        So, you need to split the batch data into two parts, like this

        ``input, ground_truth = batch_data[0], batch_data[1]``

        .. Caution::

            Don't forget to move these data to device, by using ``input.to(device)`` .
            (Access through the trainer is wrapped by ``_mv_device``, which moves
            the returned values to ``self.device`` automatically.)

        :param batch_data: One batch data from dataloader.
        :param device: the device that data will be located.
        :return: The certain variable with correct device location.

        Example::

            # load and unzip the data from one batch tuple (input, ground_truth)
            input, ground_truth = batch_data[0], batch_data[1]
            # move these data to device
            return input.to(device), ground_truth.to(device)
        """
        input_img, ground_truth = batch_data[0], batch_data[1]
        return input_img, ground_truth

    def _train_iteration(self, opt: Optimizer, compute_loss_fc: FunctionType, csv_filename: str = "Train"):
        """One optimization step: zero grads, compute loss, backprop, step,
        then log scalars to tensorboard and one row to ``<csv_filename>.csv``.

        NOTE(review): the tensorboard tag is hard-coded to "Train" while the
        csv file name follows ``csv_filename`` -- confirm whether the tag
        should follow it too.
        """
        opt.zero_grad()
        loss, var_dic = compute_loss_fc()
        loss.backward()
        opt.step()
        self.watcher.scalars(var_dict=var_dic, global_step=self.step, tag="Train")
        # Recover the attribute name this optimizer was registered under.
        opt_name = list(self._opts.keys())[list(self._opts.values()).index(opt)]
        self.watcher.scalars(var_dict={"Learning rate": opt.lr}, global_step=self.step, tag=opt_name)
        self.loger.write(self.step, self.current_epoch, var_dic, csv_filename, header=self.step <= 1)

    def _record_configs(self, configs_names=None):
        """to register the ``Model`` , ``Optimizer`` , ``Trainer`` and ``Performance`` config info.

        The default (``configs_names=None``) records everything; otherwise only
        the categories named in ``configs_names`` are recorded.
        If you want to record more configures info, you can add more module to ``self.loger.regist_config`` .
        The following is an example.

        Example::

            # for opt.configure
            self.loger.regist_config(opt, self.current_epoch)
            # for model.configure
            self.loger.regist_config(model, self.current_epoch )
            # for self.performance.configure
            self.loger.regist_config(self.performance, self.current_epoch)
            # for trainer.configure
            self.loger.regist_config(self, self.current_epoch)

        :return:
        """
        if (configs_names is None) or "model" in configs_names:
            _models = super(SupTrainer, self).__getattribute__("_models")
            for name, model in _models.items():
                self.loger.regist_config(model, self.current_epoch, self.step, config_filename=name)
        if (configs_names is None) or "dataset" in configs_names:
            _datasets = super(SupTrainer, self).__getattribute__("_datasets")
            for name, dataset in _datasets.items():
                self.loger.regist_config(dataset, config_filename=name)
        if (configs_names is None) or "optimizer" in configs_names:
            _opts = super(SupTrainer, self).__getattribute__("_opts")
            for name, opt in _opts.items():
                self.loger.regist_config(opt, self.current_epoch, self.step, config_filename=name)
        # NOTE(review): the second `configs_names is None` test below is redundant.
        if (configs_names is None) or "trainer" in configs_names or (configs_names is None):
            self.loger.regist_config(self, config_filename=self.__class__.__name__)
        if (configs_names is None) or "performance" in configs_names:
            self.loger.regist_config(self.performance, self.current_epoch, self.step, config_filename="performance")

    def plot_graphs_lazy(self):
        """Plot model graph on tensorboard.
        To plot all models graphs in trainer, by using variable name as model name.

        :return:
        """
        _models = super(SupTrainer, self).__getattribute__("_models")
        for name, model in _models.items():
            self.watcher.graph_lazy(model, name)

    def _check_point(self):
        """Ask every registered model to checkpoint itself if its
        check-point position matches the current epoch."""
        _models = super(SupTrainer, self).__getattribute__("_models")
        current_epoch = super(SupTrainer, self).__getattribute__("current_epoch")
        logdir = super(SupTrainer, self).__getattribute__("logdir")
        for name, model in _models.items():
            model.is_checkpoint(name, current_epoch, logdir)

    def _change_lr(self, position_type="step", position=2):
        """Apply lr reset/decay on every optimizer whose ``position_type`` matches.

        :param position_type: "step" or "epoch" -- which counter `position` refers to.
        :param position: current value of that counter.
        :return: True only if every matching optimizer changed its lr
            (vacuously True when there is none).
        """
        is_change = True
        _opts = super(SupTrainer, self).__getattribute__("_opts")
        for opt in _opts.values():
            if opt.position_type == position_type:
                reset_lr = opt.is_reset_lr(position)
                if reset_lr:
                    opt.do_lr_decay(reset_lr=reset_lr)
                elif opt.is_decay_lr(position):
                    opt.do_lr_decay()
                else:
                    is_change = False
        return is_change

    def valid_epoch(self):
        # Hook: override in subclasses to run validation after each epoch.
        pass

    def test(self):
        # Hook: override in subclasses to run the final test pass.
        pass

    @property
    def configure(self):
        # Trainer config snapshot recorded by `_record_configs`.
        config_dict = dict()
        config_dict["nepochs"] = int(self.nepochs)
        return config_dict
class Performance(object):
    """this is a performance watcher.

    Collects host RAM and (optionally) NVIDIA GPU statistics into
    ``config_dic`` so they can be logged by ``Loger.regist_config``.
    """

    def __init__(self, gpu_ids_abs: Union[list, tuple] = ()):
        """
        :param gpu_ids_abs: absolute GPU ids to poll; empty means no GPU stats.
        """
        self.config_dic = dict()
        self.gpu_ids = gpu_ids_abs

    def mem_info(self):
        """Record total/used RAM (GB) and usage percent into ``config_dic``."""
        from psutil import virtual_memory
        mem = virtual_memory()
        self.config_dic['mem_total_GB'] = round(mem.total / 1024 ** 3, 2)
        self.config_dic['mem_used_GB'] = round(mem.used / 1024 ** 3, 2)
        self.config_dic['mem_percent'] = mem.percent

    def gpu_info(self):
        """Record driver version, per-GPU memory and smoothed utilization.

        Best-effort: any NVML failure (missing pynvml, driver error) is caught
        and printed, leaving ``config_dic`` partially filled.
        """
        # pip install nvidia-ml-py3
        # BUGFIX: condition was `len(self.gpu_ids) >= 0`, which is always true,
        # so NVML was initialized even when no GPU ids were configured.
        if len(self.gpu_ids) > 0 and torch.cuda.is_available():
            try:
                import pynvml
                pynvml.nvmlInit()
                self.config_dic['gpu_driver_version'] = pynvml.nvmlSystemGetDriverVersion()
                for gpu_id in self.gpu_ids:
                    handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
                    gpu_id_name = "gpu%s" % gpu_id
                    mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                    gpu_utilize = pynvml.nvmlDeviceGetUtilizationRates(handle)
                    self.config_dic['%s_device_name' % gpu_id_name] = pynvml.nvmlDeviceGetName(handle)
                    self.config_dic['%s_mem_total' % gpu_id_name] = gpu_mem_total = round(mem_info.total / 1024 ** 3, 2)
                    self.config_dic['%s_mem_used' % gpu_id_name] = gpu_mem_used = round(mem_info.used / 1024 ** 3, 2)
                    self.config_dic['%s_mem_percent' % gpu_id_name] = round((gpu_mem_used / gpu_mem_total) * 100, 1)
                    # Utilization fluctuates fast; smooth it heavily (keep 80% of old value).
                    self._set_dict_smooth('%s_utilize_gpu' % gpu_id_name, gpu_utilize.gpu, 0.8)
                pynvml.nvmlShutdown()
            except Exception as e:
                print(e)

    def _set_dict_smooth(self, key: str, value, smooth: float = 0.3):
        """Store `value` under `key`, exponentially smoothed against the
        previous value: new = value * (1 - smooth) + old * smooth."""
        now = value
        if key in self.config_dic:
            last = self.config_dic[key]
            self.config_dic[key] = now * (1 - smooth) + last * smooth
        else:
            self.config_dic[key] = now

    @property
    def configure(self):
        """Refresh RAM and GPU stats and return the config dict."""
        self.mem_info()
        # BUGFIX: gpu_info() was called twice here, doubling NVML queries and
        # double-stepping the utilization smoothing per snapshot.
        self.gpu_info()
        return self.config_dic
class Loger(object):
    """this is a log recorder.

    Writes two kinds of csv files under ``logdir``:

    * configuration snapshots (:meth:`regist_config`) -- one file per object,
      a new row appended only when the configuration actually changes;
    * per-step training metrics (:meth:`write`).
    """

    def __init__(self, logdir: str = "log"):
        """
        :param logdir: directory in which all csv files are created.
        """
        self.logdir = logdir
        self.regist_dict = dict({})
        self._build_dir()

    def _build_dir(self):
        # Create the log directory on first use.
        if not os.path.exists(self.logdir):
            # BUGFIX: message previously interpolated the builtin `dir`
            # (printing "<built-in function dir>") instead of the directory name.
            print("%s directory is not found. Build now!" % self.logdir)
            os.makedirs(self.logdir)

    def regist_config(self, opt_model_data: Union[SupTrainer, Optimizer, Model, DataLoadersFactory, Performance],
                      epoch=None,
                      step=None,
                      config_filename: str = None):
        """
        get obj's configure. flag is time point, usually use `epoch`.
        obj_name default is 'opt_model_data' class name.
        If you pass two same class obj, you should give each of them a unique `config_filename`.

        :param opt_model_data: Optm, Model, dataset, trainer or Performance object; must expose ``.configure``.
        :param epoch: time point such as `epoch`.
        :param step: time point such as `step`.
        :param config_filename: default is 'opt_model_data' class name.
        :return:
        """
        if config_filename is None:
            config_filename = opt_model_data.__class__.__name__
        obj_config_dic = opt_model_data.configure.copy()
        path = os.path.join(self.logdir, config_filename + ".csv")
        is_registed = config_filename in self.regist_dict.keys()
        if not is_registed:
            # First registration: remember the config and write header + first row.
            self.regist_dict[config_filename] = obj_config_dic.copy()
            config_dic = self._stamp(obj_config_dic, epoch, step)
            with open(path, "w", newline="", encoding="utf-8") as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(config_dic.keys())
                writer.writerow(config_dic.values())
        elif self.regist_dict[config_filename] != obj_config_dic:
            # Already registered: append a row only when the config changed
            # since the last snapshot; identical configs are skipped.
            self.regist_dict[config_filename] = obj_config_dic.copy()
            config_dic = self._stamp(obj_config_dic, epoch, step)
            with open(path, "a", newline="", encoding="utf-8") as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(config_dic.values())

    @staticmethod
    def _stamp(obj_config_dic: dict, epoch, step) -> dict:
        """Prefix a config dict with optional ``step``/``epoch`` columns."""
        config_dic = dict()
        if step is not None:
            config_dic.update({"step": step})
        if epoch is not None:
            config_dic.update({"epoch": epoch})
        config_dic.update(obj_config_dic)
        return config_dic

    def write(self, step: int, current_epoch: int, msg_dic: dict, filename: str, header=True):
        """Append one row of scalar metrics to ``<filename>.csv``.

        :param step: global training step.
        :param current_epoch: current epoch number.
        :param msg_dic: metric name -> value; 0-dim tensors are unwrapped in place.
        :param filename: csv file name (without extension).
        :param header: also write a header row (use on the first step only).
        """
        if msg_dic is None:
            return
        else:
            for key, value in msg_dic.items():
                if hasattr(value, "item"):
                    # Unwrap 0-dim tensors into plain python numbers.
                    msg_dic[key] = value.detach().cpu().item()
            path = os.path.join(self.logdir, filename + ".csv")
            dic = dict({"step": step, "current_epoch": current_epoch})
            dic.update(msg_dic)
            with open(path, "a", newline="", encoding="utf-8") as csvfile:
                writer = csv.writer(csvfile)
                if header:
                    writer.writerow(dic.keys())
                writer.writerow(dic.values())

    def clear_regist(self):
        """Forget all registered configs so the next ``regist_config`` rewrites files."""
        self.regist_dict = dict({})
class Watcher(object):
    """this is a params and images watcher

    Thin wrapper around ``torch.utils.tensorboard.SummaryWriter`` that logs
    scalars, parameter histograms, image grids, embeddings, model graphs and
    an optional training-progress gif under ``logdir``.
    """

    def __init__(self, logdir: str):
        self.logdir = logdir
        self.writer = SummaryWriter(logdir)
        self._build_dir(logdir)
        # Frames accumulated by set_training_progress_images, flushed to a gif on close().
        self.training_progress_images = []
        self.gif_duration = 0.5
        # NOTE(review): `self.handel` is never used; graph_lazy keeps its hook
        # handle in a local variable instead -- confirm this attribute is dead.
        self.handel = None

    def model_params(self, model: torch.nn.Module, global_step: int):
        """Log a histogram for every non-bias parameter of `model`."""
        for name, param in model.named_parameters():
            if "bias" in name:
                continue
            self.writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step)

    def scalars(self, var_dict: dict, global_step: int, tag="Train"):
        """Log each ``name -> value`` in `var_dict` as a scalar under `tag`."""
        for key, scalar in var_dict.items():
            self.writer.add_scalars(key, {tag: scalar}, global_step)

    @staticmethod
    def _sample(tensor: torch.Tensor, num_samples: int, shuffle=True):
        """Pick `num_samples` rows from `tensor`, randomly or from the front.

        :raises ValueError: if more samples are requested than available.
        """
        total = len(tensor)
        if num_samples > total:
            raise ValueError("sample(%d) greater than the total amount(%d)!" % (num_samples, len(tensor)))
        if shuffle:
            rand_index = random.sample(list(range(total)), num_samples)
            sampled_tensor: torch.Tensor = tensor[rand_index]
        else:
            sampled_tensor: torch.Tensor = tensor[:num_samples]
        return sampled_tensor

    def image(self, img_tensors: torch.Tensor, global_step: int, tag: str = "Train/input",
              grid_size: Union[list, tuple] = (3, 1), shuffle=True, save_file=False):
        """Log a normalized grid of sampled images; optionally save it as a png.

        :param img_tensors: rank-4 tensor (batch, channels, height, width).
        :param grid_size: (rows, columns) of the grid.
        :raises TypeError: if `img_tensors` is not rank 4.
        """
        if len(img_tensors.size()) != 4:
            raise TypeError("img_tensors rank should be 4, got %d instead" % len(img_tensors.size()))
        self._build_dir(os.path.join(self.logdir, "plots", tag))
        rows, columns = grid_size[0], grid_size[1]
        batch_size = len(img_tensors)  # img_tensors =>(batchsize, 3, 256, 256)
        num_samples: int = min(batch_size, rows * columns)
        sampled_tensor = self._sample(img_tensors, num_samples, shuffle).detach().cpu()
        # (sample_num, 3, 32,32) tensors
        # sampled_images = map(transforms.Normalize(mean, std), sampled_tensor) # (sample_num, 3, 32,32) images
        # NOTE(review): make_grid's `nrow` is images *per row*; here it is fed
        # `rows` -- confirm grid orientation matches the (rows, columns) intent.
        sampled_images: torch.Tensor = make_grid(sampled_tensor, nrow=rows, normalize=True, scale_each=True)
        self.writer.add_image(tag, sampled_images, global_step)
        if save_file:
            img = transforms.ToPILImage()(sampled_images)
            filename = "%s/plots/%s/E%03d.png" % (self.logdir, tag, global_step)
            img.save(filename)

    def embedding(self, data: torch.Tensor, label_img: torch.Tensor = None, label=None, global_step: int = None,
                  tag: str = "embedding"):
        """ Show PCA, t-SNE of `mat` on tensorboard

        :param data: An img tensor with shape  of (N, C, H, W)
        :param label_img: Label img on each data point.
        :param label: Label of each img. It will convert to str.
        :param global_step: Img step label.
        :param tag: Tag of this plot.
        """
        # Flatten each sample to a feature vector for the projector.
        features = data.view(len(data), -1)
        self.writer.add_embedding(features, metadata=label, label_img=label_img, global_step=global_step, tag=tag)

    def set_training_progress_images(self, img_tensors: torch.Tensor, grid_size: Union[list, tuple] = (3, 1)):
        """Render one grid frame from `img_tensors` and queue it for the gif.

        :raises ValueError: if `img_tensors` is not rank 4.
        """
        if len(img_tensors.size()) != 4:
            raise ValueError("img_tensors rank should be 4, got %d instead" % len(img_tensors.size()))
        rows, columns = grid_size[0], grid_size[1]
        batch_size = len(img_tensors)  # img_tensors =>(batchsize, 3, 256, 256)
        num_samples = min(batch_size, rows * columns)
        sampled_tensor = self._sample(img_tensors, num_samples, False).detach().cpu()  # (sample_num, 3, 32,32) tensors
        sampled_images = make_grid(sampled_tensor, nrow=rows, normalize=True, scale_each=True)
        # CHW -> HWC for imageio.
        img_grid = np.transpose(sampled_images.numpy(), (1, 2, 0))
        self.training_progress_images.append(img_grid)

    def save_in_gif(self):
        """Write all queued frames to ``plots/training.gif`` and drop them."""
        import imageio
        import warnings
        filename = "%s/plots/training.gif" % self.logdir
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            imageio.mimsave(filename, self.training_progress_images, duration=self.gif_duration)
        # Release frames; close() treats the falsy value as "nothing to save".
        self.training_progress_images = None

    def graph(self, model: Union[torch.nn.Module, torch.nn.DataParallel, Model], name: str, use_gpu: bool,
              *input_shape):
        """Write `model`'s graph by tracing it with ones-tensors of `input_shape`.

        :param model: nn.Module, nn.DataParallel or jdit Model.
        :param name: sub-directory name for the model's own SummaryWriter.
        :param use_gpu: move the dummy inputs to CUDA (model must already be there).
        :param input_shape: one shape tuple per model input.
        :raises TypeError: for unsupported model types.
        """
        if isinstance(model, torch.nn.Module):
            proto_model: torch.nn.Module = model
            num_params: int = self._count_params(proto_model)
        elif isinstance(model, torch.nn.DataParallel):
            proto_model: torch.nn.Module = model.module
            num_params: int = self._count_params(proto_model)
        elif isinstance(model, Model):
            proto_model: torch.nn.Module = model.model
            num_params: int = model.num_params
        else:
            raise TypeError("Only `nn.Module`, `nn.DataParallel` and `Model` can be passed!")
        model_logdir = os.path.join(self.logdir, name)
        self._build_dir(model_logdir)
        writer_for_model = SummaryWriter(log_dir=model_logdir)
        input_list = tuple(torch.ones(shape).cuda() if use_gpu else torch.ones(shape) for shape in input_shape)
        # Logged at two consecutive steps -- presumably so the constant shows
        # up as a visible line in tensorboard; confirm intent.
        self.scalars({'ParamsNum': num_params}, 0, tag="ParamsNum")
        self.scalars({'ParamsNum': num_params}, 1, tag="ParamsNum")
        # Dry-run forward pass before tracing the graph.
        proto_model(*input_list)
        writer_for_model.add_graph(proto_model, input_list)
        writer_for_model.close()

    def graph_lazy(self, model: Union[torch.nn.Module, torch.nn.DataParallel, Model], name: str):
        """Like :meth:`graph`, but defers tracing until the model's first real
        forward pass, capturing the actual inputs via a self-removing hook.

        :raises TypeError: for unsupported model types.
        """
        if isinstance(model, torch.nn.Module):
            proto_model: torch.nn.Module = model
            num_params: int = self._count_params(proto_model)
        elif isinstance(model, torch.nn.DataParallel):
            proto_model: torch.nn.Module = model.module
            num_params: int = self._count_params(proto_model)
        elif isinstance(model, Model):
            proto_model: torch.nn.Module = model.model
            num_params: int = model.num_params
        else:
            raise TypeError("Only `nn.Module`, `nn.DataParallel` and `Model` can be passed!, got %s instead" % model)
        model_logdir = os.path.join(self.logdir, name)
        self._build_dir(model_logdir)
        self.scalars({'ParamsNum': num_params}, 0, tag=name)
        self.scalars({'ParamsNum': num_params}, 1, tag=name)

        def hook(model, layer_input, layer_output):
            writer_for_model = SummaryWriter(log_dir=model_logdir)
            # Rebuild a batch-size-1 input from the captured layer input.
            input_for_test = tuple(i[0].detach().clone().unsqueeze(0) for i in layer_input)
            # `handel` is closed over from below; removing it here makes the
            # hook fire exactly once.
            handel.remove()
            if isinstance(proto_model, torch.nn.DataParallel):
                writer_for_model.add_graph(proto_model.module, input_for_test)
            else:
                writer_for_model.add_graph(proto_model, input_for_test)
            writer_for_model.close()
            del writer_for_model

        handel = model.register_forward_hook(hook=hook)

    def close(self):
        # self.writer.export_scalars_to_json("%s/scalers.json" % self.logdir)
        # Flush any pending gif frames (falsy when empty or already saved).
        if self.training_progress_images:
            self.save_in_gif()
        self.writer.close()

    @staticmethod
    def _count_params(proto_model: torch.nn.Module):
        """count the total parameters of model.

        :param proto_model: pytorch module
        :return: number of parameters
        """
        num_params = 0
        for param in proto_model.parameters():
            num_params += param.numel()
        return num_params

    @staticmethod
    def _build_dir(dirs: str):
        # mkdir -p equivalent; no-op if the path already exists.
        if not os.path.exists(dirs):
            os.makedirs(dirs)
if __name__ == '__main__':
    # Manual smoke test for Loger.regist_config: register an optimizer config
    # several times and verify that `log/Optimizer.csv` only gains rows when
    # the configuration actually changes.
    import torch.nn as nn

    test_log = Loger('log')
    test_model = nn.Linear(10, 1)
    test_opt = Optimizer(test_model.parameters(), "Adam", lr_decay=2, decay_position=[1, 3])
    test_log.regist_config(test_opt, epoch=1)   # first registration: header + row
    test_opt.do_lr_decay()
    test_log.regist_config(test_opt, epoch=2)   # lr changed -> new row appended
    test_log.regist_config(test_opt, epoch=3)   # unchanged config -> no write
    test_log.regist_config(test_opt)