# Source code for xenonpy.model.nn.wrap

#  Copyright (c) 2021. yoshida-lab. All rights reserved.
#  Use of this source code is governed by a BSD-style
#  license that can be found in the LICENSE file.
from functools import partial

import torch as tc
from torch import nn

__all__ = ['Optim', 'LrScheduler', 'Init', 'L1']


class Optim(object):
    """Factory helpers that pre-bind arguments to :mod:`torch.optim` optimizers.

    Each static method returns a :func:`functools.partial` of the matching
    optimizer class. Call the returned factory with the model parameters
    (and any remaining keyword arguments) to construct the optimizer, e.g.
    ``Optim.adam(lr=1e-3)(model.parameters())``.
    """

    @staticmethod
    def sgd(*args, **kwargs):
        """Wrapper for :class:`torch.optim.SGD`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.SGD
        """
        return partial(tc.optim.SGD, *args, **kwargs)

    @staticmethod
    def ada_delta(*args, **kwargs):
        """Wrapper for :class:`torch.optim.Adadelta`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.Adadelta
        """
        return partial(tc.optim.Adadelta, *args, **kwargs)

    @staticmethod
    def ada_grad(*args, **kwargs):
        """Wrapper for :class:`torch.optim.Adagrad`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.Adagrad
        """
        return partial(tc.optim.Adagrad, *args, **kwargs)

    @staticmethod
    def adam(*args, **kwargs):
        """Wrapper for :class:`torch.optim.Adam`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.Adam
        """
        return partial(tc.optim.Adam, *args, **kwargs)

    @staticmethod
    def sparse_adam(*args, **kwargs):
        """Wrapper for :class:`torch.optim.SparseAdam`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.SparseAdam
        """
        return partial(tc.optim.SparseAdam, *args, **kwargs)

    @staticmethod
    def ada_max(*args, **kwargs):
        """Wrapper for :class:`torch.optim.Adamax`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.Adamax
        """
        return partial(tc.optim.Adamax, *args, **kwargs)

    @staticmethod
    def asgd(*args, **kwargs):
        """Wrapper for :class:`torch.optim.ASGD`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.ASGD
        """
        return partial(tc.optim.ASGD, *args, **kwargs)

    @staticmethod
    def lbfgs(*args, **kwargs):
        """Wrapper for :class:`torch.optim.LBFGS`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.LBFGS
        """
        return partial(tc.optim.LBFGS, *args, **kwargs)

    @staticmethod
    def rms_prop(*args, **kwargs):
        """Wrapper for :class:`torch.optim.RMSprop`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.RMSprop
        """
        return partial(tc.optim.RMSprop, *args, **kwargs)

    @staticmethod
    def r_prop(*args, **kwargs):
        """Wrapper for :class:`torch.optim.Rprop`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.Rprop
        """
        return partial(tc.optim.Rprop, *args, **kwargs)
class LrScheduler(object):
    """Factory helpers that pre-bind arguments to :mod:`torch.optim.lr_scheduler` schedulers.

    Each static method returns a :func:`functools.partial` of the matching
    scheduler class. Call the returned factory with an optimizer instance to
    construct the scheduler, e.g. ``LrScheduler.step_lr(step_size=30)(optimizer)``.
    """

    @staticmethod
    def lambda_lr(*args, **kwargs):
        """Wrapper for :class:`torch.optim.lr_scheduler.LambdaLR`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.LambdaLR
        """
        return partial(tc.optim.lr_scheduler.LambdaLR, *args, **kwargs)

    @staticmethod
    def step_lr(*args, **kwargs):
        """Wrapper for :class:`torch.optim.lr_scheduler.StepLR`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.StepLR
        """
        return partial(tc.optim.lr_scheduler.StepLR, *args, **kwargs)

    @staticmethod
    def multi_step_lr(*args, **kwargs):
        """Wrapper for :class:`torch.optim.lr_scheduler.MultiStepLR`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.MultiStepLR
        """
        return partial(tc.optim.lr_scheduler.MultiStepLR, *args, **kwargs)

    @staticmethod
    def exponential_lr(*args, **kwargs):
        """Wrapper for :class:`torch.optim.lr_scheduler.ExponentialLR`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ExponentialLR
        """
        return partial(tc.optim.lr_scheduler.ExponentialLR, *args, **kwargs)

    @staticmethod
    def reduce_lr_on_plateau(*args, **kwargs):
        """Wrapper for :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`.

        https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
        """
        return partial(tc.optim.lr_scheduler.ReduceLROnPlateau, *args, **kwargs)
class Init(object):
    """Factory helpers that pre-bind arguments to :mod:`torch.nn.init` initializers."""

    @staticmethod
    def uniform(*, scale=0.1):
        """Wrapper for :func:`torch.nn.init.uniform_`.

        Returns a :func:`functools.partial` that fills a tensor in place with
        values drawn from the symmetric uniform distribution U(-scale, scale).

        Parameters
        ----------
        scale : float
            Half-width of the uniform interval; bounds become ``a=-scale``,
            ``b=scale``.
        """
        bound = scale
        # ``nn.init.uniform`` is a deprecated alias of ``uniform_``;
        # target the canonical in-place initializer directly.
        return partial(nn.init.uniform_, a=-bound, b=bound)
class L1(object):
    """Factory helpers that pre-bind arguments to one-dimensional :mod:`torch.nn` layers.

    Each static method returns a :func:`functools.partial` of the matching
    layer class; call the returned factory to construct the layer, e.g.
    ``L1.linear(in_features, out_features)()``.
    """

    @staticmethod
    def conv(*args, **kwargs):
        """Wrapper for :class:`torch.nn.Conv1d`.

        https://pytorch.org/docs/stable/nn.html#torch.nn.Conv1d
        """
        return partial(nn.Conv1d, *args, **kwargs)

    @staticmethod
    def linear(*args, **kwargs):
        """Wrapper for :class:`torch.nn.Linear`.

        https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
        """
        return partial(nn.Linear, *args, **kwargs)

    @staticmethod
    def batch_norm(*args, **kwargs):
        """Wrapper for :class:`torch.nn.BatchNorm1d`.

        https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm1d
        """
        return partial(nn.BatchNorm1d, *args, **kwargs)

    @staticmethod
    def instance_norm(*args, **kwargs):
        """Wrapper for :class:`torch.nn.InstanceNorm1d`.

        https://pytorch.org/docs/stable/nn.html#torch.nn.InstanceNorm1d
        """
        return partial(nn.InstanceNorm1d, *args, **kwargs)