xenonpy.model.training.extension package
Submodules
xenonpy.model.training.extension.persist module
- class xenonpy.model.training.extension.persist.Persist(path=None, *, model_class=None, model_params=None, increment=False, sync_training_step=False, only_best_states=False, no_model_saving=False, **describe)
Bases: BaseExtension
Trainer extension for data persistence.
- Parameters:
  - model_class (Optional[Callable]) – A factory function for model reconstruction. In most cases this is the model class, which inherits from torch.nn.Module.
  - model_params (Union[tuple, dict, any, None]) – The parameters for model reconstruction. This can be anything, but in general it is a dict that can be used as keyword arguments.
  - increment (bool) – If True, the directory name of path will be decorated with an auto-increment number, e.g. model_dir@1 is used for model_dir.
  - sync_training_step (bool) – If True, trainer.training_info will be saved at each iteration. Default is False, which saves trainer.training_info only at each epoch.
  - only_best_states (bool) – If True, only the models with the best states in terms of each of the criteria will be saved.
  - no_model_saving (bool) – Indicates whether to skip saving the connected model and checkpoints.
  - describe (Any) – Any other information describing this model. This information will be saved under the model directory with the name describe.pkl.z.
- property describe
- property model_structure
- property no_model_saving
- property path
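The increment option can be pictured with a small standard-library sketch. It assumes that the first free suffix @1, @2, ... is chosen by checking which directories already exist on disk; the helper name increment_dir is hypothetical and not part of XenonPy, whose actual naming logic may differ.

```python
from pathlib import Path
import tempfile


def increment_dir(path):
    """Hypothetical helper: return path@N for the smallest N >= 1 not yet on disk.

    Illustrates the model_dir@1 naming described for increment=True;
    this is not XenonPy's implementation.
    """
    base = Path(path)
    n = 1
    while True:
        candidate = base.with_name(f"{base.name}@{n}")
        if not candidate.exists():
            return candidate
        n += 1


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        model_dir = Path(tmp) / "model_dir"
        first = increment_dir(model_dir)   # .../model_dir@1
        first.mkdir()
        second = increment_dir(model_dir)  # .../model_dir@2, since @1 now exists
        print(first.name, second.name)     # model_dir@1 model_dir@2
```

With the real extension, this behaviour is requested through the constructor shown above, e.g. Persist('model_dir', increment=True).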
xenonpy.model.training.extension.tensor_convert module
- class xenonpy.model.training.extension.tensor_convert.TensorConverter(*, x_dtype=None, y_dtype=None, empty_cache=False, auto_reshape=True, argmax=False, probability=False)
Bases: BaseExtension
Convert tensor-like data into torch.Tensor automatically.
- Parameters:
  - x_dtype (Union[dtype, Sequence[dtype], None]) – The torch.dtype(s) of X data. If None, all data will be converted into the torch.get_default_dtype() type. Can be a tuple of torch.dtype when X is a tuple.
  - y_dtype (Union[dtype, Sequence[dtype], None]) – The torch.dtype(s) of y data. If None, all data will be converted into the torch.get_default_dtype() type. Can be a tuple of torch.dtype when y is a tuple.
  - empty_cache (bool) – See also: https://pytorch.org/docs/stable/cuda.html#torch.cuda.empty_cache
  - auto_reshape (bool) – Reshape a tensor to (-1, 1) if its shape is (n,). Default True.
  - argmax (bool) – Apply np.argmax(out, 1) on the output. This should only be used with classification models. Default False. If True, the probability parameter is ignored.
  - probability (bool) – Apply scipy.special.softmax on the output. This should only be used with classification models. Default False.
- input_proc(x_in, y_in, trainer)
Convert data to torch.Tensor.
- output_proc(y_pred, y_true, is_training)
Convert torch.Tensor to numpy.ndarray.
- Parameters:
  - y_pred (Union[torch.Tensor, Tuple[torch.Tensor]]) – Predicted output of the model.
  - y_true (Union[torch.Tensor, Tuple[torch.Tensor]]) – Ground-truth targets.
  - is_training (bool) – Specifies whether the model is in training mode.
- property argmax
- property auto_reshape
- property empty_cache
- property probability
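The reshape rule and the two output options can be sketched in plain Python on nested lists. The helper names reshape_column and post_process are hypothetical: the real extension works on torch.Tensor and calls np.argmax(out, 1) and scipy.special.softmax, whereas this sketch reimplements argmax and a row-wise softmax by hand to stay dependency-free (the row-wise axis is an assumption). It does mirror the documented precedence: when argmax is set, probability is ignored.

```python
import math


def reshape_column(values):
    """Turn a flat sequence of shape (n,) into a column of shape (n, 1)."""
    return [[v] for v in values]


def _softmax(row):
    # Subtract the row maximum for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def post_process(logits, argmax=False, probability=False):
    """Hypothetical post-processing of a 2-D output (rows = samples, columns = classes).

    argmax=True returns the index of the largest score in each row, like
    np.argmax(out, 1), and takes precedence over probability=True, which
    returns a row-wise softmax. With neither flag the output is unchanged.
    """
    if argmax:
        # Ties resolve to the first maximum, as with np.argmax.
        return [max(range(len(row)), key=row.__getitem__) for row in logits]
    if probability:
        return [_softmax(row) for row in logits]
    return logits


if __name__ == "__main__":
    print(reshape_column([0.5, 1.5, 2.5]))  # [[0.5], [1.5], [2.5]]
    out = [[1.0, 3.0, 2.0], [0.0, 0.0, 5.0]]
    print(post_process(out, argmax=True, probability=True))  # [1, 2]
    probs = post_process(out, probability=True)
    print([round(sum(p), 6) for p in probs])  # [1.0, 1.0]
```

In XenonPy itself these switches are set on the constructor, e.g. TensorConverter(argmax=True) for a classifier whose outputs should be class indices.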
xenonpy.model.training.extension.validator module
- class xenonpy.model.training.extension.validator.Validator(metrics_func, *, each_iteration=True, early_stopping=None, trace_order=1, warming_up=0, **trace_criteria)
Bases: BaseExtension
Validator extension.
- Parameters:
  - metrics_func (Union[str, Callable[[Any, Any], Dict]]) – Function for metrics calculation. If str, it should be regress or classify to specify the training type; see xenonpy.model.utils.classification_metrics() and xenonpy.model.utils.regression_metrics() for details. You can also provide your own calculation function.
  - each_iteration (bool) – If True, validation will be executed at every iteration. Otherwise, validation runs only when each epoch is done. Default True.
  - early_stopping (Optional[int]) – Patience of the early-stopping condition. The condition traces the criteria set by the trace_criteria parameter.
  - trace_order (int) – How many ranks of each criterion in trace_criteria will be saved as checkpoints. Checkpoint names follow the format criterion_rank, e.g. mae_1.
  - warming_up (int) – The validator does not set or update checkpoints before warming_up epochs have passed.
  - trace_criteria (Dict[str, float]) – Validation criteria, given in the format criterion=target, e.g. mae=0, corr=1. The names of the criteria must be consistent with the keys returned by metrics_func.
- classify = 'classify'
- regress = 'regress'
- property warming_up
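To make metrics_func, trace_criteria, trace_order, warming_up and early_stopping more concrete, here is a dependency-free sketch. my_metrics shows the shape of a custom metrics function: it returns a dict whose keys match the traced criteria. CriteriaTracker is hypothetical and not XenonPy's implementation; it assumes that a value is better the closer it is to its target (|value - target|), that epochs are counted from 1 with the first warming_up epochs skipped, and that early_stopping counts consecutive validations in which no criterion reaches a new best.

```python
import math


def my_metrics(y_true, y_pred):
    """Example custom metrics_func: returns a dict whose keys match trace_criteria."""
    n = len(y_true)
    mae = sum(abs(t - p) for t, p in zip(y_true, y_pred)) / n
    mt, mp = sum(y_true) / n, sum(y_pred) / n
    cov = sum((t - mt) * (p - mp) for t, p in zip(y_true, y_pred))
    st = math.sqrt(sum((t - mt) ** 2 for t in y_true))
    sp = math.sqrt(sum((p - mp) ** 2 for p in y_pred))
    corr = cov / (st * sp) if st and sp else 0.0  # Pearson correlation
    return {"mae": mae, "corr": corr}


class CriteriaTracker:
    """Hypothetical tracker: keeps the trace_order best values of each criterion,
    ranked by distance to its target, and counts non-improving epochs for patience."""

    def __init__(self, trace_order=1, warming_up=0, early_stopping=None, **trace_criteria):
        self.trace_order = trace_order
        self.warming_up = warming_up
        self.early_stopping = early_stopping
        self.targets = trace_criteria
        self.best = {name: [] for name in trace_criteria}  # sorted (distance, epoch) pairs
        self.stale = 0  # consecutive validations without a new best

    def update(self, epoch, metrics):
        """Record one validation; return True if training should stop early."""
        if epoch <= self.warming_up:  # no checkpoints during warm-up
            return False
        improved = False
        for name, target in self.targets.items():
            dist = abs(metrics[name] - target)
            ranks = self.best[name]
            if len(ranks) < self.trace_order or dist < ranks[-1][0]:
                if not ranks or dist < ranks[0][0]:
                    improved = True  # new rank-1 value for this criterion
                ranks.append((dist, epoch))
                ranks.sort()
                del ranks[self.trace_order:]  # keep only the top trace_order ranks
        self.stale = 0 if improved else self.stale + 1
        return self.early_stopping is not None and self.stale >= self.early_stopping

    def checkpoint_names(self):
        """Names in the documented criterion_rank format, e.g. mae_1."""
        return [f"{name}_{i}" for name, ranks in self.best.items()
                for i in range(1, len(ranks) + 1)]


if __name__ == "__main__":
    tracker = CriteriaTracker(trace_order=2, warming_up=1, early_stopping=2, mae=0, corr=1)
    # Hypothetical per-epoch validation results.
    history = [
        {"mae": 0.9, "corr": 0.1},
        {"mae": 0.5, "corr": 0.5},
        {"mae": 0.6, "corr": 0.4},
        {"mae": 0.7, "corr": 0.3},
    ]
    for epoch, metrics in enumerate(history, start=1):
        if tracker.update(epoch, metrics):
            print(f"early stop at epoch {epoch}")  # early stop at epoch 4
            break
    print(tracker.checkpoint_names())  # ['mae_1', 'mae_2', 'corr_1', 'corr_2']
```

Following the signature above, the equivalent configuration for the real extension would be Validator('regress', trace_order=2, warming_up=1, early_stopping=2, mae=0, corr=1).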