mlreco.models.layers.common.gnn_full_chain module
-
class mlreco.models.layers.common.gnn_full_chain.FullChainGNN(cfg)
Bases: torch.nn.modules.module.Module
GNN section of the full chain.
-
MODULES= ['grappa_shower', 'grappa_track', 'grappa_inter', 'grappa_shower_loss', 'grappa_track_loss', 'grappa_inter_loss', 'full_chain_loss', 'spice', 'spice_loss', 'fragment_clustering', 'chain', 'dbscan', ('uresnet_ppn', ['uresnet_lonely', 'ppn'])]
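MODULES mixes plain module names with one (parent, [children]) tuple. The sketch below expands such a list into every module name it mentions, assuming that the tuple means the `uresnet_ppn` block wraps the `uresnet_lonely` and `ppn` sub-modules. `flatten_modules` is a hypothetical helper written for illustration, not part of mlreco.

```python
# MODULES copied from the class attribute documented above.
MODULES = ['grappa_shower', 'grappa_track', 'grappa_inter',
           'grappa_shower_loss', 'grappa_track_loss', 'grappa_inter_loss',
           'full_chain_loss', 'spice', 'spice_loss', 'fragment_clustering',
           'chain', 'dbscan', ('uresnet_ppn', ['uresnet_lonely', 'ppn'])]


def flatten_modules(modules):
    """Hypothetical helper: list every module name, parents before children.

    Assumes a tuple entry has the form (parent_name, [child_names]).
    """
    names = []
    for entry in modules:
        if isinstance(entry, tuple):
            parent, children = entry
            names.append(parent)
            names.extend(children)
        else:
            names.append(entry)
    return names


print(flatten_modules(MODULES)[-3:])  # ['uresnet_ppn', 'uresnet_lonely', 'ppn']
```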
-
__init__(cfg) Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
run_gnn(grappa, input, result, clusts, labels, kwargs={}) Generic function that gathers in one place the code common to running any GNN model.
- Parameters
grappa
input
result (dict) – updated in place with the GNN outputs
clusts
labels
kwargs (dict, optional)
- Returns
None (modifies the result dict in place)
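The in-place contract of run_gnn (write outputs into `result`, return None) can be sketched as below. This is a minimal illustration, not the mlreco implementation: `run_gnn_sketch`, the `prefix` argument and the toy model are hypothetical, and the sketch assumes the model is any callable returning a dict of outputs.

```python
def run_gnn_sketch(model, inputs, result, clusts, prefix, kwargs=None):
    """Hypothetical sketch of a shared GNN runner that updates `result` in place.

    `model` is any callable returning a dict of outputs; `prefix` is an
    illustrative stage name used to namespace the keys written to `result`.
    """
    kwargs = {} if kwargs is None else kwargs  # no dict shared between calls
    outputs = model(inputs, clusts, **kwargs)
    for key, value in outputs.items():
        # Namespace each output so several GNN stages can share one result dict.
        result[f"{prefix}_{key}"] = value
    return None


def toy_gnn(inputs, clusts, scale=1.0):
    # Stand-in model: one "score" per cluster, here its size times `scale`.
    return {"node_pred": [len(c) * scale for c in clusts]}


result = {}
run_gnn_sketch(toy_gnn, None, result, [[0, 1], [2]], "shower", {"scale": 2.0})
print(result)  # {'shower_node_pred': [4.0, 2.0]}
```

The sketch defaults `kwargs` to None rather than `{}`, since a mutable default is created once and shared by every call that omits the argument.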
-
select_particle_in_group(result, counts, b, particles, part_primary_ids, node_pred, group_pred, fragments) Merge fragments into particle instances and retain the primary fragment ID of each group.
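A minimal sketch of the merge step, assuming each fragment is a list of voxel indices, `group_pred` gives one predicted group ID per fragment, and the primary fragment of a group is the one with the highest primary score. The function and argument names are hypothetical and do not mirror the real signature.

```python
def merge_fragments(fragments, group_pred, primary_scores):
    """Hypothetical sketch: merge fragments sharing a group ID into particles.

    fragments: list of voxel-index lists, one per fragment
    group_pred: predicted group ID of each fragment
    primary_scores: score that each fragment is its group's primary
    Returns (particles, primary_ids): merged voxel indices per group and the
    index of the highest-scoring fragment in each group.
    """
    groups = {}
    for frag_id, gid in enumerate(group_pred):
        groups.setdefault(gid, []).append(frag_id)
    particles, primary_ids = [], []
    for gid in sorted(groups):
        members = groups[gid]
        particles.append(sorted(v for f in members for v in fragments[f]))
        primary_ids.append(max(members, key=lambda f: primary_scores[f]))
    return particles, primary_ids


parts, prims = merge_fragments([[0, 1], [2], [3, 4], [5]],
                               [0, 0, 1, 1], [0.9, 0.2, 0.1, 0.7])
print(parts, prims)  # [[0, 1, 2], [3, 4, 5]] [0, 3]
```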
-
get_all_fragments(result, input) Given geometric or CNN clustering results and (optional) true fragment labels, return the true or predicted fragments.
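One way to form fragments from per-voxel cluster IDs is sketched below. It assumes each voxel carries a batch ID and a cluster ID, with -1 marking unassigned voxels; the same routine would apply whether the IDs come from the true fragment labels or from the predicted clustering. `form_fragments` is a hypothetical helper.

```python
def form_fragments(batch_ids, cluster_ids, ignore=-1):
    """Hypothetical sketch: turn per-voxel cluster IDs into fragments.

    A fragment is the list of voxel indices sharing a (batch, cluster) pair;
    voxels whose cluster ID equals `ignore` are left out.
    """
    fragments = {}
    for idx, (b, c) in enumerate(zip(batch_ids, cluster_ids)):
        if c == ignore:
            continue
        fragments.setdefault((b, c), []).append(idx)
    # Sort by (batch, cluster) so the output order is deterministic.
    return [fragments[key] for key in sorted(fragments)]


print(form_fragments([0, 0, 0, 1, 1], [3, -1, 3, 0, 0]))  # [[0, 2], [3, 4]]
```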
-
run_fragment_gnns(result, input) Run all fragment-level GNN models:
- Shower GNN
- Track GNN
- Particle GNN (optional?)
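The stage ordering above can be sketched as a loop over named stages, each updating the shared result dict, with the particle stage running only when enabled. The stage names, the `enabled` set and the callables are assumptions made for illustration.

```python
def run_fragment_stages(result, stages, enabled):
    """Hypothetical sketch: run fragment-level GNN stages in a fixed order.

    `stages` maps a stage name to a callable that updates `result` in place;
    a stage runs only if its name is in `enabled`.
    """
    for name in ("shower", "track", "particle"):
        if name in enabled and name in stages:
            stages[name](result)
    return result


# Toy stages that only record that they ran (n=n binds the loop variable).
stages = {n: (lambda r, n=n: r.setdefault("ran", []).append(n))
          for n in ("shower", "track", "particle")}
print(run_fragment_stages({}, stages, {"shower", "track"}))
# {'ran': ['shower', 'track']}
```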
-
forward(input) The input can be any of the following:
- input data only;
- input data and clustering labels, in this order;
- input data, segmentation labels and clustering labels, in this order (when deghosting is enabled, the segmentation labels are needed to adapt the clustering labels properly).
- Parameters
input (list of np.ndarray) –
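The three accepted layouts can be unpacked by list length, as in the sketch below. `unpack_inputs` is hypothetical, and raising an error when deghosting is on but segmentation labels are missing is this sketch's own reading of the note above, not confirmed behavior of forward().

```python
def unpack_inputs(inputs, deghosting=False):
    """Hypothetical sketch of the three input layouts accepted by forward().

    Returns (data, seg_label, clust_label); missing entries are None.
    """
    if len(inputs) == 1:
        return inputs[0], None, None
    if len(inputs) == 2:
        if deghosting:
            # Assumption: clustering labels cannot be adapted without segmentation labels.
            raise ValueError("deghosting needs segmentation labels as well")
        return inputs[0], None, inputs[1]
    if len(inputs) == 3:
        return inputs[0], inputs[1], inputs[2]
    raise ValueError(f"expected 1 to 3 inputs, got {len(inputs)}")
```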
-
__module__= 'mlreco.models.layers.common.gnn_full_chain'
-
training: bool
-
-
class mlreco.models.layers.common.gnn_full_chain.FullChainLoss(cfg)
Bases: torch.nn.modules.loss._Loss
Loss for UResNet + PPN chain.
-
__init__(cfg) Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
__module__= 'mlreco.models.layers.common.gnn_full_chain'
-
forward(out, seg_label, ppn_label=None, cluster_label=None, kinematics_label=None, particle_graph=None, iteration=None) Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
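The difference between calling the instance and calling forward() directly can be illustrated with a simplified, pure-Python stand-in for the Module call protocol. `TinyModule` is hypothetical and much simpler than torch.nn.Module (for example, its hooks cannot replace the output); it only shows why hooks run on `m(x)` but not on `m.forward(x)`.

```python
class TinyModule:
    """Simplified stand-in for torch.nn.Module's call protocol (not PyTorch)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Hooks receive (module, args, output), mirroring PyTorch's hook signature.
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError("subclasses define forward()")

    def __call__(self, *args):
        output = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


calls = []
m = Doubler()
m.register_forward_hook(lambda mod, args, out: calls.append(out))
m(3)          # runs forward() and then the hook
m.forward(3)  # runs forward() only; the hook is skipped
print(calls)  # [6]
```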
-
reduction: str
-