Following GraphConvLayer and GraphPoolLayer, I implemented DeepChem's GraphGatherLayer as a custom PyTorch layer.
I ported DeepChem's GraphGatherLayer to PyTorch and tried feeding the output of the previously implemented GraphConvLayer and GraphPoolLayer into the new GraphGatherLayer.
import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
import torch.nn as nn
import numpy as np
from torch_scatter import scatter_max
def unsorted_segment_sum(data, segment_ids, num_segments):
    """Sum the entries of ``data`` that share a segment id along dim 0.

    Args:
        data: Tensor whose first dimension indexes the items to be summed.
        segment_ids: Integer tensor of segment indices. Either 1-D with one
            id per row of ``data``, or already the same shape as ``data``.
        num_segments: Size of the first dimension of the result.

    Returns:
        Tensor of shape ``[num_segments, *data.shape[1:]]`` with the same
        dtype and device as ``data``. Segments that receive no rows are zero.
    """
    segment_ids = segment_ids.to(device=data.device, dtype=torch.int64)
    if segment_ids.dim() == 1:
        # scatter_add needs one index per element of data, so broadcast the
        # per-row ids across all trailing dimensions.
        trailing = [1] * (data.dim() - 1)
        segment_ids = segment_ids.reshape(-1, *trailing).expand_as(data)
    out_shape = [num_segments] + list(data.shape[1:])
    # Accumulate directly in data's dtype/device: a float32 round trip would
    # silently lose precision for float64 and large-integer inputs.
    summed = torch.zeros(out_shape, dtype=data.dtype, device=data.device)
    return summed.scatter_add(0, segment_ids, data)
class GraphConv(nn.Module):
    """Graph convolution over atoms grouped by degree.

    For every degree from 1 to ``max_deg`` the layer applies one affine map
    to the summed neighbor features and another to the atom's own features,
    then adds the two. Degree-0 atoms (only when ``min_deg == 0``) get a
    single self affine map. The per-degree results are concatenated along
    dim 0 and passed through ``activation``.

    Args:
        in_channel: Number of input features per atom.
        out_channel: Number of output features per atom.
        min_deg: Smallest atom degree handled by the layer.
        max_deg: Largest atom degree handled by the layer.
        activation: Callable applied to the concatenated output; ``None``
            or the default identity leaves the output unchanged.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 min_deg=0,
                 max_deg=10,
                 activation=lambda x: x
                 ):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.min_degree = min_deg
        self.max_degree = max_deg
        self.activation = activation
        # Two affine maps (neighbors, self) per degree >= 1, plus one self
        # map for degree 0 when min_deg == 0.
        num_deg = 2 * self.max_degree + (1 - self.min_degree)
        # ParameterList (rather than a plain list) registers the weights with
        # the module so parameters(), state_dict(), .to() and optimizers see
        # them.
        self.W_list = nn.ParameterList([
            nn.Parameter(torch.tensor(
                np.random.normal(size=(in_channel, out_channel)),
                dtype=torch.float64))
            for _ in range(num_deg)])
        self.b_list = nn.ParameterList([
            nn.Parameter(torch.zeros(out_channel, dtype=torch.float64))
            for _ in range(num_deg)])

    def forward(self, atom_features, deg_slice, deg_adj_lists):
        """Apply the degree-wise convolution.

        Args:
            atom_features: 2-D tensor with one row per atom; rows are expected
                to be grouped by degree as described by ``deg_slice``.
            deg_slice: 2-D array/tensor whose row ``deg - min_degree`` holds
                ``(begin, size)`` of the atoms with that degree.
            deg_adj_lists: Sequence whose entry ``deg - 1`` holds the neighbor
                indices of the degree-``deg`` atoms.

        Returns:
            Tensor of new atom features, concatenated in degree order.
        """
        W = iter(self.W_list)
        b = iter(self.b_list)
        # Sum all neighbors using adjacency matrix
        deg_summed = self.sum_neigh(atom_features, deg_adj_lists)
        # Get collection of modified atom features
        new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]
        for deg in range(1, self.max_degree + 1):
            # Obtain relevant atoms for this degree
            rel_atoms = deg_summed[deg - 1]
            self_atoms = self._atoms_of_degree(atom_features, deg_slice, deg)
            # Apply hidden affine to relevant atoms and append
            rel_out = torch.matmul(rel_atoms, next(W)) + next(b)
            self_out = torch.matmul(self_atoms, next(W)) + next(b)
            new_rel_atoms_collection[deg - self.min_degree] = rel_out + self_out
        # Determine the min_deg=0 case
        if self.min_degree == 0:
            self_atoms = self._atoms_of_degree(atom_features, deg_slice, 0)
            # Only use the self layer
            new_rel_atoms_collection[0] = torch.matmul(self_atoms, next(W)) + next(b)
        # Combine all atoms back into the list
        atom_features = torch.cat(new_rel_atoms_collection, 0)
        if self.activation is not None:
            atom_features = self.activation(atom_features)
        return atom_features

    def _atoms_of_degree(self, atom_features, deg_slice, deg):
        """Return the rows of ``atom_features`` belonging to degree ``deg``."""
        begin = int(deg_slice[deg - self.min_degree, 0])
        size = int(deg_slice[deg - self.min_degree, 1])
        # Plain slicing also accepts an empty block that starts at the end.
        return atom_features[begin:begin + size]

    def sum_neigh(self, atoms, deg_adj_lists):
        """Store the summed atoms by degree"""
        deg_summed = self.max_degree * [None]
        for deg in range(1, self.max_degree + 1):
            # as_tensor avoids the copy-construct warning when the adjacency
            # list is already a tensor.
            index = torch.as_tensor(
                deg_adj_lists[deg - 1], dtype=torch.int64, device=atoms.device)
            gathered_atoms = atoms[index]
            # Sum the gathered neighbor features for each atom, and store
            deg_summed[deg - 1] = torch.sum(gathered_atoms, 1)
        return deg_summed
class GraphPool(nn.Module):
    """Max-pool every atom's features with those of its neighbors.

    Args:
        min_degree: Smallest atom degree handled by the layer.
        max_degree: Largest atom degree handled by the layer.
    """

    def __init__(self, min_degree=0, max_degree=10):
        super().__init__()
        self.min_degree = min_degree
        self.max_degree = max_degree

    def forward(self, atom_features, deg_slice, deg_adj_lists):
        """Return element-wise maxima over each atom and its neighbors.

        Args:
            atom_features: 2-D tensor with one row per atom; rows are expected
                to be grouped by degree as described by ``deg_slice``.
            deg_slice: 2-D array/tensor whose row ``deg - min_degree`` holds
                ``(begin, size)`` of the atoms with that degree.
            deg_adj_lists: Sequence whose entry ``deg - 1`` holds the neighbor
                indices of the degree-``deg`` atoms.

        Returns:
            Tensor with the pooled features, concatenated in degree order.
            Degree-0 atoms (when ``min_degree == 0``) pass through unchanged.
        """
        deg_maxed = (self.max_degree + 1 - self.min_degree) * [None]
        for deg in range(1, self.max_degree + 1):
            # Get self atoms
            begin = int(deg_slice[deg - self.min_degree, 0])
            size = int(deg_slice[deg - self.min_degree, 1])
            # Expand dims so the atom itself can be stacked with its neighbors
            self_atoms = atom_features[begin:begin + size].unsqueeze(1)
            # always deg-1 for deg_adj_lists
            index = torch.as_tensor(
                deg_adj_lists[deg - 1],
                dtype=torch.int64,
                device=atom_features.device)
            gathered_atoms = torch.cat([self_atoms, atom_features[index]], 1)
            if gathered_atoms.shape[0] > 0:
                maxed_atoms = torch.max(gathered_atoms, 1)[0]
            else:
                # Keep dtype, device and feature width of the input so the
                # final concatenation never mixes types or shapes.
                maxed_atoms = gathered_atoms.new_empty(
                    (0,) + tuple(gathered_atoms.shape[2:]))
            deg_maxed[deg - self.min_degree] = maxed_atoms
        if self.min_degree == 0:
            begin = int(deg_slice[0, 0])
            size = int(deg_slice[0, 1])
            # Degree-0 atoms have no neighbors to pool with
            deg_maxed[0] = atom_features[begin:begin + size]
        return torch.cat(deg_maxed, 0)
class GraphGather(nn.Module):
    """Aggregate atom features into one feature vector per molecule.

    Args:
        batch_size: Number of molecules per batch; fixes the number of output
            rows.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def forward(self, atom_features, membership):
        """Concatenate per-molecule sum and max of the atom features.

        Args:
            atom_features: 2-D tensor with one row per atom.
            membership: 1-D integer tensor giving the molecule index of each
                atom.

        Returns:
            Tensor with ``batch_size`` rows: the per-molecule sums followed,
            along dim 1, by the per-molecule maxima.
        """
        assert self.batch_size > 1, "graph_gather requires batches larger than 1"
        sparse_reps = unsorted_segment_sum(atom_features, membership, self.batch_size)
        # dim_size pins the max block to batch_size rows. Without it the row
        # count follows the largest id in `membership`, so a batch whose
        # trailing molecules own no atoms would not line up with the sum block.
        max_reps, _ = scatter_max(
            atom_features, membership, dim=0, dim_size=self.batch_size)
        return torch.cat([sparse_reps, max_reps], 1)
class GCNDataset(data.Dataset):
    """Map-style dataset that pairs each SMILES string with its label."""

    def __init__(self, smiles_list, label_list):
        self.smiles_list, self.label_list = smiles_list, label_list

    def __len__(self):
        # The dataset length is the number of SMILES strings.
        sample_count = len(self.smiles_list)
        return sample_count

    def __getitem__(self, index):
        # Look up both halves of the sample at the same position.
        smiles = self.smiles_list[index]
        label = self.label_list[index]
        return smiles, label
def gcn_collate_fn(batch):
    """Collate ``(smiles, label)`` samples into graph-convolution inputs.

    Args:
        batch: Iterable of ``(smiles, label)`` pairs.

    Returns:
        Tuple ``(atom_feature, deg_slice, membership, deg_adj_lists, labels)``:
        float64 atom features, the degree slice table (float64), the int64
        atom-to-molecule membership, the adjacency lists with the first entry
        dropped, and a list with one label tensor per sample.

    Raises:
        ValueError: If a SMILES string cannot be parsed.
    """
    from rdkit import Chem

    cmf = ConvMolFeaturizer()
    mols = []
    labels = []
    for sample, label in batch:
        mol = Chem.MolFromSmiles(sample)
        if mol is None:
            # Fail here with the offending input instead of deep inside the
            # featurizer.
            raise ValueError("could not parse SMILES string: {!r}".format(sample))
        mols.append(mol)
        labels.append(torch.tensor(label))
    conv_mols = cmf.featurize(mols)
    multiConvMol = ConvMol.agglomerate_mols(conv_mols)
    atom_feature = torch.tensor(multiConvMol.get_atom_features(), dtype=torch.float64)
    # dtype kept as float64 so the returned value is unchanged for callers.
    deg_slice = torch.tensor(multiConvMol.deg_slice, dtype=torch.float64)
    membership = torch.tensor(multiConvMol.membership, dtype=torch.int64)
    # Fetch the adjacency lists once and drop entry 0 (presumably the
    # degree-0 list, which has no neighbors — TODO confirm) so that index
    # deg - 1 addresses the atoms of degree deg.
    deg_adj_lists = list(multiConvMol.get_deg_adjacency_lists()[1:])
    return atom_feature, deg_slice, membership, deg_adj_lists, labels
def main():
    """Run GraphConv -> GraphPool -> GraphGather on three small alkanes."""
    smiles = ["CCC", "CCCC", "CCCCC"]
    targets = [1, 0, 1]
    loader = data.DataLoader(
        GCNDataset(smiles, targets),
        batch_size=3,
        shuffle=False,
        collate_fn=gcn_collate_fn,
    )
    conv_layer = GraphConv(75, 20)
    pool_layer = GraphPool()
    gather_layer = GraphGather(3)
    for atom_feature, deg_slice, membership, deg_adj_lists, labels in loader:
        # Dump each collated input under its own heading.
        inputs = (
            ("atom_feature", atom_feature),
            ("deg_slice", deg_slice),
            ("membership", membership),
        )
        for title, value in inputs:
            print(title)
            print(value)
        print("result")
        conv_out = conv_layer(atom_feature, deg_slice, deg_adj_lists)
        pooled = pool_layer(conv_out, deg_slice, deg_adj_lists)
        mol_features = gather_layer(pooled, membership)
        print(mol_features)


if __name__ == "__main__":
    main()
Here is the output. For now, the result has shape (number of molecules) x 40 dimensions, which shows that the atom features have been aggregated per molecule. As usual, I like this white-box feeling (my comment is the same every time, so I'll skip it). This time, I had a lot of trouble porting TensorFlow's unsorted_segment_sum and unsorted_segment_max operations. Verification comes next.
tensor([[ 7.7457, 2.1970, 22.1151, 1.8238, 7.5860, 15.5079, -1.3865, 5.3634,
0.3872, 24.7713, 30.9865, 13.0032, 5.8331, 12.8195, 9.2520, 16.4660,
-8.8977, 10.5881, 16.8875, 3.6356, 2.5819, 0.7323, 7.3717, 0.6079,
2.5287, 5.1693, -0.4622, 1.7878, 0.1291, 8.2571, 10.3288, 4.3344,
1.9444, 4.2732, 3.0840, 5.4887, -2.9659, 3.5294, 5.6292, 1.2119],
[12.4624, 16.9705, 26.8321, 4.3047, 17.4027, 23.3370, -1.8487, 7.1511,
0.2538, 23.2520, 25.0874, 17.3375, 7.7775, 9.7369, 8.3362, 20.8373,
-4.3081, 14.1175, 17.6781, 6.4011, 3.1156, 4.2426, 6.7080, 1.0762,
4.3507, 5.8342, -0.4622, 1.7878, 0.0634, 5.8130, 6.2718, 4.3344,
1.9444, 2.4342, 2.0840, 5.2093, -1.0770, 3.5294, 4.4195, 1.6003],
[17.1790, 31.7441, 33.5401, 8.6282, 27.2195, 31.1660, -4.6301, 4.2145,
-1.0452, 29.0650, 31.3592, 15.0395, 14.6857, 12.1711, 10.4202, 26.0466,
3.5187, 10.4842, 22.0976, 9.1667, 3.6493, 7.7530, 6.7080, 2.1586,
6.1727, 6.4992, -0.4622, 1.7878, 0.0634, 5.8130, 6.2718, 4.3344,
3.5990, 2.4342, 2.0840, 5.2093, 1.8909, 3.5294, 4.4195, 1.9887]],
dtype=torch.float64, grad_fn=<CatBackward>)
Recommended Posts