mlreco.models.layers.common.gnn_full_chain module

class mlreco.models.layers.common.gnn_full_chain.FullChainGNN(cfg)

Bases: torch.nn.modules.module.Module

GNN section of the full chain.

MODULES = ['grappa_shower', 'grappa_track', 'grappa_inter', 'grappa_shower_loss', 'grappa_track_loss', 'grappa_inter_loss', 'full_chain_loss', 'spice', 'spice_loss', 'fragment_clustering', 'chain', 'dbscan', ('uresnet_ppn', ['uresnet_lonely', 'ppn'])]
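MODULES mixes plain module names with one (parent, [children]) tuple. The following is a minimal sketch of walking that list. It assumes the tuple means `uresnet_lonely` and `ppn` are sub-modules grouped under a `uresnet_ppn` block; this page does not confirm that interpretation. `flatten_modules` is a hypothetical helper, not part of mlreco.

```python
# Copy of the MODULES list documented above.
MODULES = ['grappa_shower', 'grappa_track', 'grappa_inter',
           'grappa_shower_loss', 'grappa_track_loss', 'grappa_inter_loss',
           'full_chain_loss', 'spice', 'spice_loss', 'fragment_clustering',
           'chain', 'dbscan', ('uresnet_ppn', ['uresnet_lonely', 'ppn'])]


def flatten_modules(modules):
    """Return a flat list of module names (hypothetical helper).

    Assumption: a (parent, [children]) tuple groups sub-modules under a
    parent config block; the children are listed right after the parent.
    """
    names = []
    for entry in modules:
        if isinstance(entry, tuple):
            parent, children = entry
            names.append(parent)
            names.extend(children)
        else:
            names.append(entry)
    return names


print(flatten_modules(MODULES)[-3:])
# → ['uresnet_ppn', 'uresnet_lonely', 'ppn']
```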
__init__(cfg)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

run_gnn(grappa, input, result, clusts, labels, kwargs={})

Generic helper that groups in one place the code common to running any of the GNN models.

Parameters
  • grappa – GNN model to run

  • input

  • result – result dictionary, modified in place

  • clusts

  • labels

  • kwargs – optional keyword arguments (default: {})

Returns
  None (modifies the result dict in place)
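Because run_gnn returns None and writes into the result dict in place, every GNN stage can share one accumulator. A minimal plain-Python sketch of that pattern follows, with a toy callable standing in for a GRAPPA model. `run_gnn_sketch`, its `prefix` argument and the key-naming scheme are hypothetical; the real method also takes `clusts` and `labels`, which the sketch omits.

```python
def run_gnn_sketch(model, inputs, result, prefix, **kwargs):
    """Hypothetical stand-in for run_gnn: call a GNN model and merge its
    outputs into `result` in place. Returns None, like run_gnn."""
    output = model(inputs, **kwargs)
    for key, value in output.items():
        # Namespace each output by a prefix so that several GNNs
        # (shower, track, ...) can share a single result dict.
        result[f"{prefix}_{key}"] = value


def toy_gnn(inputs, scale=1):
    # Toy "model": returns one node score per input.
    return {"node_pred": [x * scale for x in inputs]}


result = {}
ret = run_gnn_sketch(toy_gnn, [1, 2, 3], result, "shower", scale=2)
print(ret, result)
# → None {'shower_node_pred': [2, 4, 6]}
```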

select_particle_in_group(result, counts, b, particles, part_primary_ids, node_pred, group_pred, fragments)

Merges fragments into particle instances and retains the primary fragment ID of each group.
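The merge step can be illustrated on toy data. This is a minimal sketch, not the method's implementation: it assumes fragments are lists of voxel indices, that `group_pred` gives one group ID per fragment, and that the primary of a group is the fragment with the highest primary score (a hypothetical stand-in for how `node_pred` is used). `merge_fragments` is a hypothetical name.

```python
from collections import defaultdict


def merge_fragments(fragments, group_pred, primary_score):
    """Merge fragments that share a group ID into one particle each.

    fragments     : list of lists of voxel indices, one list per fragment
    group_pred    : predicted group ID for each fragment
    primary_score : score that a fragment is the primary of its group
    Returns (particles, primary_ids), ordered by group ID.
    """
    members = defaultdict(list)
    for frag_id, group in enumerate(group_pred):
        members[group].append(frag_id)

    particles, primary_ids = [], []
    for group in sorted(members):
        frag_ids = members[group]
        # Union of the voxels of every fragment in the group.
        particles.append(sorted(v for f in frag_ids for v in fragments[f]))
        # Assumption: the highest-scoring fragment is the group's primary.
        primary_ids.append(max(frag_ids, key=lambda f: primary_score[f]))
    return particles, primary_ids


fragments = [[0, 1], [2], [3, 4], [5]]
group_pred = [0, 0, 1, 1]
primary_score = [0.9, 0.2, 0.1, 0.7]
print(merge_fragments(fragments, group_pred, primary_score))
# → ([[0, 1, 2], [3, 4, 5]], [0, 3])
```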

get_all_fragments(result, input)

Given geometric or CNN clustering results and (optional) true fragment labels, returns either the true or the predicted fragments.
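The choice between true and predicted fragments can be sketched as a simple switch. This sketch assumes true fragments are rebuilt by grouping voxels on a per-voxel fragment label, with -1 marking unlabeled voxels; the `use_true_fragments` flag and both function names are hypothetical, not options documented on this page.

```python
from collections import defaultdict


def fragments_from_labels(frag_labels):
    """Group voxel indices by their true fragment label (-1 = unlabeled)."""
    groups = defaultdict(list)
    for voxel, label in enumerate(frag_labels):
        if label >= 0:
            groups[label].append(voxel)
    return [groups[k] for k in sorted(groups)]


def select_fragments(pred_fragments, frag_labels=None, use_true_fragments=False):
    """Return true fragments when requested (labels required), otherwise
    the predicted (geometric or CNN) fragments."""
    if use_true_fragments:
        if frag_labels is None:
            raise ValueError("true fragments requested but no labels given")
        return fragments_from_labels(frag_labels)
    return pred_fragments


pred = [[0, 1, 2], [3, 4]]
labels = [0, 0, 1, 1, -1]
print(select_fragments(pred))                                   # → [[0, 1, 2], [3, 4]]
print(select_fragments(pred, labels, use_true_fragments=True))  # → [[0, 1], [2, 3]]
```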

run_fragment_gnns(result, input)

Run all fragment-level GNN models.

  1. Shower GNN

  2. Track GNN

  3. Particle GNN (optional?)
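The ordering above can be sketched as a loop over stages that each read from and add to a shared result dict. This assumes each stage is toggled by the configuration; the `enabled` set, the stage callables and `run_fragment_stages` are hypothetical placeholders for the real GRAPPA models.

```python
def run_fragment_stages(stages, result, enabled):
    """Run the fragment-level GNN stages in a fixed order.

    stages  : dict mapping a stage name to a callable(result) -> dict
    enabled : set of stage names switched on in the configuration
    The particle stage only runs when explicitly enabled.
    """
    for name in ("shower", "track", "particle"):
        if name in enabled:
            # Each stage can read earlier outputs from `result` and adds its own.
            result.update(stages[name](result))
    return result


stages = {
    "shower": lambda r: {"shower_done": True},
    "track": lambda r: {"track_done": True},
    "particle": lambda r: {"particle_sees_shower": r.get("shower_done", False)},
}
print(run_fragment_stages(stages, {}, enabled={"shower", "track"}))
# → {'shower_done': True, 'track_done': True}
print(run_fragment_stages(stages, {}, enabled={"shower", "track", "particle"}))
# → {'shower_done': True, 'track_done': True, 'particle_sees_shower': True}
```

Running the stages in a fixed order is what lets a later stage (here the particle GNN) consume outputs of the earlier ones.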

get_all_particles(frag_result, result, input)
run_particle_gnns(result, input, frag_result)
full_chain_gnn(result, input)
forward(input)

Input can be one of the following, in this order:
  • input data only

  • input data, cluster labels

  • input data, segmentation labels, cluster labels (when deghosting is enabled, the segmentation labels are needed to adapt the cluster labels properly)

Parameters

input (list of np.ndarray) – input arrays, in one of the orders listed above
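The three accepted input layouts can be told apart by the length of the list. A minimal sketch, assuming forward dispatches on that length; `unpack_inputs` is a hypothetical helper, and placeholder strings stand in for the np.ndarray inputs.

```python
def unpack_inputs(inputs):
    """Split the forward() input list according to its length.

    1 element : input data
    2 elements: input data, cluster labels
    3 elements: input data, segmentation labels, cluster labels
    Missing entries are returned as None.
    """
    if len(inputs) == 1:
        return inputs[0], None, None
    if len(inputs) == 2:
        data, cluster_label = inputs
        return data, None, cluster_label
    if len(inputs) == 3:
        data, seg_label, cluster_label = inputs
        return data, seg_label, cluster_label
    raise ValueError(f"expected 1 to 3 inputs, got {len(inputs)}")


print(unpack_inputs(["data"]))                  # → ('data', None, None)
print(unpack_inputs(["data", "clust"]))         # → ('data', None, 'clust')
print(unpack_inputs(["data", "seg", "clust"]))  # → ('data', 'seg', 'clust')
```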

training: bool
class mlreco.models.layers.common.gnn_full_chain.FullChainLoss(cfg)

Bases: torch.nn.modules.loss._Loss

Loss for UResNet + PPN chain

__init__(cfg)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(out, seg_label, ppn_label=None, cluster_label=None, kinematics_label=None, particle_graph=None, iteration=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
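Most label arguments of this forward default to None, which suggests that loss terms are computed only for the stages whose labels are supplied. The sketch below shows one common way to combine such terms, a weighted sum that skips missing ones. The weighting scheme, the stage names and `combine_losses` are assumptions for illustration, not the documented behaviour of FullChainLoss.

```python
def combine_losses(component_losses, weights=None):
    """Weighted sum of per-stage losses (hypothetical combination rule).

    component_losses : dict stage name -> loss value, or None when the stage
                       is disabled or its labels (e.g. ppn_label) are missing
    weights          : optional dict stage name -> weight (default 1.0)
    Returns the total loss and the per-stage terms that contributed to it.
    """
    weights = weights or {}
    terms = {name: weights.get(name, 1.0) * loss
             for name, loss in component_losses.items() if loss is not None}
    return sum(terms.values()), terms


total, terms = combine_losses(
    {"segmentation": 0.5, "ppn": None, "grappa_shower": 0.25},
    weights={"grappa_shower": 2.0},
)
print(total, terms)
# → 1.0 {'segmentation': 0.5, 'grappa_shower': 0.5}
```

The same function works unchanged on torch scalar tensors, since it only relies on multiplication and addition.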

reduction: str
mlreco.models.layers.common.gnn_full_chain.setup_chain_cfg(self, cfg, print_info=True)

Prepares the configuration for both FullChain and FullChainLoss.

Makes sure the configuration is logically sound by running some basic checks.
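A "logically sound" check can be pictured as verifying that every enabled stage has its prerequisites enabled. This is a minimal sketch; the dependency table below is purely illustrative and is not the set of rules setup_chain_cfg actually enforces. `DEPENDENCIES` and `check_chain_cfg` are hypothetical names.

```python
# Hypothetical dependency table (illustrative only): each stage lists the
# stages it is assumed to need.
DEPENDENCIES = {
    "grappa_shower": ["fragment_clustering"],
    "grappa_track": ["fragment_clustering"],
    "grappa_inter": ["grappa_shower", "grappa_track"],
}


def check_chain_cfg(enabled, print_info=True):
    """Raise if an enabled stage is missing one of its prerequisites."""
    for stage in sorted(enabled):
        missing = [dep for dep in DEPENDENCIES.get(stage, []) if dep not in enabled]
        if missing:
            raise ValueError(f"{stage} requires {missing} to be enabled")
    if print_info:
        print("Enabled stages:", ", ".join(sorted(enabled)))


check_chain_cfg({"fragment_clustering", "grappa_shower"})
# prints: Enabled stages: fragment_clustering, grappa_shower
```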