mlreco.models.layers.cluster_cnn.graph_spice_embedder module

class mlreco.models.layers.cluster_cnn.graph_spice_embedder.GraphSPICEEmbedder(cfg, name='graph_spice_embedder')

Bases: mlreco.models.layers.common.uresnet_layers.UResNet

MODULES = ['network_base', 'uresnet', 'graph_spice_embedder']
__init__(cfg, name='graph_spice_embedder')

Initializes internal Module state, shared by both nn.Module and ScriptModule.
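MODULES names the configuration blocks associated with this layer. The sketch below is hypothetical: it assumes cfg is a nested dict keyed by those names and that the layer reads its own options from cfg[name], which would be consistent with name defaulting to 'graph_spice_embedder'. The empty blocks are placeholders, and missing_blocks and own_block are illustrative helpers, not part of mlreco.

```python
# HYPOTHETICAL sketch: block contents are placeholders, not the real
# option names read by GraphSPICEEmbedder.
MODULES = ['network_base', 'uresnet', 'graph_spice_embedder']

cfg = {
    'network_base': {},          # placeholder for shared backbone options
    'uresnet': {},               # placeholder for UResNet options
    'graph_spice_embedder': {},  # placeholder for embedder-specific options
}


def missing_blocks(cfg, modules=MODULES):
    """Return the names in `modules` that have no block in `cfg`."""
    return [m for m in modules if m not in cfg]


def own_block(cfg, name='graph_spice_embedder'):
    """ASSUMED lookup: a layer reads its options from cfg[name]."""
    return cfg.get(name, {})
```

A check like missing_blocks can catch a misspelled block name before the model is built.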

__module__ = 'mlreco.models.layers.cluster_cnn.graph_spice_embedder'
training: bool
get_embeddings(input)

point_cloud is a list whose length equals the minibatch size (the method assumes a minibatch size of 1):

  • point_cloud[0] has 5 columns: 3 spatial coordinates, 1 batch coordinate and 1 feature.
  • label has shape (N + 5*num_labels, 1), where N = point_cloud[0].shape[0] is the number of points; it contains the segmentation label of each point plus the coordinates of the ground-truth points.
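The layout described above can be illustrated with a minimal sketch in plain Python. The helpers make_point_cloud and pack_label are hypothetical, and two details are assumptions not stated in the docstring: the column order (x, y, z, batch, feature), taken from the order the docstring lists the columns, and the packing of label as per-point labels first, then the flattened 5-value ground-truth rows.

```python
# HYPOTHETICAL helpers illustrating the documented shapes. The column order
# [x, y, z, batch, feature] is ASSUMED from the order the docstring lists them.

def make_point_cloud(points, batch_id=0):
    """Build rows [x, y, z, batch, feature] from (x, y, z, feature) tuples."""
    return [[x, y, z, batch_id, f] for (x, y, z, f) in points]


def pack_label(seg_labels, gt_points):
    """Return an (N + 5*num_labels, 1) column: the N per-point segmentation
    labels, then the flattened 5-value ground-truth rows (ASSUMED order)."""
    if any(len(row) != 5 for row in gt_points):
        raise ValueError("each ground-truth entry must have 5 values")
    flat = list(seg_labels) + [v for row in gt_points for v in row]
    return [[v] for v in flat]


# Minibatch of size 1, as the docstring assumes.
point_cloud = [make_point_cloud([(0, 1, 2, 0.5), (3, 4, 5, 1.0)])]
label = pack_label([1, 2], [[0, 0, 0, 0, 7]])
```

Here N = 2 points and num_labels = 1, so label has 2 + 5*1 = 7 rows of one value each.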

Returns

  • feature_enc: encoder features at each spatial resolution.
  • feature_dec: decoder features at each spatial resolution.
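Both outputs hold one entry per spatial resolution. As a hypothetical illustration of what that means for a U-Net-style network such as UResNet, the sketch below computes the spatial extent at each level. It assumes stride-2 downsampling per level, an encoder list ordered finest to coarsest, and a decoder list with one entry per upsampling step ordered coarsest to finest. None of these details is confirmed by the docstring, resolution_sizes is a hypothetical helper, and 768 and 5 are example values.

```python
def resolution_sizes(spatial_size, depth, factor=2):
    """Spatial extent at each encoder level, finest first.

    ASSUMES every level downsamples by `factor`; check the actual network
    configuration before relying on this.
    """
    return [spatial_size // factor**i for i in range(depth)]


# Example values only: a 768-voxel volume and a 5-level encoder.
enc_sizes = resolution_sizes(768, 5)
# ASSUMED: one decoder map per upsampling step, ordered coarse to fine.
dec_sizes = list(reversed(enc_sizes[:-1]))

print(enc_sizes)  # [768, 384, 192, 96, 48]
print(dec_sizes)  # [96, 192, 384, 768]
```

Under these assumptions the decoder has one fewer entry than the encoder, because the coarsest encoder level is the bottleneck and is not upsampled to itself.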

forward(input)

Training-time forward pass.