xenonpy.model.training.extension package
Submodules
xenonpy.model.training.extension.persist module
- class xenonpy.model.training.extension.persist.Persist(path=None, *, model_class=None, model_params=None, increment=False, sync_training_step=False, only_best_states=False, no_model_saving=False, **describe)[source]
Bases: BaseExtension

Trainer extension for data persistence.
- Parameters:
  - model_class (Optional[Callable]) – A factory function for model reconstruction. In most cases this is the model class that inherits from torch.nn.Module.
  - model_params (Union[tuple, dict, Any, None]) – The parameters for model reconstruction. This can be anything, but in general it is a dict that can be used as kwargs.
  - increment (bool) – If True, the dir name of path will be suffixed with an auto-increment number, e.g. model_dir@1 for model_dir.
  - sync_training_step (bool) – If True, save trainer.training_info at every iteration. Default is False, which saves trainer.training_info only at the end of each epoch.
  - only_best_states (bool) – If True, only save the models with the best states in terms of each of the criteria.
  - no_model_saving (bool) – Indicates whether to save the connected model and checkpoints.
  - describe (Any) – Any other information to describe this model. This information will be saved under the model dir with the name describe.pkl.z.
- property describe
- property model_structure
- property no_model_saving
- property path
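The model_class and model_params pair defines a simple reconstruction contract: the stored parameters are fed back to the factory when the model is rebuilt from disk. A minimal sketch of that contract, using a hypothetical stand-in class (ToyNet is not part of xenonpy; real usage would pass a torch.nn.Module subclass):

```python
# Hypothetical illustration of the model_class / model_params contract
# that Persist relies on for model reconstruction.
class ToyNet:
    def __init__(self, n_in, n_out):
        self.n_in = n_in
        self.n_out = n_out

model_class = ToyNet                    # factory for model reconstruction
model_params = dict(n_in=290, n_out=1)  # used as kwargs at reload time

# Persist('model_dir', model_class=model_class, model_params=model_params)
# can later rebuild an equivalent, untrained model instance like so:
model = model_class(**model_params)
```

Because model_params is applied as keyword arguments, any picklable dict that matches the factory's signature works.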
xenonpy.model.training.extension.tensor_convert module
- class xenonpy.model.training.extension.tensor_convert.TensorConverter(*, x_dtype=None, y_dtype=None, empty_cache=False, auto_reshape=True, argmax=False, probability=False)[source]
Bases: BaseExtension

Convert tensor-like data into torch.Tensor automatically.

- Parameters:
  - x_dtype (Union[dtype, Sequence[dtype], None]) – The torch.dtype(s) of X data. If None, all data will be converted to the torch.get_default_dtype() type. Can be a tuple of torch.dtype when your X is a tuple.
  - y_dtype (Union[dtype, Sequence[dtype], None]) – The torch.dtype(s) of y data. If None, all data will be converted to the torch.get_default_dtype() type. Can be a tuple of torch.dtype when your y is a tuple.
  - empty_cache (bool) – See also: https://pytorch.org/docs/stable/cuda.html#torch.cuda.empty_cache
  - auto_reshape (bool) – Reshape a tensor to (-1, 1) if its shape is (n,). Default True.
  - argmax (bool) – Apply np.argmax(out, 1) on the output. This should only be used with classification models. Default False. If True, the probability parameter is ignored.
  - probability (bool) – Apply scipy.special.softmax on the output. This should only be used with classification models. Default False.
- input_proc(x_in, y_in, trainer)[source]
Convert data to torch.Tensor.
- output_proc(y_pred, y_true, is_training)[source]
Convert torch.Tensor to numpy.ndarray.

- Parameters:
  - y_pred (Union[torch.Tensor, Tuple[torch.Tensor]]) –
  - y_true (Union[torch.Tensor, Tuple[torch.Tensor]]) –
  - is_training (bool) – Specifies whether the model is in training mode.
- property argmax
- property auto_reshape
- property empty_cache
- property probability
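The three output-processing options can be illustrated with NumPy equivalents. This is an illustration only: the extension itself applies these transforms to torch.Tensor model outputs, and for probability it uses scipy.special.softmax (a plain-NumPy softmax is shown here to keep the sketch self-contained).

```python
import numpy as np

raw = np.array([[2.0, 1.0, 0.1],
                [0.2, 3.0, 0.5]])   # mock classifier output, shape (2, 3)

# argmax=True: reduce per-class scores to class labels via np.argmax(out, 1)
labels = np.argmax(raw, 1)          # one label per row

# probability=True: softmax over the class axis
e = np.exp(raw - raw.max(axis=1, keepdims=True))
probs = e / e.sum(axis=1, keepdims=True)   # each row sums to 1.0

# auto_reshape=True: a (n,) output is reshaped to (-1, 1)
y = np.arange(4.0)                  # shape (4,)
y2d = y.reshape(-1, 1)              # shape (4, 1)
```

Since argmax takes precedence over probability, a model wrapped with both set to True returns labels, not probabilities.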
xenonpy.model.training.extension.validator module
- class xenonpy.model.training.extension.validator.Validator(metrics_func, *, each_iteration=True, early_stopping=None, trace_order=1, warming_up=0, **trace_criteria)[source]
Bases: BaseExtension

Validator extension.
- Parameters:
  - metrics_func (Union[str, Callable[[Any, Any], Dict]]) – Function for metrics calculation. If str, it should be regress or classify to specify the training type. See xenonpy.model.utils.classification_metrics() and xenonpy.model.utils.regression_metrics() for details. You can also provide your own calculation function.
  - each_iteration (bool) – If True, validation is executed at every iteration; otherwise, only at the end of each epoch. Default True.
  - early_stopping (Optional[int]) – Patience of the early-stopping condition. This condition traces the criteria set by the trace_criteria parameter.
  - trace_order (int) – How many ranks of trace_criteria will be saved as checkpoints. Checkpoint names follow the format criterion_rank, e.g. mae_1.
  - warming_up (int) – The validator does not set/update checkpoints before warming_up epochs.
  - trace_criteria (Dict[str, float]) – Validation criteria. Should follow the format criterion=target, e.g. mae=0, corr=1. The names of the criteria must be consistent with the output of metrics_func.
- classify = 'classify'
- regress = 'regress'
- property warming_up
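A custom metrics_func must map (y_pred, y_true) to a dict whose keys match the criterion names passed via trace_criteria. A minimal sketch (the commented Validator call is illustrative of the signature documented above, not a tested invocation):

```python
import numpy as np

# Hypothetical user-defined metrics function: returns a dict keyed by the
# same names used in trace_criteria (e.g. mae=0, corr=1).
def my_metrics(y_pred, y_true):
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    y_true = np.asarray(y_true, dtype=float).ravel()
    return {
        'mae': float(np.mean(np.abs(y_pred - y_true))),
        'corr': float(np.corrcoef(y_pred, y_true)[0, 1]),
    }

# Validator(my_metrics, early_stopping=50, mae=0, corr=1) would then trace
# 'mae' toward the target 0 and 'corr' toward the target 1 when deciding
# which checkpoints to keep and when to stop early.
scores = my_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 4.0])
```

Keeping the returned keys consistent with trace_criteria is required; a criterion name missing from the dict cannot be traced.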