import torch
import numpy as np
from mlreco.models.grappa import GNN, GNNLoss
from mlreco.utils.deghosting import adapt_labels_knn as adapt_labels
from mlreco.utils.gnn.evaluation import (node_assignment_score,
primary_assignment)
from mlreco.utils.gnn.cluster import (form_clusters,
get_cluster_batch,
get_cluster_label)
class FullChainGNN(torch.nn.Module):
    """
    GNN section of the full chain.

    Runs the enabled GrapPA stages on top of the CNN outputs: fragment
    aggregation (shower / track / particle GNNs), interaction clustering,
    kinematics, and an optional cosmic discriminator on the interactions.

    See Also
    --------
    mlreco.models.full_chain.FullChain, FullChainLoss
    """
    MODULES = ['grappa_shower', 'grappa_track', 'grappa_inter',
               'grappa_shower_loss', 'grappa_track_loss', 'grappa_inter_loss',
               'full_chain_loss', 'spice', 'spice_loss',
               'fragment_clustering', 'chain', 'dbscan',
               ('uresnet_ppn', ['uresnet_lonely', 'ppn'])]

    def __init__(self, cfg):
        """
        Instantiate the GrapPA modules enabled in the chain configuration.

        Parameters
        ----------
        cfg : dict
            Full model configuration. ``setup_chain_cfg`` reads the chain
            options; each ``grappa_*`` section configures one GNN.
        """
        super().__init__()

        # Configure the chain first (sets the batch/coordinate columns and,
        # presumably, the enable_* flags read below)
        setup_chain_cfg(self, cfg)

        # Initialize the particle aggregator modules
        if self.enable_gnn_shower:
            self.grappa_shower = GNN(cfg, name='grappa_shower', batch_col=self.batch_col, coords_col=self.coords_col)
            grappa_shower_cfg = cfg.get('grappa_shower', {})
            self._shower_ids = grappa_shower_cfg.get('base', {}).get('node_type', 0)
            self._shower_use_true_particles = grappa_shower_cfg.get('use_true_particles', False)
            if not isinstance(self._shower_ids, list):
                self._shower_ids = [self._shower_ids]

        if self.enable_gnn_track:
            self.grappa_track = GNN(cfg, name='grappa_track', batch_col=self.batch_col, coords_col=self.coords_col)
            grappa_track_cfg = cfg.get('grappa_track', {})
            self._track_ids = grappa_track_cfg.get('base', {}).get('node_type', 1)
            self._track_use_true_particles = grappa_track_cfg.get('use_true_particles', False)
            if not isinstance(self._track_ids, list):
                self._track_ids = [self._track_ids]

        if self.enable_gnn_particle:
            self.grappa_particle = GNN(cfg, name='grappa_particle', batch_col=self.batch_col, coords_col=self.coords_col)
            grappa_particle_cfg = cfg.get('grappa_particle', {})
            self._particle_ids = grappa_particle_cfg.get('base', {}).get('node_type', [0, 1, 2, 3])
            self._particle_use_true_particles = grappa_particle_cfg.get('use_true_particles', False)
            # Normalize to a list like the shower/track ids: get_all_particles iterates over it
            if not isinstance(self._particle_ids, list):
                self._particle_ids = [self._particle_ids]

        if self.enable_gnn_inter:
            self.grappa_inter = GNN(cfg, name='grappa_inter', batch_col=self.batch_col, coords_col=self.coords_col)
            grappa_inter_cfg = cfg.get('grappa_inter', {})
            self._inter_ids = grappa_inter_cfg.get('base', {}).get('node_type', [0, 1, 2, 3])
            self._inter_use_true_particles = grappa_inter_cfg.get('use_true_particles', False)
            self.inter_source_col = cfg.get('grappa_inter_loss', {}).get('edge_loss', {}).get('source_col', 6)
            self._inter_use_shower_primary = grappa_inter_cfg.get('use_shower_primary', True)
            self._inter_enforce_semantics = grappa_inter_cfg.get('enforce_semantics', True)
            self._inter_enforce_semantics_shape = grappa_inter_cfg.get('enforce_semantics_shape', (4, 5))
            self._inter_enforce_semantics_map = grappa_inter_cfg.get('enforce_semantics_map', [[0, 0, 1, 1, 1, 2, 3], [0, 1, 2, 3, 4, 1, 1]])

        if self.enable_gnn_kinematics:
            self.grappa_kinematics = GNN(cfg, name='grappa_kinematics', batch_col=self.batch_col, coords_col=self.coords_col)
            self._kinematics_use_true_particles = cfg.get('grappa_kinematics', {}).get('use_true_particles', False)

    def run_gnn(self, grappa, input, result, clusts, labels, kwargs=None):
        """
        Generic function to group in one place the common code to run a GNN model.

        Parameters
        ----------
        grappa : callable
            GrapPA module to run; called as
            ``grappa(input, clusts, batch_size=self.batch_size, **kwargs)``.
        input : list
            Input data, forwarded as-is to ``grappa``.
        result : dict
            Result dictionary, updated in place.
        clusts : array-like
            List of index arrays (indexing the input data), one per node.
        labels : dict
            Maps GrapPA output keys to the key under which each output is
            stored in ``result``. If ``'group_pred'`` is a key, node group
            predictions are also computed and stored under
            ``labels['group_pred']``.
        kwargs : dict, optional
            Extra keyword arguments forwarded to ``grappa``.

        Returns
        -------
        None
            ``result`` is modified in place.
        """
        # A None default avoids sharing one mutable dict across calls
        if kwargs is None:
            kwargs = {}

        # Pass data through the GrapPA model
        gnn_output = grappa(input, clusts, batch_size=self.batch_size, **kwargs)

        # Update the result dictionary if the corresponding label exists
        for key, tag in labels.items():
            if key in gnn_output:
                result[tag] = gnn_output[key]

        # Make group predictions based on the GNN output, if requested
        if 'group_pred' in labels:
            group_ids = []
            for b, batch_clusts in enumerate(gnn_output['clusts'][0]):
                n_clusts = len(batch_clusts)
                if n_clusts < 2:
                    # Fewer than two nodes: no edge to score, assign group 0
                    group_ids.append(np.zeros(n_clusts, dtype=np.int64))
                else:
                    group_ids.append(node_assignment_score(
                        gnn_output['edge_index'][0][b],
                        gnn_output['edge_pred'][0][b].detach().cpu().numpy(),
                        n_clusts))
            result[labels['group_pred']] = [group_ids]

    def select_particle_in_group(self, result, counts, b, particles,
                                 part_primary_ids,
                                 node_pred,
                                 group_pred,
                                 fragments):
        """
        Merge the fragments of batch entry ``b`` into one instance per predicted group.

        For each group id in ``result[group_pred][0][b]``, the union of the
        group's fragments (``result[fragments][0][b]``) is appended to
        ``particles`` as voxel indices offset by the voxel counts of the
        previous batch entries. One id per group is appended to
        ``part_primary_ids``: the index of the group's primary fragment (from
        ``primary_assignment``) when ``node_pred`` is a key of ``result``,
        the group id otherwise.

        Parameters
        ----------
        result : dict
            Chain result dictionary.
        counts : tensor-like
            Number of voxels in each batch entry.
        b : int
            Batch entry to process.
        particles, part_primary_ids : list
            Output lists, extended in place.
        node_pred, group_pred, fragments : str or None
            Keys of ``result`` holding node predictions (may be None or absent),
            group predictions and fragment index arrays.
        """
        voxel_inds = counts[:b].sum().item() + np.arange(counts[b].item())
        primary_labels = None
        if node_pred in result:
            primary_labels = primary_assignment(
                result[node_pred][0][b].detach().cpu().numpy(),
                result[group_pred][0][b])
        for g in np.unique(result[group_pred][0][b]):
            group_mask = np.where(result[group_pred][0][b] == g)[0]
            particles.append(
                voxel_inds[np.concatenate(result[fragments][0][b][group_mask])])
            if node_pred in result:
                primary_id = group_mask[primary_labels[group_mask]][0]
                part_primary_ids.append(primary_id)
            else:
                part_primary_ids.append(g)

    def get_all_fragments(self, result, input):
        """
        Return true or predicted fragments with their metadata.

        When ``self.use_true_fragments`` is set, fragments are formed from
        column 5 of the label clustering tensor; otherwise they are read from
        the ``'frags'``, ``'frag_seg'``, ``'frag_batch_ids'`` and
        ``'semantic_labels'`` entries of ``result``.

        Returns
        -------
        dict
            Keys ``'frags'``, ``'frag_seg'``, ``'frag_batch_ids'``,
            ``'semantic_labels'``, ``'vids'`` and ``'counts'``.
        """
        if self.use_true_fragments:
            label_clustering = result['label_clustering'][0]
            fragments = form_clusters(label_clustering[0].int().cpu().numpy(),
                                      column=5,
                                      batch_index=self.batch_col)
            fragments = np.array(fragments, dtype=object)
            frag_seg = get_cluster_label(label_clustering[0].int(),
                                         fragments,
                                         column=-1)
            semantic_labels = label_clustering[0].int()[:, -1]
            frag_batch_ids = get_cluster_batch(input[0][:, :5],
                                               fragments,
                                               batch_index=self.batch_col)
        else:
            fragments = result['frags'][0]
            frag_seg = result['frag_seg'][0]
            frag_batch_ids = result['frag_batch_ids'][0]
            semantic_labels = result['semantic_labels'][0]

        frag_dict = {
            'frags': fragments,
            'frag_seg': frag_seg,
            'frag_batch_ids': frag_batch_ids,
            'semantic_labels': semantic_labels
        }

        # Since <vids> and <counts> depend on the batch column of the input
        # tensor, they are shared between the two settings.
        frag_dict['vids'] = result['vids'][0]
        frag_dict['counts'] = result['counts'][0]

        return frag_dict

    def run_fragment_gnns(self, result, input):
        """
        Run all fragment-level GNN models.

        1. Shower GNN
        2. Track GNN
        3. Particle GNN (optional)

        Outputs are stored in ``result`` in place.

        Returns
        -------
        dict
            The fragment dictionary from ``get_all_fragments``.
        """
        frag_dict = self.get_all_fragments(result, input)
        fragments = frag_dict['frags']
        frag_seg = frag_dict['frag_seg']

        if self.enable_gnn_shower:
            # Run shower GrapPA: merges shower fragments into shower instances
            em_mask, kwargs = self.get_extra_gnn_features(fragments,
                                                          frag_seg,
                                                          self._shower_ids,
                                                          input,
                                                          result,
                                                          use_ppn=self.use_ppn_in_gnn,
                                                          use_supp=self.use_supp_in_gnn)
            output_keys = {'clusts': 'shower_fragments',
                           'node_pred': 'shower_node_pred',
                           'edge_pred': 'shower_edge_pred',
                           'edge_index': 'shower_edge_index',
                           'group_pred': 'shower_group_pred',
                           'input_node_features': 'shower_node_features'}
            self.run_gnn(self.grappa_shower,
                         input,
                         result,
                         fragments[em_mask],
                         output_keys,
                         kwargs)

        if self.enable_gnn_track:
            # Run track GrapPA: merges tracks fragments into track instances
            track_mask, kwargs = self.get_extra_gnn_features(fragments,
                                                             frag_seg,
                                                             self._track_ids,
                                                             input,
                                                             result,
                                                             use_ppn=self.use_ppn_in_gnn,
                                                             use_supp=self.use_supp_in_gnn)
            output_keys = {'clusts': 'track_fragments',
                           'node_pred': 'track_node_pred',
                           'edge_pred': 'track_edge_pred',
                           'edge_index': 'track_edge_index',
                           'group_pred': 'track_group_pred',
                           'input_node_features': 'track_node_features'}
            self.run_gnn(self.grappa_track,
                         input,
                         result,
                         fragments[track_mask],
                         output_keys,
                         kwargs)

        if self.enable_gnn_particle:
            # Run particle GrapPA: merges the fragments whose semantic
            # labels are in _particle_ids into particle instances
            mask, kwargs = self.get_extra_gnn_features(fragments,
                                                       frag_seg,
                                                       self._particle_ids,
                                                       input,
                                                       result,
                                                       use_ppn=self.use_ppn_in_gnn,
                                                       use_supp=self.use_supp_in_gnn)
            kwargs['groups'] = frag_seg[mask]
            output_keys = {'clusts': 'particle_fragments',
                           'node_pred': 'particle_node_pred',
                           'edge_pred': 'particle_edge_pred',
                           'edge_index': 'particle_edge_index',
                           'group_pred': 'particle_group_pred'}
            self.run_gnn(self.grappa_particle,
                         input,
                         result,
                         fragments[mask],
                         output_keys,
                         kwargs)

        return frag_dict

    def get_all_particles(self, frag_result, result, input):
        """
        Build particle instances from the fragment GNN outputs.

        Fragments grouped by an enabled particle/shower/track GNN become one
        particle per group. Every remaining fragment of each batch entry
        becomes its own particle, with primary id -1. Each particle's
        semantic class is the most frequent semantic label among its voxels.

        Stores per-batch ``'particles'`` (batch-local voxel ids) and
        ``'particles_seg'`` in ``result``.

        Returns
        -------
        dict
            Keys ``'particles'``, ``'part_seg'``, ``'part_batch_ids'``,
            ``'part_primary_ids'`` and ``'counts'``.
        """
        fragments = frag_result['frags']
        frag_seg = frag_result['frag_seg']
        frag_batch_ids = frag_result['frag_batch_ids']
        semantic_labels = frag_result['semantic_labels']
        vids = frag_result['vids']
        counts = frag_result['counts']

        # Merge fragments into particle instances, retain primary fragment id of showers
        particles, part_primary_ids = [], []
        # len(counts) may exceed len(np.unique(frag_batch_ids)),
        # e.g. if an event has no shower fragments
        for b in range(len(counts)):
            mask = (frag_batch_ids == b)
            # Append one particle per particle group.
            # To use true group predictions, change use_group_pred to True
            # in each grappa config.
            if self.enable_gnn_particle:
                self.select_particle_in_group(result, counts, b, particles,
                                              part_primary_ids,
                                              'particle_node_pred',
                                              'particle_group_pred',
                                              'particle_fragments')
                for c in self._particle_ids:
                    mask &= (frag_seg != c)
            # Append one particle per shower group
            if self.enable_gnn_shower:
                self.select_particle_in_group(result, counts, b, particles,
                                              part_primary_ids,
                                              'shower_node_pred',
                                              'shower_group_pred',
                                              'shower_fragments')
                for c in self._shower_ids:
                    mask &= (frag_seg != c)
            # Append one particle per track group
            if self.enable_gnn_track:
                self.select_particle_in_group(result, counts, b, particles,
                                              part_primary_ids,
                                              'track_node_pred',
                                              'track_group_pred',
                                              'track_fragments')
                for c in self._track_ids:
                    mask &= (frag_seg != c)

            # Append one particle per fragment that is not already accounted for
            particles.extend(fragments[mask])
            part_primary_ids.extend(-np.ones(np.sum(mask)).astype(int))

        # Equal-length index arrays must be stored as a 2D int array,
        # ragged ones as a 1D object array
        same_length = np.all([len(p) == len(particles[0]) for p in particles])
        particles = np.array(particles,
                             dtype=object if not same_length else np.int64)
        part_batch_ids = get_cluster_batch(input[0],
                                           particles,
                                           batch_index=self.batch_col)
        part_primary_ids = np.array(part_primary_ids, dtype=np.int32)
        part_seg = np.empty(len(particles), dtype=np.int32)
        for i, p in enumerate(particles):
            vals, cnts = semantic_labels[p].unique(return_counts=True)
            part_seg[i] = vals[torch.argmax(cnts)].item()

        # Store in result the intermediate particles, split per batch entry
        bcids = [np.where(part_batch_ids == b)[0] for b in range(len(counts))]
        same_length = [np.all([len(c) == len(particles[b][0]) for c in particles[b]])
                       for b in bcids]
        parts = [np.array([vids[c].astype(np.int64) for c in particles[b]],
                          dtype=object if not same_length[idx] else np.int64)
                 for idx, b in enumerate(bcids)]
        parts_seg = [part_seg[b] for b in bcids]

        result.update({
            'particles': [parts],
            'particles_seg': [parts_seg]
        })

        part_result = {
            'particles': particles,
            'part_seg': part_seg,
            'part_batch_ids': part_batch_ids,
            'part_primary_ids': part_primary_ids,
            'counts': counts
        }

        return part_result

    def run_particle_gnns(self, result, input, frag_result):
        """
        Run all particle-level models on the particles built from ``frag_result``.

        1. Interaction GNN (merges particles into interactions)
        2. Kinematics GNN (requires the interaction GNN)
        3. Cosmic discriminator (see ``_run_cosmic_discrimination``)

        Parameters
        ----------
        result : dict
            Chain result dictionary, updated in place.
        input : list
            Input data; ``input[0]`` is the voxel tensor.
        frag_result : dict
            Output of ``run_fragment_gnns``.

        Raises
        ------
        Exception
            If clustering labels or an upstream stage required by the
            configuration are missing.
        IndexError
            If the primary fragment of a shower particle cannot be located.
        """
        part_result = self.get_all_particles(frag_result, result, input)
        particles = part_result['particles']
        part_seg = part_result['part_seg']
        part_batch_ids = part_result['part_batch_ids']
        part_primary_ids = part_result['part_primary_ids']
        counts = part_result['counts']

        label_clustering = result['label_clustering'][0] if 'label_clustering' in result else None
        if label_clustering is None and (self.use_true_fragments or (self.enable_cosmic and self._cosmic_use_true_interactions)):
            raise Exception('Need clustering labels to use true fragments or true interactions.')

        device = input[0].device

        if self.enable_gnn_inter:
            if self._inter_use_true_particles:
                particles = form_clusters(label_clustering[0].int().cpu().numpy(), min_size=-1, column=self.inter_source_col, cluster_classes=self._inter_ids)
                particles = np.array(particles, dtype=object)
                part_seg = get_cluster_label(label_clustering[0].int(), particles, column=-1)
                part_batch_ids = get_cluster_batch(label_clustering[0], particles, batch_index=0)
                _, counts = torch.unique(label_clustering[0][:, 0], return_counts=True)

            # For showers, select the primary fragment for extra feature extraction
            extra_feats_particles = []
            for i, p in enumerate(particles):
                if part_seg[i] == 0 and not self._inter_use_true_particles and self._inter_use_shower_primary:
                    voxel_inds = counts[:part_batch_ids[i]].sum().item() + \
                                 np.arange(counts[part_batch_ids[i]].item())
                    if len(voxel_inds) and len(result['shower_fragments'][0][part_batch_ids[i]]) > 0:
                        shower_frags = result['shower_fragments'][0][part_batch_ids[i]]
                        try:
                            p = voxel_inds[shower_frags[part_primary_ids[i]]]
                        except IndexError as err:
                            raise IndexError(
                                'Could not locate the primary fragment of shower particle {} '
                                '(batch {}, primary id {}, {} shower fragments, {} voxels in batch)'.format(
                                    i, part_batch_ids[i], part_primary_ids[i],
                                    len(shower_frags), len(voxel_inds))) from err
                extra_feats_particles.append(p)

            same_length = np.all([len(p) == len(extra_feats_particles[0])
                                  for p in extra_feats_particles])
            extra_feats_particles = np.array(extra_feats_particles,
                                             dtype=object if not same_length else np.int64)

            # Run interaction GrapPA: merges particle instances into interactions
            inter_mask, kwargs = self.get_extra_gnn_features(extra_feats_particles,
                                                             part_seg,
                                                             self._inter_ids,
                                                             input,
                                                             result,
                                                             use_ppn=self.use_ppn_in_gnn,
                                                             use_supp=True)
            output_keys = {'clusts': 'inter_particles',
                           'edge_pred': 'inter_edge_pred',
                           'edge_index': 'inter_edge_index',
                           'group_pred': 'inter_group_pred',
                           'node_pred': 'inter_node_pred',
                           'node_pred_type': 'node_pred_type',
                           'node_pred_p': 'node_pred_p',
                           'node_pred_vtx': 'node_pred_vtx',
                           'input_node_features': 'particle_node_features',
                           'input_edge_features': 'particle_edge_features'}
            self.run_gnn(self.grappa_inter,
                         input,
                         result,
                         particles[inter_mask],
                         output_keys,
                         kwargs)

            # If requested, enforce that particle PID predictions are compatible with semantics,
            # i.e. set logits to -inf if they belong to incompatible PIDs.
            # NOTE(review): part_seg/part_batch_ids cover all particles while the logits
            # cover particles[inter_mask]; this assumes inter_mask keeps every particle.
            if self._inter_enforce_semantics and 'node_pred_type' in result:
                sem_pid_logic = -float('inf') * torch.ones(self._inter_enforce_semantics_shape, dtype=input[0].dtype, device=input[0].device)
                sem_pid_logic[self._inter_enforce_semantics_map] = 0.
                pid_logits = result['node_pred_type']
                for i in range(len(pid_logits)):
                    for b in range(len(pid_logits[i])):
                        pid_logits[i][b] += sem_pid_logic[part_seg[part_batch_ids == b]]
                result['node_pred_type'] = pid_logits

        # ---
        # 4. GNN for particle flow & kinematics
        # ---
        if self.enable_gnn_kinematics:
            if not self.enable_gnn_inter:
                raise Exception("Need interaction clustering before kinematic GNN.")
            output_keys = {'clusts': 'kinematics_particles',
                           'edge_index': 'kinematics_edge_index',
                           'node_pred_p': 'kinematics_node_pred_p',
                           'node_pred_type': 'kinematics_node_pred_type',
                           'edge_pred': 'flow_edge_pred'}
            self.run_gnn(self.grappa_kinematics,
                         input,
                         result,
                         particles[inter_mask],
                         output_keys)

        # ---
        # 5. CNN for interaction classification
        # ---
        if self.enable_cosmic:
            self._run_cosmic_discrimination(result, input, label_clustering, device)

    def _run_cosmic_discrimination(self, result, input, label_clustering, device):
        """
        Form interactions and run the cosmic discriminator on each of them.

        Interactions come from the true labels (column 7 of the label
        clustering) when ``self._cosmic_use_true_interactions`` is set, and
        from the interaction GNN group predictions otherwise. The predictions
        are stored per batch entry in ``result['interactions']`` (batch-local
        voxel ids) and ``result['inter_cosmic_pred']``.

        Raises
        ------
        Exception
            If neither the interaction GNN nor true interactions are
            available, or true interactions are requested without labels.
        """
        if not self.enable_gnn_inter and not self._cosmic_use_true_interactions:
            raise Exception("Need interaction clustering before cosmic discrimination.")

        _, counts = torch.unique(input[0][:, self.batch_col], return_counts=True)
        interactions, inter_primary_ids = [], []
        # Note to self: inter_primary_ids is not used as of now
        if self._cosmic_use_true_interactions:
            if label_clustering is None:
                raise Exception("The option to use true interactions requires label segmentation and clustering in the network input.")
            interactions = form_clusters(label_clustering[0], column=7, batch_index=self.batch_col)
            interactions = [inter.cpu().numpy() for inter in interactions]
        else:
            for b in range(len(counts)):
                self.select_particle_in_group(result, counts, b, interactions, inter_primary_ids,
                                              None, 'inter_group_pred', 'particles')

        same_length = np.all([len(inter) == len(interactions[0]) for inter in interactions])
        interactions = [inter.astype(np.int64) for inter in interactions]
        interactions = np.array(interactions,
                                dtype=object if not same_length else np.int64)
        inter_batch_ids = np.array(get_cluster_batch(input[0], interactions, batch_index=self.batch_col))

        # Only fetch the feature map when it is actually used
        if self._cosmic_use_input_data:
            inter_input_data = input[0].float()
        else:
            if 'ppn_feature_dec' in result:
                feature_map = result['ppn_feature_dec'][0][-1]
            else:
                feature_map = result['ppn_layers'][0][-1]
            if not torch.is_tensor(feature_map):
                feature_map = feature_map.features
            inter_input_data = torch.cat([input[0][:, :4].float(), feature_map], dim=1)

        # Replace batch id column with a global "interaction id"
        # because ResidualEncoder uses the batch id column to shape its output.
        # Chunks are concatenated once (instead of once per interaction); the
        # leading empty float tensor keeps the original dtype promotion.
        inter_chunks = [torch.empty((0, inter_input_data.size(1)), dtype=torch.float, device=device)]
        for i, interaction in enumerate(interactions):
            chunk = inter_input_data[interaction]  # advanced indexing returns a copy
            chunk[:, self.batch_col] = i
            inter_chunks.append(chunk)
        inter_data = torch.cat(inter_chunks, dim=0)

        inter_cosmic_pred = self.cosmic_discriminator(inter_data)

        # Reorganize into batches before storing in result dictionary
        batches, counts = torch.unique(input[0][:, self.batch_col], return_counts=True)
        # In case one of the events is "missing" and len(counts) < batch_size
        if len(counts) < self.batch_size:
            new_counts = torch.zeros(self.batch_size, dtype=torch.int64, device=counts.device)
            # The batch column may be stored as floats: index with integers
            new_counts[batches.long()] = counts
            counts = new_counts

        vids = np.concatenate([np.arange(n.item()) for n in counts])
        bcids = [np.where(inter_batch_ids == b)[0] for b in range(len(counts))]
        same_length = [np.all([len(c) == len(interactions[b][0]) for c in interactions[b]])
                       for b in bcids]

        interactions_np = [np.array([vids[c].astype(np.int64) for c in interactions[b]],
                                    dtype=object if not same_length[idx] else np.int64)
                           for idx, b in enumerate(bcids)]
        inter_cosmic_pred_np = [inter_cosmic_pred[b] for b in bcids]

        result.update({
            'interactions': [interactions_np],
            'inter_cosmic_pred': [inter_cosmic_pred_np]
        })

    def full_chain_gnn(self, result, input):
        """Run the fragment-level then the particle-level GNNs; return ``result``."""
        frag_dict = self.run_fragment_gnns(result, input)
        self.run_particle_gnns(result, input, frag_dict)

        return result

    def forward(self, input):
        """
        Run the CNN part of the chain, then the GNN part when applicable.

        Input can be either of the following:
        - input data only
        - input data, label clustering in this order
        - input data, label segmentation, label clustering in this order
          (when deghosting is enabled, label segmentation is needed to
          adapt label clustering properly)

        The GNN stages only run when the input is non-empty, fragments were
        produced (``'frags'`` in the CNN result), ``process_fragments`` is
        set and at least one of the shower / track / interaction / particle
        GNNs is enabled.

        Parameters
        ----------
        input : list
            Input data (see above).

        Returns
        -------
        dict
            Result dictionary, passed through the ``revert_func`` returned by
            ``full_chain_cnn`` (not defined in this class).
        """
        result, input, revert_func = self.full_chain_cnn(input)
        if len(input[0]) and 'frags' in result and self.process_fragments and (self.enable_gnn_track or self.enable_gnn_shower or self.enable_gnn_inter or self.enable_gnn_particle):
            result = self.full_chain_gnn(result, input)

        result = revert_func(result)
        return result
class FullChainLoss(torch.nn.modules.loss._Loss):
    """
    Loss of the full chain.

    Combines the losses of every enabled stage (deghosting, segmentation,
    PPN, CNN clustering, shower / track / particle / interaction GNNs,
    kinematics, cosmic discrimination) into one weighted total.

    See Also
    --------
    mlreco.models.full_chain.FullChainLoss, FullChainGNN
    """

    def __init__(self, cfg):
        """
        Instantiate the GNN losses enabled in the chain configuration and read
        the per-stage loss weights from the ``full_chain_loss`` section
        (each defaults to 1.0).

        Parameters
        ----------
        cfg : dict
            Full model configuration.
        """
        super().__init__()

        # Configure the chain first
        setup_chain_cfg(self, cfg, False)

        if self.enable_gnn_shower:
            self.shower_gnn_loss = GNNLoss(cfg, 'grappa_shower_loss', batch_col=self.batch_col, coords_col=self.coords_col)
        if self.enable_gnn_track:
            self.track_gnn_loss = GNNLoss(cfg, 'grappa_track_loss', batch_col=self.batch_col, coords_col=self.coords_col)
        if self.enable_gnn_particle:
            self.particle_gnn_loss = GNNLoss(cfg, 'grappa_particle_loss', batch_col=self.batch_col, coords_col=self.coords_col)
        if self.enable_gnn_inter:
            self.inter_gnn_loss = GNNLoss(cfg, 'grappa_inter_loss', batch_col=self.batch_col, coords_col=self.coords_col)
        if self.enable_gnn_kinematics:
            self.kinematics_loss = GNNLoss(cfg, 'grappa_kinematics_loss', batch_col=self.batch_col, coords_col=self.coords_col)
        if self.enable_cosmic:
            self.cosmic_loss = GNNLoss(cfg, 'cosmic_loss', batch_col=self.batch_col, coords_col=self.coords_col)

        # Initialize the loss weights
        self.loss_config = cfg.get('full_chain_loss', {})

        self.deghost_weight = self.loss_config.get('deghost_weight', 1.0)
        self.segmentation_weight = self.loss_config.get('segmentation_weight', 1.0)
        self.ppn_weight = self.loss_config.get('ppn_weight', 1.0)
        self.cnn_clust_weight = self.loss_config.get('cnn_clust_weight', 1.0)
        self.shower_gnn_weight = self.loss_config.get('shower_gnn_weight', 1.0)
        self.track_gnn_weight = self.loss_config.get('track_gnn_weight', 1.0)
        self.particle_gnn_weight = self.loss_config.get('particle_gnn_weight', 1.0)
        self.inter_gnn_weight = self.loss_config.get('inter_gnn_weight', 1.0)
        self.kinematics_weight = self.loss_config.get('kinematics_weight', 1.0)
        self.flow_weight = self.loss_config.get('flow_weight', 1.0)
        self.kinematics_p_weight = self.loss_config.get('kinematics_p_weight', 1.0)
        self.kinematics_type_weight = self.loss_config.get('kinematics_type_weight', 1.0)
        self.cosmic_weight = self.loss_config.get('cosmic_weight', 1.0)

    def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics_label=None,
                particle_graph=None, iteration=None):
        """
        Compute the combined full chain loss.

        The deghosting, segmentation, PPN and CNN clustering losses
        (``deghost_loss``, ``uresnet_loss``, ``ppn_loss``,
        ``spatial_embeddings_loss``) are not defined in this class.
        Presumably a subclass provides them.

        Parameters
        ----------
        out : dict
            Output of the full chain forward pass.
        seg_label : list
            Segmentation labels; ``seg_label[0][:, -1]`` holds the semantic
            class (class 5 is treated as ghost below).
        ppn_label, cluster_label, kinematics_label : list, optional
            Labels for the PPN, clustering/GNN and kinematics losses.
        particle_graph : optional
            Forwarded as ``graph`` to the interaction and kinematics losses.
        iteration : int, optional
            Forwarded to the interaction GNN loss.

        Returns
        -------
        dict
            Per-stage metrics (keys prefixed with the stage name), the
            weighted total ``'loss'`` and ``'accuracy'`` averaged over the
            enabled stages.
        """
        res = {}
        accuracy, loss = 0., 0.

        if self.enable_charge_rescaling:
            # Semantic class 5 is used as the ghost label
            ghost_label = torch.cat((seg_label[0][:, :4], (seg_label[0][:, -1] == 5).type(seg_label[0].dtype).reshape(-1, 1)), dim=-1)
            res_deghost = self.deghost_loss({'segmentation': out['ghost']}, [ghost_label])
            for key in res_deghost:
                res['deghost_' + key] = res_deghost[key]
            accuracy += res_deghost['accuracy']
            loss += self.deghost_weight * res_deghost['loss']
            deghost = (out['ghost'][0][:, 0] > out['ghost'][0][:, 1]) & (seg_label[0][:, -1] < 5)  # Only apply loss to reco/true non-ghosts

        if self.enable_uresnet and 'segmentation' in out:
            if not self.enable_charge_rescaling:
                res_seg = self.uresnet_loss(out, seg_label)
            else:
                res_seg = self.uresnet_loss({'segmentation': [out['segmentation'][0][deghost]]}, [seg_label[0][deghost]])
            for key in res_seg:
                res['segmentation_' + key] = res_seg[key]
            accuracy += res_seg['accuracy']
            loss += self.segmentation_weight * res_seg['loss']

        if self.enable_ppn and 'ppn_output_coordinates' in out:
            # Apply the PPN loss
            res_ppn = self.ppn_loss(out, seg_label, ppn_label)
            for key in res_ppn:
                res['ppn_' + key] = res_ppn[key]
            accuracy += res_ppn['accuracy']
            loss += self.ppn_weight * res_ppn['loss']

        # Adapt the labels to the deghosted voxels when a downstream stage needs
        # them (same GNN set as FullChainGNN.forward, plus the other consumers)
        if self.enable_ghost and 'ghost_label' in out \
                and (self.enable_cnn_clust or
                     self.enable_gnn_track or
                     self.enable_gnn_shower or
                     self.enable_gnn_particle or
                     self.enable_gnn_inter or
                     self.enable_gnn_kinematics or
                     self.enable_cosmic):
            deghost = out['ghost_label'][0]

            true_mask = deghost if self.cheat_ghost else None

            # Adapt to ghost points
            if cluster_label is not None:
                cluster_label = adapt_labels(out,
                                             seg_label,
                                             cluster_label,
                                             batch_column=self.batch_col,
                                             true_mask=true_mask)

            if kinematics_label is not None:
                kinematics_label = adapt_labels(out,
                                                seg_label,
                                                kinematics_label,
                                                batch_column=self.batch_col,
                                                true_mask=true_mask)

            segment_label = seg_label[0][deghost][:, -1]
            seg_label = seg_label[0][deghost]
        else:
            segment_label = seg_label[0][:, -1]
            seg_label = seg_label[0]

        if self.enable_cnn_clust:
            # If there is no track voxel, maybe GraphSpice didn't run
            if self._enable_graph_spice and 'graph' in out:
                graph_spice_out = {
                    'graph': out['graph'],
                    'graph_info': out['graph_info'],
                    'spatial_embeddings': out['spatial_embeddings'],
                    'feature_embeddings': out['feature_embeddings'],
                    'covariance': out['covariance'],
                    'hypergraph_features': out['hypergraph_features'],
                    'features': out['features'],
                    'occupancy': out['occupancy'],
                    'coordinates': out['coordinates'],
                    'batch_indices': out['batch_indices'],
                }

                segmentation_pred = out['segmentation'][0]
                if self.enable_ghost:
                    segmentation_pred = segmentation_pred[deghost]
                if self._gspice_use_true_labels:
                    gs_seg_label = torch.cat([cluster_label[0][:, :4], segment_label[:, None]], dim=1)
                else:
                    gs_seg_label = torch.cat([cluster_label[0][:, :4], torch.argmax(segmentation_pred, dim=1)[:, None]], dim=1)

                # NOTE: We need to limit loss computation to voxels that are
                # in the intersection of truth and prediction.
                # Setting seg label to -1 does not work (embeddings already
                # have a shape based on predicted semantics). Instead we set
                # the cluster label to -1 and the GraphSPICEEmbeddingLoss
                # will remove voxels with true cluster label -1.
                # Work on a copy so the masking does not leak into the GNN
                # losses below (nor into the caller's label tensor).
                gs_cluster_label = cluster_label[0].clone()
                if not self._gspice_use_true_labels:
                    gs_cluster_label[(gs_cluster_label[:, -1] != torch.argmax(segmentation_pred, dim=1)), 5] = -1

                res_graph_spice = self.spatial_embeddings_loss(graph_spice_out, [gs_seg_label], [gs_cluster_label])
                if 'accuracy' in res_graph_spice:
                    accuracy += res_graph_spice['accuracy']
                loss += self.cnn_clust_weight * res_graph_spice['loss']
                for key in res_graph_spice:
                    res['graph_spice_' + key] = res_graph_spice[key]
            elif 'embeddings' in out:
                # Apply the CNN dense clustering loss to HE voxels only (semantic class < 4)
                he_mask = segment_label < 4
                clust_label = [cluster_label[0][he_mask].clone()]
                cnn_clust_output = {'embeddings': [out['embeddings'][0][he_mask]],
                                    'seediness': [out['seediness'][0][he_mask]],
                                    'margins': [out['margins'][0][he_mask]]}
                # FIXME does this suppose that clust_label has same ordering as embeddings?
                res_cnn_clust = self.spatial_embeddings_loss(cnn_clust_output, clust_label)
                for key in res_cnn_clust:
                    res['cnn_clust_' + key] = res_cnn_clust[key]
                accuracy += res_cnn_clust['accuracy']
                loss += self.cnn_clust_weight * res_cnn_clust['loss']

        if self.enable_gnn_shower:
            # Apply the GNN shower clustering loss
            gnn_out = {}
            if 'shower_edge_pred' in out:
                gnn_out = {
                    'clusts': out['shower_fragments'],
                    'node_pred': out['shower_node_pred'],
                    'edge_pred': out['shower_edge_pred'],
                    'edge_index': out['shower_edge_index']
                }
            res_gnn_shower = self.shower_gnn_loss(gnn_out, cluster_label)
            for key in res_gnn_shower:
                res['grappa_shower_' + key] = res_gnn_shower[key]
            accuracy += res_gnn_shower['accuracy']
            loss += self.shower_gnn_weight * res_gnn_shower['loss']

        if self.enable_gnn_track:
            # Apply the GNN track clustering loss
            gnn_out = {}
            if 'track_edge_pred' in out:
                gnn_out = {
                    'clusts': out['track_fragments'],
                    'edge_pred': out['track_edge_pred'],
                    'edge_index': out['track_edge_index']
                }
            res_gnn_track = self.track_gnn_loss(gnn_out, cluster_label)
            for key in res_gnn_track:
                res['grappa_track_' + key] = res_gnn_track[key]
            accuracy += res_gnn_track['accuracy']
            loss += self.track_gnn_weight * res_gnn_track['loss']

        if self.enable_gnn_particle:
            # Apply the GNN particle clustering loss
            gnn_out = {}
            if 'particle_edge_pred' in out:
                gnn_out = {
                    'clusts': out['particle_fragments'],
                    'node_pred': out['particle_node_pred'],
                    'edge_pred': out['particle_edge_pred'],
                    'edge_index': out['particle_edge_index']
                }
            res_gnn_part = self.particle_gnn_loss(gnn_out, cluster_label)
            for key in res_gnn_part:
                res['grappa_particle_' + key] = res_gnn_part[key]
            accuracy += res_gnn_part['accuracy']
            loss += self.particle_gnn_weight * res_gnn_part['loss']

        if self.enable_gnn_inter:
            # Apply the GNN interaction grouping loss
            gnn_out = {}
            if 'inter_edge_pred' in out:
                gnn_out = {
                    'clusts': out['inter_particles'],
                    'edge_pred': out['inter_edge_pred'],
                    'edge_index': out['inter_edge_index']
                }
            if 'inter_node_pred' in out: gnn_out.update({'node_pred': out['inter_node_pred']})
            if 'node_pred_type' in out: gnn_out.update({'node_pred_type': out['node_pred_type']})
            if 'node_pred_p' in out: gnn_out.update({'node_pred_p': out['node_pred_p']})
            if 'node_pred_vtx' in out: gnn_out.update({'node_pred_vtx': out['node_pred_vtx']})
            if 'particle_node_features' in out: gnn_out.update({'input_node_features': out['particle_node_features']})
            if 'particle_edge_features' in out: gnn_out.update({'input_edge_features': out['particle_edge_features']})

            res_gnn_inter = self.inter_gnn_loss(gnn_out, cluster_label, node_label=kinematics_label, graph=particle_graph, iteration=iteration)
            for key in res_gnn_inter:
                res['grappa_inter_' + key] = res_gnn_inter[key]
            accuracy += res_gnn_inter['accuracy']
            loss += self.inter_gnn_weight * res_gnn_inter['loss']

        if self.enable_gnn_kinematics:
            # Loss on node predictions (type & momentum).
            # NOTE(review): FullChainGNN stores the kinematics GNN node outputs
            # under 'kinematics_node_pred_type'/'kinematics_node_pred_p'; this
            # reads 'node_pred_type'/'node_pred_p' — confirm this is intended.
            gnn_out = {}
            if 'flow_edge_pred' in out:
                gnn_out = {
                    'clusts': out['kinematics_particles'],
                    'edge_pred': out['flow_edge_pred'],
                    'edge_index': out['kinematics_edge_index']
                }
            if 'node_pred_type' in out:
                gnn_out.update({'node_pred_type': out['node_pred_type']})
            if 'node_pred_p' in out:
                gnn_out.update({'node_pred_p': out['node_pred_p']})
            res_kinematics = self.kinematics_loss(gnn_out, kinematics_label, graph=particle_graph)
            for key in res_kinematics:
                res['grappa_kinematics_' + key] = res_kinematics[key]
            accuracy += res_kinematics['node_accuracy']
            # TODO: kinematics_p_weight and kinematics_type_weight are read in __init__ but not applied here
            loss += self.kinematics_weight * res['grappa_kinematics_loss']

            # Loss on edge predictions (particle hierarchy)
            res['flow_loss'] = res_kinematics['edge_loss']
            res['flow_accuracy'] = res_kinematics['edge_accuracy']
            accuracy += res_kinematics['edge_accuracy']
            loss += self.flow_weight * res_kinematics['edge_loss']

        if self.enable_cosmic:
            gnn_out = {
                'clusts': out['interactions'],
                'node_pred': out['inter_cosmic_pred'],
            }
            res_cosmic = self.cosmic_loss(gnn_out, cluster_label)
            for key in res_cosmic:
                res['cosmic_' + key] = res_cosmic[key]
            accuracy += res_cosmic['accuracy']
            loss += self.cosmic_weight * res_cosmic['loss']

        # Combine the results: average accuracy over the enabled stages
        # (kinematics contributes both a node and an edge accuracy)
        n_tasks = int(self.enable_charge_rescaling) + int(self.enable_uresnet) + int(self.enable_ppn) + int(self.enable_gnn_shower) \
            + int(self.enable_gnn_inter) + int(self.enable_gnn_track) + int(self.enable_cnn_clust) \
            + 2 * int(self.enable_gnn_kinematics) + int(self.enable_cosmic) + int(self.enable_gnn_particle)
        if n_tasks > 0:
            accuracy /= n_tasks

        res['loss'] = loss
        res['accuracy'] = accuracy

        if self.verbose:
            if self.enable_charge_rescaling:
                print('Deghosting Accuracy: {:.4f}'.format(res_deghost['accuracy']))
            if self.enable_uresnet and 'segmentation' in out:
                print('Segmentation Accuracy: {:.4f}'.format(res_seg['accuracy']))
            if self.enable_ppn and 'ppn_output_coordinates' in out:
                print('PPN Accuracy: {:.4f}'.format(res_ppn['accuracy']))
            if self.enable_cnn_clust:
                # Mirror the branch selection above so the printed result exists
                if self._enable_graph_spice and 'graph' in out:
                    if 'accuracy' in res_graph_spice:
                        print('Clustering Accuracy: {:.4f}'.format(res_graph_spice['accuracy']))
                    if 'edge_accuracy' in res_graph_spice:
                        print('Clustering Edge Accuracy: {:.4f}'.format(res_graph_spice['edge_accuracy']))
                elif 'embeddings' in out:
                    print('Clustering Embedding Accuracy: {:.4f}'.format(res_cnn_clust['accuracy']))
            if self.enable_gnn_shower:
                print('Shower fragment clustering accuracy: {:.4f}'.format(res_gnn_shower['edge_accuracy']))
                print('Shower primary prediction accuracy: {:.4f}'.format(res_gnn_shower['node_accuracy']))
            if self.enable_gnn_track:
                print('Track fragment clustering accuracy: {:.4f}'.format(res_gnn_track['edge_accuracy']))
            if self.enable_gnn_particle:
                print('Particle fragment clustering accuracy: {:.4f}'.format(res_gnn_part['edge_accuracy']))
                print('Particle primary prediction accuracy: {:.4f}'.format(res_gnn_part['node_accuracy']))
            if self.enable_gnn_inter:
                print('Interaction grouping accuracy: {:.4f}'.format(res_gnn_inter['edge_accuracy']))
            if self.enable_gnn_kinematics:
                print('Flow accuracy: {:.4f}'.format(res_kinematics['edge_accuracy']))
            if 'node_pred_type' in out:
                if 'grappa_inter_type_accuracy' in res:
                    print('Particle ID accuracy: {:.4f}'.format(res['grappa_inter_type_accuracy']))
                elif 'grappa_kinematics_type_accuracy' in res:
                    print('Particle ID accuracy: {:.4f}'.format(res['grappa_kinematics_type_accuracy']))
            if 'node_pred_p' in out:
                if 'grappa_inter_p_accuracy' in res:
                    print('Momentum accuracy: {:.4f}'.format(res['grappa_inter_p_accuracy']))
                elif 'grappa_kinematics_p_accuracy' in res:
                    print('Momentum accuracy: {:.4f}'.format(res['grappa_kinematics_p_accuracy']))
            if 'node_pred_vtx' in out:
                if 'grappa_inter_vtx_score_accuracy' in res:
                    print('Primary particle score accuracy: {:.4f}'.format(res['grappa_inter_vtx_score_accuracy']))
                elif 'grappa_kinematics_vtx_score_accuracy' in res:
                    print('Primary particle score accuracy: {:.4f}'.format(res['grappa_kinematics_vtx_score_accuracy']))
            if self.enable_cosmic:
                print('Cosmic discrimination accuracy: {:.4f}'.format(res_cosmic['accuracy']))

        return res
[docs]def setup_chain_cfg(self, cfg, print_info=True):
"""
Prepare both FullChain and FullChainLoss
Make sure config is logically sound with some basic checks
See Also
--------
mlreco.models.full_chain.FullChain, FullChainGNN
"""
chain_cfg = cfg.get('chain', {})
self.use_me = chain_cfg.get('use_mink', True)
self.batch_col = 0 if self.use_me else 3
self.coords_col = (1, 4) if self.use_me else (0, 3)
self.batch_size = None # To be set at forward time
self.process_fragments = chain_cfg.get('process_fragments', False)
self.use_true_fragments = chain_cfg.get('use_true_fragments', False)
self.use_true_particles = chain_cfg.get('use_true_particles', False)
self._gspice_use_true_labels = cfg.get('graph_spice', {}).get('use_true_labels', False)
self.enable_charge_rescaling = chain_cfg.get('enable_charge_rescaling', False)
self.enable_ghost = chain_cfg.get('enable_ghost', False)
self.cheat_ghost = chain_cfg.get('cheat_ghost', False)
self.verbose = chain_cfg.get('verbose', False)
self.enable_uresnet = chain_cfg.get('enable_uresnet', True)
self.enable_ppn = chain_cfg.get('enable_ppn', True)
self.enable_dbscan = chain_cfg.get('enable_dbscan', True)
self.enable_cnn_clust = chain_cfg.get('enable_cnn_clust', False)
self.enable_gnn_shower = chain_cfg.get('enable_gnn_shower', False)
self.enable_gnn_track = chain_cfg.get('enable_gnn_track', False)
self.enable_gnn_particle = chain_cfg.get('enable_gnn_particle', False)
self.enable_gnn_inter = chain_cfg.get('enable_gnn_inter', False)
self.enable_gnn_kinematics = chain_cfg.get('enable_gnn_kinematics', False)
self.enable_cosmic = chain_cfg.get('enable_cosmic', False)
if self.verbose and print_info:
print("Shower GNN: {}".format(self.enable_gnn_shower))
print("Track GNN: {}".format(self.enable_gnn_track))
print("Particle GNN: {}".format(self.enable_gnn_particle))
print("Interaction GNN: {}".format(self.enable_gnn_inter))
print("Kinematics GNN: {}".format(self.enable_gnn_kinematics))
print("Cosmic GNN: {}".format(self.enable_cosmic))
if (self.enable_gnn_shower or \
self.enable_gnn_track or \
self.enable_gnn_particle or \
self.enable_gnn_inter or \
self.enable_gnn_kinematics or self.enable_cosmic):
if self.verbose and print_info:
msg = """
Since one of the GNNs are turned on, process_fragments is turned ON.
"""
print(msg)
self.process_fragments = True
if self.process_fragments and self.verbose and print_info:
msg = """
Fragment processing is turned ON. When training CNN models from
scratch, we recommend turning fragment processing OFF as without
reliable segmentation and/or cnn clustering outputs this could take
prohibitively large training iterations.
"""
print(msg)
# If fragment processing is turned off, no inputs to GNN
if not self.process_fragments:
self.enable_gnn_shower = False
self.enable_gnn_track = False
self.enable_gnn_particle = False
self.enable_gnn_inter = False
self.enable_gnn_kinematics = False
self.enable_cosmic = False
# Whether to use PPN information (GNN shower clustering step only)
self.use_ppn_in_gnn = chain_cfg.get('use_ppn_in_gnn', False)
self.use_supp_in_gnn = chain_cfg.get('use_supp_in_gnn', True)
# Make sure the deghosting config is consistent
if self.enable_ghost and not self.enable_charge_rescaling:
assert cfg['uresnet_ppn']['uresnet_lonely']['ghost']
if self.enable_ppn:
assert cfg['uresnet_ppn']['ppn']['ghost']
# Enforce basic logical order
# 1. Need semantics for everything
assert self.enable_uresnet
# 2. If PPN is used in GNN, need PPN
if self.enable_gnn_shower or self.enable_gnn_track:
assert self.enable_ppn or (not self.use_ppn_in_gnn)
# 3. Need at least one of two dense clusterer
# assert self.enable_dbscan or self.enable_cnn_clust
# 4. Check that SPICE and DBSCAN are not redundant
if self.enable_cnn_clust and self.enable_dbscan:
if 'spice' in cfg:
assert not (np.array(cfg['spice']['spice_fragment_manager']['cluster_classes']) == \
np.array(np.array(cfg['dbscan']['dbscan_fragment_manager']['cluster_classes'])).reshape(-1)).any()
else:
assert 'graph_spice' in cfg
assert set(cfg['dbscan']['dbscan_fragment_manager']['cluster_classes']).issubset(
set(cfg['graph_spice']['skip_classes']))
if self.enable_gnn_particle: # If particle fragment GNN is used, make sure it is not redundant
if self.enable_gnn_shower:
assert cfg['grappa_shower']['base']['node_type'] \
not in cfg['grappa_particle']['base']['node_type']
if self.enable_gnn_track:
assert cfg['grappa_track']['base']['node_type'] \
not in cfg['grappa_particle']['base']['node_type']
if self.enable_cosmic: assert self.enable_gnn_inter # Cosmic classification needs interaction clustering