# Copyright (c) 2021. yoshida-lab. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
from collections import OrderedDict
from collections import defaultdict
from pathlib import Path
from typing import Any, Union, Dict, Callable, Tuple
import joblib
import pandas as pd
import torch
from deprecated import deprecated
from torch.nn import Module
from xenonpy.model.training.base import BaseRunner
__all__ = ['Checker']
class Checker(object):
    """
    Check point.

    Persists and retrieves model artifacts (structure, parameter states,
    data frames, arbitrary picklable objects) under a single directory.
    """

    class __SL:
        # Serializer shim: routes state dicts through torch's save/load
        # instead of the default joblib handle.
        dump = torch.save
        load = torch.load

    def __init__(
        self,
        path: Union[Path, str] = None,
        *,
        increment: bool = False,
        device: Union[bool, str, torch.device] = 'cpu',
        default_handle: Tuple[Callable, str] = (joblib, '.pkl.z'),
    ):
        """
        Parameters
        ----------
        path: Union[Path, str]
            Dir path for data access. Can be ``Path`` or ``str``.
            Given a relative path will be resolved to abstract path automatically.
        increment : bool
            Set to ``True`` to prevent the potential risk of overwriting.
            Default ``False``.
        device : Union[bool, str, torch.device]
            Device that loaded tensors/models are mapped onto via
            ``torch.load(..., map_location=...)``. Default ``'cpu'``.
        default_handle : Tuple[Callable, str]
            Fallback ``(serializer, file suffix)`` pair used for objects that
            are neither pandas DataFrames nor torch objects.
            Default ``(joblib, '.pkl.z')``.
        """
        if path is None:
            # No path given: fall back to a dir named after the cwd itself.
            path = Path().cwd().name
            self._path = Path.cwd() / path
        else:
            self._path = Path.cwd() / path
            # NOTE(review): source indentation was lost in extraction; the
            # increment handling is reconstructed inside this ``else`` branch
            # (i.e. it never applies to the cwd-name fallback) -- confirm
            # against upstream.
            if increment:
                # Find the first free ``<path>@<i>`` name to avoid overwriting.
                i = 1
                while Path(f'{path}@{i}').exists():
                    i += 1
                self._path = Path.cwd() / f'{path}@{i}'
        self._path.mkdir(parents=True, exist_ok=True)
        # self._path = self._path.resolve()
        self._device = BaseRunner.check_device(device)
        self._handle = default_handle
        # name -> absolute file path for every artifact under ``self._path``
        self._files: Dict[str, str] = defaultdict(str)
        self._make_file_index()

    @classmethod
    @deprecated('This method is rotten and will be removed in v1.0.0, use `Checker(<model path>)` instead')
    def load(cls, model_path):
        # Deprecated alias for the constructor.
        return cls(model_path)

    @property
    def path(self):
        # Directory holding all checkpointed artifacts, as ``str``.
        return str(self._path)

    @property
    def files(self):
        # Names of all indexed artifacts (keys of the file index).
        return list(self._files.keys())

    @property
    def model_name(self):
        """
        Returns
        -------
        str
            Model name.
        """
        return self._path.name

    @property
    def model_structure(self):
        # ``str(model)`` snapshot written by the ``model`` setter.
        structure = self['model_structure']
        return structure

    @property
    def training_info(self):
        return self['training_info']

    @property
    def describe(self):
        """
        Description for this model.

        Returns
        -------
        dict
            Description.
        """
        return self['describe']

    @property
    def model(self):
        """
        Returns
        -------
        model: :class:`torch.nn.Module`
            A pytorch model, with the final state applied when one was saved;
            ``None`` when no model file exists.
        """
        if (self._path / 'model.pth.m').exists():
            model = torch.load(str(self._path / 'model.pth.m'), map_location=self._device)
            state = self.final_state
            if state is not None:
                try:
                    model.load_state_dict(state)
                except AttributeError:
                    # pytorch 1.6.0 compatibility: models pickled before
                    # ``_non_persistent_buffers_set`` existed need it patched
                    # in before the state dict can be applied. Catching
                    # ``AttributeError`` (the base of torch 1.6/1.7's removed
                    # ``ModuleAttributeError``) keeps this branch working on
                    # newer torch versions where that class no longer exists.
                    for _, m in model.named_modules():
                        m._non_persistent_buffers_set = set()
                    model.load_state_dict(state)
                return model
            else:
                return model
        return None

    @model.setter
    def model(self, model: Module):
        """
        Set a model instance.

        Saves the model object itself, its initial state dict, and a textual
        snapshot of its structure.

        Parameters
        ----------
        model: :class:`torch.nn.Module`
            Pytorch model instance.
        """
        if isinstance(model, Module):
            self(model=model)
            self.init_state = model.state_dict()
            self(model_structure=str(model))
        else:
            raise TypeError(f'except `torch.nn.Module` object but got {type(model)}')

    @property
    @deprecated('This property is rotten and will be removed in v1.0.0, use `checker.model` instead')
    def trained_model(self):
        # Legacy artifact name first; otherwise rebuild from model + final state.
        if (self._path / 'trained_model.@1.pkl.z').exists():
            return torch.load(str(self._path / 'trained_model.@1.pkl.z'), map_location=self._device)
        else:
            tmp = self.final_state
            if tmp is not None:
                model: torch.nn.Module = self.model
                model.load_state_dict(tmp)
                return model
            return None

    @property
    def model_class(self):
        if (self._path / 'model_class.pkl.z').exists():
            return self['model_class']
        return None

    @property
    def model_params(self):
        if (self._path / 'model_params.pkl.z').exists():
            return self['model_params']
        return None

    @property
    def init_state(self):
        if (self._path / 'init_state.pth.s').exists():
            return torch.load(str(self._path / 'init_state.pth.s'), map_location=self._device)
        return None

    @init_state.setter
    def init_state(self, state: OrderedDict):
        # Validate eagerly: a state dict must be a non-empty OrderedDict of tensors.
        if not isinstance(state, OrderedDict) or not state:
            raise TypeError('state must be a non-empty OrderedDict')
        for v in state.values():
            if not isinstance(v, torch.Tensor):
                raise TypeError('all values of the state dict must be torch.Tensor')
        self((Checker.__SL, '.pth.s'), init_state=state)

    @property
    def final_state(self):
        if (self._path / 'final_state.pth.s').exists():
            return torch.load(str(self._path / 'final_state.pth.s'), map_location=self._device)
        return None

    @final_state.setter
    def final_state(self, state: OrderedDict):
        # Same validation as ``init_state``.
        if not isinstance(state, OrderedDict) or not state:
            raise TypeError('state must be a non-empty OrderedDict')
        for v in state.values():
            if not isinstance(v, torch.Tensor):
                raise TypeError('all values of the state dict must be torch.Tensor')
        self((Checker.__SL, '.pth.s'), final_state=state)

    def _make_file_index(self):
        # Build the name -> path index from artifacts already on disk.
        for f in [f for f in self._path.iterdir() if f.match('*.pkl.*') or f.match('*.pd.*') or f.match('*.pth.*')]:
            # select data: drop the two-part suffix, e.g. ``foo.pth.s`` -> ``foo``
            fn = '.'.join(f.name.split('.')[:-2])
            self._files[fn] = str(f)

    def _save_data(self, data: Any, filename: str, handle) -> str:
        """Persist ``data`` in a format chosen by its type; return the file path.

        NOTE(review): an existing file with the same name IS overwritten here,
        despite what ``__call__``'s docstring says -- confirm intended behavior.
        """
        if isinstance(data, pd.DataFrame):
            file = str(self._path / (filename + '.pd.xz'))
            self._files[filename] = file
            pd.to_pickle(data, file)
        elif isinstance(data, (torch.Tensor, torch.nn.Module)):
            file = str(self._path / (filename + '.pth.m'))
            self._files[filename] = file
            torch.save(data, file)
        else:
            # Fall back to the supplied ``(serializer, suffix)`` handle.
            file = str(self._path / (filename + handle[1]))
            self._files[filename] = file
            handle[0].dump(data, file)
        return file

    def _load_data(self, file: str, handle):
        """Load a named artifact; ``None`` if it is unknown or missing on disk."""
        fp = self._files[file]
        if fp == '':
            return None
        fp_ = Path(fp)
        if not fp_.exists():
            return None
        # The middle suffix element encodes the storage format.
        pattern = fp_.name.split('.')[-2]
        if pattern == 'pd':
            return pd.read_pickle(fp)
        if pattern == 'pth':
            return torch.load(fp, map_location=self._device)
        if pattern == 'pkl':
            return joblib.load(fp)
        else:
            return handle.load(fp)

    def __getattr__(self, name: str):
        """
        Return sub-dataset.

        Parameters
        ----------
        name: str
            Dataset name.

        Returns
        -------
        self
        """
        if name == 'checkpoints':
            # Lazily create the nested checker, then cache it on the instance
            # so this __getattr__ is only ever hit once per object.
            sub_set = self.__class__(self._path / name, increment=False, device=self._device)
            setattr(self, f'{name}', sub_set)
            return sub_set
        raise AttributeError(f'no such attribute named {name}')

    def __getitem__(self, item):
        if isinstance(item, str):
            return self._load_data(item, self._handle[0])
        else:
            raise KeyError(f'{item}')

    def __call__(self, handle=None, **named_data: Any):
        """
        Save data with or without name.
        Data with same name will not be overwritten.

        Parameters
        ----------
        handle: Tuple[Callable, str]
        named_data: dict
            Named data as k,v pair.
        """
        if handle is None:
            handle = self._handle
        for k, v in named_data.items():
            self._save_data(v, k, handle)

    def set_checkpoint(self, **kwargs):
        # Store checkpoint states under the ``checkpoints`` sub-dir,
        # serialized through torch save/load.
        self.checkpoints((Checker.__SL, '.pth.s'), **kwargs)

    def __repr__(self):
        cont_ls = ['<{}> includes:'.format(self.__class__.__name__)]
        for k, v in self._files.items():
            cont_ls.append('"{}": {}'.format(k, v))
        return '\n'.join(cont_ls)