xenonpy.model.training.dataset package
Submodules
xenonpy.model.training.dataset.array module
xenonpy.model.training.dataset.cgcnn module
- class xenonpy.model.training.dataset.cgcnn.CrystalGraphDataset(crystal_features, targets=None)
Bases: Dataset
- Parameters:
root_dir (str) – The path to the root directory of the dataset
max_num_nbr (int) – The maximum number of neighbors while constructing the crystal graph
radius (float) – The cutoff radius for searching neighbors
dmin (float) – The minimum distance for constructing GaussianDistance
step (float) – The step size for constructing GaussianDistance
random_seed (int or None) – Random seed for shuffling the dataset
- Returns:
atom_fea (torch.Tensor shape (n_i, atom_fea_len))
nbr_fea (torch.Tensor shape (n_i, M, nbr_fea_len))
nbr_fea_idx (torch.LongTensor shape (n_i, M))
target (torch.Tensor shape (1, ))
cif_id (str or int)
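The dmin, step and radius parameters above describe a Gaussian expansion of interatomic distances into the nbr_fea_len bond features. The following is a minimal sketch of that idea, assuming GaussianDistance follows the usual CGCNN-style Gaussian basis. The function name gaussian_expand, the use of radius as the upper bound dmax, and the default width var = step are assumptions made for illustration, not part of the documented API.

```python
import torch


def gaussian_expand(distances, dmin, dmax, step, var=None):
    """Expand distances on a grid of Gaussian basis functions (hypothetical helper).

    distances: tensor of any shape, e.g. (n_i, M) neighbor distances.
    Returns a tensor with one extra trailing axis of length len(centers).
    """
    # Centers of the Gaussians, evenly spaced from dmin up to dmax (inclusive).
    centers = torch.arange(dmin, dmax + step, step)
    if var is None:
        var = step  # assumption: width equal to the grid spacing
    # Broadcast (..., 1) against (n_centers,) to get (..., n_centers).
    return torch.exp(-((distances.unsqueeze(-1) - centers) ** 2) / var ** 2)


# Two atoms with M = 3 neighbors each; values are illustrative only.
distances = torch.tensor([[1.0, 2.5, 4.0], [1.5, 3.0, 7.5]])
nbr_fea = gaussian_expand(distances, dmin=0.0, dmax=8.0, step=0.5)
print(nbr_fea.shape)  # (n_i, M, nbr_fea_len)
```

With these illustrative values the grid has 17 centers, so each distance becomes a 17-dimensional feature vector that peaks at the center closest to that distance.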
- static collate_fn(dataset_list)
Collate a list of data and return a batch for predicting crystal properties.
- Parameters:
dataset_list (list of tuples, one for each data point) –
(atom_fea, nbr_fea, nbr_fea_idx, target)
atom_fea: torch.Tensor shape (n_i, atom_fea_len)
nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len)
nbr_fea_idx: torch.LongTensor shape (n_i, M)
target: torch.Tensor shape (1, )
cif_id: str or int
- Returns:
N = sum(n_i) is the total number of atoms in the batch; N0 is the number of crystals in the batch
batch_atom_fea (torch.Tensor shape (N, orig_atom_fea_len)) – Atom features from atom type
batch_nbr_fea (torch.Tensor shape (N, M, nbr_fea_len)) – Bond features of each atom’s M neighbors
batch_nbr_fea_idx (torch.LongTensor shape (N, M)) – Indices of M neighbors of each atom
crystal_atom_idx (list of torch.LongTensor of length N0) – Mapping from the crystal idx to atom idx
target (torch.Tensor shape (N0, 1)) – Target value for prediction, one row per crystal
batch_cif_ids (list)
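The batching described above amounts to concatenating per-crystal tensors along the atom axis while shifting each crystal's neighbor indices by the number of atoms already in the batch. The following is a minimal sketch of that bookkeeping, not the library's implementation. The names collate_crystals and fake_crystal are hypothetical, and the sketch assumes each item is a 5-tuple that carries cif_id as its last element, as the field list suggests.

```python
import torch


def collate_crystals(dataset_list):
    """Hypothetical stand-in for the collate step described above."""
    batch_atom_fea, batch_nbr_fea, batch_nbr_fea_idx = [], [], []
    crystal_atom_idx, batch_target, batch_cif_ids = [], [], []
    base_idx = 0  # number of atoms already placed in the batch
    for atom_fea, nbr_fea, nbr_fea_idx, target, cif_id in dataset_list:
        n_i = atom_fea.shape[0]
        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        # Shift local neighbor indices into the batch-wide atom numbering.
        batch_nbr_fea_idx.append(nbr_fea_idx + base_idx)
        # Record which batch rows belong to this crystal.
        crystal_atom_idx.append(torch.arange(n_i, dtype=torch.long) + base_idx)
        batch_target.append(target)
        batch_cif_ids.append(cif_id)
        base_idx += n_i
    return (
        torch.cat(batch_atom_fea, dim=0),
        torch.cat(batch_nbr_fea, dim=0),
        torch.cat(batch_nbr_fea_idx, dim=0),
        crystal_atom_idx,
        torch.stack(batch_target, dim=0),
        batch_cif_ids,
    )


def fake_crystal(n_atoms, cif_id, M=4, atom_fea_len=8, nbr_fea_len=5):
    """Random data with the documented per-item shapes (illustration only)."""
    return (
        torch.randn(n_atoms, atom_fea_len),
        torch.randn(n_atoms, M, nbr_fea_len),
        torch.randint(0, n_atoms, (n_atoms, M)),
        torch.tensor([float(n_atoms)]),
        cif_id,
    )


batch = collate_crystals([fake_crystal(3, "a"), fake_crystal(5, "b")])
atom_fea, nbr_fea, nbr_idx, crystal_atom_idx, target, cif_ids = batch
print(atom_fea.shape, nbr_fea.shape, nbr_idx.shape, target.shape)
```

In this sketch two crystals with 3 and 5 atoms give N = 8 rows in the atom-level tensors, crystal_atom_idx holds two index tensors, and the stacked target has one row per crystal. A function of this shape could be passed as the collate_fn argument of torch.utils.data.DataLoader.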