# Copyright (c) 2021. TsumiNa. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path
from platform import version as sys_ver
from sys import version as py_ver
from typing import Union, Callable, Any, OrderedDict
import numpy as np
import torch
from xenonpy import __version__
from xenonpy.model.training import Trainer, Checker
from xenonpy.model.training.base import BaseExtension, BaseRunner
# Names exported by a star-import of this module.
__all__ = ['Persist']
class Persist(BaseExtension):
    """
    Trainer extension for data persistence.

    All records are written through a :class:`Checker` that is created in
    :meth:`before_proc`. Accessors that read from the checker (``describe``,
    ``path``, ``model_structure``, ``get_checkpoint``, item access and calling
    the instance) are therefore only usable after training has started.
    """

    def __init__(self,
                 path: Union[Path, str] = None,
                 *,
                 model_class: Callable = None,
                 model_params: Union[tuple, dict, Any] = None,
                 increment: bool = False,
                 sync_training_step: bool = False,
                 only_best_states: bool = False,
                 no_model_saving: bool = False,
                 **describe: Any):
        """
        Parameters
        ----------
        path
            Path for model saving.
        model_class
            A factory function for model reconstructing.
            In most cases this is the model class that inherits from :class:`torch.nn.Module`.
        model_params
            The parameters for model reconstructing.
            This can be anything, but in general it is a dict which can be used as kwargs parameters.
        increment
            If ``True``, the dir name of path will be decorated with an auto increment number,
            e.g. use ``model_dir@1`` for ``model_dir``.
        sync_training_step
            If ``True``, will save ``trainer.training_info`` at each iteration.
            Default is ``False``, only save ``trainer.training_info`` at each epoch.
        only_best_states
            If ``True``, will only save the models with the best states in terms of each of the criteria.
        no_model_saving
            If ``True``, the connected model is not handed to the checker and
            checkpoints are skipped. ``after_proc`` still assigns the final
            ``state_dict`` to the checker.
        describe:
            Any other information to describe this model.
            This information will be saved under model dir by name ``describe.pkl.z``.
            A key that matches one of the automatically recorded entries
            (``python``, ``system``, ``numpy``, ``torch``, ``xenonpy``,
            ``device``, ``start``, ``finish``, ``time_elapsed``) replaces the
            automatic value when training starts.
        """
        self._model_class: Callable = model_class
        self._model_params: Union[tuple, dict, Any] = model_params
        self.sync_training_step = sync_training_step
        self.only_best_states = only_best_states
        self._increment = increment
        self._describe = describe
        self._describe_ = None
        self._checker: Union[Checker, None] = None
        self._tmp_args: list = []
        self._tmp_kwargs: dict = {}
        self._epoch_count = 0
        self._no_model_saving = no_model_saving
        # Goes through the property setter, which needs ``_checker`` to exist already.
        self.path = path

    @property
    def describe(self):
        """Description held by the checker.

        Raises
        ------
        ValueError
            If accessed before training has started.
        """
        if self._checker is None:
            raise ValueError('can not access property `describe` before training')
        return self._checker.describe

    @property
    def no_model_saving(self):
        """The ``no_model_saving`` flag given at construction (read-only)."""
        return self._no_model_saving

    @property
    def path(self):
        """The checker's path as a string.

        Raises
        ------
        ValueError
            On read, if training has not started yet; on write, if training
            has already started.
        """
        if self._checker is None:
            raise ValueError('can not access property `path` before training')
        return str(self._checker.path)

    @path.setter
    def path(self, path: Union[Path, str]):
        if self._checker is not None:
            raise ValueError('can not reset property `path` after training')
        self._path = path

    @property
    def model_structure(self):
        """Model structure held by the checker.

        Raises
        ------
        ValueError
            If accessed before training has started.
        """
        # Same guard as the other checker-backed properties; without it the
        # failure is an opaque AttributeError on ``None``.
        if self._checker is None:
            raise ValueError('can not access property `model_structure` before training')
        return self._checker.model_structure

    def get_checkpoint(self, id_: str = None):
        """Look up saved checkpoints through the checker.

        Parameters
        ----------
        id_
            Key of the checkpoint to fetch. If ``None``, ``checkpoints.files``
            of the checker is returned instead (presumably the names of the
            available checkpoints -- confirm against ``Checker``).

        Raises
        ------
        RuntimeError
            If called before training has started.
        """
        if self._checker is None:
            raise RuntimeError('calling of this method only after the model training')
        if id_ is not None:
            return self._checker.checkpoints[id_]
        return self._checker.checkpoints.files

    def __call__(self, handle: Any = None, **kwargs: Any):
        """Forward ``handle`` and ``kwargs`` to the underlying checker.

        Raises
        ------
        RuntimeError
            If called before training has started.
        """
        if self._checker is None:
            raise RuntimeError('calling of this method only after the model training')
        self._checker(handle=handle, **kwargs)

    def __getitem__(self, item):
        """Delegate item access to the underlying checker.

        Raises
        ------
        RuntimeError
            If used before training has started.
        """
        if self._checker is None:
            raise RuntimeError('calling of this method only after the model training')
        return self._checker[item]

    def on_checkpoint(self,
                      checkpoint: Trainer.checkpoint_tuple,
                      _trainer: BaseRunner = None,
                      _is_training: bool = True,
                      *_dependence: 'BaseExtension') -> None:
        """Store a deep copy of ``checkpoint`` through the checker.

        Nothing is stored when ``no_model_saving`` is set. With
        ``only_best_states`` only checkpoints whose id ends in ``_1`` are kept,
        and they are stored under the id with that suffix removed.
        """
        if self._no_model_saving:
            return None

        key = checkpoint.id
        if self.only_best_states:
            # The last ``_``-separated token is presumably the rank of the
            # snapshot for its criterion, '1' being the best -- confirm against Trainer.
            *prefix, rank = checkpoint.id.split('_')
            if rank != '1':
                return None
            key = '_'.join(prefix)

        # Deep copy so later changes to the live state cannot alter what is saved.
        self._checker.set_checkpoint(**{key: deepcopy(checkpoint._asdict())})

    def step_forward(self,
                     step_info: OrderedDict[Any, int],
                     trainer: Trainer = None,
                     _is_training: bool = True,
                     *_dependence: BaseExtension) -> None:
        """Save ``trainer.training_info`` as training advances.

        With ``sync_training_step`` the info is saved on every call; otherwise
        only when ``step_info['i_epoch']`` exceeds the last epoch that was saved.
        A ``None`` training info is never saved.
        """
        if self.sync_training_step:
            training_info = trainer.training_info
            if training_info is not None:
                self._checker(training_info=training_info)
        else:
            epoch = step_info['i_epoch']
            if epoch > self._epoch_count:
                training_info = trainer.training_info
                if training_info is not None:
                    # Advance the marker only after a successful read, so an
                    # epoch with no info yet is retried on the next step.
                    self._epoch_count = epoch
                    self._checker(training_info=training_info)

    def before_proc(self, trainer: Trainer = None, _is_training: bool = True, *_dependence: BaseExtension) -> None:
        """Create the checker and record the model and environment description."""
        self._checker = Checker(self._path, increment=self._increment)
        # Per-run state starts over with the new checker; a stale counter from
        # an earlier run would suppress the per-epoch saves in ``step_forward``.
        self._epoch_count = 0

        if self._model_class is not None:
            self._checker(model_class=self._model_class)
        if self._model_params is not None:
            self._checker(model_params=self._model_params)
        if not self._no_model_saving:
            self._checker.model = trainer.model

        # Built as a dict literal rather than ``dict(key=..., **extra)``: the
        # call form raises TypeError when a user-supplied key collides with
        # one of the entries below, whereas here the user value simply wins.
        self._describe_ = {
            'python': py_ver,
            'system': sys_ver(),
            'numpy': np.__version__,
            'torch': torch.__version__,
            'xenonpy': __version__,
            'device': str(trainer.device),
            'start': datetime.now().strftime('%Y/%m/%d %H:%M:%S'),
            'finish': 'N/A',
            'time_elapsed': 'N/A',
            **self._describe,
        }
        self._checker(describe=self._describe_)

    def after_proc(self, trainer: Trainer = None, _is_training: bool = True, *_dependence: 'BaseExtension') -> None:
        """Record finish time, elapsed time, final state and training info."""
        self._describe_.update(finish=datetime.now().strftime('%Y/%m/%d %H:%M:%S'),
                               time_elapsed=str(timedelta(seconds=trainer.timer.elapsed)))
        # NOTE(review): the final state is assigned even when ``no_model_saving``
        # is set -- confirm this is intended.
        self._checker.final_state = trainer.model.state_dict()
        self._checker(
            training_info=trainer.training_info,
            describe=self._describe_,
        )