Transfer Learning
There are various methods for transfer learning, such as fine-tuning and frozen feature extraction. In this tutorial, we demonstrate how to perform transfer learning by frozen feature extraction in XenonPy.
This tutorial uses Refractive Index data collected from Polymer Genome. We do not provide these data directly in this tutorial; if you want to rerun this notebook locally, you must collect them yourself.
frozen feature extraction
Transfer learning by frozen feature extraction can be split into 2 steps:
1. Obtain pre-trained model(s) to serve as source model(s). This can be done by accessing XenonPy.MDL.
2. Use a feature extractor to generate “neural descriptors” from the source model(s). Here, we introduce our feature extractor, xenonpy.descriptor.FrozenFeaturizer.
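Conceptually, the two steps amount to pushing the target data through the source network with its weights held fixed, reading off hidden-layer values as descriptors, and training a new model only on those descriptors. The NumPy sketch below illustrates this under stated assumptions: a hypothetical, randomly initialised network stands in for a real pre-trained model, and an ordinary least-squares fit stands in for the new model. None of these names are XenonPy APIs.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a pre-trained source model (8 -> 6 -> 4 hidden units).
# These weights are "frozen": nothing below ever updates them.
W1, b1 = rng.normal(size=(8, 6)), np.zeros(6)
W2, b2 = rng.normal(size=(6, 4)), np.zeros(4)


def neural_descriptors(x):
    """Forward pass through the frozen layers; return the last hidden layer."""
    h1 = np.maximum(x @ W1 + b1, 0.0)  # first hidden layer with ReLU
    return h1 @ W2 + b2                # last hidden layer, taken before its ReLU


# Hypothetical target-domain data: 50 samples described by 8 input features.
X_target = rng.normal(size=(50, 8))
y_target = rng.normal(size=50)

# Only the new model is fitted: least squares on the neural descriptors plus a bias column.
F = neural_descriptors(X_target)
F1 = np.hstack([F, np.ones((len(F), 1))])
coef, *_ = np.linalg.lstsq(F1, y_target, rcond=None)
y_fit = F1 @ coef
```

The key design point is that the source network is used purely as a fixed function. Any regressor could replace the least-squares step, which is what section 6 does with a random forest and a Bayesian ridge model.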
The following code shows a case study of transferring Refractive Index knowledge from inorganic to organic materials: the source model was trained on inorganic compounds, and the target materials are polymers.
Let’s do this transfer learning step-by-step.
useful functions
Running the following cell loads some commonly used packages, such as NumPy and pandas. It also imports some in-house functions used in this tutorial; see the ‘tools.ipynb’ file to check what will be imported.
[1]:
%run tools.ipynb
access pre-trained models with MDL class
We provide a wide range of APIs for querying and downloading our models. These APIs can be accessed with plain HTTP requests. For convenience, we wrapped some of the most popular ones into XenonPy; all of these functions can be accessed through xenonpy.mdl.MDL.
[3]:
from xenonpy.mdl import MDL
mdl = MDL()
mdl
mdl.version
[3]:
MDL(api_key='anonymous.user.key', endpoint='http://xenon.ism.ac.jp/api')
[3]:
'0.1.1'
1. query Refractive Index models
[29]:
query = mdl(modelset_has="Stable", property_has="refractive")
query
[29]:
QueryModelDetailsWith(api_key='anonymous.user.key', endpoint='http://xenon.ism.ac.jp/api', variables={'modelset_has': 'Stable', 'property_has': 'refractive'})
Queryable:
id
transferred
succeed
isRegression
deprecated
modelset
method
property
descriptor
lang
accuracy
precision
recall
f1
sensitivity
prevalence
specificity
ppv
npv
meanAbsError
maxAbsError
meanSquareError
rootMeanSquareError
r2
pValue
spearmanCorr
pearsonCorr
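Several of the queryable fields are standard regression metrics. As a sketch, assuming the usual textbook definitions (how the MDL server computes them, for example on which data split, is not stated here), a few of them can be computed as follows. The dictionary keys simply mirror the field names above.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Compute a subset of the regression metrics listed above for 1-D arrays."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    mse = np.mean(err ** 2)
    ss_res = np.sum(err ** 2)                      # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return {
        'meanAbsError': np.mean(np.abs(err)),
        'maxAbsError': np.max(np.abs(err)),
        'meanSquareError': mse,
        'rootMeanSquareError': np.sqrt(mse),
        'r2': 1.0 - ss_res / ss_tot,
        'pearsonCorr': np.corrcoef(y_true, y_pred)[0, 1],
    }


# Toy values for illustration only
m = regression_metrics([1.5, 1.6, 1.7], [1.4, 1.6, 1.9])
```

Note that r2 can be negative (here it is -1.5) when predictions are worse than predicting the mean, while pearsonCorr only measures linear agreement and stays high for this toy example.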
[30]:
summary = query('id', 'modelset', 'meanAbsError', 'pearsonCorr').sort_values('meanAbsError')
summary.head(3)
[30]:
| | id | modelset | meanAbsError | pearsonCorr |
|---|---|---|---|---|
| 311 | 2949 | Stable inorganic compounds in materials project | 0.282434 | 0.873065 |
| 847 | 4017 | Stable inorganic compounds in materials project | 0.290382 | 0.827995 |
| 1189 | 4702 | Stable inorganic compounds in materials project | 0.293135 | 0.876289 |
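Note that sort_values keeps the original row labels, so after sorting, the top row carries the label 311 rather than 0. The pandas sketch below uses hypothetical ids and errors to show the difference between label-based and positional lookup when picking the model with the lowest MAE.

```python
import pandas as pd

# Toy query result (hypothetical ids and errors) with the default 0..n-1 index
summary = pd.DataFrame({'id': [101, 102, 103], 'meanAbsError': [0.35, 0.28, 0.31]})
ranked = summary.sort_values('meanAbsError')  # row labels are now ordered 1, 2, 0

by_label = ranked['id'][0]          # label 0: the row that came first *before* sorting
by_position = ranked['id'].iloc[0]  # position 0: the row with the lowest MAE
best_id = by_position.item()        # convert the NumPy integer to a plain Python int
```

Positional access with .iloc is therefore the reliable way to take "the first row" of a sorted table.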
2. download the best performing model based on MAE
[31]:
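# .iloc[0] takes the first row of the sorted table (lowest MAE);
# summary.id[0] would instead look up the row whose index label is 0
results = mdl.pull(summary.id.iloc[0].item())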
results
100%|██████████| 1/1 [00:00<00:00, 1.26it/s]
[31]:
| | id | model |
|---|---|---|
| 0 | 2335 | /Users/liuchang/Google 云端硬盘/postdoctoral/tutor... |
3. load Refractive Index data from Polymer Genome and calculate the Composition descriptors
[32]:
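# placeholder: load your own Polymer Genome data as a pandas DataFrame indexed by 'ID_name',
# with at least a 'composition' column (dicts of element: count) and a 'Refractive Index' column
pg = ...
pg.head(3)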
[32]:
ID_name | Smiles | Natoms | Ntypes | Volume of Cell($\AA^3$) | Band Gap, PBE(eV) | Band Gap, HSE06(eV) | Dielectric Constant | Dielectric Constant, Electronic | Dielectric Constant, Ionic | Atomization Energy(eV/atom) | Density(g/cm$^3$) | Refractive Index | Ionization Energy(eV) | Electron Affinity(eV) | Cohesive Energy(eV/atom) | composition | Formula |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
MOL1 | [C@H]([CH]O)(O[C@H]1[C@H](CO)O[C@@H]([CH][C@@H... | 84 | 3 | 572.42 | 5.62 | 7.48 | 3.78 | 2.85 | 0.93 | -5.48 | 1.88 | 1.69 | 6.87 | 0.83 | -0.63 | {'O': 20.0, 'H': 40.0, 'C': 24.0} | H40C24O20 |
MOL2 | [CH][C@H](C[C@@H](C[C@H](C[CH][C]=[CH])C(=[CH]... | 128 | 2 | 1258.30 | 3.94 | 4.83 | 2.72 | 2.64 | 0.08 | -5.90 | 1.10 | 1.62 | 3.56 | 1.56 | -0.63 | {'C': 64.0, 'H': 64.0} | H64C64 |
MOL3 | C[C@H](C[CH][CH][CH]C)[CH2].C[C@@H](C[CH][CH][... | 108 | 2 | 762.10 | 6.32 | 7.70 | 2.61 | 2.59 | 0.02 | -5.14 | 1.10 | 1.61 | 6.19 | 0.43 | -0.51 | {'C': 36.0, 'H': 72.0} | H72C36 |
[33]:
from xenonpy.descriptor import Compositions
pg_desc = Compositions().transform(pg['composition'])
pg_desc.head(3)
[33]:
ID_name | ave:atomic_number | ave:atomic_radius | ave:atomic_radius_rahm | ave:atomic_volume | ave:atomic_weight | ave:boiling_point | ave:bulk_modulus | ave:c6_gb | ave:covalent_radius_cordero | ave:covalent_radius_pyykko | ... | min:num_s_valence | min:period | min:specific_heat | min:thermal_conductivity | min:vdw_radius | min:vdw_radius_alvarez | min:vdw_radius_mm3 | min:vdw_radius_uff | min:sound_velocity | min:Polarizability |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
MOL1 | 4.095238 | 98.42891 | 168.333333 | 11.561905 | 7.721000 | 1488.27381 | 54.596505 | 20.761905 | 51.333333 | 51.666667 | ... | 1.0 | 1.0 | 0.711 | 0.02658 | 110.0 | 120.0 | 162.0 | 288.6 | 317.5 | 0.666793 |
MOL2 | 3.500000 | 85.00000 | 172.000000 | 9.700000 | 6.509500 | 2560.14000 | 44.899820 | 27.205000 | 52.000000 | 53.500000 | ... | 1.0 | 1.0 | 0.711 | 0.18050 | 110.0 | 120.0 | 162.0 | 288.6 | 1270.0 | 0.666793 |
MOL3 | 2.666667 | 83.00000 | 166.000000 | 11.166667 | 4.675667 | 1713.52000 | 48.866426 | 20.306667 | 45.000000 | 46.333333 | ... | 1.0 | 1.0 | 0.711 | 0.18050 | 110.0 | 120.0 | 162.0 | 288.6 | 1270.0 | 0.666793 |
3 rows × 290 columns
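The ave: and min: columns above are statistics of elemental properties over each composition. A composition-weighted mean and a plain minimum reproduce the ave:atomic_number and min:period values in the table; for example, for MOL2 (C64H64), (6 × 64 + 1 × 64) / 128 = 3.5. The sketch below is a hand-written illustration covering only two properties for H, C and O. It is not the Compositions implementation, and the remaining columns (hidden behind "..." above) presumably use other statistics of the same elemental data.

```python
# Elemental properties for the elements in MOL1-MOL3 (standard periodic-table values)
ELEMENTS = {
    'H': {'atomic_number': 1, 'period': 1},
    'C': {'atomic_number': 6, 'period': 2},
    'O': {'atomic_number': 8, 'period': 2},
}


def ave_and_min(composition, prop):
    """Return (composition-weighted average, minimum) of one elemental property."""
    total = sum(composition.values())
    ave = sum(n * ELEMENTS[el][prop] for el, n in composition.items()) / total
    smallest = min(ELEMENTS[el][prop] for el in composition)
    return ave, smallest


mol1 = {'O': 20.0, 'H': 40.0, 'C': 24.0}
mol2 = {'C': 64.0, 'H': 64.0}
mol1_ave, mol1_min = ave_and_min(mol1, 'atomic_number')
mol2_ave, mol2_min = ave_and_min(mol2, 'atomic_number')
```

Weighting by atom counts means that a polymer rich in hydrogen, such as MOL3 (H72C36), gets a low average atomic number (2.666667), even though it contains carbon.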
4. predict Polymer Genome Refractive Index directly from a model trained on inorganic compounds
[34]:
from xenonpy.model.training import Checker
checker = Checker(results.model[0])
checker.checkpoints
[34]:
<Checker> includes:
"mse_3": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mse_3.pth.s
"mse_1": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mse_1.pth.s
"mae_2": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mae_2.pth.s
"r2_5": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/r2_5.pth.s
"mse_5": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mse_5.pth.s
"r2_1": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/r2_1.pth.s
"mae_4": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mae_4.pth.s
"r2_3": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/r2_3.pth.s
"mae_3": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mae_3.pth.s
"r2_4": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/r2_4.pth.s
"mse_2": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mse_2.pth.s
"mae_1": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mae_1.pth.s
"mae_5": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mae_5.pth.s
"r2_2": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/r2_2.pth.s
"mse_4": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/mse_4.pth.s
"pearsonr_5": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/pearsonr_5.pth.s
"pearsonr_1": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/pearsonr_1.pth.s
"pearsonr_3": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/pearsonr_3.pth.s
"pearsonr_4": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/pearsonr_4.pth.s
"pearsonr_2": /Users/liuchang/Google 云端硬盘/postdoctoral/tutorial/xenonpy_hands-on_20190925 2/inorganic.crystal.refractive_index/xenonpy.compositions/pytorch.nn.neural_network/290-187-151-110-102-1-$wHSyTrhF/checkpoints/pearsonr_2.pth.s
[35]:
# --- pre-trained model for prediction
from xenonpy.model.training import Trainer
trainer = Trainer.load(from_=checker)
trainer
[35]:
Trainer(clip_grad=None, cuda=None, epochs=200, loss_func=None,
lr_scheduler=None,
model=SequentialLinear(
(layer_0): LinearLayer(
(linear): Linear(in_features=290, out_features=187, bias=True)
(dropout): Dropout(p=0.1)
(normalizer): BatchNorm1d(187, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(layer_1): LinearLayer(
(linear): Linear(...
(normalizer): BatchNorm1d(110, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(layer_3): LinearLayer(
(linear): Linear(in_features=110, out_features=102, bias=True)
(dropout): Dropout(p=0.1)
(normalizer): BatchNorm1d(102, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(output): Linear(in_features=102, out_features=1, bias=True)
),
non_blocking=False, optimizer=None)
[36]:
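# restore the model state saved in the 'mae_1' checkpoint
trainer.reset(to='mae_1')
# predict the polymers' refractive indices directly with the inorganic-trained model
y_pred = trainer.predict(x_in=torch.tensor(pg_desc.values, dtype=torch.float)).detach().numpy().flatten()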
y_true = pg['Refractive Index'].values
draw(y_true, y_pred, prop_name='Refractive Index')
5. frozen feature extraction
FrozenFeaturizer accepts a PyTorch model as its input.
[37]:
from xenonpy.descriptor import FrozenFeaturizer
# --- init FrozenFeaturizer with NN model
ff = FrozenFeaturizer(model=trainer.model)
ff
[37]:
FrozenFeaturizer(cuda=False, depth=None,
model=SequentialLinear(
(layer_0): LinearLayer(
(linear): Linear(in_features=290, out_features=187, bias=True)
(dropout): Dropout(p=0.1)
(normalizer): BatchNorm1d(187, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(layer_1): LinearLayer(
(linear): Linear(in_features=187, out_features=151, bias=...
(normalizer): BatchNorm1d(110, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(layer_3): LinearLayer(
(linear): Linear(in_features=110, out_features=102, bias=True)
(dropout): Dropout(p=0.1)
(normalizer): BatchNorm1d(102, eps=0.1, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(output): Linear(in_features=102, out_features=1, bias=True)
),
on_errors='raise', return_type='any')
The following code will generate new “neural descriptors” from the corresponding neural network model.
[38]:
neural_descriptors = ff.transform(pg_desc, depth=2, return_type='df')
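Here, depth=x means that the outputs of the last x hidden layers before the output neuron(s) are concatenated and used as the neural descriptors. With depth=2, the last two hidden layers of this model (110 and 102 neurons) give the 212 columns shown below.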
[39]:
neural_descriptors.head(3)
[39]:
ID_name | L(-2)_1 | L(-2)_2 | L(-2)_3 | L(-2)_4 | L(-2)_5 | L(-2)_6 | L(-2)_7 | L(-2)_8 | L(-2)_9 | L(-2)_10 | ... | L(-1)_93 | L(-1)_94 | L(-1)_95 | L(-1)_96 | L(-1)_97 | L(-1)_98 | L(-1)_99 | L(-1)_100 | L(-1)_101 | L(-1)_102 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
MOL1 | -1.499497 | 1.254353 | -1.071701 | -2.697761 | -1.807236 | -2.01011 | -1.944547 | -1.533609 | -1.380140 | -0.416400 | ... | -0.146019 | -0.599903 | -0.168826 | 0.135178 | 0.607779 | 0.700811 | -1.405137 | -0.209901 | -0.297132 | 0.071797 |
MOL2 | -1.295168 | 1.335794 | -1.657156 | -3.456876 | -1.372355 | -2.39925 | -1.761726 | -1.736894 | -1.464569 | -1.349594 | ... | 0.026329 | -0.161609 | -0.542276 | 0.246088 | 1.047052 | 0.067110 | -1.963629 | -0.109732 | -0.328646 | 0.103320 |
MOL3 | -1.514608 | 1.205772 | -1.270560 | -2.831598 | -1.677628 | -2.02982 | -1.933016 | -1.568646 | -1.363735 | -0.656209 | ... | -0.118927 | -0.526747 | -0.172702 | 0.154083 | 0.656889 | 0.607651 | -1.418730 | -0.193771 | -0.278343 | 0.091769 |
3 rows × 212 columns
In the column names, L(-1) denotes the last hidden layer, L(-2) the one before it, and the number after the underscore is the neuron index within that layer.
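To make the depth and naming conventions concrete, here is a NumPy sketch of a frozen forward pass through a hypothetical small network (5 inputs, hidden layers of 4, 3 and 2 units). It concatenates the last `depth` hidden layers and names the columns L(-k)_i. This is not FrozenFeaturizer's code. BatchNorm and dropout are left out, and recording each layer's value before the ReLU is an assumption suggested by the negative entries in the table above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical small network: 5 inputs -> hidden layers of 4, 3 and 2 units.
# The output layer is left out because it is never part of the descriptors.
sizes = [5, 4, 3, 2]
weights = [rng.normal(size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n) for n in sizes[1:]]


def frozen_features(x, depth):
    """Concatenate the last `depth` hidden layers; name columns like L(-k)_i."""
    hidden = []
    h = x
    for W, b in zip(weights, biases):
        z = h @ W + b           # value recorded for this hidden layer (pre-ReLU)
        hidden.append(z)
        h = np.maximum(z, 0.0)  # ReLU output passed on to the next layer
    picked = hidden[-depth:]    # depth=2 -> [second-to-last, last] hidden layer
    names = [f'L({k - depth})_{i + 1}'
             for k, layer in enumerate(picked)
             for i in range(layer.shape[1])]
    return np.hstack(picked), names


X = rng.normal(size=(3, 5))
features, columns = frozen_features(X, depth=2)
```

With depth=2, this toy network yields 3 + 2 = 5 columns. By the same rule, the real model's 110- and 102-unit layers yield 212 columns.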
[41]:
from xenonpy.visualization import DescriptorHeatmap
sorted_prop = pg['Refractive Index'].sort_values()
sorted_desc = neural_descriptors.loc[sorted_prop.index]
heatmap = DescriptorHeatmap(
bc=True, # use box-cox transform
# save=dict(fname='heatmap_density', dpi=150, bbox_inches='tight'), # save figure to file
figsize=(70, 10))
heatmap.fit(sorted_desc)
heatmap.draw(sorted_prop)
[41]:
DescriptorHeatmap(bc=True, col_cluster=True, col_colors=None, col_linkage=None,
figsize=(70, 10), mask=None, method='average',
metric='euclidean', pivot_kws=None, row_cluster=False,
row_colors=None, row_linkage=None, save=None)
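The bc=True option above refers to the Box-Cox transform. For a positive value x and parameter λ, it is (x^λ - 1)/λ, or ln x when λ = 0. The sketch below applies it for a fixed λ. In practice, λ is usually fitted by maximum likelihood (scipy.stats.boxcox does this when no λ is given). Because the transform requires x > 0, negative neural descriptors need a shift or some other preprocessing first. How DescriptorHeatmap handles this is not covered here.

```python
import numpy as np


def box_cox(x, lam):
    """Box-Cox transform for a fixed lambda; defined only for x > 0."""
    x = np.asarray(x, dtype=float)
    if np.any(x <= 0):
        raise ValueError('Box-Cox requires strictly positive values')
    if lam == 0:
        return np.log(x)               # limit of (x**lam - 1) / lam as lam -> 0
    return (x ** lam - 1.0) / lam


logged = box_cox([1.0, np.e], 0)       # lambda = 0 reduces to the natural log
rooted = box_cox([1.0, 4.0], 0.5)      # lambda = 0.5 is a scaled square root
```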
6. use neural descriptors to train new models.
Here, we train a random forest model and a Bayesian ridge linear model on the neural descriptors.
[42]:
# split data
from xenonpy.datatools import Splitter
y = pg['Refractive Index']
splitter = Splitter(len(y), test_size=0.2)
X_train, X_test, y_train, y_test = splitter.split(neural_descriptors, y.values.reshape(-1, 1))
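The split above holds out 20% of the polymers for testing. The NumPy sketch below is a hand-rolled equivalent under the assumption that Splitter performs a shuffled random split; its actual implementation (for example, seeding) may differ. The data are hypothetical. As in the cell above, the target is kept as a column vector and then flattened with ravel() before reaching scikit-learn, whose regressors expect a 1-D target.

```python
import numpy as np


def random_split(n_samples, test_size=0.2, seed=None):
    """Shuffle the row indices; the first `test_size` fraction becomes the test set."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_samples)
    n_test = int(round(n_samples * test_size))
    return idx[n_test:], idx[:n_test]  # (train indices, test indices)


# Hypothetical data: 10 samples, 2 descriptors, target as a column vector
X = np.arange(20, dtype=float).reshape(10, 2)
y = np.linspace(1.5, 1.9, 10).reshape(-1, 1)
train_idx, test_idx = random_split(len(y), test_size=0.2, seed=0)
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]
y_train_1d = y_train.ravel()  # scikit-learn regressors expect a 1-D target
```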
[43]:
# random forest
from sklearn.ensemble import RandomForestRegressor
rf = RandomForestRegressor(n_estimators=100)
rf.fit(X_train, y_train.ravel())
y_pred = rf.predict(X_test)
y_fit_pred = rf.predict(X_train)
[43]:
RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
max_features='auto', max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=100,
n_jobs=None, oob_score=False, random_state=None,
verbose=0, warm_start=False)
[44]:
draw(y_test.ravel(), y_pred, y_train.ravel(), y_fit_pred, prop_name='refractive index')
[45]:
# bayesian linear
from sklearn.linear_model import BayesianRidge
br = BayesianRidge()
br.fit(X_train, y_train.ravel())
y_pred = br.predict(X_test)
y_fit_pred = br.predict(X_train)
[45]:
BayesianRidge(alpha_1=1e-06, alpha_2=1e-06, compute_score=False, copy_X=True,
fit_intercept=True, lambda_1=1e-06, lambda_2=1e-06, n_iter=300,
normalize=False, tol=0.001, verbose=False)
[46]:
draw(y_test.ravel(), y_pred, y_train.ravel(), y_fit_pred, prop_name='refractive index')