Pre-trained Model Library

XenonPy.MDL is a library of pre-trained models obtained by training neural networks and other supervised learning models on diverse materials data describing structure-property relationships.

XenonPy offers a simple-to-use toolchain for seamless transfer learning with these pre-trained models. This tutorial focuses on querying and retrieving models.

useful functions

Running the following cell loads some commonly used packages, such as NumPy and pandas. It also imports some in-house functions used in this tutorial. See the tools.ipynb file for the full list of imports.

[1]:
%run tools.ipynb

access pre-trained models with MDL class

We provide a wide range of APIs for querying and downloading our models. They can be accessed with plain HTTP requests. For convenience, we implemented the most popular ones and wrapped them into XenonPy; all of them are accessible through xenonpy.mdl.MDL.

[2]:
# --- import necessary libraries

from xenonpy.mdl import MDL
[3]:
# --- init and check

mdl = MDL()
mdl

mdl.version
[3]:
MDL(api_key='anonymous.user.key', endpoint='http://xenon.ism.ac.jp/api')
[3]:
'0.1.1'

Notice that MDL takes two optional parameters, api_key and endpoint. endpoint is the URL from which data is fetched, and api_key is an access token used to verify authorization and action permissions. At the moment, the default key, anonymous.user.key, is the only valid option; we will open a public registration system when it is ready.
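Because the APIs are plain HTTP, they can in principle be called without XenonPy. The sketch below only builds a request with Python's standard library and never sends it. It assumes the server speaks GraphQL over POST, which we infer from the response handling shown in the upload traceback later in this tutorial (a JSON body with data and errors keys); the GraphQL document and the api_key header name are hypothetical, so check the XenonPy source before relying on either.

```python
import json
import urllib.request

# Values shown by the MDL() repr above.
ENDPOINT = "http://xenon.ism.ac.jp/api"
API_KEY = "anonymous.user.key"


def build_mdl_request(document: str, variables: dict) -> urllib.request.Request:
    """Build, but do not send, a GraphQL-style POST request.

    Assumption: the server accepts {"query": ..., "variables": ...} JSON
    bodies. The "api_key" header name is hypothetical.
    """
    body = json.dumps({"query": document, "variables": variables}).encode("utf-8")
    headers = {"Content-Type": "application/json", "api_key": API_KEY}
    return urllib.request.Request(ENDPOINT, data=body, headers=headers, method="POST")


# Hypothetical GraphQL document; the real field and argument names may differ.
DOCUMENT = "query ($query: [String!]!) { queryModelDetails(query: $query) { id property } }"

req = build_mdl_request(DOCUMENT, {"query": ["refractive"]})
print(req.get_method(), req.full_url)
print(json.loads(req.data)["variables"])
```

Sending it would take urllib.request.urlopen(req); the MDL class exists so that you do not have to deal with this plumbing yourself.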

querying

There are many ways to query models. The most straightforward is to call mdl directly: it accepts one or more keywords, and any model matching a keyword is a hit. For example, to query models that predict the refractive index:

[4]:
# --- query data

query = mdl('refractive')
query
[4]:
QueryModelDetails(api_key='anonymous.user.key', endpoint='http://xenon.ism.ac.jp/api', variables={'query': ('refractive',)})
Queryable:
 id
 transferred
 succeed
 isRegression
 deprecated
 modelset
 method
 property
 descriptor
 lang
 accuracy
 precision
 recall
 f1
 sensitivity
 prevalence
 specificity
 ppv
 npv
 meanAbsError
 maxAbsError
 meanSquareError
 rootMeanSquareError
 r2
 pValue
 spearmanCorr
 pearsonCorr

As you can see, calling a querying method does not execute the query immediately; it simply returns a query object. Printing the object shows its queryable list, and only the variables in that list can be fetched from the server.

You can also get the queryable list through the query.queryable property:

[5]:
query.queryable
[5]:
['id',
 'transferred',
 'succeed',
 'isRegression',
 'deprecated',
 'modelset',
 'method',
 'property',
 'descriptor',
 'lang',
 'accuracy',
 'precision',
 'recall',
 'f1',
 'sensitivity',
 'prevalence',
 'specificity',
 'ppv',
 'npv',
 'meanAbsError',
 'maxAbsError',
 'meanSquareError',
 'rootMeanSquareError',
 'r2',
 'pValue',
 'spearmanCorr',
 'pearsonCorr']
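The deferred execution described above is a common lazy-query pattern. As an illustration only, here is a sketch with a hypothetical LazyQuery class and a fake fetcher standing in for the server; it is not XenonPy's implementation. The validation against the queryable list, the "no arguments means everything" rule, and the automatically appended id column are modeled on the behaviour shown in this tutorial.

```python
from typing import Callable, Dict, List, Optional, Sequence

Fetcher = Callable[[Dict, List[str]], List[Dict]]


class LazyQuery:
    """Hypothetical stand-in for an MDL query object (not XenonPy's class).

    Creating the object only records the search variables; nothing is
    fetched until the object is called with the fields to return.
    """

    def __init__(self, fetch: Fetcher, queryable: Sequence[str], **variables):
        self._fetch = fetch
        self.queryable = list(queryable)
        self.variables = variables
        self.results: Optional[List[Dict]] = None  # set by the latest call

    def __call__(self, *fields: str) -> List[Dict]:
        # No fields means "return every queryable variable".
        selected = list(fields) if fields else list(self.queryable)
        unknown = [f for f in selected if f not in self.queryable]
        if unknown:
            raise ValueError(f"not queryable: {unknown}")
        # Mirror the tables in this tutorial, where id is returned unasked.
        if "id" in self.queryable and "id" not in selected:
            selected.append("id")
        self.results = self._fetch(self.variables, selected)
        return self.results


# A fake fetcher that records each call instead of contacting a server.
calls = []


def fake_fetch(variables: Dict, fields: List[str]) -> List[Dict]:
    calls.append((variables, tuple(fields)))
    row = {"id": 2335, "property": "inorganic.crystal.refractive_index",
           "pearsonCorr": 0.631021}
    return [{f: row[f] for f in fields}]


lazy = LazyQuery(fake_fetch, ["id", "property", "pearsonCorr"], query=("refractive",))
print(len(calls))  # nothing has been fetched yet
print(lazy("property", "pearsonCorr"))
```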

Let’s say we only want the variables modelset, method, property, descriptor, meanAbsError, meanSquareError, pValue, and pearsonCorr. Execute the query as follows:

[6]:
query(
    'modelset',
    'method',
    'property',
    'descriptor',
    'meanAbsError',
    'meanSquareError',
    'pValue',
    'pearsonCorr'
)
[6]:
modelset method property descriptor meanAbsError meanSquareError pValue pearsonCorr id
0 Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions 0.422960 0.418109 None 0.631021 2335
1 Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions 0.545381 1.641444 None 0.578646 2338
2 Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions 0.708027 3.087893 None 0.439089 2339
3 Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions 0.585778 2.238217 None 0.531137 2341
4 Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions 0.542811 2.174997 None 0.558526 2342
... ... ... ... ... ... ... ... ... ...
3595 Polymer Genome Dataset pytorch.nn.neural_network organic.polymer.refractive_index xenonpy.compositions 0.082331 0.019165 None 0.824684 31191
3596 Polymer Genome Dataset pytorch.nn.neural_network organic.polymer.refractive_index xenonpy.compositions 0.094522 0.023909 None 0.667404 31192
3597 Polymer Genome Dataset pytorch.nn.neural_network organic.polymer.refractive_index xenonpy.compositions 0.128509 0.039972 None 0.680136 31193
3598 Polymer Genome Dataset pytorch.nn.neural_network organic.polymer.refractive_index xenonpy.compositions 0.096644 0.025381 None 0.704920 31194
3599 Polymer Genome Dataset pytorch.nn.neural_network organic.polymer.refractive_index xenonpy.compositions 0.122035 0.040172 None 0.679644 31195

3600 rows × 9 columns

If everything goes right, you get a pandas DataFrame in return. 3600 models matched the keyword, and they come from more than one model set (for example, Stable inorganic compounds in materials project and Polymer Genome Dataset). The id column is included even though it was not requested. Note that every value in the pValue column is None. This is not a querying error: p-values were simply not recorded during training.
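Because pValue carries no information here, you may want to drop empty columns and check how the hits split across model sets. A minimal sketch on records shaped like a return_json=True result (a list of dicts); the three rows are copied from the table above, and drop_empty_fields is a hypothetical helper:

```python
from collections import Counter
from typing import Dict, List

# Sample records copied from the table above (subset of columns).
rows: List[Dict] = [
    {"id": 2335, "modelset": "Stable inorganic compounds in materials project",
     "meanAbsError": 0.422960, "pValue": None},
    {"id": 2338, "modelset": "Stable inorganic compounds in materials project",
     "meanAbsError": 0.545381, "pValue": None},
    {"id": 31191, "modelset": "Polymer Genome Dataset",
     "meanAbsError": 0.082331, "pValue": None},
]


def drop_empty_fields(records: List[Dict]) -> List[Dict]:
    """Remove every key whose value is None in all records."""
    keys = {k for r in records for k in r}
    empty = {k for k in keys if all(r.get(k) is None for r in records)}
    return [{k: v for k, v in r.items() if k not in empty} for r in records]


cleaned = drop_empty_fields(rows)
models_per_set = Counter(r["modelset"] for r in cleaned)
print(models_per_set.most_common())
```

With the default DataFrame output, query.results.dropna(axis=1, how='all') and query.results['modelset'].value_counts() should do the same job in pandas.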

You can also retrieve the result of the last query from the query object through its results property.

[7]:
query.results.head(3)
[7]:
modelset method property descriptor meanAbsError meanSquareError pValue pearsonCorr id
0 Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions 0.422960 0.418109 None 0.631021 2335
1 Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions 0.545381 1.641444 None 0.578646 2338
2 Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions 0.708027 3.087893 None 0.439089 2339

Keyword querying with mdl is simple but coarse. Often, we know exactly what we want.

Let’s say we want models that were trained on an inorganic model set and predict the refractive index. In this case, we pass modelset_has='inorganic' and property_has='refractive'.

[8]:
# --- query data

mdl(modelset_has='inorganic', property_has='refractive')()
[8]:
id transferred succeed isRegression deprecated modelset method property descriptor lang meanAbsError maxAbsError meanSquareError rootMeanSquareError r2 pValue spearmanCorr pearsonCorr
0 2335 False True True False Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.422960 1.782320 0.418109 0.646613 0.226642 None 0.805147 0.631021
1 2338 False True True False Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.545381 8.948927 1.641444 1.281188 0.330978 None 0.808188 0.578646
2 2339 False True True False Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.708027 10.142218 3.087893 1.757240 0.173911 None 0.839800 0.439089
3 2341 False True True False Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.585778 11.835847 2.238217 1.496067 0.241867 None 0.690957 0.531137
4 2342 False True True False Stable inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.542811 11.045835 2.174997 1.474787 0.275737 None 0.886405 0.558526
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
2395 4863 False True True False All inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.429642 2.204849 0.434043 0.658819 0.441885 None 0.808280 0.712245
2396 4865 False True True False All inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.569579 7.622173 1.656191 1.286931 0.215002 None 0.829544 0.475225
2397 4867 False True True False All inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.436152 2.635126 0.559099 0.747729 0.426260 None 0.834226 0.655400
2398 4868 False True True False All inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.383506 2.473203 0.339653 0.582797 0.599966 None 0.840016 0.798442
2399 4871 False True True False All inorganic compounds in materials project pytorch.nn.neural_network inorganic.crystal.refractive_index xenonpy.compositions python 0.394862 2.490165 0.397413 0.630407 0.674495 None 0.800693 0.839631

2400 rows × 18 columns

Only models whose model set name contains inorganic were returned: here, those from the Stable inorganic compounds in materials project and All inorganic compounds in materials project model sets, 2400 models in total. If you call a query object without arguments, all queryable variables are returned.
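Judging from these results, the *_has filters behave like substring matches that must all hold at once. The sketch below re-implements that reading locally with a hypothetical filter_models helper; case-insensitive substring matching is an assumption, not documented server behaviour:

```python
from typing import Dict, Iterable, List


def filter_models(records: Iterable[Dict], **conditions: str) -> List[Dict]:
    """Keep records for which every '<field>_has' condition matches.

    Hypothetical local re-implementation: case-insensitive substring
    matching, all conditions combined with AND.
    """
    checks = {}
    for key, needle in conditions.items():
        if not key.endswith("_has"):
            raise TypeError(f"unsupported condition: {key}")
        checks[key[: -len("_has")]] = needle.lower()
    return [
        r for r in records
        if all(needle in str(r.get(field, "")).lower()
               for field, needle in checks.items())
    ]


# Three models taken from the query results shown above.
models = [
    {"id": 2335, "modelset": "Stable inorganic compounds in materials project",
     "property": "inorganic.crystal.refractive_index"},
    {"id": 4863, "modelset": "All inorganic compounds in materials project",
     "property": "inorganic.crystal.refractive_index"},
    {"id": 31191, "modelset": "Polymer Genome Dataset",
     "property": "organic.polymer.refractive_index"},
]

hits = filter_models(models, modelset_has="inorganic", property_has="refractive")
print([m["id"] for m in hits])
```

One consequence worth checking if the assumption holds: property_has='organic' would also match inorganic.crystal.* properties, because organic is a substring of inorganic.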

list/get_detail variables

You can also inspect the database’s metadata with the mdl.list_* and mdl.get_*_detail methods. For example, mdl.list_properties() returns:

[9]:
mdl.list_properties()()
[9]:
name fullName symbol unit describe
0 inorganic.crystal.efermi
1 inorganic.crystal.refractive_index
2 inorganic.crystal.band_gap
3 inorganic.crystal.density
4 inorganic.crystal.total_magnetization
5 inorganic.crystal.dielectric_const_elec
6 inorganic.crystal.dielectric_const_total
7 inorganic.crystal.final_energy_per_atom
8 inorganic.crystal.formation_energy_per_atom
9 inorganic.crystal.volume
10 organic.polymer.band_gap_pbe
11 organic.polymer.ionization_energy
12 organic.polymer.electron_affinity
13 organic.polymer.atomization_energy
14 organic.polymer.cohesive_energy
15 organic.polymer.refractive_index
16 organic.polymer.density
17 organic.polymer.dielectric_constant
18 organic.polymer.volume_of_cell
19 organic.polymer.dielectric_constant_electronic
20 organic.polymer.dielectric_constant_ionic
21 organic.polymer.band_gap_hse06
22 organic.small_molecule

and mdl.get_property_detail returns the details of a single property, including the number of models associated with it (count):

[10]:
mdl.get_property_detail('inorganic.crystal.efermi')()
[10]:
{'name': 'inorganic.crystal.efermi',
 'fullName': '',
 'symbol': '',
 'unit': '',
 'describe': '',
 'count': 2481}

These methods are handy for finding out what the database contains.
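The property names listed above appear to follow a dotted <class>.<subclass>.<property> scheme (an observation, not a documented rule; organic.small_molecule has only two levels). Grouping them by their first two levels gives a quick overview of coverage; group_by_prefix is a hypothetical helper:

```python
from collections import defaultdict
from typing import Dict, List

# Property names as printed by mdl.list_properties() above.
PROPERTIES = [
    "inorganic.crystal.efermi", "inorganic.crystal.refractive_index",
    "inorganic.crystal.band_gap", "inorganic.crystal.density",
    "inorganic.crystal.total_magnetization", "inorganic.crystal.dielectric_const_elec",
    "inorganic.crystal.dielectric_const_total", "inorganic.crystal.final_energy_per_atom",
    "inorganic.crystal.formation_energy_per_atom", "inorganic.crystal.volume",
    "organic.polymer.band_gap_pbe", "organic.polymer.ionization_energy",
    "organic.polymer.electron_affinity", "organic.polymer.atomization_energy",
    "organic.polymer.cohesive_energy", "organic.polymer.refractive_index",
    "organic.polymer.density", "organic.polymer.dielectric_constant",
    "organic.polymer.volume_of_cell", "organic.polymer.dielectric_constant_electronic",
    "organic.polymer.dielectric_constant_ionic", "organic.polymer.band_gap_hse06",
    "organic.small_molecule",
]


def group_by_prefix(names: List[str], depth: int = 2) -> Dict[str, List[str]]:
    """Group dotted names by their first `depth` components."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for name in names:
        prefix = ".".join(name.split(".")[:depth])
        groups[prefix].append(name)
    return dict(groups)


for prefix, names in group_by_prefix(PROPERTIES).items():
    print(f"{prefix}: {len(names)}")
```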

get training info/env and download url by model ID

Every model has a unique ID, which we can use to get more information about that particular model. The following cells show how to get a model’s training info and training environment by its ID.

[11]:
info = mdl.get_training_info(model_id=1234)()
_, ax = plt.subplots(figsize=(10, 5), dpi=100)
info.tail(3)
info.plot(y=['train_mse_loss', 'val_mse'], ax=ax)
[11]:
total_iters i_epoch i_batch train_mse_loss val_mae val_mse val_rmse val_r2 val_pearsonr val_spearmanr val_p_value val_max_ae
694 694 22 23 2.060905 1.098812 2.106717 1.451453 0.701376 0.850455 0.847006 0.0 12.112278
695 695 22 24 2.049812 1.162004 2.331702 1.526991 0.669485 0.844444 0.839916 0.0 12.192457
696 696 22 25 2.156136 1.092442 2.166551 1.471921 0.692895 0.837554 0.828305 0.0 12.244514
[11]:
<matplotlib.axes._subplots.AxesSubplot at 0x1a1f5ff050>
[Figure: training curves (train_mse_loss and val_mse) of model 1234]
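The training info is an ordinary table, so the usual tools apply. For example, a sketch that picks the iteration with the lowest validation MSE, using only the three rows shown by info.tail(3):

```python
from typing import Dict, List

# The three rows of info.tail(3) shown above (subset of columns).
tail: List[Dict] = [
    {"total_iters": 694, "i_epoch": 22, "val_mse": 2.106717, "val_r2": 0.701376},
    {"total_iters": 695, "i_epoch": 22, "val_mse": 2.331702, "val_r2": 0.669485},
    {"total_iters": 696, "i_epoch": 22, "val_mse": 2.166551, "val_r2": 0.692895},
]

# The row with the lowest validation MSE among these iterations.
best = min(tail, key=lambda row: row["val_mse"])
print(best["total_iters"], best["val_mse"])
```

On the full DataFrame, info['val_mse'].idxmin() returns the index label of that row.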
[12]:
mdl.get_training_env(model_id=1234)()
[12]:
{'python': '3.7.4 (default, Aug 13 2019, 20:35:49) \n[GCC 7.3.0]',
 'system': '#60-Ubuntu SMP Tue Jul 2 18:22:20 UTC 2019',
 'numpy': '1.16.4',
 'torch': '1.1.0',
 'xenonpy': '0.4.0.beta4',
 'device': 'cuda:2',
 'start': '2019/09/17 20:55:56',
 'finish': '2019/09/17 21:03:04',
 'time_elapsed': '5 days, 15:33:23.552799',
 'author': 'Chang Liu',
 'email': 'liu.chang@ism.ac.jp',
 'dataset': 'materials project'}
[13]:
mdl.get_model_urls(1234, 5678)()
[13]:
id etag url
0 1234 9274418b5ee2026ea8714b7edc7d012e-1 http://xenon.ism.ac.jp/mdl/inorganic.crystal.e...
1 5678 c5a962fce76a773b208cc59631999a25-1 http://xenon.ism.ac.jp/mdl/inorganic.crystal.b...

The output DataFrame contains a url column. If you only want the URL strings, pass 'url' as a queryable variable:

[14]:
mdl.get_model_urls(1234, 5678)('url')
[14]:
url
0 http://xenon.ism.ac.jp/mdl/inorganic.crystal.e...
1 http://xenon.ism.ac.jp/mdl/inorganic.crystal.b...

If you don’t want a DataFrame, or want to control the output type yourself, set return_json=True:

[15]:
mdl.get_model_urls(1234, 5678)('url', return_json=True)
[15]:
[{'url': 'http://xenon.ism.ac.jp/mdl/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3.tar.gz'},
 {'url': 'http://xenon.ism.ac.jp/mdl/inorganic.crystal.band_gap/xenonpy.compositions/pytorch.nn.neural_network/290-168-132-111-1-$Nbe3TKMYM.tar.gz'}]

You can use these URLs to download models yourself, but we recommend using mdl.pull.
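If you do download manually, the sketch below fetches one archive with the standard library. The directory layout <save_to>/<property>/<descriptor>/<method>/ is an assumption that mirrors the URL path and matches the paths Checker prints later; download_model is a hypothetical helper and may differ from what mdl.pull does:

```python
import os
import urllib.parse
import urllib.request


def download_model(url: str, save_to: str = ".") -> str:
    """Download one model archive and return its local path.

    Assumption: files are stored under <save_to>/<property>/<descriptor>/
    <method>/, mirroring the URL path after "/mdl/".
    """
    segments = [urllib.parse.unquote(s)
                for s in urllib.parse.urlparse(url).path.split("/") if s]
    # Keep only the part of the path after the "mdl" segment, if present.
    if "mdl" in segments:
        segments = segments[segments.index("mdl") + 1:]
    target_dir = os.path.join(save_to, *segments[:-1])
    os.makedirs(target_dir, exist_ok=True)
    local_path = os.path.join(target_dir, segments[-1])
    urllib.request.urlretrieve(url, local_path)
    return local_path


# Example call (needs network access to the server):
# download_model("http://xenon.ism.ac.jp/mdl/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3.tar.gz", save_to="models")
```

Unpack the archive with tarfile.open(path).extractall(...) only if you trust its source; mdl.pull remains the recommended route.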

[16]:
mdl.pull?
Signature:
mdl.pull(
    *model_ids: Union[int, pandas.core.series.Series, pandas.core.frame.DataFrame],
    save_to: str = '.',
) -> pandas.core.frame.DataFrame
Docstring:
Download model(s) from XenonPy.MDL server.

Parameters
----------
model_ids
    Model ids.
    It can be given by a dataframe.
    In this case, the column with name ``id`` will be used.
save_to
    Path to save models.

Returns
-------
File:      ~/projects/XenonPy/xenonpy/mdl/mdl.py
Type:      method

mdl.pull returns a DataFrame whose model column contains the local path of each downloaded model. In the rest of this tutorial, ret refers to that DataFrame.

upload model

Uploading models will not be available until we open public registration. If you try to upload a model now, you will get the error ‘operation needs a logged in user and the corresponded api_key’.

[18]:
mdl.upload_model(
    modelset_id=2,
    describe=dict(
        property='test2',
        descriptor='test2',
        method='test',
        lang='test',
        deprecated=True,
        isRegression=True,
        meanAbsError=1.11,
        pearsonCorr=0.85,
    ),
    training_info=dict(a=1, b=2),
    supplementary=dict(true=[1,2,3,4], pred=[1,2,2,4])
)(file='data')
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-f862c734ea61> in <module>
     13     training_info=dict(a=1, b=2),
     14     supplementary=dict(true=[1,2,3,4], pred=[1,2,2,4])
---> 15 )(file='data')

~/projects/XenonPy/xenonpy/mdl/base.py in __call__(self, file, return_json, *querying_vars)
    104         ret = ret.json()
    105         if 'errors' in ret:
--> 106             raise ValueError(ret['errors'][0]['message'])
    107         query_name = self.__class__.__name__
    108         ret = ret['data'][query_name[0].lower() + query_name[1:]]

ValueError: operation needs a logged in user and the corresponded api_key

retrieve model

You can use xenonpy.model.training.Checker to load the downloaded models. For example, to load the model with id 1234:

[19]:
from xenonpy.model.training import Checker
[20]:
# ret: the DataFrame returned by mdl.pull; take the local path of model 1234
checker = Checker(ret.loc[ret.id == 1234, 'model'].iloc[0])
checker
[20]:
<Checker> includes:
"final_state": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/final_state.pth.s
"training_info": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/training_info.pd.xz
"model_class": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/model_class.pkl.z
"init_state": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/init_state.pth.s
"model": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/model.pth.m
"splitter": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/splitter.pkl.z
"model_structure": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/model_structure.pkl.z
"describe": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/describe.pkl.z
"data_indices": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/data_indices.pkl.z
"model_params": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/model_params.pkl.z

Note that the random suffix $WEnkZ6e3 at the end of the directory name guarantees that each model has a unique name.
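The numeric part of the name also seems informative: 290-180-177-162-46-32-1 matches the layer widths of the network printed below, from 290 input features down to a single output. Treating this as a naming convention is our inference, not documented behaviour; a sketch with a hypothetical parse_model_name helper:

```python
from typing import List, Tuple


def parse_model_name(name: str) -> Tuple[List[int], str]:
    """Split a name like '290-180-...-1-$WEnkZ6e3' into widths and suffix.

    Assumption: the numbers before '-$' are the layer widths, from input
    features to output; this is an inferred convention.
    """
    widths_part, suffix = name.rsplit("-$", 1)
    widths = [int(w) for w in widths_part.split("-")]
    return widths, suffix


widths, suffix = parse_model_name("290-180-177-162-46-32-1-$WEnkZ6e3")
layers = list(zip(widths[:-1], widths[1:]))  # (in_features, out_features) pairs
print(suffix, layers)
```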

To load the model into Python, access the checker.model property.

[21]:
checker.model
[21]:
SequentialLinear(
  (layer_0): LinearLayer(
    (linear): Linear(in_features=290, out_features=180, bias=True)
    (dropout): Dropout(p=0.1)
    (normalizer): BatchNorm1d(180, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (layer_1): LinearLayer(
    (linear): Linear(in_features=180, out_features=177, bias=True)
    (dropout): Dropout(p=0.1)
    (normalizer): BatchNorm1d(177, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (layer_2): LinearLayer(
    (linear): Linear(in_features=177, out_features=162, bias=True)
    (dropout): Dropout(p=0.1)
    (normalizer): BatchNorm1d(162, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (layer_3): LinearLayer(
    (linear): Linear(in_features=162, out_features=46, bias=True)
    (dropout): Dropout(p=0.1)
    (normalizer): BatchNorm1d(46, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (layer_4): LinearLayer(
    (linear): Linear(in_features=46, out_features=32, bias=True)
    (dropout): Dropout(p=0.1)
    (normalizer): BatchNorm1d(32, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (output): Linear(in_features=32, out_features=1, bias=True)
)
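As a sanity check on the printout, the number of learnable parameters can be computed from the layer widths alone. This sketch assumes exactly the structure shown: every Linear has a bias, each hidden layer has a BatchNorm1d with affine=True, and Dropout and ReLU have no parameters; count_parameters is a hypothetical helper:

```python
from typing import Sequence


def count_parameters(widths: Sequence[int]) -> int:
    """Count learnable parameters of a SequentialLinear-style network.

    Linear layers contribute weight + bias; each hidden BatchNorm1d
    contributes weight + bias per feature. Running statistics are buffers,
    not parameters, so they are not counted.
    """
    total = 0
    for n_in, n_out in zip(widths[:-1], widths[1:]):
        total += n_in * n_out + n_out  # Linear weight and bias
    total += sum(2 * w for w in widths[1:-1])  # BatchNorm1d weight and bias
    return total


print(count_parameters([290, 180, 177, 162, 46, 32, 1]))
```

For the widths above this gives 123,482. With PyTorch installed, sum(p.numel() for p in checker.model.parameters()) should return the same number.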

Use checker.checkpoints to list checkpoints.

[22]:
checker.checkpoints
[22]:
<Checker> includes:
"mse_3": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mse_3.pth.s
"mse_1": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mse_1.pth.s
"mae_2": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mae_2.pth.s
"r2_5": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/r2_5.pth.s
"mse_5": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mse_5.pth.s
"r2_1": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/r2_1.pth.s
"mae_4": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mae_4.pth.s
"r2_3": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/r2_3.pth.s
"mae_3": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mae_3.pth.s
"r2_4": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/r2_4.pth.s
"mse_2": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mse_2.pth.s
"mae_1": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mae_1.pth.s
"mae_5": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mae_5.pth.s
"r2_2": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/r2_2.pth.s
"mse_4": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/mse_4.pth.s
"pearsonr_5": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/pearsonr_5.pth.s
"pearsonr_1": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/pearsonr_1.pth.s
"pearsonr_3": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/pearsonr_3.pth.s
"pearsonr_4": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/pearsonr_4.pth.s
"pearsonr_2": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.efermi/xenonpy.compositions/pytorch.nn.neural_network/290-180-177-162-46-32-1-$WEnkZ6e3/checkpoints/pearsonr_2.pth.s

Use checker.checkpoints[<checkpoint name>] to load the information stored in a specific checkpoint.
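With 20 checkpoints, it helps to see which metrics were tracked. The names appear to follow a <metric>_<index> pattern, and the sketch below groups them with a hypothetical group_checkpoints helper; what the index means (for example, a rank among the best checkpoints) is not stated in this tutorial:

```python
from collections import defaultdict
from typing import Dict, List

# Checkpoint names listed by checker.checkpoints above.
NAMES = ["mse_3", "mse_1", "mae_2", "r2_5", "mse_5", "r2_1", "mae_4", "r2_3",
         "mae_3", "r2_4", "mse_2", "mae_1", "mae_5", "r2_2", "mse_4",
         "pearsonr_5", "pearsonr_1", "pearsonr_3", "pearsonr_4", "pearsonr_2"]


def group_checkpoints(names: List[str]) -> Dict[str, List[int]]:
    """Map each metric prefix to its sorted checkpoint indices.

    Assumption: every name has the form '<metric>_<integer>'.
    """
    groups: Dict[str, List[int]] = defaultdict(list)
    for name in names:
        metric, index = name.rsplit("_", 1)
        groups[metric].append(int(index))
    return {metric: sorted(indices) for metric, indices in groups.items()}


print(group_checkpoints(NAMES))
```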

[23]:
checker.checkpoints['mse_3']
[23]:
OrderedDict([('id', 'mse_3'),
             ('iterations', 670),
             ('model_state',
              OrderedDict([('layer_0.linear.weight',
                            tensor([[-2.7413,  2.2444, -0.6347,  ..., -0.4093,  1.7569,  0.3331],
                                    [-3.4710,  3.1230, -1.5549,  ...,  0.7105,  0.1097, -0.5686],
                                    [-1.0452, -0.0773, -1.8657,  ..., -2.4842, -2.4716,  0.0133],
                                    ...,
                                    [-2.8635,  2.0810, -1.8306,  ..., -0.7263,  0.1966,  0.6690],
                                    [-1.4911,  0.3969, -2.4763,  ..., -2.1228, -2.3481, -0.2280],
                                    [-4.5753,  3.2695, -3.0195,  ..., -1.5596, -0.2794,  0.9001]])),
                           ('layer_0.linear.bias',
                            tensor([ 1.8674e-01,  1.2594e+00,  6.3830e-03,  8.8717e-03, -2.2134e-02,
                                     1.0698e+00,  1.3682e-01, -6.5420e-01, -2.2904e-02,  2.5760e-01,
                                     1.3303e-02, -7.1474e-01,  5.6013e-02,  2.4740e-01,  3.6663e-02,
                                    -4.6552e-02,  3.5276e-01,  4.7275e-01,  7.8601e-02,  1.7847e-01,
                                     4.1219e-01,  1.1378e-01, -1.5997e-01, -4.0251e-02, -5.4254e-02,
                                     4.1694e-01,  1.6805e-01,  6.6942e-01,  4.1481e-02,  1.0744e-01,
                                     5.8623e-01,  9.0315e-01,  1.2878e+00,  5.1398e-01,  2.8303e-01,
                                     9.7579e-02,  1.7690e-01,  1.6764e-01, -2.3814e-01,  2.6120e-01,
                                     5.1376e-01,  2.8208e-01,  3.4689e-02,  4.8316e-02,  2.2929e-01,
                                     4.6762e-02, -4.5354e-02,  1.7217e-03,  4.2179e-03,  5.0165e-01,
                                     2.7844e-01, -1.4434e-01, -8.3940e-01,  9.0725e-02, -1.2393e-01,
                                    -7.6476e-02, -5.0103e-02,  2.0573e-02, -4.6605e-01,  4.2584e-04,
                                    -1.2431e-03,  1.7416e-02,  1.9315e-01, -8.2507e-02,  5.6182e-02,
                                    -1.5245e-01, -1.6458e-02,  1.2927e-01, -1.0788e-01, -6.2298e-02,
                                     3.5064e-01,  3.9070e-01,  1.4682e-02,  7.9495e-02,  3.9572e-01,
                                     3.2543e-02,  3.1317e-02,  2.2445e-01, -5.1132e-01,  9.8786e-01,
                                     8.7456e-02, -2.9127e-02,  2.8352e-01, -5.4528e-02, -3.7345e-02,
                                     1.5468e-01,  8.5600e-01, -4.6745e-03,  4.3161e-01,  2.3493e-02,
                                     2.1969e-01, -4.6591e-02,  4.1278e-02, -3.0714e-01,  3.9996e-02,
                                     3.8986e-02,  5.1559e-02,  6.2200e-02,  5.3427e-02,  1.5805e-01,
                                     4.4626e-01,  1.8942e-01,  3.0517e-02,  8.5060e-03,  2.3743e-01,
                                     4.2354e-02,  1.4166e-02,  1.6934e-01,  4.8535e-01, -4.7722e-02,
                                     4.4760e-01,  2.1024e-01, -3.7049e-02, -3.6345e-01,  1.1755e-01,
                                    -8.5783e-02, -1.7177e-01, -2.4121e-02,  1.7525e-01,  5.6658e-01,
                                    -1.1588e-01, -6.6737e-02,  1.9659e-01, -5.5139e-02,  2.2182e-01,
                                     4.2854e-01,  9.8938e-01,  5.8801e-01,  3.2118e-01,  3.4174e-01,
                                    -6.3782e-02,  2.4754e-01, -6.7626e-02,  2.2835e-01,  9.6599e-03,
                                     1.3314e+00,  6.5396e-03,  2.8145e-02, -1.1071e+00,  5.0319e-01,
                                     1.7117e-01,  5.1287e-01, -8.9040e-02, -1.4837e-01, -1.2847e-02,
                                     3.7948e-01, -9.2569e-04, -3.3302e-03,  2.0668e-01,  3.3091e-04,
                                     1.2669e-01, -6.7318e-02,  4.1707e-01, -5.6539e-02,  1.6370e-02,
                                    -6.2774e-02,  2.8587e-02, -3.5977e-02,  3.8779e-01, -4.2425e-02,
                                     7.1804e-03,  3.5864e-01,  1.6941e+00,  9.8432e-01,  2.3810e-02,
                                     9.4629e-02,  1.3361e-01, -9.8752e-01, -2.1931e-02,  1.8669e+00,
                                     4.0989e-02, -3.4537e-02,  3.9911e-01,  2.6089e-02,  2.7814e-02,
                                     4.0855e-01,  4.8781e-02,  1.2521e+00,  1.6534e-03,  1.0430e+00])),
                           ('layer_0.normalizer.weight',
                            tensor([ 0.5104,  0.5791,  0.3863,  0.5147,  0.4768,  0.6566,  0.8734,  0.6645,
                                     0.2071,  0.5819,  0.5087,  0.8138, -0.0092,  0.3608,  0.7038,  0.4413,
                                     0.4833,  0.4180,  0.3368,  0.8463,  0.9392,  0.6924,  0.9296,  0.3491,
                                     0.8069,  0.6619,  0.4793,  0.8459,  0.3441,  0.4590,  0.2950,  0.7371,
                                     1.0435,  0.5632,  0.8104,  0.3961,  0.8829,  0.4403,  0.6482,  0.1659,
                                     0.8168,  0.6445,  0.2457,  0.5658,  0.6175,  0.5902,  0.4836,  0.4961,
                                     0.8358,  0.6547,  1.0532,  0.2442,  0.7796,  0.6544,  0.4556,  0.7625,
                                     0.5656,  0.6227,  0.4305,  0.6411,  0.0154, -0.0189,  0.8002,  0.7829,
                                    -0.3495,  0.5086,  0.8894,  0.7311,  0.2554,  0.6265,  0.2209,  0.7341,
                                     0.3376,  0.6506,  0.5737,  0.6515,  0.3971,  0.6820,  0.2186,  0.9487,
                                     0.6806,  0.0019,  0.4236, -0.0177,  0.5589,  0.8026,  1.0764,  0.4386,
                                     0.9103,  0.5176,  0.5547,  0.7266,  0.5542,  0.8855,  0.5962,  0.2110,
                                     0.5829,  0.4500,  0.4146,  0.7857,  0.3986,  0.2688,  0.6396,  0.1095,
                                     0.3394, -0.0042,  0.8618,  0.8215,  0.6750,  0.5059,  0.6141,  0.3358,
                                     0.3855,  0.3334,  0.8454,  0.6317,  1.0210,  0.8109,  0.2147,  0.5198,
                                     0.0800,  0.7424,  0.2460,  0.8399,  0.8723,  0.5189,  0.7770,  1.0010,
                                     0.2620,  0.8994,  0.6743,  0.4495,  0.7078,  0.3760,  0.3484,  0.6523,
                                     0.6504, -0.3513,  0.3846,  0.3045,  0.6785,  0.9537,  0.3647,  0.3666,
                                     0.5900,  0.8485,  0.4425,  0.5642,  0.6941,  0.4128,  1.0695,  0.4578,
                                     0.3176,  0.6530,  0.7904,  0.7789,  0.7717,  0.3916,  0.4966,  0.3998,
                                     0.4341,  0.5248,  0.5466,  0.6723,  0.2765,  0.4336,  0.2317,  0.4454,
                                     0.1769,  1.0672,  0.7877,  0.2391,  0.3636,  0.2797,  0.3982,  0.2906,
                                     0.4604,  0.6001,  0.2475,  0.8258])),
                           ('layer_0.normalizer.bias',
                            tensor([-0.1962, -0.0443, -0.2787, -0.0800, -0.4596, -0.1044, -0.1859,  0.1330,
                                     0.0798, -0.1919, -0.1971, -0.2456, -0.0873, -0.3314, -0.1081, -0.3820,
                                    -0.2372, -0.1259,  0.2093, -0.1731, -0.0683, -0.2343, -0.2495, -0.1465,
                                    -0.4826,  0.0515, -0.1702, -0.4426,  0.0551, -0.0520, -0.1534, -0.0523,
                                    -0.0462, -0.1940,  0.1447, -0.0863, -0.1390, -0.1070,  0.0972,  0.0248,
                                    -0.4015, -0.2151, -0.1441,  0.0269, -0.0848, -0.1441, -0.3927, -0.4433,
                                    -0.4028, -0.0485, -0.1162,  0.1694, -0.1743, -0.2001, -0.3018, -0.5739,
                                     0.0599, -0.2187,  0.1332, -0.1887, -0.1445, -0.0679, -0.1672, -0.1422,
                                    -0.3938,  0.2836,  0.0275,  0.1534,  0.1156, -0.2450, -0.1039, -0.1504,
                                    -0.3893, -0.1549, -0.2389, -0.4126, -0.4205, -0.0470,  0.2166, -0.0405,
                                    -0.1702, -0.0595,  0.0128, -0.1007,  0.0836,  0.0919,  0.1459, -0.2995,
                                    -0.1740, -0.2381, -0.2750, -0.1821, -0.2118, -0.2194, -0.1279, -0.3674,
                                     0.0930, -0.0964,  0.0278, -0.3024,  0.0738, -0.0719, -0.4944, -0.2141,
                                    -0.1037, -0.0635, -0.3994, -0.1838, -0.1033, -0.1383, -0.0299,  0.0610,
                                    -0.1442,  0.0914, -0.1958, -0.2726, -0.2943, -0.1576,  0.1738,  0.0190,
                                     0.1938, -0.0365,  0.1417, -0.1462, -0.3113,  0.0542, -0.1261, -0.2680,
                                    -0.0870, -0.1613, -0.4117, -0.0631, -0.3835, -0.0387, -0.0150,  0.1125,
                                    -0.2066, -0.4758,  0.3709, -0.0092, -0.2423, -0.5018,  0.0690,  0.0409,
                                    -0.1773, -0.2771,  0.0214, -0.3464, -0.0111, -0.1203, -0.2852, -0.3696,
                                    -0.0033, -0.0295, -0.1827, -0.1715, -0.3728,  0.0648,  0.0032, -0.2143,
                                    -0.0813, -0.1441,  0.3781, -0.2067, -0.0122,  0.0530, -0.0088,  0.1383,
                                     0.0448,  0.1684, -0.1943, -0.0861,  0.1116, -0.2284, -0.0506, -0.0290,
                                    -0.1651,  0.2154, -0.3402, -0.0656])),
                           ('layer_0.normalizer.running_mean',
                            tensor([-1.2129e+06, -2.8267e+05, -1.8897e+05,  3.0676e+05, -3.7214e+05,
                                     2.1488e+05, -1.3920e+06,  3.9283e+05, -2.0499e+05, -7.2328e+05,
                                     1.0294e+06, -9.7885e+04, -8.7283e+05,  1.5692e+06,  2.8840e+05,
                                    -3.6137e+05, -3.1562e+05, -2.7778e+05, -3.0051e+05, -1.0735e+06,
                                    -1.8500e+05, -2.7350e+05, -2.9411e+05,  1.1861e+06, -3.7914e+05,
                                    -1.7805e+05, -1.1815e+06, -3.5938e+05,  6.2004e+05, -7.6983e+05,
                                     3.5743e+05, -2.6413e+05, -1.4222e+05, -1.5700e+05, -5.4438e+03,
                                    -1.1430e+06,  1.6714e+05, -1.1032e+06, -7.9594e+03,  2.0885e+05,
                                    -5.4440e+05, -6.4795e+05,  5.7594e+05, -2.3447e+05, -7.4664e+05,
                                     7.4651e+05, -7.0639e+05, -6.9201e+05, -2.2437e+05,  1.0384e+05,
                                    -8.4054e+05,  6.0643e+05,  3.3801e+03, -3.2249e+05, -4.2711e+05,
                                    -1.1043e+05, -2.9549e+05,  1.0297e+06, -1.1715e+05,  1.2755e+05,
                                     7.3041e+05,  8.6747e+05, -1.7673e+05, -7.2339e+04, -2.5830e+06,
                                    -6.3911e+04,  4.1610e+05,  1.6950e+05, -3.6926e+03, -4.5985e+05,
                                     2.6051e+05, -7.6430e+05,  1.2047e+06, -4.1445e+05, -6.9484e+05,
                                    -9.0072e+05,  1.4235e+06,  3.1018e+05, -2.3230e+04, -2.7941e+05,
                                    -3.4013e+04,  3.8489e+04, -6.6028e+05,  3.6265e+05,  1.4223e+06,
                                    -1.5331e+05, -1.1047e+05, -8.2704e+05, -1.0100e+05, -5.6225e+04,
                                    -1.0116e+06,  5.0040e+05, -1.1833e+06, -7.6824e+04,  1.3367e+06,
                                     2.0518e+06,  5.9261e+05, -3.5650e+05,  1.5055e+05, -2.9754e+05,
                                    -2.4155e+05, -1.4951e+06,  6.6657e+05,  6.3000e+04,  6.2948e+05,
                                    -4.8433e+05, -4.4355e+05, -1.0859e+05, -2.5631e+05,  4.4660e+05,
                                    -2.3497e+05,  2.0366e+04,  1.9218e+05, -2.9244e+05, -3.2546e+05,
                                    -5.2584e+05, -7.7124e+04,  3.7811e+05, -2.3890e+05, -1.7612e+05,
                                    -5.7066e+05,  4.8926e+04, -2.7278e+05,  9.3206e+05, -2.6738e+05,
                                    -2.6777e+05, -1.4219e+05, -4.1422e+05,  7.3759e+05, -6.6235e+05,
                                     1.0376e+06, -2.9605e+05,  4.3901e+04, -1.4490e+04, -1.0147e+05,
                                    -2.4937e+04, -5.1361e+04, -1.7725e+06,  2.0963e+04,  4.7144e+05,
                                    -9.8638e+05, -4.2225e+05,  7.8836e+04,  7.1857e+04, -4.3700e+02,
                                    -5.0268e+05,  3.0922e+05, -1.3685e+05, -7.1152e+05, -3.0388e+05,
                                    -1.3728e+05,  8.0650e+05, -2.1282e+05,  6.7584e+05, -2.0512e+05,
                                    -1.5178e+05, -6.8367e+05,  4.0433e+05, -4.0490e+05,  8.2663e+05,
                                    -2.6122e+05, -7.7120e+05, -7.1985e+03, -3.4465e+05,  2.2040e+05,
                                    -4.8741e+05, -7.2770e+05, -2.9618e+04,  4.6617e+05, -7.7764e+04,
                                    -3.3846e+05,  1.5320e+05, -6.9355e+04, -3.7454e+05, -1.2436e+05,
                                     8.0627e+05, -1.0685e+06, -1.6126e+05,  1.0174e+06, -2.2256e+05])),
                           ('layer_0.normalizer.running_var',
                            tensor([2.8734e+12, 1.0846e+11, 4.1629e+12, 5.6613e+11, 3.7155e+12, 1.1434e+11,
                                    1.0065e+13, 3.2466e+11, 8.6504e+11, 7.8276e+11, 1.9834e+12, 4.1663e+10,
                                    2.1929e+12, 1.0113e+13, 1.1670e+12, 1.8113e+12, 2.1190e+11, 2.4035e+11,
                                    6.7875e+11, 5.2790e+12, 1.9742e+11, 6.1139e+11, 4.2168e+11, 1.4509e+12,
                                    1.2841e+12, 1.9323e+11, 1.2050e+12, 1.4804e+11, 1.1170e+12, 3.4495e+12,
                                    5.7916e+11, 2.9058e+11, 1.1115e+11, 5.0044e+10, 2.3590e+10, 1.5554e+12,
                                    7.8063e+10, 3.3613e+12, 3.5218e+11, 1.0136e+11, 3.4714e+11, 1.1977e+12,
                                    3.1281e+11, 2.4396e+11, 2.0106e+12, 7.3961e+11, 2.3273e+12, 2.2810e+12,
                                    6.7913e+11, 3.6403e+10, 4.0202e+12, 1.7135e+12, 1.9181e+10, 4.7717e+11,
                                    4.6237e+11, 6.4884e+11, 7.9020e+11, 6.5175e+11, 1.7558e+11, 1.2140e+12,
                                    9.9990e+11, 1.6053e+12, 5.6199e+11, 8.0261e+10, 4.2270e+12, 4.0093e+10,
                                    1.7091e+11, 7.3718e+10, 3.5427e+10, 1.1907e+12, 1.8352e+11, 2.5166e+12,
                                    1.9807e+12, 1.3332e+12, 3.2925e+11, 1.2385e+12, 1.7686e+12, 3.5039e+11,
                                    8.4248e+10, 4.6650e+11, 1.1013e+11, 3.3575e+12, 5.2994e+11, 6.9099e+11,
                                    1.7189e+12, 1.5802e+11, 8.0225e+10, 2.2410e+12, 6.5210e+10, 9.8536e+10,
                                    1.2242e+12, 1.6420e+12, 2.2263e+12, 9.3532e+10, 3.2451e+12, 4.3126e+12,
                                    1.7026e+12, 1.4901e+12, 1.2225e+11, 2.8621e+11, 2.1624e+11, 2.8024e+12,
                                    3.3252e+11, 3.8007e+11, 2.4959e+12, 4.3159e+11, 1.9044e+12, 9.7064e+10,
                                    1.3783e+11, 2.6057e+11, 5.2831e+11, 2.9040e+11, 1.8034e+12, 4.2261e+11,
                                    5.9581e+11, 2.3897e+12, 2.9622e+10, 2.5139e+11, 3.8814e+11, 2.0273e+11,
                                    1.0169e+12, 6.6312e+11, 3.9934e+11, 1.5518e+12, 1.6059e+11, 2.0429e+11,
                                    7.3254e+10, 3.9329e+11, 5.0400e+12, 2.5599e+12, 1.3495e+12, 3.0244e+11,
                                    4.8005e+10, 5.1000e+10, 3.0274e+11, 2.7907e+10, 5.9713e+11, 2.4381e+12,
                                    9.9436e+09, 8.3648e+11, 3.2557e+12, 4.1370e+11, 5.8659e+10, 3.1412e+10,
                                    8.6003e+10, 9.1349e+11, 2.6566e+11, 6.0918e+11, 2.4435e+12, 4.5672e+11,
                                    8.8486e+10, 6.6567e+11, 1.4290e+11, 4.8815e+12, 4.3067e+11, 1.8827e+11,
                                    4.8827e+12, 2.1159e+11, 9.6153e+11, 1.4606e+12, 4.5325e+11, 4.5757e+11,
                                    1.0029e+10, 1.4708e+11, 6.1296e+11, 6.3223e+11, 4.4463e+11, 1.1444e+10,
                                    1.1462e+12, 5.3529e+10, 6.8962e+11, 5.9348e+11, 2.1475e+11, 4.1514e+11,
                                    1.0480e+12, 1.2999e+12, 3.4963e+12, 8.6104e+10, 9.5250e+11, 2.2460e+11])),
                           ('layer_0.normalizer.num_batches_tracked',
                            tensor(670)),
                           ('layer_1.linear.weight',
                            tensor([[-0.3181, -0.0467,  0.1220,  ..., -0.3037,  0.0751, -0.5011],
                                    [ 0.1769,  0.0907,  0.6257,  ...,  0.0173,  0.1333,  0.4019],
                                    [ 0.0093, -0.2277, -0.0919,  ..., -0.2902, -0.1235, -0.2536],
                                    ...,
                                    [ 0.4216,  0.1898, -0.1141,  ...,  0.1129, -0.0119,  0.1866],
                                    [-0.0128,  0.0372, -0.3504,  ..., -0.1693, -0.2387, -0.3789],
                                    [-0.1159,  0.1799,  0.2841,  ...,  0.1712,  0.0442, -0.0150]])),
                           ('layer_1.linear.bias',
                            tensor([ 0.4272, -0.9039,  0.1904, -0.2160,  0.4101, -0.1103,  0.4987,  0.2544,
                                     0.1019,  0.0742,  0.6390,  0.0388,  0.0458, -0.6708, -0.0971, -0.1611,
                                     0.0943, -1.1995, -0.0920,  0.3232,  0.2385, -0.1232,  0.1995, -0.4173,
                                    -0.0610,  0.0653,  0.1749,  0.0328,  0.4999, -0.2223,  0.0241,  0.7735,
                                     0.1619,  0.0434, -0.1896, -0.3395,  0.0403,  0.3407, -0.3957,  0.1370,
                                     0.5093,  0.2799,  0.0323, -0.0116,  0.2430,  0.2210, -0.0503,  0.0454,
                                    -0.0668,  0.4199,  0.1628, -1.0166,  0.4744, -0.3817,  0.1295, -0.2087,
                                     0.5544,  0.1096,  0.0749, -1.4452,  0.0350, -0.5916,  0.0621,  0.0480,
                                     0.1958,  0.4391,  0.5505,  0.2575,  0.0325,  0.3898,  0.3660,  0.2256,
                                     0.4397, -0.0505,  0.3201,  0.3299,  0.1729,  0.2120,  0.3927,  0.1806,
                                    -0.0665,  0.2679,  0.0802,  0.2549,  0.3061, -0.5109,  0.2203, -0.1280,
                                     0.2171, -0.0454,  0.3390,  0.1481,  0.8642,  0.1444, -0.2233, -0.2728,
                                    -0.1692,  0.2542,  0.1593, -0.4494, -0.0971,  0.2951, -0.1180,  0.2975,
                                     0.1763, -0.0883,  0.1806,  0.2068, -0.1619,  0.0125, -0.0279,  0.3348,
                                     0.5742, -0.1454,  0.1260,  0.6690,  0.3612, -0.0294,  0.1213,  0.4081,
                                    -0.2089,  0.2120,  0.2244,  0.4372, -0.0115,  0.7784,  0.1602, -0.0813,
                                    -0.0493, -0.1658,  0.3238,  0.2760, -0.0432, -0.0420, -0.0659,  0.3089,
                                     0.0190,  0.5050,  0.6044, -0.2358,  0.2355, -0.2841,  0.0530,  0.4726,
                                     0.5186,  0.2180,  0.2962, -0.1304,  0.1614,  0.3440,  0.2312,  0.6344,
                                    -0.3127, -0.2918,  0.2700, -0.1242,  0.6109, -0.1531,  0.0123,  0.2946,
                                     0.5476,  0.4071,  0.2653, -0.4429,  0.1034, -0.0617,  0.2970, -0.0116,
                                    -0.3955, -0.0198,  0.6731,  0.4815,  0.4867,  0.0948,  0.2610,  0.1566,
                                     0.0384])),
                           ('layer_1.normalizer.weight',
                            tensor([ 0.7724,  0.5528,  0.7754,  0.7696,  0.2109,  0.4952,  0.7166,  0.2034,
                                     0.3974,  0.4473,  0.4182,  0.4169,  0.8112,  0.9732,  0.6718,  0.6471,
                                     0.6182,  0.2758,  0.2656,  0.7420,  0.8008,  0.0970,  0.3573,  0.7887,
                                     0.4323,  0.7148,  0.8219,  0.4425,  0.2586,  0.8524,  0.7631, -0.1480,
                                     0.6211,  0.7283,  0.7019,  0.6664,  0.7825,  0.2799,  0.7122,  0.8373,
                                     0.3143,  0.5367,  0.8127,  0.5353,  0.6630,  0.7752,  0.2445,  0.5905,
                                     0.2394,  0.8094,  0.6661,  0.2757,  0.3573,  0.5570,  0.7235,  0.6305,
                                     0.8710,  0.4753,  0.3409,  0.3190,  0.6961,  0.5270,  0.6013,  0.0756,
                                     0.7081,  0.5906,  0.1810,  0.7286, -0.0057,  0.8344,  0.4059,  0.1987,
                                     0.7828,  0.7904,  0.8655,  0.4139,  0.7423,  0.3984,  0.7705,  0.7546,
                                     0.5489,  0.4179,  0.3176,  0.3345,  0.7589,  0.3580,  0.1110,  0.5608,
                                     0.5362,  0.0094,  0.2373,  0.3635,  0.7349,  0.5421,  0.7940,  0.8640,
                                     0.0260,  0.2689,  0.6665,  0.7943,  0.5509,  0.4087,  0.7532,  0.4168,
                                     0.9124,  0.0878,  0.5972,  0.6600,  0.7846,  0.3822,  0.3208,  0.6693,
                                     0.4407,  0.4913,  0.7079,  0.6869,  0.3143,  0.7635,  0.7462,  0.1772,
                                     0.7289,  0.0499,  0.5132,  0.6099,  0.8893,  0.6448,  0.4385,  0.7160,
                                     0.6739,  0.3674,  0.1015,  0.7265, -0.0049,  0.0171,  0.0766,  0.4091,
                                     0.7618,  0.4541,  0.7765,  0.9321,  0.0817,  0.4848,  0.0063,  0.3314,
                                     0.5573,  0.7276,  0.2552,  0.6275,  0.5366, -0.2922,  0.0431,  0.6978,
                                     0.8908,  0.6146,  0.4422,  0.6265,  0.3998,  0.7618,  0.8459,  0.6532,
                                     0.6608,  0.6642,  0.3993,  0.7534,  0.8476,  0.7271,  0.7883,  0.5109,
                                     0.6671,  0.3548, -0.3233,  0.6174,  0.6103,  0.8735,  0.1984,  0.8545,
                                     0.5555])),
                           ('layer_1.normalizer.bias',
                            tensor([-4.8687e-01, -6.5793e-01, -1.7362e-01, -3.0067e-01,  1.2581e-01,
                                    -2.2906e-01,  3.5819e-02,  5.0886e-02, -2.7870e-01, -1.7128e-01,
                                     3.4052e-02, -4.9039e-01, -5.6463e-01, -5.9361e-01, -4.0186e-01,
                                    -9.2585e-02, -2.2541e-01, -7.4718e-01, -5.0870e-03, -2.2716e-01,
                                    -2.0561e-02, -2.0329e-01, -8.4164e-02, -5.8757e-01, -1.9374e-01,
                                    -3.3367e-01, -2.2995e-01, -1.2856e-01,  9.6702e-03, -3.8536e-01,
                                    -3.0355e-01, -4.4498e-01, -1.8743e-01, -1.9710e-01, -5.3411e-01,
                                    -6.9149e-01, -4.0172e-01,  2.4301e-02, -3.2459e-01, -9.3060e-02,
                                    -3.6030e-02, -1.1579e-01, -2.8658e-01, -2.4621e-01, -6.2366e-03,
                                    -1.2196e-01, -1.3555e-01,  7.4362e-03, -2.5930e-01, -4.3468e-01,
                                    -1.5946e-01, -6.9949e-01, -4.7001e-02, -4.9838e-01, -2.6839e-01,
                                    -5.2396e-01, -1.5207e-01, -1.1730e-01, -2.5288e-01, -6.5695e-01,
                                    -1.9395e-01, -4.9804e-01, -1.9134e-01, -1.1425e-01, -2.8989e-01,
                                    -2.3164e-01,  1.1617e-01, -2.5385e-01, -9.5173e-02, -3.7006e-01,
                                     5.5945e-02,  2.0635e-01, -3.8304e-01, -4.2781e-01, -1.5396e-01,
                                    -1.1947e-01, -2.7199e-02, -2.8786e-01, -2.0645e-01, -1.5319e-01,
                                    -3.2556e-01, -1.1684e-01,  5.5866e-03, -3.4911e-01, -2.7693e-01,
                                    -3.4123e-01,  2.0492e-01, -3.2356e-01, -1.6581e-01, -9.6886e-02,
                                     1.1676e-01, -1.9239e-01, -1.6082e-01, -7.9961e-02, -5.2100e-01,
                                    -5.1340e-01, -1.8974e-01,  2.5859e-02, -1.3545e-01, -3.9954e-01,
                                    -7.8536e-02, -2.9645e-01, -4.8769e-01,  2.5232e-02,  1.1020e-01,
                                    -1.9782e-01, -1.7286e-01, -2.3323e-01, -4.2869e-01,  3.8918e-02,
                                    -1.7333e-01, -3.7193e-01,  1.0578e-01, -9.5304e-02, -4.1619e-01,
                                    -2.1274e-02,  2.9453e-01, -2.1269e-01, -1.0715e-01,  1.2427e-01,
                                    -1.1718e-01,  1.1451e-01, -4.9374e-04, -1.6703e-01, -1.1600e-01,
                                     7.1238e-02,  2.0437e-02, -1.5701e-01, -8.9630e-02, -2.1474e-01,
                                     1.4690e-01,  1.7623e-02, -5.2768e-02, -8.0910e-02,  1.1171e-01,
                                     9.7640e-02, -4.6893e-01,  2.8726e-02, -2.6173e-01, -6.1300e-01,
                                     1.5880e-01, -2.9128e-01, -4.5310e-02,  1.8769e-01, -3.5182e-01,
                                    -3.7850e-01,  1.3322e-03, -1.7291e-01, -1.0457e-01, -4.1100e-01,
                                    -9.8451e-02, -4.5245e-01, -7.2406e-01, -4.2713e-01,  3.8897e-02,
                                    -3.4561e-01, -4.0211e-02, -3.8704e-01, -3.5258e-01, -2.0678e-01,
                                    -3.6973e-01, -3.4208e-01, -3.0089e-02, -4.2235e-01, -1.4618e-01,
                                     1.4217e-01, -1.5353e-01, -1.9246e-01, -5.5187e-01, -9.1403e-02,
                                    -4.0453e-01, -1.8983e-01, -1.1410e-01, -3.3431e-01,  1.1839e-01,
                                    -1.3737e-01, -2.5578e-01])),
                           ('layer_1.normalizer.running_mean',
                            tensor([-1.0012,  0.9059, -0.5933,  1.2517, -0.1599, -0.2621, -0.5754, -0.7055,
                                    -0.1043,  0.3594, -0.0773, -0.7850, -0.7410, -0.4356, -0.0951, -0.8019,
                                     0.1038,  0.1979, -1.2412, -0.1483, -0.2528, -1.2501, -0.0598,  0.3535,
                                    -0.0999, -0.4959, -0.0107, -1.2390, -0.0301, -0.1888,  0.5956, -0.5109,
                                    -0.9230,  0.5715,  0.1734,  0.0887, -0.2923, -0.4678,  0.1751, -0.5447,
                                    -1.0108,  0.0278, -0.5504,  0.2809, -0.8226, -0.3325, -0.2127, -0.6235,
                                     0.2548,  0.5462,  0.0092,  0.2038, -0.6879, -0.1080, -1.1428,  0.1823,
                                    -0.1586, -0.4211,  0.4230,  1.0469, -0.1016,  0.4202, -0.1107, -1.6935,
                                    -0.0985, -0.3614, -0.2401, -0.4718,  0.2483, -0.3143,  0.4494,  0.2874,
                                    -0.5771, -0.1627,  0.1163, -0.6621, -0.3129, -0.2038, -0.2718, -0.7657,
                                     0.0480,  0.9996, -0.6919,  1.5673,  0.0111, -0.5343,  0.6665,  1.6854,
                                    -0.9618,  0.1406, -1.0776,  0.3344, -0.9839, -0.1319,  0.0170,  1.1883,
                                    -1.3508, -0.4064, -0.9850,  0.1428,  0.9735, -0.3728, -0.0558, -0.4150,
                                     0.0863, -1.1260, -0.2841, -0.0181, -0.5096, -0.4574,  0.2621, -0.2732,
                                    -0.3459,  0.3559, -0.1139, -0.2430,  0.4331, -0.7160,  0.1169,  0.2474,
                                    -0.6933,  0.7129, -0.8441, -0.0880, -0.1472, -0.2216, -0.6946, -0.9052,
                                    -1.0473,  0.4883, -0.7345, -0.5874, -0.1417, -0.5711,  0.1673, -1.0069,
                                    -0.2246,  0.5413, -0.9195, -0.5592, -0.5670, -0.2906, -1.0233,  0.0110,
                                    -0.4851,  0.3339, -1.1661, -1.1126,  0.0384,  0.0168, -1.2003, -0.5942,
                                     0.7896,  0.4963, -0.3686,  0.1805,  0.1795, -0.1698,  0.3953, -0.3368,
                                    -0.4614, -0.0705, -0.2032,  0.3827, -0.1626, -0.1555, -1.2359,  0.2446,
                                    -0.8220,  0.3409,  0.4112, -0.0505, -0.7286,  0.9226, -0.3724, -0.6035,
                                    -0.2093])),
                           ('layer_1.normalizer.running_var',
                            tensor([ 1.5784,  4.3603,  3.0406,  3.9843,  2.5842,  1.6847,  4.2242,  2.3038,
                                     1.4713,  2.3048,  3.4151,  1.8586,  2.6211,  1.4248,  1.6324,  3.5756,
                                     1.8121, 11.3280,  3.3998,  2.4089,  1.8012,  1.3024,  2.6088,  1.7034,
                                     1.7349,  2.2189,  2.0389,  2.4620,  1.6277,  1.7833,  1.6252,  2.9811,
                                     3.4582,  1.0297,  1.1774,  1.5030,  1.6943,  2.2126,  0.9261,  3.2761,
                                     3.3161,  2.1361,  1.7398,  1.4311,  3.6328,  2.6587,  1.8304,  2.9714,
                                     1.9231,  1.7223,  1.9149,  9.2152,  4.8232,  1.8995,  2.0808,  1.6926,
                                     1.0224,  2.7655,  1.2581, 10.1292,  1.2400,  2.4070,  1.8323,  3.7722,
                                     2.4576,  2.8575,  3.3780,  1.9544,  0.2371,  2.2880,  3.0949,  2.1179,
                                     1.9815,  1.7560,  1.2390,  3.1567,  2.6115,  1.3628,  2.7039,  1.5545,
                                     1.0431,  2.3704,  4.2586,  4.2653,  1.0075,  2.4698,  2.4783,  3.4799,
                                     4.7091,  0.1635,  3.5400,  1.0513,  3.8962,  2.7156,  1.4790,  2.0708,
                                     1.2599,  3.2595,  3.0823,  1.0620,  3.8206,  3.5075,  1.2568,  3.8073,
                                     1.5245,  0.6714,  2.0583,  1.1045,  1.7741,  1.8926,  2.9702,  1.6884,
                                     4.6135,  1.3068,  1.2621,  2.7422,  1.9513,  2.1740,  1.9496,  2.4949,
                                     4.0215,  1.8097,  2.8780,  1.5046,  1.2563,  3.1307,  4.7151,  4.2565,
                                     2.1220,  1.8459,  4.7927,  3.1600,  0.0777,  0.2535,  1.7539,  2.6288,
                                     1.2955,  2.9272,  1.5976,  1.9883,  2.9310,  1.9140,  0.8124,  1.6121,
                                     1.0627,  2.2470,  3.8162,  3.4301,  5.9179, 10.3987,  1.6794,  1.4560,
                                     1.0918,  1.5822,  2.9202,  1.4964,  3.1723,  0.9711,  1.6405,  3.9672,
                                     1.6582,  1.3467,  2.9507,  1.8645,  2.6762,  3.0827,  3.7867,  1.4285,
                                     2.0862,  2.0662,  7.5074,  2.5092,  2.6214,  2.2816,  2.0605,  2.0044,
                                     2.0017])),
                           ('layer_1.normalizer.num_batches_tracked',
                            tensor(670)),
                           ('layer_2.linear.weight',
                            tensor([[ 0.1494, -0.1234, -0.2821,  ..., -0.1532, -0.0732, -0.1128],
                                    [-0.0551,  0.0989, -0.3374,  ..., -0.0806, -0.0468,  0.0616],
                                    [-0.0130,  0.0661, -0.1006,  ...,  0.2441,  0.1385,  0.0855],
                                    ...,
                                    [ 0.1844, -0.2233, -0.3211,  ...,  0.0852, -0.2506,  0.1146],
                                    [ 0.2989, -0.0142,  0.0532,  ...,  0.2498,  0.1741,  0.1107],
                                    [ 0.0098,  0.3458,  0.2180,  ..., -0.1585,  0.2146,  0.1526]])),
                           ('layer_2.linear.bias',
                            tensor([ 3.0239e-01, -5.3302e-02,  2.2736e-04,  2.0087e-01, -3.8188e-02,
                                    -1.3411e-01,  1.2893e-01, -2.1880e-02, -5.9000e-02, -5.4540e-02,
                                     5.2248e-02,  1.0601e-01,  2.3720e-01,  1.8199e-01, -8.2406e-02,
                                     2.6802e-01,  2.7499e-01,  1.5136e-01,  6.0337e-02,  5.2905e-02,
                                     2.6242e-01,  5.0922e-01,  1.0646e-01,  6.9072e-02,  5.6843e-01,
                                     7.4985e-02,  3.5433e-01,  2.1146e-01,  1.9029e-02, -2.7482e-02,
                                     2.5046e-02, -1.1538e-01,  9.9619e-02,  2.5464e-01,  5.0537e-01,
                                     4.9756e-01, -6.2948e-02,  1.4187e-01, -3.8272e-01,  1.1998e-01,
                                     2.3396e-01,  8.7638e-01, -6.1409e-02, -1.7969e-01, -4.5910e-02,
                                     2.2924e-01,  2.2706e-01,  3.2552e-01,  1.0975e-01,  6.2968e-02,
                                    -6.4498e-02,  2.7169e-01,  2.5185e-01,  1.8035e-01,  3.2196e-01,
                                     2.3540e-01, -2.2477e-01, -1.8710e-01,  1.4068e-01,  4.5910e-01,
                                     1.6799e-01,  6.5608e-03,  1.2377e-01,  5.6258e-02, -7.8144e-01,
                                     4.3098e-01,  5.6467e-01, -3.8760e-02, -1.3453e-01,  1.7881e-02,
                                     4.1133e-01,  4.2046e-02,  5.0666e-01,  6.8481e-01,  4.9632e-01,
                                     7.2266e-02,  8.6528e-02,  1.9726e-01,  2.0609e-01,  2.1501e-01,
                                     1.3276e-01,  1.6760e-01,  2.7045e-01,  4.3999e-02,  3.6845e-01,
                                     5.2264e-02,  7.5264e-01,  1.1433e-01,  8.4747e-02,  3.0166e-01,
                                     1.9079e-01,  1.0732e-01,  1.6198e-01, -4.9254e-01,  3.7556e-01,
                                     7.9555e-02,  4.4684e-02, -1.6131e-01, -8.8023e-03, -1.7224e-01,
                                     8.8625e-02,  5.8451e-02, -6.2197e-03,  2.8570e-01,  3.5402e-01,
                                     6.2331e-01,  1.1865e-01,  3.0338e-01,  1.5998e-01,  4.2997e-01,
                                     1.6426e-01,  2.3096e-01,  3.6988e-01,  1.8134e-01,  2.1643e-01,
                                     1.7414e-01,  1.7112e-01,  5.2216e-01,  2.0863e-01, -7.4228e-02,
                                    -2.9927e-02,  3.0342e-01, -7.7274e-02,  9.7756e-02,  1.1052e-01,
                                    -4.7219e-01,  6.3637e-01, -9.5980e-01,  9.8897e-02,  2.1375e-01,
                                     4.4751e-02, -8.4467e-03,  1.0384e-01,  2.0667e-01,  1.5470e-01,
                                     6.8581e-02,  1.7625e-01,  4.8077e-02, -1.1737e-01,  2.7715e-01,
                                     3.4703e-01, -1.4836e-01,  6.2579e-02,  3.4763e-02,  1.2605e-01,
                                    -6.6361e-02,  1.4257e-01,  5.6900e-01,  4.5195e-02,  3.9331e-01,
                                    -5.3968e-03,  6.3631e-01,  1.8731e-02,  1.5331e-03,  3.2232e-01,
                                     2.8391e-01,  6.9658e-02,  2.5956e-01,  7.0743e-04,  8.3131e-02,
                                    -4.0698e-02, -2.9941e-02])),
                           ('layer_2.normalizer.weight',
                            tensor([ 0.5702,  0.2644,  0.0021,  0.3180,  0.0068,  0.7779,  0.9335,  0.8019,
                                    -0.0026,  0.0130,  0.2658,  0.7768,  0.2027,  0.8133,  0.4755,  0.4897,
                                     0.8120,  0.8660,  0.7208,  0.6162,  1.0882,  0.4964,  0.3053,  0.7868,
                                     0.7445,  0.4685,  0.4096,  0.6145,  1.0769,  0.5526,  0.0037,  0.4450,
                                     0.8094,  0.4504,  0.1680,  0.2447,  0.7307,  0.8844,  0.7824,  0.9964,
                                     0.6331,  0.3162,  0.7100,  1.1262,  0.7011,  0.6610,  0.3499,  0.3103,
                                     0.3882,  0.4211,  0.8804,  0.8797,  0.8348,  0.3552,  0.8234,  0.6807,
                                     0.8677,  1.0028,  0.5138,  0.4255,  0.2642,  0.3518,  0.8169,  1.1435,
                                     0.4310,  0.6880,  0.4420,  0.4483,  0.3779,  0.8295,  0.7879,  0.4734,
                                     0.4358, -0.3511,  0.8456,  0.9694,  1.0054,  0.7382,  0.5879,  0.7284,
                                     0.2063,  0.6878,  0.2600,  0.2843,  0.3376,  0.7433,  0.2563, -0.2748,
                                     0.2734,  0.8698,  0.8010,  0.5559,  0.4545,  0.4025,  0.3367,  0.3711,
                                     0.6905,  0.4669,  0.2259,  0.7472,  0.5634,  0.5947,  0.5681,  0.5274,
                                     0.3540,  0.2592,  0.7315,  0.1985,  0.6367,  0.6228,  0.5678,  0.7469,
                                     0.4683,  0.7877,  0.3997,  0.7646,  0.4645,  0.5035,  0.9133,  0.0689,
                                     0.5252,  0.2042,  0.3749,  0.8870,  0.2443,  0.5081,  0.5015,  0.2709,
                                     0.8070,  0.6796,  0.3015,  0.7249, -0.0299,  0.2952,  0.8090,  0.8051,
                                     0.5114,  0.7190,  0.7958,  0.8325,  0.2930,  0.5036,  0.4367,  0.4449,
                                     0.3949,  0.2757,  0.4136,  0.8340,  0.1668,  0.3975,  0.6852,  0.3907,
                                     0.9086,  1.0399,  0.4983,  0.4873,  0.9199,  0.7441,  0.3017,  0.9651,
                                     0.7199,  1.2637])),
                           ('layer_2.normalizer.bias',
                            tensor([-0.2877, -0.1392, -0.0760, -0.1539,  0.1394, -0.3909, -0.2643,  0.0480,
                                    -0.0762, -0.0654, -0.1486, -0.4975,  0.0746, -0.2274, -0.5470,  0.1071,
                                    -0.7426, -0.4433, -0.0887, -0.1903, -0.6628, -0.0657, -0.1192, -0.3334,
                                     0.3805, -0.3382,  0.0676, -0.1222, -0.1984, -0.2463, -0.0655, -0.3659,
                                    -0.2818, -0.2096,  0.1345,  0.2461, -0.1120, -0.1843, -0.4706, -0.0271,
                                    -0.1814,  0.2776, -0.1216, -0.6338, -0.4935, -0.2531, -0.2297, -0.1417,
                                    -0.0701, -0.2186, -0.5407, -0.4192, -0.1523, -0.1759,  0.0016, -0.0079,
                                    -0.4961, -0.0490, -0.2012, -0.3597, -0.3253, -0.0746, -0.1436, -0.1339,
                                    -0.4577, -0.0792,  0.1882, -0.2153,  0.0198, -0.4788, -0.0718, -0.7235,
                                    -0.2115, -0.5022,  0.0753,  0.0306, -0.4394, -0.1115, -0.0326, -0.1410,
                                    -0.2049, -0.2567,  0.3239, -0.2708,  0.1325,  0.3172,  0.2874, -0.2593,
                                     0.0281, -0.3286,  0.0809, -0.1017, -0.1189, -0.1250, -0.0465, -0.2186,
                                    -0.2037, -0.5524, -0.1744, -0.3866, -0.2166,  0.1573, -0.3897,  0.0766,
                                    -0.2397,  0.2393, -0.3184, -0.0487, -0.3327, -0.1560, -0.4015,  0.0770,
                                    -0.0272,  0.1276, -0.1546, -0.5013,  0.0713, -0.0201, -0.0893, -0.1252,
                                    -0.2020,  0.3025, -0.2530,  0.2395, -0.3860, -0.4408,  0.0632, -0.5001,
                                    -0.4768, -0.0856,  0.1180, -0.4356, -0.0909, -0.0226, -0.2583, -0.1047,
                                    -0.2219,  0.0344, -0.3541,  0.2453,  0.2751, -0.0668, -0.2798,  0.1193,
                                     0.2693, -0.2330, -0.2963,  0.0207,  0.0744,  0.3772, -0.0594,  0.2716,
                                    -0.5286, -0.0988,  0.0071,  0.0785, -0.4082, -0.0996, -0.2286, -0.1072,
                                    -0.0199, -0.0759])),
                           ('layer_2.normalizer.running_mean',
                            tensor([-0.4771,  0.3641,  0.8236, -0.5117, -0.0686, -0.1060, -0.2271,  1.1366,
                                    -0.2673, -0.6526, -0.0598, -0.3061, -0.8754, -0.5231, -0.1757,  0.4500,
                                    -0.8126, -0.2425, -0.5052, -0.7121, -0.4026, -0.9407, -0.2459, -0.9544,
                                     0.4757, -0.5780, -0.4919,  0.7358, -0.7519, -1.2301, -0.5611, -0.1652,
                                    -0.2985, -1.5551, -0.4821,  0.0284,  0.7578, -0.6171,  1.1162, -0.2456,
                                     0.0706,  0.7641, -0.7250, -0.0534,  0.3452, -1.3787, -0.1682, -0.4565,
                                     0.0713, -0.7465, -0.1761, -0.3235, -0.1148, -1.7422, -0.4193, -0.4965,
                                     1.6195, -0.7968, -0.8410, -0.0312, -0.3054, -1.1875, -0.0566, -0.3492,
                                     0.0022, -0.2289,  0.0654, -0.3188,  1.0440, -0.6859,  0.0535, -0.2057,
                                    -0.9382, -1.3899,  0.4321,  0.1345, -0.8927,  0.0757, -0.8157, -0.7032,
                                    -0.0305, -0.7336,  0.8059, -0.1560,  0.6057,  0.1573,  1.3014, -0.1614,
                                     0.6065, -0.0084, -1.0559, -0.9398, -1.2521,  0.4449, -0.5810, -1.4539,
                                    -0.2939,  0.2501, -0.4125,  0.1147, -0.6748,  0.5822, -1.2044, -0.1167,
                                     0.0862,  0.5174, -0.8685, -1.8570,  0.0200, -0.6124, -0.3150,  0.5020,
                                    -0.7448, -0.5465, -1.0633, -0.2995,  0.4695, -0.5782, -0.0652, -0.5671,
                                    -1.0913, -0.6757,  0.6314, -0.3749, -0.8141, -0.3002,  0.3342, -0.5968,
                                    -0.5119,  0.0035,  0.1476,  0.0465,  0.4401,  0.0067,  0.2266, -0.7001,
                                    -0.3087, -0.8820, -0.2069,  0.8888, -0.2508,  1.0863, -0.4786,  0.1945,
                                     0.5957, -0.1170, -0.3074,  0.1786, -0.8405,  0.0903,  0.1077,  0.4824,
                                    -0.3188,  0.6018,  0.2859, -1.1954, -0.3282, -0.7685, -0.7977, -0.6385,
                                     0.9811,  0.7143])),
                           ('layer_2.normalizer.running_var',
                            tensor([1.2081, 0.7104, 0.2791, 2.3922, 1.3909, 2.2632, 2.7923, 3.4170, 0.2371,
                                    0.8223, 1.3578, 1.3718, 3.6416, 1.1920, 1.7282, 2.5161, 1.9402, 0.8729,
                                    1.1787, 1.5293, 0.9767, 3.6726, 1.8250, 1.4192, 3.7494, 1.0485, 2.5048,
                                    2.8385, 2.6879, 1.5259, 0.3320, 5.0883, 2.3942, 2.9072, 2.1876, 4.2835,
                                    2.2496, 1.4994, 1.1825, 3.0025, 1.3385, 4.1513, 1.3618, 1.3812, 3.3829,
                                    2.7717, 2.4212, 2.5637, 0.7795, 2.8931, 1.4690, 1.9977, 3.8444, 2.4258,
                                    0.9577, 3.2737, 2.3147, 2.9155, 1.6753, 1.4580, 1.7667, 2.2222, 1.0000,
                                    3.4903, 4.8764, 1.9582, 2.8803, 1.6848, 3.7632, 2.2867, 2.3507, 2.0570,
                                    1.8109, 2.8613, 3.5178, 2.9047, 1.7047, 0.8481, 2.2359, 1.6086, 1.2468,
                                    1.9544, 4.4804, 1.5151, 2.4366, 5.1584, 4.9039, 7.4413, 3.6987, 1.0244,
                                    2.5796, 1.8977, 1.5950, 2.9663, 2.5948, 1.5380, 4.0055, 1.2963, 1.2724,
                                    1.0664, 1.9278, 1.7581, 1.5861, 3.4465, 1.5621, 5.3650, 1.7151, 5.2662,
                                    1.2776, 1.0102, 0.9393, 3.4590, 2.4199, 1.5162, 2.2081, 1.1188, 3.0619,
                                    2.5459, 1.2963, 0.4084, 2.9916, 1.2197, 1.8776, 2.3511, 2.1200, 1.6047,
                                    2.1336, 2.0964, 1.9538, 3.3240, 2.2276, 1.0014, 0.3187, 2.3750, 2.5523,
                                    1.7299, 1.6644, 0.9426, 1.8765, 4.0248, 2.9807, 2.4890, 1.6497, 4.3925,
                                    3.5190, 3.1066, 2.3513, 1.8139, 2.6495, 4.8435, 3.8340, 3.7585, 1.8870,
                                    2.6590, 2.8758, 2.4845, 1.1878, 1.7869, 1.4710, 2.2284, 2.6282, 3.5244])),
                           ('layer_2.normalizer.num_batches_tracked',
                            tensor(670)),
                           ('layer_3.linear.weight',
                            tensor([[ 0.1380, -0.1854,  0.0619,  ..., -0.2357,  0.0291, -0.0855],
                                    [ 0.1408,  0.0109,  0.0273,  ..., -0.1300, -0.2806,  0.0283],
                                    [-0.0817,  0.0103, -0.1799,  ...,  0.0643, -0.0872, -0.2488],
                                    ...,
                                    [-0.1365,  0.2386,  0.0441,  ...,  0.1613,  0.0506,  0.0980],
                                    [ 0.0804,  0.0271, -0.0292,  ...,  0.1756, -0.0669,  0.1731],
                                    [ 0.0381,  0.2557, -0.1221,  ...,  0.2586,  0.0182, -0.0256]])),
                           ('layer_3.linear.bias',
                            tensor([ 1.7908e-02, -1.6474e-02, -1.1066e-01,  2.7437e-01,  7.9325e-01,
                                     7.8987e-02,  3.3952e-01, -8.7965e-02,  1.5753e-01, -1.2900e-01,
                                     7.6526e-02,  7.7629e-03,  5.0898e-01, -3.8382e-01, -6.2505e-02,
                                     7.2312e-01,  4.5212e-02,  9.1312e-01,  4.5915e-02,  1.3952e-02,
                                     2.2207e-01, -5.3636e-04,  1.4154e-01,  1.2470e-01,  1.3919e-01,
                                     1.1014e-01,  1.1680e-01, -3.1510e-02,  4.0542e-02, -1.7219e-04,
                                    -1.0441e-01,  5.2226e-01, -5.1104e-02,  3.9572e-01, -2.4648e-01,
                                    -3.9327e-01, -1.4421e-01, -6.7385e-02,  2.2774e-02, -1.5080e-01,
                                    -5.9285e-03,  1.3464e-01,  1.1147e-01, -2.1358e-01, -3.9889e-01,
                                     1.2666e-01])),
                           ('layer_3.normalizer.weight',
                            tensor([ 0.4035,  0.9273,  0.3436,  0.8313, -0.5483,  0.5186,  0.3613,  0.3362,
                                     0.7600,  0.4670,  1.1825,  0.7006,  0.4275,  1.0126,  0.9773,  0.3357,
                                     0.2955,  0.3329,  0.3705, -0.0120,  0.8153,  0.9983,  0.7633,  0.8837,
                                     0.3240,  0.8691,  0.9471,  0.8034,  0.6000,  0.8582,  0.3787,  0.6470,
                                     1.0252,  0.7739,  0.3396,  0.8057,  0.8804,  0.7036,  0.8939,  0.5345,
                                     0.9509,  0.7852,  0.8444,  0.6482,  0.7647,  0.9864])),
                           ('layer_3.normalizer.bias',
                            tensor([ 0.3733,  0.2051, -0.0269,  0.2724, -0.4801, -0.1726,  0.4889,  0.4670,
                                     0.0054,  0.3607, -0.0175,  0.0667,  0.3025, -0.1931,  0.0393,  0.4796,
                                    -0.4264,  0.4658,  0.4557, -0.0588,  0.1771,  0.1360, -0.1883,  0.3348,
                                    -0.0160,  0.1065,  0.2562, -0.0518,  0.2925,  0.1273,  0.6136,  0.2823,
                                    -0.1507,  0.1460,  0.0915, -0.2639,  0.0656,  0.1196, -0.1569,  0.1028,
                                     0.0275,  0.4489,  0.1869, -0.4018, -0.2103, -0.0286])),
                           ('layer_3.normalizer.running_mean',
                            tensor([-0.9715, -2.2032, -1.8011,  0.9685,  0.6200, -0.9597,  0.7768, -2.1247,
                                    -2.0955, -2.6104,  1.6368, -1.4368,  1.0332,  1.4593,  0.8677,  1.6211,
                                    -1.0680,  1.0735, -0.7232,  1.2941, -0.8071, -1.6285, -2.1765, -1.8546,
                                    -1.0937,  0.5600,  0.9602,  0.9767, -0.7245, -1.1330, -2.3668,  0.4982,
                                     0.3226, -1.4879, -3.2258, -0.0994, -2.7357, -1.0780, -2.0709, -1.0271,
                                    -1.8341,  1.0092, -0.7226, -0.3674,  3.3634, -0.3263])),
                           ('layer_3.normalizer.running_var',
                            tensor([ 4.4649,  9.8838,  5.2923,  4.5300,  5.1298,  1.6735,  5.8431,  8.8731,
                                     4.8244,  8.9389,  6.5502,  4.1085,  7.3151,  6.8593,  4.8760,  9.5597,
                                     1.4886,  6.5716,  3.2722,  0.5959,  4.1359,  7.5561,  4.8947,  5.4064,
                                     2.4573,  3.7298,  4.4916,  4.9595,  2.8120,  7.7168, 10.7011,  4.4031,
                                     4.6533,  5.6067,  8.5437,  3.1499,  7.7018,  4.7582,  4.0134,  2.9158,
                                     5.3908,  7.7204,  3.1456,  2.2102, 10.8318,  2.5861])),
                           ('layer_3.normalizer.num_batches_tracked',
                            tensor(670)),
                           ('layer_4.linear.weight',
                            tensor([[ 0.0698, -0.0195, -0.0756,  ..., -0.0462,  0.0105, -0.0389],
                                    [ 0.1114,  0.1341,  0.3744,  ..., -0.1303, -0.5447,  0.0424],
                                    [ 0.4859,  0.2065,  0.2659,  ..., -0.2024, -0.3696, -0.2466],
                                    ...,
                                    [ 0.5584, -0.0504,  0.2195,  ..., -0.2966,  0.0918, -0.2569],
                                    [-0.0763,  0.0830,  0.2275,  ...,  0.0154, -0.0329, -0.0522],
                                    [-0.1180,  0.2637,  0.1172,  ...,  0.1356, -0.6853,  0.2353]])),
                           ('layer_4.linear.bias',
                            tensor([-0.0691, -0.1060, -0.0187,  0.3468,  0.1665,  0.0228, -0.0955, -0.0045,
                                    -0.1590,  0.2351,  0.1438,  0.3307, -0.0113, -0.0621, -0.1006, -0.0504,
                                     0.1788,  0.2348, -0.0194, -0.1862,  0.3357,  0.0535,  0.0694,  0.0870,
                                    -0.1025,  0.2098,  0.0873,  0.1724,  0.0155, -0.0475,  0.0931,  0.3493])),
                           ('layer_4.normalizer.weight',
                            tensor([ 0.2297,  0.2990,  0.2737,  0.8078,  0.5716,  0.4337,  0.6599,  0.4357,
                                     0.3786,  0.7697,  0.3827,  1.0374,  0.9383,  0.4150,  0.5763,  0.0082,
                                     0.5816,  0.2989,  0.7998,  1.0263,  0.8123,  0.3771,  0.6268,  0.8554,
                                     0.5588, -0.6344,  0.5484,  0.7461,  0.9628,  0.8320,  0.6235,  0.9862])),
                           ('layer_4.normalizer.bias',
                            tensor([-0.1278,  0.4389,  0.4260,  0.3198, -0.1430,  0.3581, -0.2091, -0.0930,
                                     0.3896,  0.3542, -0.1799,  0.2433,  0.2380, -0.1205,  0.3009, -0.0899,
                                     0.2727,  0.4550, -0.1858,  0.1964,  0.4114, -0.0596,  0.0278, -0.2187,
                                     0.0882, -0.0525, -0.0552,  0.2251,  0.2028,  0.1410, -0.1418,  0.2488])),
                           ('layer_4.normalizer.running_mean',
                            tensor([-0.2209, -0.0189, -0.1345, -0.8261, -0.1632, -0.5029,  0.1678,  0.3956,
                                    -0.2796, -0.5636, -0.1228, -0.5516, -0.8519,  0.3655, -0.5660, -0.3087,
                                    -0.7080, -0.0173, -0.1170, -0.6635, -0.8329,  0.0423, -0.4747, -0.1199,
                                    -0.5585,  0.8095,  0.2742, -0.8614, -0.7616, -0.3526,  0.0804, -0.6352])),
                           ('layer_4.normalizer.running_var',
                            tensor([ 0.1660,  9.9155,  9.3877, 10.0118,  1.8771,  8.6701,  0.9420,  1.2411,
                                     7.6982,  7.8437,  0.5575,  8.9085, 10.7996,  1.3539,  6.6764,  0.0753,
                                     6.6262,  9.5189,  0.6057,  5.4100,  8.8232,  1.7903,  4.0475,  1.0141,
                                     3.2860,  5.7909,  1.9748,  6.6539,  7.9738,  4.0276,  0.9936,  6.8559])),
                           ('layer_4.normalizer.num_batches_tracked',
                            tensor(670)),
                           ('output.weight',
                            tensor([[ 0.0907,  0.3841,  0.4231,  0.2701,  0.1529,  0.3734,  0.1848,  0.1908,
                                      0.3445,  0.2799,  0.2094,  0.2791,  0.2999,  0.2314,  0.3558,  0.0392,
                                      0.3215,  0.3956,  0.1413,  0.2460,  0.2496,  0.1880,  0.2432,  0.1266,
                                      0.2394, -0.7282,  0.2006,  0.2259,  0.2687,  0.2108,  0.1736,  0.2261]])),
                           ('output.bias', tensor([0.1369]))]))])
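Every LinearLayer in the dump above contributes the same group of entries: a `linear` weight and bias, plus the `normalizer` weight, bias, running mean, running variance and batch counter. Dropout and ReLU hold no state, so they have no entries. The sketch below uses a hypothetical helper, `expected_layer_keys`, that is not part of XenonPy. It lists the keys and shapes you should expect for one layer, assuming PyTorch's conventions: `Linear.weight` has shape (out, in), and BatchNorm1d stores vectors of length `out` plus a scalar counter. It is checked against `layer_4` (46 inputs, 32 outputs), whose `linear.bias` above has 32 values.

```python
def expected_layer_keys(name, in_features, out_features):
    """Hypothetical helper: expected state_dict keys and shapes for one
    LinearLayer (Linear + BatchNorm1d), following PyTorch conventions."""
    shapes = {
        f"{name}.linear.weight": (out_features, in_features),  # Linear stores (out, in)
        f"{name}.linear.bias": (out_features,),
    }
    # BatchNorm1d: learnable affine parameters plus running statistics
    for suffix in ("weight", "bias", "running_mean", "running_var"):
        shapes[f"{name}.normalizer.{suffix}"] = (out_features,)
    shapes[f"{name}.normalizer.num_batches_tracked"] = ()  # scalar counter
    return shapes


def count_learnable(shapes):
    """Count trainable parameters, skipping running statistics and counters."""
    total = 0
    for key, shape in shapes.items():
        if "running_" in key or key.endswith("num_batches_tracked"):
            continue
        n = 1
        for dim in shape:
            n *= dim
        total += n
    return total


layer_4 = expected_layer_keys("layer_4", 46, 32)
for key, shape in layer_4.items():
    print(key, shape)
print("learnable parameters:", count_learnable(layer_4))
```

For `layer_4`, this gives 46 × 32 + 32 linear parameters and 2 × 32 BatchNorm parameters, for 1568 trainable values in total.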

reuse model using trainer

xenonpy.model.training.Trainer can load a model and its checkpoints either from a checker or directly from a model directory.

[24]:
from xenonpy.model.training import Trainer
[25]:
trainer = Trainer.load(from_=checker)
trainer
[25]:
Trainer(clip_grad=None, cuda=None, epochs=200, loss_func=None,
        lr_scheduler=None,
        model=SequentialLinear(
  (layer_0): LinearLayer(
    (linear): Linear(in_features=290, out_features=180, bias=True)
    (dropout): Dropout(p=0.1)
    (normalizer): BatchNorm1d(180, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (layer_1): LinearLayer(
    (linear): Linear(...
    (normalizer): BatchNorm1d(46, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (layer_4): LinearLayer(
    (linear): Linear(in_features=46, out_features=32, bias=True)
    (dropout): Dropout(p=0.1)
    (normalizer): BatchNorm1d(32, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (output): Linear(in_features=32, out_features=1, bias=True)
),
        non_blocking=False, optimizer=None)
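The printed model lists four submodules for each LinearLayer: linear, dropout, normalizer and activation. A repr shows registration order, so treating it as the forward order is an assumption. Under that assumption, a layer in evaluation mode computes ReLU(BN(W x + b)). Dropout is the identity at inference, and BatchNorm1d uses the stored running statistics with the `eps=0.1` shown above. The numpy sketch below uses small made-up numbers, not the model's weights. `linear_layer_eval` is a hypothetical name.

```python
import numpy as np


def linear_layer_eval(x, W, b, gamma, beta, running_mean, running_var, eps=0.1):
    """Sketch of one LinearLayer in eval mode (assumed order:
    Linear -> Dropout -> BatchNorm1d -> ReLU). Dropout is a no-op at inference."""
    z = x @ W.T + b  # Linear: W has shape (out, in), as in a PyTorch state_dict
    z = (z - running_mean) / np.sqrt(running_var + eps)  # normalize with running stats
    z = gamma * z + beta  # learnable affine transform of BatchNorm1d
    return np.maximum(z, 0.0)  # ReLU


# Toy example: 2 input features -> 3 output units (illustrative values only)
x = np.array([[1.0, 2.0]])
W = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
b = np.zeros(3)
gamma = np.array([1.0, 1.0, -1.0])
beta = np.zeros(3)
running_mean = np.ones(3)
running_var = np.full(3, 0.9)  # 0.9 + eps 0.1 = 1.0, so the scale is 1

out = linear_layer_eval(x, W, b, gamma, beta, running_mean, running_var)
print(out)
```

Here the linear step gives [1, 2, 3], centring gives [0, 1, 2], the negative `gamma` flips the last unit to -2, and ReLU clips it to 0.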

The following code shows how to reuse the model for prediction.

[26]:
# if you have not built the sample data yet, run:
# preset.build('mp_samples', api_key=<your materials project api key>)
from xenonpy.datatools import preset

data = preset.mp_samples
data.head(3)
[26]:
band_gap composition density e_above_hull efermi elements final_energy_per_atom formation_energy_per_atom pretty_formula structure volume
mp-1008807 0.0000 {'Rb': 1.0, 'Cu': 1.0, 'O': 1.0} 4.784634 0.996372 1.100617 [Rb, Cu, O] -3.302762 -0.186408 RbCuO [[-3.05935361 -3.05935361 -3.05935361] Rb, [0.... 57.268924
mp-1009640 0.0000 {'Pr': 1.0, 'N': 1.0} 8.145777 0.759393 5.213442 [Pr, N] -7.082624 -0.714336 PrN [[0. 0. 0.] Pr, [1.57925232 1.57925232 1.58276... 31.579717
mp-1016825 0.7745 {'Hf': 1.0, 'Mg': 1.0, 'O': 3.0} 6.165888 0.589550 2.424570 [Hf, Mg, O] -7.911723 -3.060060 HfMgO3 [[2.03622802 2.03622802 2.03622802] Hf, [0. 0.... 67.541269
[27]:
from xenonpy.descriptor import Compositions

prop = data['efermi'].dropna().to_frame()  # reshape to 2-D
desc = Compositions(featurizers='classic').transform(data.loc[prop.index]['composition'])

desc.head(3)
prop.head(3)
[27]:
ave:atomic_number ave:atomic_radius ave:atomic_radius_rahm ave:atomic_volume ave:atomic_weight ave:boiling_point ave:bulk_modulus ave:c6_gb ave:covalent_radius_cordero ave:covalent_radius_pyykko ... min:num_s_valence min:period min:specific_heat min:thermal_conductivity min:vdw_radius min:vdw_radius_alvarez min:vdw_radius_mm3 min:vdw_radius_uff min:sound_velocity min:Polarizability
mp-1008807 24.666667 174.067140 209.333333 25.666667 55.004267 1297.063333 72.868680 1646.90 139.333333 128.333333 ... 1.0 2.0 0.360 0.02658 152.0 150.0 182.0 349.5 317.5 0.802
mp-1009640 33.000000 137.000000 232.500000 19.050000 77.457330 1931.200000 43.182441 1892.85 137.000000 123.500000 ... 2.0 2.0 0.192 0.02583 155.0 166.0 193.0 360.6 333.6 1.100
mp-1016825 21.600000 153.120852 203.400000 13.920000 50.158400 1420.714000 76.663625 343.82 102.800000 96.000000 ... 2.0 2.0 0.146 0.02658 152.0 150.0 182.0 302.1 317.5 0.802

3 rows × 290 columns

[27]:
efermi
mp-1008807 1.100617
mp-1009640 5.213442
mp-1016825 2.424570
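The `ave:` columns appear to be composition-weighted means of elemental properties. For example, RbCuO contains elements with atomic numbers 37, 29 and 8, and (37 + 29 + 8) / 3 ≈ 24.667 matches `ave:atomic_number` for mp-1008807 above. The sketch below reproduces that column for the three samples using a small hard-coded element table. The table and the helper name `weighted_average` are illustrative. It is an assumption that the `classic` featurizer applies exactly this weighting to every `ave:` feature.

```python
# Atomic numbers for the elements in the three sample compositions (illustrative subset)
ATOMIC_NUMBER = {"N": 7, "O": 8, "Mg": 12, "Cu": 29, "Rb": 37, "Pr": 59, "Hf": 72}


def weighted_average(composition, table):
    """Composition-weighted mean of an elemental property.

    composition: dict mapping element symbol -> amount, like data['composition'].
    table: dict mapping element symbol -> property value.
    """
    total = sum(composition.values())
    return sum(table[el] * amount for el, amount in composition.items()) / total


samples = {
    "mp-1008807": {"Rb": 1.0, "Cu": 1.0, "O": 1.0},
    "mp-1009640": {"Pr": 1.0, "N": 1.0},
    "mp-1016825": {"Hf": 1.0, "Mg": 1.0, "O": 3.0},
}

for mp_id, comp in samples.items():
    print(mp_id, round(weighted_average(comp, ATOMIC_NUMBER), 6))
```

The three results agree with the `ave:atomic_number` values printed above: 24.666667, 33.0 and 21.6.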
[28]:
y_pred = trainer.predict(x_in=torch.tensor(desc.values, dtype=torch.float)).detach().numpy().flatten()
y_true = prop.values.flatten()

draw(y_true, y_pred, prop_name='Efermi ($eV$)')
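The plot compares the two arrays visually. Standard regression metrics put numbers on the same comparison and can be computed directly from `y_true` and `y_pred`. This is a minimal numpy sketch. It is not necessarily what `draw` computes internally (that helper is defined in tools.ipynb), and the toy arrays stand in for the real predictions.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """MAE, RMSE and R^2 for two 1-D arrays of equal length."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    mae = np.mean(np.abs(err))  # mean absolute error
    rmse = np.sqrt(np.mean(err ** 2))  # root-mean-square error
    ss_res = np.sum(err ** 2)  # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    r2 = 1.0 - ss_res / ss_tot  # coefficient of determination
    return {"mae": mae, "rmse": rmse, "r2": r2}


# Toy values for illustration only; in the notebook, pass y_true and y_pred instead
metrics = regression_metrics([1.0, 2.0, 3.0], [1.5, 2.0, 2.5])
print(metrics)
```

Both arrays must be 1-D and aligned. That is why the cell above calls `.flatten()` on the 2-D property frame and on the `(n, 1)` prediction tensor.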
Missing directory and/or file name information!
[Figure: y_true versus y_pred for Efermi ($eV$), drawn by draw()]