Dans la continuité du GraphConvLayer d'hier, j'ai implémenté le GraphPoolLayer de DeepChem sous forme de couche personnalisée PyTorch.
Une fois ce GraphPoolLayer porté de DeepChem vers PyTorch, j'ai essayé de fournir la sortie du GraphConvLayer précédent en entrée du nouveau GraphPoolLayer.
import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
import torch.nn as nn
import numpy as np
class GraphConv(nn.Module):
    """PyTorch port of DeepChem's GraphConv layer.

    Atoms are expected to be grouped by degree, with ``deg_slice`` giving the
    (start, count) of each degree bucket inside ``atom_features``. For every
    degree ``d`` in ``1..max_deg`` two affine maps are learned: one applied to
    the sum of each atom's neighbours and one applied to the atom itself; the
    two results are added. When ``min_deg == 0`` an extra self-only affine map
    handles degree-0 atoms.

    Args:
        in_channel: Number of input features per atom.
        out_channel: Number of output features per atom.
        min_deg: Smallest atom degree handled; only 0 or 1 are supported by
            the indexing below.
        max_deg: Largest atom degree handled.
        activation: Callable applied to the concatenated output. ``None``
            (the default) means identity.

    Raises:
        ValueError: If ``min_deg`` is not 0 or 1.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 min_deg=0,
                 max_deg=10,
                 activation=None):
        super().__init__()
        if min_deg not in (0, 1):
            # forward() indexes buckets with ``deg - min_degree`` for deg >= 1
            # and consumes exactly 2 * max_deg (+1) weights; other values break that.
            raise ValueError(f"min_deg must be 0 or 1, got {min_deg!r}")
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.min_degree = min_deg
        self.max_degree = max_deg
        # Stored (rather than a default lambda) so the module stays picklable.
        self.activation = activation
        # Two affine maps (neighbour + self) per degree 1..max_deg, plus one
        # self-only map for degree 0 when min_deg == 0.
        num_deg = 2 * self.max_degree + (1 - self.min_degree)
        # nn.ParameterList (not a plain list) registers the tensors with the
        # module so parameters(), state_dict(), .to(device) and optimizers see
        # them. Weights are created directly in float64 (np.random.normal
        # returns float64) instead of round-tripping through float32.
        self.W_list = nn.ParameterList([
            nn.Parameter(torch.from_numpy(
                np.random.normal(size=(in_channel, out_channel))))
            for _ in range(num_deg)])
        self.b_list = nn.ParameterList([
            nn.Parameter(torch.zeros(out_channel, dtype=torch.float64))
            for _ in range(num_deg)])

    def forward(self, atom_features, deg_slice, deg_adj_lists):
        """Apply the degree-specific graph convolution.

        Args:
            atom_features: 2-D tensor, one row per atom, ``in_channel`` columns,
                rows ordered by degree bucket.
            deg_slice: Indexable so that ``deg_slice[deg - min_degree]`` is the
                (start row, row count) of the atoms with degree ``deg``.
            deg_adj_lists: ``deg_adj_lists[d - 1]`` holds, for each atom of
                degree ``d`` (in bucket order), the row indices of its
                neighbours in ``atom_features``.

        Returns:
            Tensor with one row per atom covered by ``deg_slice`` and
            ``out_channel`` columns, after ``activation`` (if any).
        """
        W = iter(self.W_list)
        b = iter(self.b_list)
        # Sum all neighbours of every atom, grouped by degree.
        deg_summed = self.sum_neigh(atom_features, deg_adj_lists)
        new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]
        for deg in range(1, self.max_degree + 1):
            rel_atoms = deg_summed[deg - 1]
            begin = int(deg_slice[deg - self.min_degree, 0])
            size = int(deg_slice[deg - self.min_degree, 1])
            self_atoms = torch.narrow(atom_features, 0, begin, size)
            # Weight order matters: neighbour map first, then self map.
            rel_out = torch.matmul(rel_atoms, next(W)) + next(b)
            self_out = torch.matmul(self_atoms, next(W)) + next(b)
            new_rel_atoms_collection[deg - self.min_degree] = rel_out + self_out
        if self.min_degree == 0:
            # Degree-0 atoms have no neighbour sum; only the self map applies.
            begin = int(deg_slice[0, 0])
            size = int(deg_slice[0, 1])
            self_atoms = torch.narrow(atom_features, 0, begin, size)
            new_rel_atoms_collection[0] = torch.matmul(self_atoms, next(W)) + next(b)
        # Re-assemble the buckets in degree order.
        out = torch.cat(new_rel_atoms_collection, 0)
        if self.activation is not None:
            out = self.activation(out)
        return out

    def sum_neigh(self, atoms, deg_adj_lists):
        """Return a list whose entry ``d - 1`` sums the neighbours of degree-``d`` atoms."""
        deg_summed = self.max_degree * [None]
        for deg in range(1, self.max_degree + 1):
            index = torch.as_tensor(deg_adj_lists[deg - 1], dtype=torch.int64,
                                    device=atoms.device)
            gathered_atoms = atoms[index]
            # Sum over the neighbour axis.
            deg_summed[deg - 1] = torch.sum(gathered_atoms, 1)
        return deg_summed
class GraphPool(nn.Module):
    """PyTorch port of DeepChem's GraphPool layer.

    Replaces each atom's feature vector with the element-wise maximum over the
    atom itself and its neighbours. Degree-0 atoms (kept only when
    ``min_degree == 0``) are passed through unchanged. The feature dimension
    is preserved.

    Args:
        min_degree: Smallest atom degree handled; only 0 or 1 are supported by
            the indexing below.
        max_degree: Largest atom degree handled.

    Raises:
        ValueError: If ``min_degree`` is not 0 or 1.
    """

    def __init__(self, min_degree=0, max_degree=10):
        super().__init__()
        if min_degree not in (0, 1):
            # forward() indexes buckets with ``deg - min_degree`` for deg >= 1;
            # other values would silently write to the wrong slots.
            raise ValueError(f"min_degree must be 0 or 1, got {min_degree!r}")
        self.min_degree = min_degree
        self.max_degree = max_degree

    def forward(self, atom_features, deg_slice, deg_adj_lists):
        """Max-pool every atom over itself and its neighbours.

        Args:
            atom_features: 2-D tensor, one row per atom, rows ordered by
                degree bucket.
            deg_slice: Indexable so that ``deg_slice[deg - min_degree]`` is the
                (start row, row count) of the atoms with degree ``deg``.
            deg_adj_lists: ``deg_adj_lists[d - 1]`` holds, for each atom of
                degree ``d`` (in bucket order), the row indices of its
                neighbours in ``atom_features``.

        Returns:
            Tensor with one row per atom covered by ``deg_slice`` and the same
            number of columns, dtype and device as ``atom_features``.
        """
        deg_maxed = (self.max_degree + 1 - self.min_degree) * [None]
        for deg in range(1, self.max_degree + 1):
            begin = int(deg_slice[deg - self.min_degree, 0])
            size = int(deg_slice[deg - self.min_degree, 1])
            # (n_deg, 1, F): the atom itself, as one extra "neighbour" slot.
            self_atoms = torch.unsqueeze(
                torch.narrow(atom_features, 0, begin, size), 1)
            index = torch.as_tensor(deg_adj_lists[deg - 1], dtype=torch.int64,
                                    device=atom_features.device)
            gathered_atoms = torch.cat([self_atoms, atom_features[index]], 1)
            if gathered_atoms.shape[0] > 0:
                maxed_atoms = torch.max(gathered_atoms, 1)[0]
            else:
                # Empty degree bucket: emit a (0, F) tensor matching the input's
                # dtype and device so torch.cat needs no legacy special-casing.
                maxed_atoms = atom_features.new_zeros(
                    (0,) + tuple(atom_features.shape[1:]))
            deg_maxed[deg - self.min_degree] = maxed_atoms
        if self.min_degree == 0:
            # Degree-0 atoms have no neighbours, so they pass through as-is.
            begin = int(deg_slice[0, 0])
            size = int(deg_slice[0, 1])
            deg_maxed[0] = torch.narrow(atom_features, 0, begin, size)
        return torch.cat(deg_maxed, 0)
class GCNDataset(data.Dataset):
    """Map-style dataset pairing SMILES strings with labels by position.

    No featurization happens here; each item is the raw
    ``(smiles, label)`` tuple.
    """

    def __init__(self, smiles_list, label_list):
        self.smiles_list, self.label_list = smiles_list, label_list

    def __len__(self):
        """Number of samples, taken from the SMILES list."""
        return len(self.smiles_list)

    def __getitem__(self, index):
        """Return the ``(smiles, label)`` pair stored at ``index``."""
        smiles = self.smiles_list[index]
        label = self.label_list[index]
        return smiles, label
def gcn_collate_fn(batch):
    """Featurize a batch of ``(smiles, label)`` pairs into one merged ConvMol.

    Args:
        batch: Iterable of ``(smiles, label)`` pairs, e.g. items from
            ``GCNDataset``.

    Returns:
        Tuple ``(atom_feature, deg_slice, membership, deg_adj_lists, labels)``:
        float64 tensors for the first three, the per-degree adjacency arrays
        starting at degree 1 as a list, and a list of per-sample label tensors.

    Raises:
        ValueError: If a SMILES string cannot be parsed by RDKit.
    """
    from rdkit import Chem

    featurizer = ConvMolFeaturizer()
    mols = []
    labels = []
    for smiles, label in batch:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles reports parse failures by returning None; fail here
            # with a clear message instead of deep inside featurization.
            raise ValueError(f"Invalid SMILES string: {smiles!r}")
        mols.append(mol)
        labels.append(torch.tensor(label))
    conv_mols = featurizer.featurize(mols)
    multi_conv_mol = ConvMol.agglomerate_mols(conv_mols)
    atom_feature = torch.tensor(multi_conv_mol.get_atom_features(), dtype=torch.float64)
    # Kept as float64 for compatibility; the layers convert entries with int().
    # NOTE(review): membership looks like integer ids — consider int64.
    deg_slice = torch.tensor(multi_conv_mol.deg_slice, dtype=torch.float64)
    membership = torch.tensor(multi_conv_mol.membership, dtype=torch.float64)
    # Drop entry 0 so that deg_adj_lists[d - 1] matches degree d, as
    # GraphConv/GraphPool index it.
    deg_adj_lists = list(multi_conv_mol.get_deg_adjacency_lists()[1:])
    return atom_feature, deg_slice, membership, deg_adj_lists, labels
def main():
    """Run GraphConv then GraphPool on a toy batch of alkanes and print tensors."""
    dataset = GCNDataset(["CCC", "CCCC", "CCCCC"], [1, 0, 1])
    loader = data.DataLoader(
        dataset, batch_size=3, shuffle=False, collate_fn=gcn_collate_fn)
    # 75 input channels: presumably the ConvMolFeaturizer atom feature length — confirm.
    conv_layer = GraphConv(75, 20)
    pool_layer = GraphPool()
    for atom_feature, deg_slice, membership, deg_adj_lists, _labels in loader:
        for title, tensor in (("atom_feature", atom_feature),
                              ("deg_slice", deg_slice),
                              ("membership", membership)):
            print(title)
            print(tensor)
        print("result")
        conv_out = conv_layer(atom_feature, deg_slice, deg_adj_lists)
        print(pool_layer(conv_out, deg_slice, deg_adj_lists))


if __name__ == "__main__":
    main()
Voici le résultat. Pour le moment, la sortie a pour forme (nombre d'atomes × 20), ce qui semble dû au fait que la couche de pooling conserve les dimensions produites par GraphConvLayer. Comme d'habitude, j'apprécie le côté « boîte blanche » de cette implémentation (les commentaires sont identiques à ceux de la dernière fois, je ne les répète donc pas). Cependant, le calcul diffère légèrement de celui de TensorFlow, et il me faudra un peu de temps pour le vérifier.
tensor([[ 1.8113e+00, 1.1862e+00, 1.3068e+00, 1.8266e+00, 6.0706e-03,
7.2303e+00, -8.7022e-01, 1.1336e+00, -5.1411e+00, -3.3319e-02,
1.8048e+00, 4.7143e+00, 3.8385e+00, 1.7524e+00, 5.2120e+00,
2.8675e+00, 4.8746e+00, -2.5079e+00, 8.1260e+00, 7.8020e+00],
[ 1.8113e+00, 1.1862e+00, 1.3068e+00, 1.8266e+00, 6.0706e-03,
7.2303e+00, -8.7022e-01, 1.1336e+00, -5.1411e+00, -3.3319e-02,
1.8048e+00, 4.7143e+00, 3.8385e+00, 1.7524e+00, 5.2120e+00,
2.8675e+00, 4.8746e+00, -2.5079e+00, 8.1260e+00, 7.8020e+00],
[ 3.0749e+00, 2.2618e+00, 8.2658e-02, 3.1331e+00, 6.0706e-03,
4.5357e+00, -8.7022e-01, 1.1336e+00, -5.9143e+00, -3.3319e-02,
1.8048e+00, 4.7143e+00, 5.9190e+00, 1.7524e+00, 5.2120e+00,
1.5569e+00, 3.0329e+00, -2.5079e+00, 4.3327e+00, 4.7906e+00],
[ 3.0749e+00, 2.2618e+00, 8.2658e-02, 3.1331e+00, 6.0706e-03,
4.5357e+00, -8.7022e-01, 1.1336e+00, -5.9143e+00, -3.3319e-02,
1.8048e+00, 4.7143e+00, 5.9190e+00, 1.7524e+00, 5.2120e+00,
1.5569e+00, 3.0329e+00, -2.5079e+00, 4.3327e+00, 4.7906e+00],
[ 3.0749e+00, 2.2618e+00, 8.2658e-02, 3.1331e+00, 6.0706e-03,
4.5357e+00, -8.7022e-01, 1.1336e+00, -5.9143e+00, -3.3319e-02,
1.8048e+00, 4.7143e+00, 5.9190e+00, 1.7524e+00, 5.2120e+00,
1.5569e+00, 3.0329e+00, -2.5079e+00, 4.3327e+00, 4.7906e+00],
[ 3.0749e+00, 2.2618e+00, 8.2658e-02, 3.1331e+00, 6.0706e-03,
4.5357e+00, -8.7022e-01, 1.1336e+00, -5.9143e+00, -3.3319e-02,
1.8048e+00, 4.7143e+00, 5.9190e+00, 1.7524e+00, 5.2120e+00,
1.5569e+00, 3.0329e+00, -2.5079e+00, 4.3327e+00, 4.7906e+00]],
dtype=torch.float64, grad_fn=<MaxBackward0>)