xenonpy.model.training.extension package

Submodules

xenonpy.model.training.extension.persist module

class xenonpy.model.training.extension.persist.Persist(path=None, *, model_class=None, model_params=None, increment=False, sync_training_step=False, only_best_states=False, no_model_saving=False, **describe)[source]

Bases: BaseExtension

Trainer extension for data persistence

Parameters:
  • path (Union[Path, str, None]) – Path for model saving.

  • model_class (Optional[Callable]) – A factory function used to reconstruct the model. In most cases this is the model class, which inherits from torch.nn.Module.

  • model_params (Union[tuple, dict, any, None]) – The parameters used to reconstruct the model. This can be anything, but is usually a dict that can be passed as keyword arguments.

  • increment (bool) – If True, the directory name of path is decorated with an auto-incrementing number, e.g. model_dir@1 is used for model_dir.

  • sync_training_step (bool) – If True, trainer.training_info is saved at each iteration. Default is False, in which case trainer.training_info is saved only at each epoch.

  • only_best_states (bool) – If True, only the models with the best states for each of the criteria are saved.

  • no_model_saving (bool) – If True, the connected model and checkpoints are not saved.

  • describe (Any) – Any other information describing this model. This information is saved in the model directory under the name describe.pkl.z.
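
The increment and model_class/model_params behaviour described above can be illustrated with a small standalone sketch. It does not use XenonPy itself: incremented_dir and rebuild_model are hypothetical helpers, and both the numbering scheme (first unused name@n, starting from 1) and the way parameters are unpacked by type are assumptions, not the library's confirmed implementation.

```python
import tempfile
from pathlib import Path


def incremented_dir(path):
    """Return the first unused sibling of `path` named `<name>@<n>` (n = 1, 2, ...).

    Assumed scheme for illustration only.
    """
    path = Path(path)
    n = 1
    candidate = path.with_name(f"{path.name}@{n}")
    while candidate.exists():
        n += 1
        candidate = path.with_name(f"{path.name}@{n}")
    return candidate


def rebuild_model(model_class, model_params):
    """Call the factory with the saved parameters (assumed unpacking convention)."""
    if isinstance(model_params, dict):
        return model_class(**model_params)  # dict -> keyword arguments
    if isinstance(model_params, tuple):
        return model_class(*model_params)   # tuple -> positional arguments
    return model_class(model_params)        # anything else -> single argument


# Usage: two successive "runs" that target the same model_dir.
with tempfile.TemporaryDirectory() as tmp:
    first = incremented_dir(Path(tmp) / "model_dir")
    first.mkdir()
    second = incremented_dir(Path(tmp) / "model_dir")
    names = (first.name, second.name)
```

Here any callable stands in for a torch.nn.Module subclass; the point is only that a factory plus stored parameters is enough to recreate an object before loading its saved state.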

__call__(handle=None, **kwargs)[source]

Call self as a function.

after_proc(trainer=None, _is_training=True, *_dependence)[source]
Return type:

None

before_proc(trainer=None, _is_training=True, *_dependence)[source]
Return type:

None

get_checkpoint(id_=None)[source]
on_checkpoint(checkpoint, _trainer=None, _is_training=True, *_dependence)[source]
Return type:

None

step_forward(step_info, trainer=None, _is_training=True, *_dependence)[source]
Return type:

None

property describe
property model_structure
property no_model_saving
property path

xenonpy.model.training.extension.tensor_convert module

class xenonpy.model.training.extension.tensor_convert.TensorConverter(*, x_dtype=None, y_dtype=None, empty_cache=False, auto_reshape=True, argmax=False, probability=False)[source]

Bases: BaseExtension

Convert tensor-like data into torch.Tensor automatically.

Parameters:
  • x_dtype (Union[dtype, Sequence[dtype], None]) – The torch.dtype(s) of the X data. If None, all data is converted to the torch.get_default_dtype() type. Can be a tuple of torch.dtype when your X is a tuple.

  • y_dtype (Union[dtype, Sequence[dtype], None]) – The torch.dtype(s) of the y data. If None, all data is converted to the torch.get_default_dtype() type. Can be a tuple of torch.dtype when your y is a tuple.

  • empty_cache (bool) – See Also: https://pytorch.org/docs/stable/cuda.html#torch.cuda.empty_cache

  • auto_reshape (bool) – Reshape a tensor to (-1, 1) if its shape is (n,). Default True.

  • argmax (bool) – Apply np.argmax(out, 1) to the output. This should only be used with classification models. Default False. If True, the probability parameter is ignored.

  • probability (bool) – Apply scipy.special.softmax to the output. This should only be used with classification models. Default False.
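
As a rough illustration of the auto_reshape, argmax and probability options, the following sketch uses plain Python lists in place of tensors and arrays. The helper names (column, softmax_rows, argmax_rows, postprocess) are hypothetical, and applying softmax row-wise is an assumption made to match the axis used by np.argmax(out, 1); this is not XenonPy's code.

```python
import math


def column(values):
    """Reshape a flat sequence of n items into n rows of one item: (n,) -> (n, 1)."""
    return [[v] for v in values]


def softmax_rows(rows):
    """Row-wise softmax (assumed axis), shifted by the row maximum for stability."""
    out = []
    for row in rows:
        peak = max(row)
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def argmax_rows(rows):
    """Index of the largest entry in each row; the first index wins on ties."""
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


def postprocess(rows, argmax=False, probability=False):
    """argmax takes precedence: when it is True, probability is ignored."""
    if argmax:
        return argmax_rows(rows)
    if probability:
        return softmax_rows(rows)
    return rows


logits = [[1.0, 3.0], [2.0, 0.5]]
labels = postprocess(logits, argmax=True, probability=True)  # class indices only
probs = postprocess(logits, probability=True)                # each row sums to 1
```

The precedence check in postprocess mirrors the documented rule that argmax=True overrides probability.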

input_proc(x_in, y_in, trainer)[source]

Convert data to torch.Tensor.

Return type:

Union[Any, Tuple[Any, Any]]

output_proc(y_pred, y_true, is_training)[source]

Convert torch.Tensor to numpy.ndarray.

step_forward()[source]
property argmax
property auto_reshape
property empty_cache
property probability

xenonpy.model.training.extension.validator module

class xenonpy.model.training.extension.validator.Validator(metrics_func, *, each_iteration=True, early_stopping=None, trace_order=1, warming_up=0, **trace_criteria)[source]

Bases: BaseExtension

Validator extension

Parameters:
  • metrics_func (Union[str, Callable[[Any, Any], Dict]]) – Function for metrics calculation. If a str, it should be regress or classify, which specifies the training type. See xenonpy.model.utils.classification_metrics() and xenonpy.model.utils.regression_metrics() for details. You can also supply your own calculation function.

  • each_iteration (bool) – If True, validation is executed at every iteration. Otherwise, validation runs only at the end of each epoch. Default True.

  • early_stopping (Optional[int]) – The patience for early stopping. This condition traces the criteria set by the trace_criteria parameter.

  • trace_order (int) – How many ranks of each criterion in trace_criteria are saved as checkpoints. Checkpoint names follow the format criterion_rank, e.g. mae_1.

  • warming_up (int) – The validator does not set or update checkpoints before warming_up epochs.

  • trace_criteria (Dict[str, float]) – Validation criteria. Should follow the format criterion=target, e.g. mae=0, corr=1. The names of the criteria must be consistent with the output of metrics_func.
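
How trace_criteria, trace_order, warming_up and early_stopping might interact can be sketched with a toy tracker. CriteriaTracker is hypothetical and is not the Validator class. It assumes that a score is "better" when it lies closer to its target, that epochs are counted from 0, and that patience is measured in consecutive updates with no checkpoint change; none of these details are confirmed by the parameter descriptions above.

```python
from typing import Dict, List, Optional


class CriteriaTracker:
    """Toy stand-in used only to illustrate the parameters; not XenonPy's Validator."""

    def __init__(self, trace_order: int = 1, early_stopping: Optional[int] = None,
                 warming_up: int = 0, **trace_criteria: float):
        self.targets = dict(trace_criteria)  # e.g. {"mae": 0, "corr": 1}
        self.trace_order = trace_order
        self.patience = early_stopping
        self.warming_up = warming_up
        # Best distances to target per criterion, ascending (rank 1 first).
        self.best: Dict[str, List[float]] = {k: [] for k in self.targets}
        self._stale = 0

    def update(self, epoch: int, metrics: Dict[str, float]) -> List[str]:
        """Return the checkpoint names (criterion_rank) this update would write."""
        if epoch < self.warming_up:
            return []  # no checkpoints are set during warm-up
        updated = []
        for name, target in self.targets.items():
            dist = abs(metrics[name] - target)  # assumed notion of "better"
            ranks = self.best[name]
            pos = sum(1 for d in ranks if d <= dist)  # rank the new score would take
            if pos < self.trace_order:
                ranks.insert(pos, dist)
                del ranks[self.trace_order:]  # keep only the top trace_order scores
                updated.append(f"{name}_{pos + 1}")
        self._stale = 0 if updated else self._stale + 1
        return updated

    @property
    def should_stop(self) -> bool:
        return self.patience is not None and self._stale >= self.patience


tracker = CriteriaTracker(trace_order=2, early_stopping=2, warming_up=1, mae=0, corr=1)
history = [
    tracker.update(0, {"mae": 0.5, "corr": 0.5}),  # warm-up epoch, nothing saved
    tracker.update(1, {"mae": 0.4, "corr": 0.6}),
    tracker.update(2, {"mae": 0.5, "corr": 0.9}),
    tracker.update(3, {"mae": 0.9, "corr": 0.2}),  # no improvement
    tracker.update(4, {"mae": 0.9, "corr": 0.2}),  # no improvement again
]
```

With trace_order=2 each criterion keeps two ranked slots, so a score that is second best still produces a checkpoint name such as mae_2.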

before_proc(trainer)[source]
Return type:

None

on_reset()[source]
Return type:

None

step_forward(trainer, step_info)[source]
Return type:

None

classify = 'classify'
regress = 'regress'
property warming_up

Module contents