# coding=utf-8
import torch
import os
from torch.nn import init, Conv2d, Linear, ConvTranspose2d, InstanceNorm2d, BatchNorm2d, DataParallel, Module
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch import save, load
from typing import Union
from collections import OrderedDict
from types import FunctionType
class _cached_property(object):
"""
Decorator that converts a method with a single self argument into a
property cached on the instance.
Optional ``name`` argument allows you to make cached properties of other
methods. (e.g. url = _cached_property(get_absolute_url, name='url') )
"""
def __init__(self, func, name=None):
self.func = func
self.__doc__ = getattr(func, '__doc__')
self.name = name or func.__name__
def __get__(self, instance, cls=None):
if instance is None:
return self
res = instance.__dict__[self.name] = self.func(instance)
return res
[docs]class Model(object):
r"""A warapper of pytorch ``module`` .
In the simplest case, we use a raw pytorch ``module`` to assemble a ``Model`` of this class.
It can be more convenient to use some feather method, such ``_check_point`` , ``load_weights`` and so on.
* :attr:`proto_model` is the core model in this class.
It is no necessary to passing a ``module`` when you init a ``Model`` .
You can build a model later by using ``Model.define(module)`` or load a model from a file.
* :attr:`gpu_ids_abs` controls the gpus which you want to use. you should use a absolute id of gpus.
* :attr:`init_method` controls the weights init method.
* At init_method="xavier", it will use ``init.xavier_normal_`` ,
in ``pytorch.nn.init`` , to init the Conv layers of model.
* At init_method="kaiming", it will use ``init.kaiming_normal_`` ,
in ``pytorch.nn.init`` , to init the Conv layers of model.
* At init_method=your_own_method, it will be used on weights,
just like what ``pytorch.nn.init`` method does.
* :attr:`show_structure` controls whether to show your network structure.
.. note::
Don't try to pass a ``DataParallel`` model. Only ``module`` is accessible.
It will change to ``DataParallel`` class automatically by passing a muti-gpus ids, like ``[0, 1]`` .
.. note::
:attr:`gpu_ids_abs` must be a tuple or list. If you want to use cpu, just passing an ampty list like ``[]`` .
Args:
proto_model (module): A pytroch module. Default: ``None``
gpu_ids_abs (tuple or list): The absolute id of gpus. if [] using cpu. Default: ``()``
init_method (str or def): Weights init method. Default: ``"Kaiming"``
show_structure (bool): Is the structure shown. Default: ``False``
Attributes:
num_params (int): The totals amount of weights in this model.
gpu_ids_abs (list or tuple): Which device is this model on.
Examples::
>>> from torch.nn import Sequential, Conv3d
>>> # using a square kernels and equal stride
>>> module = Sequential(Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)))
>>> # using cpu to init a Model by module.
>>> net = Model(module, [], show_structure=False)
Sequential Total number of parameters: 15873
Sequential model use CPU!
apply kaiming weight init!
>>> input_tensor = torch.randn(20, 16, 10, 50, 100)
>>> output = net(input_tensor)
"""
def __init__(self, proto_model: Module,
gpu_ids_abs: Union[list, tuple] = (),
init_method: Union[str, FunctionType, None] = "kaiming",
show_structure=False,
check_point_pos=None, verbose=True):
# if not isinstance(proto_model, Module):
# raise TypeError(
# "The type of `proto_model` must be `torch.nn.Module`, but got %s instead" % type(proto_model))
self.model: Union[DataParallel, Module] = None
self.model_name = proto_model.__class__.__name__
self.weights_init = None
self.init_fc = None
self.init_name: str = None
self.num_params: int = 0
self.verbose = verbose
self.check_point_pos = check_point_pos
self.define(proto_model, gpu_ids_abs, init_method, show_structure)
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
def __getattr__(self, item):
return getattr(self.model, item)
[docs] def define(self, proto_model: Module, gpu_ids_abs: Union[list, tuple], init_method: Union[str, FunctionType, None],
show_structure: bool):
"""Define and wrap a pytorch module, according to CPU, GPU and multi-GPUs.
* Print the module's info.
* Move this module to specify device.
* Apply weight init method.
:param proto_model: Network, type of ``module``.
:param gpu_ids_abs: Be used GPUs' id, type of ``tuple`` or ``list``. If not use GPU, pass ``()``.
:param init_method: init weights method("kaiming") or ``False`` don't use any init.
:param show_structure: If print structure of model.
"""
self.num_params = self.print_network(proto_model, show_structure)
self.model = self._set_device(proto_model, gpu_ids_abs)
self.init_name = self._apply_weight_init(init_method, self.model)
self._print("apply %s weight init!" % self.init_name)
[docs] def print_network(self, proto_model: Module, show_structure=False):
"""Print total number of parameters and structure of network
:param proto_model: Pytorch module
:param show_structure: If show network's structure. default: ``False``
:return: Total number of parameters
"""
num_params = self.count_params(proto_model)
if show_structure:
self._print(str(proto_model))
num_params_log = '%s Total number of parameters: %d' % (self.model_name, num_params)
self._print(num_params_log)
return num_params
[docs] def load_weights(self, weights: Union[OrderedDict, dict, str], strict=True):
"""Assemble a model and weights from paths or passing parameters.
You can load a model from a file, passing parameters or both.
:param weights: Pytorch weights or weights file path.
:param strict: The same function in pytorch ``model.load_state_dict(weights,strict = strict)`` .
default:``True``
:return: ``module``
Example::
>>> from torchvision.models.resnet import resnet18
>>> model = Model(resnet18())
ResNet Total number of parameters: 11689512
ResNet model use CPU!
apply kaiming weight init!
>>> model.save_weights("model.pth",)
try to remove 'module.' in keys of weights dict...
>>> model.load_weights("model.pth", True)
Try to remove `moudle.` to keys of weights dict
"""
if isinstance(weights, str):
weights = load(weights, map_location=lambda storage, loc: storage)
else:
raise TypeError("`weights` must be a `dict` or a path of weights file.")
if isinstance(self.model, DataParallel):
self._print("Try to add `moudle.` to keys of weights dict")
weights = self._fix_weights(weights, "add", False)
else:
self._print("Try to remove `moudle.` to keys of weights dict")
weights = self._fix_weights(weights, "remove", False)
self.model.load_state_dict(weights, strict=strict)
[docs] def save_weights(self, weights_path: str, fix_weights=True):
"""Save a model and weights to files.
You can save a model, weights or both to file.
.. note::
This method deal well with different devices on model saving.
You don' need to care about which devices your model have saved.
:param weights_path: Pytorch weights or weights file path.
:param fix_weights: If this is true, it will remove the '.module' in keys, when you save a ``DataParallel``.
without any moving operation. Otherwise, it will move to cpu, especially in ``DataParallel``.
default:``False``
Example::
>>> from torch.nn import Linear
>>> model = Model(Linear(10,1))
Linear Total number of parameters: 11
Linear model use CPU!
apply kaiming weight init!
>>> model.save_weights("weights.pth")
try to remove 'module.' in keys of weights dict...
>>> model.load_weights("weights.pth")
Try to remove `moudle.` to keys of weights dict
"""
if fix_weights:
import copy
weights = copy.deepcopy(self.model.state_dict())
self._print("try to remove 'module.' in keys of weights dict...")
weights = self._fix_weights(weights, "remove", False)
else:
weights = self.model.state_dict()
save(weights, weights_path)
[docs] def load_point(self, model_name: str, epoch: int, logdir="log"):
"""load model and weights from a certain checkpoint.
this method is cooperate with method `self.chechPoint()`
"""
if not logdir.endswith("checkpoint"):
logdir = os.path.join(logdir, "checkpoint")
model_weights_path = os.path.join(logdir, "Weights_%s_%d.pth" % (model_name, epoch))
self.load_weights(model_weights_path, True)
def check_point(self, model_name: str, epoch: int, logdir="log"):
if not logdir.endswith("checkpoint"):
logdir = os.path.join(logdir, "checkpoint")
if not os.path.exists(logdir):
os.makedirs(logdir)
model_weights_path = os.path.join(logdir, "Weights_%s_%d.pth" % (model_name, epoch))
weights = self._fix_weights(self.model.state_dict(), "remove", False) # try to remove '.module' in keys.
save(weights, model_weights_path)
def is_checkpoint(self, model_name: str, epoch: int, logdir="log"):
if not self.check_point_pos:
return False
if isinstance(self.check_point_pos, int):
is_check_point = epoch > 0 and (epoch % self.check_point_pos) == 0
else:
is_check_point = epoch in self.check_point_pos
if is_check_point:
self.check_point(model_name, epoch, logdir)
return is_check_point
[docs] def convert_to_distributed(self, device_ids=None,
output_device=None, dim=0, broadcast_buffers=True,
process_group=None, bucket_cap_mb=25,
find_unused_parameters=False,
check_reduction=False):
"""
Args:
module (Module): module to be parallelized
device_ids (list of int or torch.device): CUDA devices. This should
only be provided when the input module resides on a single
CUDA device. For single-device modules, the ``i``th
:attr:`module` replica is placed on ``device_ids[i]``. For
multi-device modules and CPU modules, device_ids must be None
or an empty list, and input data for the forward pass must be
placed on the correct device. (default: all devices for
single-device modules)
output_device (int or torch.device): device location of output for
single-device CUDA modules. For multi-device modules and
CPU modules, it must be None, and the module itself
dictates the output location. (default: device_ids[0] for
single-device modules)
broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
the module at beginning of the forward function.
(default: ``True``)
process_group: the process group to be used for distributed data
all-reduction. If ``None``, the default process group, which
is created by ```torch.distributed.init_process_group```,
will be used. (default: ``None``)
bucket_cap_mb: DistributedDataParallel will bucket parameters into
multiple buckets so that gradient reduction of each
bucket can potentially overlap with backward computation.
:attr:`bucket_cap_mb` controls the bucket size in MegaBytes (MB)
(default: 25)
find_unused_parameters (bool): Traverse the autograd graph of all tensors
contained in the return value of the wrapped
module's ``forward`` function.
Parameters that don't receive gradients as
part of this graph are preemptively marked
as being ready to be reduced.
(default: ``False``)
check_reduction: when setting to ``True``, it enables DistributedDataParallel
to automatically check if the previous iteration's
backward reductions were successfully issued at the
beginning of every iteration's forward function.
You normally don't need this option enabled unless you
are observing weird behaviors such as different ranks
are getting different gradients, which should not
happen if DistributedDataParallel is correctly used.
(default: ``False``)
Attributes:
module (Module): the module to be parallelized
Example::
>>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
>>> net.convert_to_distributed(pg)
>>> # same thing
>>> net.model = torch.nn.DistributedDataParallel(net.model, pg)
"""
# assert isinstance(self.model, DataParallel), "please only use one gpu for one task"
self.model = DistributedDataParallel(self.model, device_ids,
output_device, dim, broadcast_buffers,
process_group, bucket_cap_mb,
find_unused_parameters,
check_reduction)
[docs] @staticmethod
def count_params(proto_model: Module):
"""count the total parameters of model.
:param proto_model: pytorch module
:return: number of parameters
"""
num_params = 0
for param in proto_model.parameters():
num_params += param.numel()
return num_params
def _apply_weight_init(self, init_method: Union[str, FunctionType], proto_model: Module):
init_name = "No"
if init_method:
if init_method == 'kaiming':
self.init_fc = getattr(init, "kaiming_normal_")
init_name = init_method
elif init_method == "xavier":
self.init_fc = getattr(init, "xavier_normal_")
init_name = init_method
else:
self.init_fc = getattr(init, init_method)
init_name = init_method.__name__
proto_model.apply(self._weight_init)
return init_name
def _weight_init(self, m):
if (m is None) or (not hasattr(m, "weight")):
return
if (m.bias is not None) and hasattr(m, "bias"):
m.bias.data.zero_()
if isinstance(m, Conv2d):
self.init_fc(m.weight)
# m.bias.data.zero_()
elif isinstance(m, Linear):
self.init_fc(m.weight)
# m.bias.data.zero_()
elif isinstance(m, ConvTranspose2d):
self.init_fc(m.weight)
# m.bias.data.zero_()
elif isinstance(m, InstanceNorm2d):
init.normal_(m.weight, 1.0, 0.02)
# m.bias.data.fill_(0)
elif isinstance(m, BatchNorm2d):
init.normal_(m.weight, 1.0, 0.02)
# m.bias.data.fill_(0)
else:
pass
@staticmethod
def _fix_weights(weights: Union[dict, OrderedDict], fix_type: str = "remove", is_strict=True):
# fix params' key
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in weights.items():
if fix_type == "remove":
if is_strict and not k.startswith("module."):
raise ValueError("The key of weights dict doesn't start with 'module.'. %s instead" % k)
name = k.replace("module.", "", 1) # remove `module.`
elif fix_type == "add":
if is_strict and k.startswith("module."):
raise ValueError("The key of weights dict is %s. Can not add 'module.'" % k)
if not k.startswith("module."):
name = "module." + k # add `module.`
else:
name = k
else:
raise TypeError("`fix_type` should be 'remove' or 'add'.")
new_state_dict[name] = v
return new_state_dict
def _set_device(self, proto_model: Module, gpu_ids_abs: list) -> Union[Module, DataParallel]:
if not gpu_ids_abs:
gpu_ids_abs = []
# old_enviroment = os.environ["CUDA_VISIBLE_DEVICES"]
# os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
# gpu_ids = [i for i in range(len(gpu_ids_abs))]
gpu_available = torch.cuda.is_available()
model_name = proto_model.__class__.__name__
if len(gpu_ids_abs) == 1:
if not gpu_available:
raise EnvironmentError("No gpu available! torch.cuda.is_available() is False. "
"CUDA_VISIBLE_DEVICES=%s" % \
os.environ["CUDA_VISIBLE_DEVICES"])
proto_model = proto_model.cuda(gpu_ids_abs[0])
self._print("%s model use GPU %s!" % (model_name, gpu_ids_abs))
elif len(gpu_ids_abs) > 1:
if not gpu_available:
raise EnvironmentError("No gpu available! torch.cuda.is_available() is False. "
"CUDA_VISIBLE_DEVICES=%s" % \
os.environ["CUDA_VISIBLE_DEVICES"])
proto_model = DataParallel(proto_model.cuda(gpu_ids_abs[0]), gpu_ids_abs)
self._print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids_abs))
else:
self._print("%s model use CPU!" % model_name)
return proto_model
def _print(self, str_msg: str):
if self.verbose:
print(str_msg)
@property
def configure(self):
config_dic = dict()
if isinstance(self.model, DataParallel):
config_dic["model_name"] = str(self.model.module.__class__.__name__)
elif isinstance(self.model, Module):
config_dic["model_name"] = str(self.model.__class__.__name__)
else:
raise TypeError("Type of `self.model` is wrong!")
config_dic["init_method"] = str(self.init_name)
config_dic["total_params"] = self.num_params
config_dic["structure"] = str(self.model)
return config_dic
if __name__ == '__main__':
    # Manual smoke test: build a ``Model`` on every device layout that the
    # current machine supports.
    from torch.nn import Sequential

    mode = Sequential(Conv2d(10, 1, 3, 1, 0))
    net = Model(mode, [], "kaiming", show_structure=False)
    if torch.cuda.is_available():
        net = Model(mode, [0], "kaiming", show_structure=False)
    if torch.cuda.device_count() > 1:
        net = Model(mode, [0, 1], "kaiming", show_structure=False)
    # GPU ids 2 and 3 both exist only when at least four devices are visible.
    if torch.cuda.device_count() > 3:
        net = Model(mode, [2, 3], "kaiming", show_structure=False)
    # NOTE(review): the original nesting of this line was lost with the file's
    # indentation; it is kept unconditional here - confirm.
    net1 = Model(mode, [], "kaiming", show_structure=False)