# Copyright (c) 2021. TsumiNa. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
import numpy as np
from collections import OrderedDict
from typing import Callable, Any, Dict, Union
from xenonpy.model.utils import regression_metrics, classification_metrics
from xenonpy.model.training import Trainer
from xenonpy.model.training.base import BaseExtension
# Public names exported by this module (used by ``from <module> import *``).
__all__ = ['Validator']
class Validator(BaseExtension):
    """
    Validator extension.

    Runs validation during training, writes the validation metrics into the
    step information (prefixed with ``val_``), saves checkpoints for the best
    scores of the traced criteria and, optionally, asks the trainer to stop
    early when nothing has improved for a while.
    """

    regress = 'regress'
    classify = 'classify'

    def __init__(self,
                 metrics_func: Union[str, Callable[[Any, Any], Dict]],
                 *,
                 each_iteration: bool = True,
                 early_stopping: Union[int, None] = None,
                 trace_order: int = 1,
                 warming_up: int = 0,
                 **trace_criteria: float):
        """
        Parameters
        ----------
        metrics_func
            Function for metrics calculation.
            If ``str``, should be ``regress`` or ``classify``, which specifies the training type.
            See :py:func:`xenonpy.model.utils.classification_metrics` and
            :py:func:`xenonpy.model.utils.regression_metrics` for the details.
            You can also give the calculation function yourself; it is called as
            ``metrics_func(y_true, y_pred)`` and must return a dict of ``name: value``.
        each_iteration
            If ``True``, validation will be executed every iteration.
            Otherwise, validation is executed only when the epoch number increases.
            Default ``True``.
        early_stopping
            Patience of the early stopping condition, counted in validations.
            Training is asked to stop once this many consecutive validations have
            passed without improvement of the training loss or of any criterion
            given by ``trace_criteria``.
            ``None`` (default) disables early stopping.
        trace_order
            How many ranks of ``trace_criteria`` will be saved as checkpoint.
            Checkpoint names follow the format ``criterion_rank``, e.g. ``mae_1``.
            When ``trace_order`` is ``1`` the checkpoint is named by the criterion only, e.g. ``mae``.
        warming_up
            Validator does not set/update checkpoints before ``warming_up`` epochs.
        trace_criteria
            Validation criteria.
            Should follow this format: ``criterion=target``, e.g. ``mae=0, corr=1``.
            The names of criteria must be consistent with the output of ``metrics_func``.

        Raises
        ------
        ValueError
            If ``warming_up`` is negative, or if criteria are given while
            ``trace_order`` is smaller than 1.
        """
        if metrics_func == 'regress':
            self.metrics_func = regression_metrics
        elif metrics_func == 'classify':
            self.metrics_func = classification_metrics
        else:
            self.metrics_func = metrics_func

        self.warming_up = warming_up
        self.each_iteration = each_iteration
        # The counter is decremented in the same validation that resets it, so
        # one extra unit is needed to tolerate ``early_stopping`` validations.
        self.patience = early_stopping + 1 if early_stopping is not None else None
        self._count = early_stopping
        self.order = trace_order

        self._epoch_count = 0
        self.trace = {}
        self.trace_order = trace_order
        self.trace_criteria = trace_criteria
        self._set_trace(trace_criteria, trace_order)

        self.from_dataset = False
        self.train_loss = np.inf

    @property
    def warming_up(self):
        """Number of epochs during which no checkpoint is set or updated."""
        return self._warming_up

    @warming_up.setter
    def warming_up(self, val: int):
        if val < 0:
            raise ValueError("`warming_up` must be equal to or greater than 0")
        self._warming_up = val

    def _set_trace(self, trace_metrics: dict, trace_order: int):
        """
        (Re)initialise the best-score table of every traced criterion.

        Each entry of ``self.trace`` is ``(target, scores)`` where ``scores`` is a
        list of ``trace_order`` best distances to ``target``, initialised to ``inf``.

        Raises
        ------
        ValueError
            If criteria are given while ``trace_order`` is smaller than 1.
        """
        # An empty rank list would fail later with an opaque IndexError when the
        # worst kept score (``scores[-1]``) is read, so reject it up front.
        if trace_metrics and trace_order < 1:
            raise ValueError("`trace_order` must be equal to or greater than 1")
        for name, target in trace_metrics.items():
            self.trace[name] = (target, [np.inf] * trace_order)

    def on_reset(self) -> None:
        """
        Restore the per-run state to what ``__init__`` sets up.

        Besides the best-score table, the epoch marker, the best training loss
        and the early-stopping counter are reset; otherwise values left over from
        a previous run would suppress per-epoch validation and distort early stopping.
        NOTE(review): assumes the trainer starts a fresh run after a reset -- confirm
        against the trainer implementation.
        """
        self._set_trace(self.trace_criteria, self.trace_order)
        self._epoch_count = 0
        self.train_loss = np.inf
        self._count = None if self.patience is None else self.patience - 1

    def before_proc(self, trainer: Trainer) -> None:
        """
        Decide where the validation data come from.

        Uses ``trainer.validate_dataset`` when neither ``trainer.x_val`` nor
        ``trainer.y_val`` is set; otherwise both ``x_val`` and ``y_val`` are required.

        Raises
        ------
        RuntimeError
            If no usable validation data are available.
        """
        x_val, y_val = trainer.x_val, trainer.y_val
        val_dataset = trainer.validate_dataset

        if x_val is None and y_val is None and val_dataset is not None:
            self.from_dataset = True
        elif x_val is None or y_val is None:
            raise RuntimeError('no data for validation')
        else:
            # Clear a flag possibly left over from an earlier use with a dataset.
            self.from_dataset = False

    def step_forward(self, trainer: Trainer, step_info: OrderedDict) -> None:
        """
        Run validation for the current step when it is due.

        With ``each_iteration`` set, every call validates. Otherwise validation
        runs only when ``step_info['i_epoch']`` exceeds the last epoch number
        at which a validation was done.
        """
        if self.each_iteration:
            self._validate(trainer, step_info)
            return

        epoch = step_info['i_epoch']
        if epoch > self._epoch_count:
            self._epoch_count = epoch
            self._validate(trainer, step_info)

    def _validate(self, trainer: Trainer, step_info: OrderedDict) -> None:
        """
        Do one validation pass.

        Predicts on the validation data, updates the best-score table and the
        checkpoints, handles the early-stopping counter, and adds every metric
        to ``step_info`` under the key ``val_<name>``.
        """
        if self.from_dataset:
            y_preds, y_trues = trainer.predict(dataset=trainer.validate_dataset)
        else:
            y_preds, y_trues = trainer.predict(trainer.x_val, trainer.y_val)

        # A new best training loss counts as improvement for early stopping.
        train_loss = step_info[trainer.loss_type]
        if train_loss < self.train_loss:
            self.train_loss = train_loss
            self._count = self.patience

        metrics = self.metrics_func(y_trues, y_preds)
        for name, (target, current) in self.trace.items():
            if name not in metrics:
                continue

            # Score is the distance to the target value; smaller is better.
            score = np.abs(metrics[name] - target)
            if score < current[-1] and step_info['i_epoch'] >= self._warming_up:
                # Insert the new score and drop the worst one so that the list
                # keeps its length and stays sorted from best to worst.
                current.append(score)
                current.sort()
                current.pop()
                self._count = self.patience

                if self.order == 1:
                    trainer.set_checkpoint(name)
                else:
                    index = current.index(score) + 1
                    # Shift the lower-ranked checkpoints down by one rank,
                    # starting from the worst, to make room for the new one.
                    for i in range(self.order, index, -1):
                        if f'{name}_{i - 1}' in trainer.checkpoints:
                            trainer.checkpoints[f'{name}_{i}'] = trainer.checkpoints[f'{name}_{i - 1}']
                    trainer.set_checkpoint(f'{name}_{index}')

        if self.patience is not None:
            self._count -= 1
            if self._count == 0:
                trainer.early_stop(
                    f'no improvement for {list(self.trace)} since the last {self.patience} iterations, '
                    f'finish training at iteration {trainer.total_iterations}')

        step_info.update({f'val_{k}': v for k, v in metrics.items()})