Source code for xenonpy.model.nn.layer

#  Copyright (c) 2021. yoshida-lab. All rights reserved.
#  Use of this source code is governed by a BSD-style
#  license that can be found in the LICENSE file.

from torch import nn

from .wrap import L1

__all__ = ['Layer1d']


[docs]class Layer1d(nn.Module): """ Base NN layer. This is a wrap around PyTorch. See here for details: http://pytorch.org/docs/master/nn.html# """ def __init__(self, n_in, n_out, *, drop_out=0., layer_func=L1.linear(bias=True), act_func=nn.ReLU(), batch_nor=L1.batch_norm(eps=1e-05, momentum=0.1, affine=True) ): """ Parameters ---------- n_in: int Size of each input sample. n_out: int Size of each output sample drop_out: float Probability of an element to be zeroed. Default: 0.5 layer_func: func Layers come with PyTorch. act_func: func Activation function. batch_nor: func Normalization layers """ super().__init__() self.layer = layer_func(n_in, n_out) self.batch_nor = None if not batch_nor else batch_nor(n_out) self.act_func = None if not act_func else act_func self.dropout = None if drop_out == 0. else nn.Dropout(drop_out)
[docs] def forward(self, *x): _out = self.layer(*x) if self.dropout: _out = self.dropout(_out) if self.batch_nor: _out = self.batch_nor(_out) if self.act_func: _out = self.act_func(_out) return _out