# Copyright (c) 2021. TsumiNa. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
import torch
from torch import nn
# Public names exported by ``from <this module> import *``.
__all__ = ['ConvLayer', 'CrystalGraphConvNet']
class ConvLayer(nn.Module):
    """
    Gated graph convolution used by CGCNN.

    For every atom and each of its M neighbor slots, the layer concatenates
    the atom's own hidden features, the neighbor's hidden features and the
    bond features. It projects that vector with a single linear layer into a
    "gate" half and a "core" half. It sums ``sigmoid(gate) * softplus(core)``
    over the neighbor axis and adds the result (after batch norm) to the
    input features before a final softplus.
    """

    def __init__(self, atom_fea_len, nbr_fea_len):
        """
        Build the convolution layer.

        Parameters
        ----------
        atom_fea_len: int
            Width of the atom hidden feature vectors (both input and output).
        nbr_fea_len: int
            Width of the bond feature vectors.
        """
        super().__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        gated_len = 2 * atom_fea_len
        # One projection yields both halves (gate and core) at once.
        # Submodule names and registration order are kept stable because
        # they form the state_dict keys.
        self.fc_full = nn.Linear(gated_len + nbr_fea_len, gated_len)
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(gated_len)
        self.bn2 = nn.BatchNorm1d(atom_fea_len)
        self.softplus2 = nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """
        Apply one round of graph convolution.

        N: total number of atoms in the batch
        M: max number of neighbors per atom

        Parameters
        ----------
        atom_in_fea: torch.Tensor shape (N, atom_fea_len)
            Atom hidden features before convolution.
        nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
            Bond features of each atom's M neighbors.
        nbr_fea_idx: torch.LongTensor shape (N, M)
            Row indices into ``atom_in_fea`` of each atom's M neighbors.

        Returns
        -------
        torch.Tensor shape (N, atom_fea_len)
            Atom hidden features after convolution.
        """
        # TODO(review): every slot of nbr_fea_idx is gathered and summed, so
        # any slot padded with index 0 contributes like a real neighbor;
        # confirm callers never pad.
        n_atoms, n_nbrs = nbr_fea_idx.shape
        width = self.atom_fea_len

        center = atom_in_fea.unsqueeze(1).expand(n_atoms, n_nbrs, width)
        neighbors = atom_in_fea[nbr_fea_idx, :]
        pair_fea = torch.cat([center, neighbors, nbr_fea], dim=2)

        gated = self.fc_full(pair_fea)
        # BatchNorm1d normalizes (batch, channels): flatten atom/neighbor axes,
        # then restore the (N, M, 2 * atom_fea_len) layout.
        gated = self.bn1(gated.view(-1, 2 * width)).view(n_atoms, n_nbrs, 2 * width)

        gate, core = gated.chunk(2, dim=2)
        messages = self.sigmoid(gate) * self.softplus1(core)
        aggregated = self.bn2(messages.sum(dim=1))
        # Residual update of the input features.
        return self.softplus2(atom_in_fea + aggregated)
class CrystalGraphConvNet(nn.Module):
    """
    Create a crystal graph convolutional neural network for predicting total
    material properties.

    See Also: [CGCNN]_.

    .. [CGCNN] `Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties`__

    __ https://doi.org/10.1103/PhysRevLett.120.145301
    """

    def __init__(self,
                 orig_atom_fea_len,
                 nbr_fea_len,
                 atom_fea_len=64,
                 n_conv=3,
                 h_fea_len=128,
                 n_h=1,
                 classification=False):
        """
        Initialize CrystalGraphConvNet.

        Parameters
        ----------
        orig_atom_fea_len: int
            Number of atom features in the input.
        nbr_fea_len: int
            Number of bond features.
        atom_fea_len: int
            Number of hidden atom features in the convolutional layers.
        n_conv: int
            Number of convolutional layers.
        h_fea_len: int
            Number of hidden features after pooling.
        n_h: int
            Number of hidden layers after pooling (``n_h - 1`` extra
            ``h_fea_len -> h_fea_len`` layers are created when ``n_h > 1``).
        classification: bool
            If True, the network emits 2 log-probabilities per crystal and
            applies dropout before the hidden layers; otherwise it emits a
            single regression value per crystal.
        """
        super(CrystalGraphConvNet, self).__init__()
        self.classification = classification
        # Module creation order is kept as-is: it fixes both the order of
        # random parameter initialization and the state_dict key layout.
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        self.convs = nn.ModuleList(
            [ConvLayer(atom_fea_len=atom_fea_len, nbr_fea_len=nbr_fea_len) for _ in range(n_conv)])
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()
        if n_h > 1:
            self.fcs = nn.ModuleList([nn.Linear(h_fea_len, h_fea_len) for _ in range(n_h - 1)])
            self.softpluses = nn.ModuleList([nn.Softplus() for _ in range(n_h - 1)])
        self.fc_out = nn.Linear(h_fea_len, 2 if self.classification else 1)
        if self.classification:
            # The output is (N0, 2); normalize over the class axis explicitly.
            # The implicit-dim form is deprecated and warns on every call.
            self.logsoftmax = nn.LogSoftmax(dim=1)
        self.dropout = nn.Dropout()

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """
        Forward pass.

        N: total number of atoms in the batch
        M: max number of neighbors
        N0: total number of crystals in the batch

        Parameters
        ----------
        atom_fea: torch.Tensor shape (N, orig_atom_fea_len)
            Atom features from atom type.
        nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
            Bond features of each atom's M neighbors.
        nbr_fea_idx: torch.LongTensor shape (N, M)
            Indices of M neighbors of each atom.
        crystal_atom_idx: list of torch.LongTensor of length N0
            Mapping from the crystal idx to atom idx.

        Returns
        -------
        prediction: torch.Tensor
            Shape (N0, 1) regression output, or shape (N0, 2) class
            log-probabilities when ``classification`` is True.
        """
        atom_fea = self.embedding(atom_fea)
        for conv_func in self.convs:
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        # Softplus is applied both before and after the conv_to_fc projection.
        crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
        crys_fea = self.conv_to_fc_softplus(crys_fea)
        if self.classification:
            crys_fea = self.dropout(crys_fea)
        # The extra hidden layers only exist when n_h > 1.
        if hasattr(self, 'fcs') and hasattr(self, 'softpluses'):
            for fc, softplus in zip(self.fcs, self.softpluses):
                crys_fea = softplus(fc(crys_fea))
        out = self.fc_out(crys_fea)
        if self.classification:
            out = self.logsoftmax(out)
        return out

    @staticmethod
    def pooling(atom_fea, crystal_atom_idx):
        """
        Mean-pool the atom features into one feature vector per crystal.

        N: total number of atoms in the batch
        N0: total number of crystals in the batch

        Parameters
        ----------
        atom_fea: torch.Tensor shape (N, atom_fea_len)
            Atom feature vectors of the batch.
        crystal_atom_idx: list of torch.LongTensor of length N0
            Mapping from the crystal idx to atom idx.

        Returns
        -------
        torch.Tensor shape (N0, atom_fea_len)
            Row ``k`` is the mean of the atom rows listed in
            ``crystal_atom_idx[k]``.
        """
        n_indexed = sum(len(idx_map) for idx_map in crystal_atom_idx)
        assert n_indexed == atom_fea.shape[0], (
            'crystal_atom_idx covers {} atoms but atom_fea has {} rows'.format(
                n_indexed, atom_fea.shape[0]))
        mean_fea = [
            torch.mean(atom_fea[idx_map], dim=0, keepdim=True) for idx_map in crystal_atom_idx
        ]
        return torch.cat(mean_fea, dim=0)