#  Copyright (c) 2021. TsumiNa. All rights reserved.
#  Use of this source code is governed by a BSD-style
#  license that can be found in the LICENSE file.

from collections import OrderedDict, namedtuple
from copy import deepcopy
from pathlib import Path
from typing import Union, Tuple, List, Any, Dict, Callable

import numpy as np
import pandas as pd
import torch
from torch.nn import Module
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from torch.utils.data import DataLoader

from deprecated import deprecated

from xenonpy.model.training import ClipValue, ClipNorm, Checker
from xenonpy.model.training.base import BaseOptimizer, BaseLRScheduler, BaseRunner
from xenonpy.utils import camel_to_snake

# Public API of this module: only the trainer is exported by ``from ... import *``.
__all__ = ["Trainer"]

class Trainer(BaseRunner):
    """Train a PyTorch model with a pluggable loss, optimizer, LR scheduler and extensions.

    The trainer owns the model, builds the optimizer / scheduler from their factory
    objects, records per-iteration step information and keeps named model-state
    checkpoints.
    """

    checkpoint_tuple = namedtuple('checkpoint', 'id iterations model_state')
    results_tuple = namedtuple('results', 'total_epochs device training_info checkpoints model')

    def __init__(
            self,
            *,
            loss_func: torch.nn.Module = None,
            optimizer: BaseOptimizer = None,
            model: Module = None,
            lr_scheduler: BaseLRScheduler = None,
            clip_grad: Union[ClipNorm, ClipValue] = None,
            epochs: int = 200,
            cuda: Union[bool, str, torch.device] = False,
            non_blocking: bool = False,
    ):
        """
        NN model trainer.

        Parameters
        ----------
        loss_func
            Loss function.
        optimizer
            Optimizer factory for model parameters tuning. It is called with
            ``model.parameters()`` to build the actual optimizer.
        model
            Pytorch NN model.
        lr_scheduler
            Learning rate scheduler factory. It is called with the built optimizer.
        clip_grad
            Clip grad before each optimize.
        epochs
            Default number of training epochs.
        cuda
            Set training device(s).
        non_blocking
            When non_blocking is ``True``, it tries to convert/move asynchronously with respect to
            the host if possible.
        """
        super().__init__(cuda=cuda)

        # None able
        self._clip_grad = clip_grad
        self._epochs = epochs
        self._non_blocking = non_blocking

        # loss function placeholder
        self._loss_func = None
        self._loss_type = None

        # model placeholder
        self._model = None
        self._init_states = None

        # optimizer placeholder: ``_optim`` is the factory, ``_optimizer`` the built instance
        self._optim = None
        self._optimizer = None
        self._optimizer_state = None

        # lr_scheduler placeholder: ``_scheduler`` is the factory, ``_lr_scheduler`` the instance
        self._scheduler = None
        self._lr_scheduler = None

        # init private vars
        self._early_stopping: Tuple[bool, str] = (False, '')
        self._checkpoints: Dict[Union[int, str], Trainer.checkpoint_tuple] = OrderedDict()
        self._training_info: List[OrderedDict] = []
        self._total_its: int = 0  # of total iterations
        self._total_epochs: int = 0  # of total epochs
        self._x_val = None
        self._y_val = None
        self._validate_dataset = None

        # init; order matters: the optimizer needs the model, the scheduler needs the optimizer
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loss_func = loss_func

    @property
    def epochs(self):
        """Default number of epochs used when ``epochs`` is not given to a training call."""
        return self._epochs

    @property
    def non_blocking(self):
        """Whether tensors/models are moved to the device with ``non_blocking=True``."""
        return self._non_blocking

    @property
    def loss_type(self):
        """Key under which the training loss is stored, ``'train_' + snake_case(loss class)``."""
        return self._loss_type

    @property
    def total_epochs(self):
        """Number of epochs run since the last reset."""
        return self._total_epochs

    @property
    def total_iterations(self):
        """Number of optimizer steps run since the last reset."""
        return self._total_its

    @property
    def x_val(self):
        """Validation input as returned by ``input_proc``, or ``None``."""
        return self._x_val

    @property
    def y_val(self):
        """Validation target as returned by ``input_proc``, or ``None``."""
        return self._y_val

    @property
    def validate_dataset(self):
        """Validation DataLoader given to the last training call, or ``None``."""
        return self._validate_dataset

    @property
    def loss_func(self):
        return self._loss_func

    @loss_func.setter
    def loss_func(self, loss_func):
        # ``None`` keeps the current loss function
        if loss_func is not None:
            self._loss_func = loss_func
            self._loss_type = 'train_' + camel_to_snake(loss_func.__class__.__name__)

    @property
    def training_info(self):
        """Per-iteration step information as a DataFrame, or ``None`` if nothing was recorded."""
        if len(self._training_info) > 0:
            return pd.DataFrame(data=self._training_info)
        return None

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, v):
        self._device = self.check_device(v)
        # Move the bound model (if any) to the new device. ``getattr`` is used because this
        # setter may run before ``__init__`` has created the placeholders.
        model = getattr(self, '_model', None)
        if model is not None:
            model.to(self._device, non_blocking=getattr(self, '_non_blocking', False))
            optimizer = getattr(self, '_optimizer', None)
            if optimizer is not None:
                # ``Optimizer.load_state_dict`` casts state tensors (e.g. momentum buffers)
                # to their parameters' device, keeping them consistent with the moved model.
                optimizer.load_state_dict(optimizer.state_dict())

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, model):
        # ``None`` keeps the current model; a new model rebinds the whole trainer via ``reset``
        if model is None:
            return
        if not isinstance(model, torch.nn.Module):
            raise TypeError(f'parameter <model> must be an instance of <torch.nn.Module> but got {type(model)}')
        self.reset(to=model)

    @property
    def optimizer(self):
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        # ``None`` rebuilds from the stored factory (e.g. after the model changed)
        if optimizer is not None:
            self._optim = optimizer
        if self._optim is not None and self._model is not None:
            self._optimizer = self._optim(self._model.parameters())
            # snapshot used by ``reset(to=None)`` to restore the initial optimizer state
            self._optimizer_state = deepcopy(self._optimizer.state_dict())
            # the scheduler is bound to the optimizer instance, so it must be rebuilt too
            self.lr_scheduler = None

    @property
    def lr_scheduler(self):
        return self._lr_scheduler

    @lr_scheduler.setter
    def lr_scheduler(self, scheduler):
        # ``None`` rebuilds from the stored factory against the current optimizer
        if scheduler is not None:
            self._scheduler = scheduler
        if self._scheduler is not None and self._optimizer is not None:
            self._lr_scheduler: Union[_LRScheduler, None] = self._scheduler(self._optimizer)

    @property
    def clip_grad(self):
        return self._clip_grad

    @clip_grad.setter
    def clip_grad(self, fn):
        self._clip_grad = fn

    @property
    def checkpoints(self):
        return self._checkpoints

    def get_checkpoint(self, checkpoint: Union[int, str] = None):
        """
        Look up a saved checkpoint.

        Parameters
        ----------
        checkpoint
            ``None`` returns the list of all checkpoint ids. An ``int`` ``n`` looks up the id
            ``'cp_n'`` (the default id :meth:`set_checkpoint` uses at iteration ``n``).
            A ``str`` is used as the id directly.

        Returns
        -------
        list or checkpoint_tuple

        Raises
        ------
        KeyError
            If no checkpoint has the requested id.
        TypeError
            If ``checkpoint`` is neither ``None``, ``int`` nor ``str``.
        """
        if checkpoint is None:
            return list(self._checkpoints.keys())
        if isinstance(checkpoint, int):
            return self._checkpoints[f'cp_{checkpoint}']
        if isinstance(checkpoint, str):
            return self._checkpoints[checkpoint]
        raise TypeError(f'parameter <checkpoint> must be str or int but got {checkpoint.__class__}')

    def set_checkpoint(self, id_: str = None):
        """
        Save a deep copy of the current model state as a checkpoint.

        Parameters
        ----------
        id_
            Checkpoint id. Defaults to ``'cp_<total iterations>'``.
        """
        if id_ is None:
            id_ = f'cp_{self._total_its}'
        cp = self.checkpoint_tuple(
            id=id_,
            iterations=self._total_its,
            model_state=deepcopy(self._model.state_dict()),
        )
        self._checkpoints[id_] = cp
        self._on_checkpoint(checkpoint=cp, trainer=self, is_training=True)

    def early_stop(self, msg: str):
        """Request the running training loop to stop after the current step, with a reason."""
        self._early_stopping = (True, msg)

    def reset(self, *, to: Union[Module, int, str] = None, remove_checkpoints: bool = True):
        """
        Reset trainer.
        This will reset all trainer states and drop all training step information.

        Parameters
        ----------
        to: Union[Module, int, str]
            A ``Module`` binds the trainer to that model (moved to the trainer's device) and
            rebuilds optimizer and scheduler. An ``int``/``str`` loads the model state of that
            checkpoint (see :meth:`get_checkpoint`). ``None`` restores the model and optimizer
            to their initial states and rebuilds the scheduler.
        remove_checkpoints
            Set to ``True`` to remove all checkpoints when resetting. Default ``True``.

        Raises
        ------
        TypeError
            If ``to`` has an unsupported type.
        """
        self._training_info = []
        self._total_its = 0
        self._total_epochs = 0
        self._early_stopping = (False, '')

        if isinstance(to, Module):
            self._model = to.to(self._device, non_blocking=self._non_blocking)
            self._init_states = deepcopy(to.state_dict())
            self.optimizer = None
            self.lr_scheduler = None
        elif isinstance(to, (int, str)):
            cp = self.get_checkpoint(to)
            self._model.load_state_dict(cp.model_state)
        elif to is None:
            # nothing bound yet -> nothing to restore
            if self._model is not None:
                self._model.load_state_dict(self._init_states)
            if self._optimizer is not None:
                self._optimizer.load_state_dict(self._optimizer_state)
                # a stale scheduler would continue from its old step count; rebuild it
                self.lr_scheduler = None
        else:
            raise TypeError(f'parameter <to> must be torch.nn.Module, int, or str but got {type(to)}')

        if remove_checkpoints:
            self._checkpoints = OrderedDict()
        self._on_reset(trainer=self, is_training=True)

    def fit(self,
            x_train: Union[Any, Tuple[Any]] = None,
            y_train: Any = None,
            x_val: Union[Any, Tuple[Any]] = None,
            y_val: Any = None,
            *,
            training_dataset: DataLoader = None,
            validation_dataset: DataLoader = None,
            epochs: int = None,
            checkpoint: Union[bool, int, Callable[[int], Tuple[bool, str]]] = None,
            progress_bar: Union[str, None] = 'auto',
            **model_params):
        """
        Train the Neural Network model.

        Drives :meth:`__call__` to completion, optionally showing a per-epoch progress bar.

        Parameters
        ----------
        x_train: Union[torch.Tensor, Tuple[torch.Tensor]]
            Training data. Mutually exclusive with ``training_dataset``.
        y_train: torch.Tensor
            Training targets. Mutually exclusive with ``training_dataset``.
        x_val : Union[Any, Tuple[Any]]
            Data for validation. Mutually exclusive with ``validation_dataset``.
        y_val : Any
            Targets for validation. Mutually exclusive with ``validation_dataset``.
        training_dataset: DataLoader
            Torch DataLoader used as training dataset. When looped over, it should yield a
            tuple containing ``x_train`` and ``y_train`` in order.
        validation_dataset: DataLoader
            Torch DataLoader used for validation.
        epochs : int
            Epochs. If not ``None``, it will overwrite ``self.epochs`` temporarily.
        checkpoint: Union[bool, int, Callable[[int], Tuple[bool, str]]]
            If ``True``, save model states at the end of each epoch.
            If ``int``, save model states every ``checkpoint`` epochs (``0`` disables it).
            If ``Callable``, it takes the current ``total_epochs`` and returns a
            ``(flag, id)`` tuple; when ``flag`` is true a checkpoint with that id is saved.
        progress_bar
            Show a progress bar over training epochs. ``'auto'`` uses ``tqdm.auto``, any other
            string uses the console ``tqdm``, ``None`` disables it. See the tqdm documentation
            for more details.
        model_params: dict
            Other model parameters, forwarded to the model on every forward pass.
        """
        if epochs is None:
            epochs = self._epochs

        # ``__call__`` is a generator: nothing runs until it is iterated
        steps = self(x_train=x_train, y_train=y_train, x_val=x_val, y_val=y_val,
                     training_dataset=training_dataset, validation_dataset=validation_dataset,
                     epochs=epochs, checkpoint=checkpoint, **model_params)

        if progress_bar is None:
            for _ in steps:
                pass
            return

        if progress_bar == 'auto':
            from tqdm.auto import tqdm
        else:
            from tqdm import tqdm

        prob = self._total_epochs
        with tqdm(total=epochs, desc='Training') as pbar:
            for _ in steps:
                # advance the bar only when a new epoch has started
                delta = self._total_epochs - prob
                if delta:
                    prob = self._total_epochs
                    pbar.update(delta)

    def __call__(self,
                 x_train: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
                 y_train: Any = None,
                 x_val: Union[Any, Tuple[Any]] = None,
                 y_val: Any = None,
                 *,
                 training_dataset: DataLoader = None,
                 validation_dataset: DataLoader = None,
                 epochs: int = None,
                 checkpoint: Union[bool, int, Callable[[int], Tuple[bool, str]]] = None,
                 **model_params):
        """
        Train the Neural Network model step by step.

        Parameters
        ----------
        x_train
            Training data. Mutually exclusive with ``training_dataset``.
        y_train
            Training targets. Mutually exclusive with ``training_dataset``.
        x_val : Union[Any, Tuple[Any]]
            Data for validation. Mutually exclusive with ``validation_dataset``.
        y_val : Any
            Targets for validation. Mutually exclusive with ``validation_dataset``.
        training_dataset: DataLoader
            Torch DataLoader used as training dataset. When looped over, it should yield a
            tuple containing ``x_train`` and ``y_train`` in order.
        validation_dataset : DataLoader
            Torch DataLoader used for validation.
        epochs : int
            Epochs. If not ``None``, it will overwrite ``self.epochs`` temporarily.
        checkpoint: Union[bool, int, Callable[[int], Tuple[bool, str]]]
            If ``True``, save model states at the end of each epoch.
            If ``int``, save model states every ``checkpoint`` epochs (``0`` disables it).
            If ``Callable``, it takes the current ``total_epochs`` and returns a
            ``(flag, id)`` tuple; when ``flag`` is true a checkpoint with that id is saved.
        model_params: dict
            Other model parameters, forwarded to the model on every forward pass.

        Yields
        ------
        OrderedDict
            Step information of each training iteration.

        Raises
        ------
        RuntimeError
            If model, loss function or optimizer is missing, or the data arguments conflict.
        """
        if self._model is None:
            raise RuntimeError(
                'no model for training, use `trainer.model = <model>` or `trainer.reset(to=<model>)` to set one')
        if self._loss_func is None:
            raise RuntimeError('no loss function for training, use `trainer.loss_func = <loss_func>` to set one')
        if self._optimizer is None:
            raise RuntimeError('no optimizer for training, use `trainer.optimizer = <optimizer>` to set one')

        if epochs is None:
            epochs = self._epochs

        if training_dataset is not None:
            if y_train is not None or x_train is not None:
                raise RuntimeError('parameter <training_dataset> is exclusive of <x_train> and <y_train>')
        else:
            if y_train is None or x_train is None:
                raise RuntimeError('missing parameter <x_train> or <y_train>')

        # training step
        def _step(x_, y_, i_b=0):

            def closure():
                self._optimizer.zero_grad()
                y_p_ = self._model(*x_, **model_params)
                y_p_, y_t_ = self.output_proc(y_p_, y_, trainer=self, is_training=True)
                loss_ = self._loss_func(y_p_, y_t_)
                loss_.backward()
                if self._clip_grad is not None:
                    self._clip_grad(self._model.parameters())
                return loss_

            # make sure running in training mode
            if not self._model.training:
                self._model.train()

            train_loss = self._optimizer.step(closure).item()

            step_info = OrderedDict(
                total_iters=self._total_its,
                i_epoch=self._total_epochs,
                i_batch=i_b + 1,
            )
            step_info[self._loss_type] = train_loss
            self._total_its += 1
            self._step_forward(step_info=step_info, trainer=self, is_training=True)
            self._training_info.append(step_info)

            if self._lr_scheduler is not None:
                if isinstance(self._lr_scheduler, ReduceLROnPlateau):
                    self._lr_scheduler.step(train_loss)
                else:
                    self._lr_scheduler.step()

            return step_info

        def _snapshot():
            # ``bool`` is a subclass of ``int``, so it must be tested first and exclusively
            if checkpoint is None:
                return
            if isinstance(checkpoint, bool):
                if checkpoint:
                    self.set_checkpoint()
            elif isinstance(checkpoint, int):
                # ``0`` means "never" instead of dividing by zero
                if checkpoint and self._total_epochs % checkpoint == 0:
                    self.set_checkpoint()
            elif callable(checkpoint):
                flag, msg = checkpoint(self._total_epochs)
                if flag:
                    self.set_checkpoint(msg)

        def _finish():
            self._after_proc(trainer=self, is_training=True)
            self._model.eval()

        if validation_dataset is not None:
            if y_val is not None or x_val is not None:
                raise RuntimeError('parameter <validation_dataset> is exclusive of <x_val> and <y_val>')
            self._validate_dataset = validation_dataset
        elif y_val is not None and x_val is not None:
            self._x_val, self._y_val = self.input_proc(x_val, y_val, trainer=self, is_training=False)

        # a new run must not inherit a stop request from a previous run
        self._early_stopping = (False, '')

        # before processing
        self._before_proc(trainer=self, is_training=True)

        # ``is not None``: truth-testing a DataLoader calls ``len()``, which fails for
        # iterable-style datasets and is falsy for empty ones
        if training_dataset is not None:
            for _ in range(epochs):
                self._total_epochs += 1
                for i_batch, (x_b, y_b) in enumerate(training_dataset):
                    x_b, y_b = self.input_proc(x_b, y_b, trainer=self, is_training=True)
                    if not isinstance(x_b, tuple):
                        x_b = (x_b,)
                    yield _step(x_b, y_b, i_batch)
                    if self._early_stopping[0]:
                        print(f'Early stopping is applied: {self._early_stopping[1]}')
                        _finish()
                        return
                _snapshot()
        else:
            x_train, y_train = self.input_proc(x_train, y_train, trainer=self, is_training=True)
            if not isinstance(x_train, tuple):
                x_train = (x_train,)
            for _ in range(epochs):
                self._total_epochs += 1
                yield _step(x_train, y_train)
                if self._early_stopping[0]:
                    print(f'Early stopping is applied: {self._early_stopping[1]}.')
                    _finish()
                    return
                _snapshot()

        # after processing
        _finish()

    @classmethod
    @deprecated(reason="will be removed in v1.0.0, please use `Trainer.from_checker` instead")
    def load(
            cls,
            from_: Union[str, Path, Checker],
            *,
            loss_func: torch.nn.Module = None,
            optimizer: BaseOptimizer = None,
            lr_scheduler: BaseLRScheduler = None,
            clip_grad: Union[ClipNorm, ClipValue] = None,
            epochs: int = 200,
            cuda: Union[bool, str, torch.device] = False,
            non_blocking: bool = False,
    ) -> 'Trainer':
        """Deprecated alias of :meth:`from_checker`."""
        return cls.from_checker(from_,
                                loss_func=loss_func,
                                optimizer=optimizer,
                                lr_scheduler=lr_scheduler,
                                clip_grad=clip_grad,
                                epochs=epochs,
                                cuda=cuda,
                                non_blocking=non_blocking)

    @classmethod
    def from_checker(
            cls,
            checker: Union[str, Path, Checker],
            *,
            loss_func: torch.nn.Module = None,
            optimizer: BaseOptimizer = None,
            lr_scheduler: BaseLRScheduler = None,
            clip_grad: Union[ClipNorm, ClipValue] = None,
            epochs: int = 200,
            cuda: Union[bool, str, torch.device] = False,
            non_blocking: bool = False,
    ) -> 'Trainer':
        """
        Load model from a local path or a :class:`xenonpy.model.training.Checker`.

        Parameters
        ----------
        checker
            Path to the model dir or a :class:`xenonpy.model.training.Checker` object.
        loss_func
            Loss function.
        optimizer
            Optimizer for model parameters tuning.
        lr_scheduler
            Learning rate scheduler.
        clip_grad
            Clip grad before each optimize.
        epochs
            Number of iterations.
        cuda
            Set training device(s).
        non_blocking
            When non_blocking is ``True``, it tries to convert/move asynchronously with respect to
            the host if possible.

        Returns
        -------
        Trainer
            A trainer bound to the stored model, with stored training info and checkpoints.

        Raises
        ------
        RuntimeError
            If the checker holds no files.
        """
        if isinstance(checker, (str, Path)):
            checker = Checker(checker)
        if len(checker.files) == 0:
            raise RuntimeError(f'{checker.path} is not a model dir')

        tmp = cls(model=checker.model,
                  cuda=cuda,
                  loss_func=loss_func,
                  optimizer=optimizer,
                  lr_scheduler=lr_scheduler,
                  clip_grad=clip_grad,
                  epochs=epochs,
                  non_blocking=non_blocking)

        training_info = checker.training_info
        if training_info is not None:
            tmp._training_info = training_info.to_dict(orient='records')

        # works whether ``checker.path`` is a ``str`` or a ``Path``
        if (Path(checker.path) / 'checkpoints').is_dir():
            for k in checker.checkpoints.files:
                tmp._checkpoints[k] = cls.checkpoint_tuple(**checker.checkpoints[k])
        return tmp

    def predict(self,
                x_in: Union[Any, Tuple[Any]] = None,
                y_true: Union[Any, Tuple[Any]] = None,
                *,
                dataset: DataLoader = None,
                checkpoint: Union[int, str] = None,
                **model_params) -> Union[Tuple[np.ndarray], np.ndarray]:
        """
        Predict from x input.

        Parameters
        ----------
        x_in
            Input data for prediction. Mutually exclusive with ``dataset``.
        y_true
            True values.
        dataset
            Pytorch :class:`torch.utils.data.DataLoader` yielding ``(x, y)`` tuples.
        checkpoint
            If given, predict with a copy of the model loaded from this checkpoint
            (see :meth:`get_checkpoint`); the trainer's own model is left untouched.
        model_params: dict
            Model parameters for prediction.

        Returns
        -------
        ret
            Predict results. If ``y_true`` is not ``None`` (or ``dataset`` is used),
            the processed true values are returned in second.

        Raises
        ------
        RuntimeError
            If neither or both of ``x_in`` and ``dataset`` are given.
        """
        use_dataset = x_in is None and y_true is None and dataset is not None
        if not use_dataset and (x_in is None or dataset is not None):
            raise RuntimeError('parameters <x_in> and <dataset> are mutually exclusive')

        # make sure eval mode; done before copying so a checkpoint copy inherits it
        self._model.eval()

        # choose the model once instead of deep-copying it for every batch
        if checkpoint is not None:
            cp = self.get_checkpoint(checkpoint)
            model = deepcopy(self._model).to(self.device)
            model.load_state_dict(cp.model_state)
        else:
            model = self._model

        def _predict(x_, y_=None):
            x_, y_ = self.input_proc(x_, y_, trainer=self, is_training=False)
            if not isinstance(x_, tuple):
                x_ = (x_,)
            y_p_ = model(*x_, **model_params)
            return self.output_proc(y_p_, y_, trainer=self, is_training=False)

        def _vstack(ls):
            # concatenate per-batch outputs along the first axis when they are arrays/tensors
            if not ls:
                return ls
            if isinstance(ls[0], np.ndarray):
                return np.concatenate(ls)
            if isinstance(ls[0], torch.Tensor):
                return torch.cat(ls, dim=0)
            return ls

        if use_dataset:
            y_preds, y_trues = [], []
            for x_b, y_b in dataset:
                y_pred, y_t = _predict(x_b, y_b)
                y_preds.append(y_pred)
                y_trues.append(y_t)
            return _vstack(y_preds), _vstack(y_trues)

        y_preds, y_trues = _predict(x_in, y_true)
        if y_trues is None:
            return y_preds
        return y_preds, y_trues

    def to_namedtuple(self):
        """Export training results; the model is a CPU copy, the trainer's model is untouched."""
        return self.results_tuple(
            total_epochs=self.total_epochs,
            device=self.device,
            training_info=self.training_info,
            checkpoints={k: deepcopy(v.model_state) for k, v in self._checkpoints.items()},
            # copy first, then move the copy: ``Module.cpu()`` works in place and would
            # otherwise move the trainer's own model off its training device
            model=deepcopy(self._model).cpu(),
        )