from functools import reduce
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d, LeakyReLU
import torch.nn.functional as F
from mlreco.models.layers.gnn.normalizations import BatchNorm, InstanceNorm
class EdgeLayer(nn.Module):
    """Edge model that predicts new edge features.

    For every edge i -> j, the source-node features, the target-node
    features and the current edge features are concatenated along the
    feature axis. The result is fed through three
    ``BatchNorm1d -> Linear`` stages, with a ``LeakyReLU`` between
    consecutive stages.

    Example uses: parent-child edge prediction, EM primary assignment.

    Notation:
        E:   number of edges
        F_x: number of node features (``node_in``)
        F_e: number of input edge features (``edge_in``)
        F_o: number of output edge features (``edge_out``)

    Args:
        node_in: F_x, width of each of ``src`` and ``dest``.
        edge_in: F_e, width of ``edge_attr``.
        edge_out: F_o, width of the returned edge features.
        leakiness: negative slope of the LeakyReLU activations.
    """

    def __init__(self, node_in, edge_in, edge_out, leakiness=0.0):
        super(EdgeLayer, self).__init__()
        # Stage k is BatchNorm1d(width_k) -> Linear(width_k, edge_out), with
        # a LeakyReLU inserted before every stage except the first. This
        # yields BN, Lin, LReLU, BN, Lin, LReLU, BN, Lin (indices 0..7).
        stage_widths = (2 * node_in + edge_in, edge_out, edge_out)
        layers = []
        for stage, width in enumerate(stage_widths):
            if stage > 0:
                layers.append(nn.LeakyReLU(negative_slope=leakiness))
            layers += [BatchNorm1d(width), nn.Linear(width, edge_out)]
        self.edge_mlp = nn.Sequential(*layers)

    def forward(self, src, dest, edge_attr, u=None, batch=None):
        """Compute updated edge features.

        Args:
            src: [E, F_x] source-node features, one row per edge.
            dest: [E, F_x] target-node features, one row per edge.
            edge_attr: [E, F_e] input edge features.
            u: unused; accepted for interface compatibility.
            batch: unused; accepted for interface compatibility.

        Returns:
            [E, F_o] tensor of output edge features.
        """
        edge_input = torch.cat((src, dest, edge_attr), dim=1)
        return self.edge_mlp(edge_input)
class NodeLayer(nn.Module):
    """Node model that predicts new node features via message passing.

    Each edge produces a message from its source-node features and its
    edge features (``node_mlp_1``). Messages are aggregated onto their
    target nodes with ``torch_scatter.scatter`` using ``reduction``. Each
    node's original features are concatenated with its aggregated message
    and passed through ``node_mlp_2``.

    Example use: particle classification from node-level features.

    Notation:
        N:   number of nodes
        E:   number of edges
        F_x: number of node features (``node_in``)
        F_e: number of edge features (``edge_in``)
        F_o: number of output node features (``node_out``)

    Args:
        node_in: F_x, width of ``x``.
        node_out: F_o, width of the messages and of the returned features.
        edge_in: F_e, width of ``edge_attr``.
        leakiness: negative slope of the LeakyReLU activations.
        reduction: ``reduce`` mode passed to ``torch_scatter.scatter``.
    """

    def __init__(self, node_in, node_out, edge_in, leakiness=0.0, reduction='mean'):
        super(NodeLayer, self).__init__()
        # Message MLP: [source-node features | edge features] -> F_o.
        self.node_mlp_1 = self._mlp(node_in + edge_in, node_out, leakiness)
        self.reduction = reduction
        # Update MLP: [node features | aggregated message] -> F_o.
        self.node_mlp_2 = self._mlp(node_in + node_out, node_out, leakiness)

    @staticmethod
    def _mlp(in_features, out_features, leakiness):
        """Build three BatchNorm1d -> Linear stages joined by LeakyReLU."""
        return nn.Sequential(
            BatchNorm1d(in_features),
            nn.Linear(in_features, out_features),
            nn.LeakyReLU(negative_slope=leakiness),
            BatchNorm1d(out_features),
            nn.Linear(out_features, out_features),
            nn.LeakyReLU(negative_slope=leakiness),
            BatchNorm1d(out_features),
            nn.Linear(out_features, out_features),
        )

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Compute updated node features.

        Args:
            x: [N, F_x] node features.
            edge_index: pair (first dimension of size 2) of [E] index
                tensors; the first indexes the source node of each edge,
                the second the node its message is aggregated onto.
            edge_attr: [E, F_e] edge features.
            u: unused; accepted for interface compatibility.
            batch: unused; accepted for interface compatibility.

        Returns:
            [N, F_o] tensor of output node features.
        """
        from torch_scatter import scatter
        sources, targets = edge_index
        messages = self.node_mlp_1(torch.cat([x[sources], edge_attr], dim=1))
        # dim_size keeps one row per node, even for nodes that receive no edge.
        aggregated = scatter(messages, targets, dim=0, dim_size=x.size(0),
                             reduce=self.reduction)
        return self.node_mlp_2(torch.cat([x, aggregated], dim=1))
class GlobalModel(nn.Module):
    """Global model that predicts one feature vector per graph.

    Node features are pooled per graph with ``torch_scatter.scatter``
    (mode ``reduction``). The pooled result is concatenated after the
    existing global features ``u`` and passed through three
    ``BatchNorm1d -> Linear`` stages joined by ``LeakyReLU``.

    Example use: event (graph) classification over a whole image.

    Notation:
        N:   number of nodes
        F_x: number of node features (``node_in``)
        F_u: number of global input features (``batch_size``, see below)
        F_o: number of output global features (``global_out``)
        B:   number of graphs in the batch

    Args:
        node_in: F_x, width of ``x``.
        batch_size: despite its name, this is F_u, the width of ``u``.
            The MLP input is ``cat([u, pooled_x])``, so its width is
            ``batch_size + node_in``.
        global_out: F_o, width of the returned global features.
        leakiness: negative slope of the LeakyReLU activations.
        reduction: ``reduce`` mode passed to ``torch_scatter.scatter``.

    NOTE: BatchNorm1d in training mode needs more than one row, so the
    training batch must contain B > 1 graphs.
    """

    def __init__(self, node_in, batch_size, global_out, leakiness=0.0, reduction='mean'):
        super(GlobalModel, self).__init__()
        in_features = node_in + batch_size  # F_x + F_u (see class docstring)
        self.global_mlp = nn.Sequential(
            BatchNorm1d(in_features),
            nn.Linear(in_features, global_out),
            nn.LeakyReLU(negative_slope=leakiness),
            BatchNorm1d(global_out),
            nn.Linear(global_out, global_out),
            nn.LeakyReLU(negative_slope=leakiness),
            BatchNorm1d(global_out),
            nn.Linear(global_out, global_out),
        )
        self.reduction = reduction

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Compute updated global features.

        Args:
            x: [N, F_x] node features.
            edge_index: unused; accepted for interface compatibility.
            edge_attr: unused; accepted for interface compatibility.
            u: [B, F_u] global features, one row per graph.
            batch: [N] graph index of each node, in the range 0..B-1.

        Returns:
            [B, F_o] tensor of output global features.
        """
        from torch_scatter import scatter
        # Force exactly one pooled row per graph in ``u``. Without dim_size
        # the output has batch.max() + 1 rows, so trailing graphs without
        # nodes would shrink it and make the concatenation below fail.
        pooled = scatter(x, batch, dim=0, dim_size=u.size(0),
                         reduce=self.reduction)
        out = torch.cat([u, pooled], dim=1)
        return self.global_mlp(out)