mlreco.models.layers.cluster_cnn.graph_spice_embedder module

class mlreco.models.layers.cluster_cnn.graph_spice_embedder.GraphSPICEEmbedder(cfg, name='graph_spice_embedder')
Bases: mlreco.models.layers.common.uresnet_layers.UResNet

MODULES = ['network_base', 'uresnet', 'graph_spice_embedder']

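MODULES lists three names, and the last one matches the default value of the constructor's name argument. The sketch below assumes that cfg is a plain dict holding one sub-dict per name in MODULES. The helper collect_module_configs is hypothetical and is not part of mlreco, and the placeholder config deliberately contains no real mlreco keys.

```python
from typing import Any, Dict, Mapping

# Class attribute as documented for GraphSPICEEmbedder.
MODULES = ['network_base', 'uresnet', 'graph_spice_embedder']


def collect_module_configs(cfg: Mapping[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper (not an mlreco API): return one sub-dict per
    MODULES entry, defaulting to {} for sections absent from cfg."""
    return {name: dict(cfg.get(name, {})) for name in MODULES}


# Placeholder config (assumption): section names only, no real option keys.
cfg = {'graph_spice_embedder': {}, 'uresnet': {}}
sections = collect_module_configs(cfg)
missing = [name for name in MODULES if name not in cfg]
print(sorted(sections), missing)
# prints: ['graph_spice_embedder', 'network_base', 'uresnet'] ['network_base']
```

Defaulting missing sections to empty dicts keeps the lookup total. Whether mlreco itself treats an absent section this way is not stated here.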
__init__(cfg, name='graph_spice_embedder')
Initializes internal Module state, shared by both nn.Module and ScriptModule.

__module__ = 'mlreco.models.layers.cluster_cnn.graph_spice_embedder'

training: bool

get_embeddings(input)
point_cloud is a list whose length equals the minibatch size (a minibatch size of 1 is assumed). point_cloud[0] has 5 columns: 3 spatial coordinates, 1 batch coordinate, and 1 feature. label has shape (point_cloud.shape[0] + 5*num_labels, 1) and contains the segmentation label of each point plus the coordinates of the ground-truth points.
Returns:
- feature_enc: encoder features at each spatial resolution.
- feature_dec: decoder features at each spatial resolution.
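The input layout described for get_embeddings can be sketched in plain Python. This is a minimal sketch that assumes the columns appear in the order the docstring lists them (x, y, z, batch, feature). It represents point_cloud as nested lists rather than tensors. The helpers make_point_cloud and expected_label_rows are hypothetical and are not mlreco functions.

```python
from typing import List, Sequence

# Assumed column order, taken from the order the docstring lists them:
# 3 spatial coordinates, then 1 batch coordinate, then 1 feature.
POINT_COLUMNS = ("x", "y", "z", "batch", "feature")


def make_point_cloud(points: Sequence[Sequence[float]]) -> List[List[List[float]]]:
    """Hypothetical helper: wrap one event's rows into a minibatch
    list of length 1 (the docstring assumes mbs = 1)."""
    rows = [[float(v) for v in p] for p in points]
    for r in rows:
        if len(r) != len(POINT_COLUMNS):
            raise ValueError(f"expected {len(POINT_COLUMNS)} columns, got {len(r)}")
    return [rows]


def expected_label_rows(num_points: int, num_labels: int) -> int:
    """Row count implied by the documented label shape
    (point_cloud.shape[0] + 5*num_labels, 1)."""
    return num_points + 5 * num_labels


point_cloud = make_point_cloud([
    [10, 12, 3, 0, 0.8],
    [11, 12, 3, 0, 1.2],
    [40, 7, 22, 0, 0.5],
])
n_points = len(point_cloud[0])
print(len(point_cloud), n_points, expected_label_rows(n_points, num_labels=2))
# prints: 1 3 13
```

Checking the column count up front surfaces a layout mismatch before any tensor is built. The 5*num_labels extra label rows follow directly from the documented shape.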