I decided to switch from Keras, Tensorflow to PyTorch. Then, we decided to implement Graph Convolutional Network (GCN) by compounds using PyTorch. First, it is necessary to convert the compound represented by SMILES into a form that can be used for learning. You may implement these processes yourself, but I thought that it would be easier if you diverted the preprocesses of DeepChem, which is a library based on Keras and Tensorflow. So I feturized SMILES with DeepChem's ConvMolFeaturizer and made it available in Pytorch's DataLoader. By doing so, we are planning to concentrate on the implementation of GCN without having to implement the troublesome compound handling process by ourselves.
--The Dataset simply holds a list of SMILES and correct answer data. --Since it is necessary to convert all the compounds in the mini-batch into a graph and generate a bond order matrix and an adjacency matrix for each mini-batch, we decided to implement collate_fn independently and give it as an argument of DataLoader. --In collate_fn, SMILES is featured and listed by DeepChem's ConvMolFeaturizer, and it is given to the agglomerate_mols method of the ConvMol class. As a result, a bond order matrix and an adjacency matrix of all compounds in the mini-batch are generated, so they are converted to PyTorch tensor format and returned together with the correct answer data.
import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
class GCNDataset(data.Dataset):
def __init__(self, smiles_list, label_list):
self.smiles_list = smiles_list
self.label_list = label_list
def __len__(self):
return len(self.smiles_list)
def __getitem__(self, index):
return self.smiles_list[index], self.label_list[index]
def gcn_collate_fn(batch):
from rdkit import Chem
cmf = ConvMolFeaturizer()
mols = []
labels = []
for sample, label in batch:
mols.append(Chem.MolFromSmiles(sample))
labels.append(torch.tensor(label))
conv_mols = cmf.featurize(mols)
multiConvMol = ConvMol.agglomerate_mols(conv_mols)
atom_feature = torch.tensor(multiConvMol.get_atom_features(), dtype=torch.float64)
deg_slice = torch.tensor(multiConvMol.deg_slice, dtype=torch.float64)
membership = torch.tensor(multiConvMol.membership, dtype=torch.float64)
return atom_feature, deg_slice, membership, labels
def main():
dataset = GCNDataset(["CCC", "CCCC", "CCCCC"], [1, 0, 1])
dataloader = data.DataLoader(dataset, batch_size=3, shuffle=False, collate_fn =gcn_collate_fn)
for atom_feature, deg_slice, membership, labels in dataloader:
print(atom_feature)
print(deg_slice)
print(membership)
if __name__ == "__main__":
main()
The mini-batch with 3 compounds is as follows. The characteristics of 12 atoms in the ternary compound, the bond order matrix, and the adjacency matrix are generated. These will be explained at another time.
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.]], dtype=torch.float64)
tensor([[ 0., 0.],
[ 0., 6.],
[ 6., 6.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.]], dtype=torch.float64)
tensor([0., 0., 1., 1., 2., 2., 0., 1., 1., 2., 2., 2.], dtype=torch.float64)
In the future, I will write the code of the GCN model and the code to learn using this DataLoader.
Compared to Keras's cramped feeling and Tensorflow's innocence, PyTorch's just rightness is very comfortable (for now).
Recommended Posts