Source code for mlreco.visualization.gnn

import numpy as np
import plotly.graph_objs as go

def scatter_clusters(voxels, labels, clusters, markersize=5, colorscale='Viridis'):
    """
    Scatter plot of the voxels that belong to clusters, colored by label rank

    Only the voxels referenced by `clusters` are drawn. Each voxel is colored
    by the position of its label within the sorted set of unique labels found
    among the drawn voxels; the raw label is used as hover text.

    Args:
        voxels (np.ndarray)    : (N,3) List of voxel coordinates
        labels (np.ndarray)    : (N) List of voxel labels
        clusters ([np.ndarray]): (C) List of arrays of voxel IDs in each cluster
        markersize (int)       : Size of the voxel markers
        colorscale (str)       : Plotly color scale name
    Returns:
        [plotly.graph_objs.Scatter3d]: Single-element list holding the scatter plot
    """
    # Gather coordinates and labels of the clustered voxels, cluster by cluster
    points = np.concatenate([voxels[c] for c in clusters], axis=0)
    point_labels = np.concatenate([labels[c] for c in clusters], axis=0)

    # Map every label onto its index in the sorted array of unique labels
    color_ids = np.unique(point_labels, return_inverse=True)[1]

    marker = dict(size=markersize,
                  color=color_ids,
                  colorscale=colorscale,
                  opacity=0.8)

    return [go.Scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2],
                         mode='markers',
                         marker=marker,
                         hovertext=point_labels)]
def _center_edge_vertices(starts, ends, edge_index):
    """
    Build the line vertices joining ``starts[i]`` to ``ends[j]`` for each edge (i, j)

    Each edge contributes three rows: its start point, its end point and a
    row of ``None`` values that keeps consecutive edges from being joined
    when all rows are drawn in one 'lines' trace.
    """
    return np.concatenate([[starts[i], ends[j], [None, None, None]]
                           for i, j in edge_index])


def _closest_voxel_edge_vertices(voxels, clusters, edge_index):
    """
    Build the line vertices joining, for each edge (i, j), the closest pair of
    voxels between cluster i and cluster j (Euclidean distance)

    Same row layout as `_center_edge_vertices`. Must only be called with a
    non-empty `edge_index` (np.concatenate rejects an empty list).
    """
    # Import the submodule explicitly: a bare `import scipy` is not guaranteed
    # to expose `scipy.spatial` on every scipy version.
    from scipy.spatial.distance import cdist

    edge_vertices = []
    for i, j in edge_index:
        vi, vj = voxels[clusters[i]], voxels[clusters[j]]
        d12 = cdist(vi, vj, 'euclidean')
        i1, i2 = np.unravel_index(np.argmin(d12), d12.shape)
        edge_vertices.append([vi[i1], vj[i2], [None, None, None]])

    return np.concatenate(edge_vertices)


def network_topology(voxels, clusters, edge_index=[], clust_labels=[], edge_labels=[],
                     mode='scatter', markersize=3, linewidth=2, colorscale='Inferno',
                     cmin=None, cmax=None, coords_col=(1, 4), **kwargs):
    """
    Network 3D topological representation

    Args:
        voxels (np.ndarray)      : (N,3+) List of voxel coordinates. If the array has
                                   more than 3 columns, the coordinates are taken from
                                   the column range given by `coords_col`
        clusters ([np.ndarray])  : (C) List of arrays of voxel IDs in each cluster
        edge_index (np.ndarray)  : (E,2) List of connections
        clust_labels (np.ndarray): (C) Cluster labels (defaults to all ones)
        edge_labels (np.ndarray) : (E) Edge labels (defaults to all ones)
        mode (str)               : Draw mode ('sphere', 'ellipsoid', 'cone', 'hull', 'scatter')
        markersize (int)         : Size of the voxel markers in pixels ('scatter' mode)
        linewidth (int)          : Width of the edge lines in pixels
        colorscale (str)         : Plotly color scale name
        cmin (float)             : Lower bound of the color range ('scatter' mode only),
                                   defaults to the smallest cluster label
        cmax (float)             : Upper bound of the color range ('scatter' mode only),
                                   defaults to the largest cluster label
        coords_col (tuple)       : (start, stop) column range of the coordinates in `voxels`
        **kwargs                 : Extra arguments forwarded to the node traces
    Returns:
        [plotly.graph_objs]: Node trace(s), followed by one edge trace if `edge_index`
                             is not empty
    Raises:
        ValueError: If `mode` is not one of the supported draw modes
    """
    # Restrict the voxel array to its coordinate columns, if it carries more
    c1, c2 = coords_col
    if voxels.shape[1] > 3:
        voxels = voxels[:, c1:c2]

    # Define the arrays of node positions (barycenter of voxels in the cluster)
    pos = np.array([voxels[c].mean(0) for c in clusters])

    # Define the node features (label, color)
    n = len(clusters)
    if not len(clust_labels):
        clust_labels = np.ones(n)

    # Print the group ID as a float only if the first label is not integer-valued
    fractional = len(clust_labels) and not float(clust_labels[0]).is_integer()
    label_format = ('Instance ID: %d<br>Group ID: ' + ('%0.3f' if fractional else '%d')
                    + '<br>Centroid: (%0.1f, %0.1f, %0.1f)')
    node_labels = [label_format % (i, clust_labels[i], pos[i, 0], pos[i, 1], pos[i, 2])
                   for i in range(n)]

    # Check whether there are edges to draw
    draw_edges = bool(len(edge_index))

    # Adjust min and max color
    if cmin is None:
        cmin = min(clust_labels)
    if cmax is None:
        cmax = max(clust_labels)

    # Define the nodes and their connections
    graph_data = []
    edge_vertices = []
    if mode == 'sphere':
        # Define the node size as a sqrt function of the amount of voxels in the cluster
        node_sizes = 4*np.sqrt([len(c) for c in clusters])

        # Define the nodes as circles centered on the cluster barycenters
        graph_data.append(go.Scatter3d(x=pos[:, 0], y=pos[:, 1], z=pos[:, 2],
                                       name='Graph nodes',
                                       mode='markers',
                                       marker=dict(
                                           symbol='circle',
                                           size=node_sizes,
                                           color=clust_labels,
                                           opacity=0.5,
                                           colorscale=colorscale,
                                           line=dict(color='rgb(50,50,50)', width=0.5)
                                       ),
                                       text=node_labels,
                                       hoverinfo='text',
                                       **kwargs))

        # Define the edges center to center
        if draw_edges:
            edge_vertices = _center_edge_vertices(pos, pos, edge_index)

    elif mode == 'ellipsoid':
        # Compute the points on a unit sphere
        phi = np.linspace(0, 2*np.pi, num=20)
        theta = np.linspace(-np.pi/2, np.pi/2, num=20)
        phi, theta = np.meshgrid(phi, theta)
        x = np.cos(theta) * np.sin(phi)
        y = np.cos(theta) * np.cos(phi)
        z = np.sin(theta)
        unit_points = np.hstack((x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)))

        # Compute the range of node values
        min_label, max_label = min(clust_labels), max(clust_labels)

        # Radius of the ellipsoid in units of chi.
        # NOTE(review): the original author states that 1.75 corresponds to a ~50%
        # probability content for a 3D Gaussian -- confirm.
        radius = 1.75

        for i, c in enumerate(clusters):
            # Get the centroid and the covariance matrix
            centroid = np.mean(voxels[c], axis=0)
            centered = voxels[c] - centroid
            covmat = np.dot(centered.T, centered)/len(c)

            # Diagonalize the covariance matrix, get rotation matrix. The eigenvalues
            # of a degenerate cluster can come out marginally negative from round-off;
            # clamp them so that the square root never yields NaN.
            w, v = np.linalg.eigh(covmat)
            diag = np.zeros((3, 3))
            np.fill_diagonal(diag, np.sqrt(np.clip(w, 0., None)))
            rotmat = np.dot(diag, v.T)

            # Rotate the points into the basis of the covariance matrix
            points = centroid + radius*np.dot(unit_points, rotmat)

            # Append Mesh3d object
            graph_data.append(go.Mesh3d(x=points[:, 0], y=points[:, 1], z=points[:, 2],
                                        alphahull=0,
                                        opacity=0.5,
                                        color=get_object_color(min_label, max_label,
                                                               clust_labels[i], colorscale),
                                        hoverinfo='text',
                                        text=node_labels[i],
                                        **kwargs))

        # Define the edges center to center
        if draw_edges:
            edge_vertices = _center_edge_vertices(pos, pos, edge_index)

    elif mode == 'cone':
        # Evaluate the cone parameters
        from sklearn.decomposition import PCA
        import numpy.linalg as LA
        pca = PCA()

        # Sum of |cos| of the angles between `norm` and the directions from
        # voxel `pid` to every other voxel of the cluster
        curv = lambda vox, pid, norm: \
            np.sum([np.abs(np.dot((v-vox[pid])/LA.norm(v-vox[pid]), norm))
                    for i, v in enumerate(vox) if i != pid])

        axes, spos, epos = [], [], []
        for c in clusters:
            # Get the voxels corresponding to the cluster
            vox = voxels[c]

            # Get the principal axis from the PCA
            pca.fit(vox)
            axis = np.array([pca.components_[0][i] for i in range(3)])

            # Order the point along the principal axis, get the end points
            pa_vals = np.dot(vox, axis)
            pids = np.argmax(pa_vals), np.argmin(pa_vals)

            # Identify the starting point as the point with the largest curvature.
            # The end point is always the *other* extremum, so that a tie between
            # the two curvatures does not collapse the cone onto a single point.
            curvs = [curv(vox, pid, axis) for pid in pids]
            start_rank = int(np.argmax(curvs))
            start_id, end_id = pids[start_rank], pids[1 - start_rank]
            spos.append(vox[start_id])
            epos.append(vox[end_id])

            # Get the full length of the principal axis
            pa_dist = pa_vals[start_id] - pa_vals[end_id]

            # Append the cone parameters
            axes.append(2.*pa_dist*axis)

        # Keep (C,3) arrays even if there are no clusters
        axes = np.array(axes, dtype=float).reshape(-1, 3)
        spos = np.array(spos, dtype=float).reshape(-1, 3)
        epos = np.array(epos, dtype=float).reshape(-1, 3)

        # Compute plotly's internal vector scale to undo it. With fewer than two
        # cones the scale stays infinite and `sizeref` below evaluates to 0.
        vector_scale = np.inf
        for k in range(1, len(spos)):
            separation = 2*LA.norm(spos[k-1] - spos[k])
            vector_scale = min(vector_scale,
                               separation / (LA.norm(axes[k-1]) + LA.norm(axes[k])))

        # Add a graph with a cone per cluster
        graph_data.append(go.Cone(x=spos[:, 0], y=spos[:, 1], z=spos[:, 2],
                                  u=axes[:, 0], v=axes[:, 1], w=axes[:, 2],
                                  name='Graph node cones',
                                  opacity=0.5,
                                  sizeref=0.5/vector_scale,
                                  showscale=False,
                                  anchor='tip'))

        # Add a graph with the starting points
        graph_data.append(go.Scatter3d(x=spos[:, 0], y=spos[:, 1], z=spos[:, 2],
                                       name='Graph node starts',
                                       mode='markers',
                                       marker=dict(
                                           symbol='circle',
                                           color=clust_labels,
                                           size=5,
                                           colorscale=colorscale
                                       ),
                                       text=node_labels,
                                       hoverinfo='text',
                                       **kwargs))

        # Join end points of primary cones to starting points of secondary cones
        if draw_edges:
            edge_vertices = _center_edge_vertices(epos, spos, edge_index)

    elif mode == 'hull':
        # Compute the range of node values
        min_label, max_label = min(clust_labels), max(clust_labels)

        # For each cluster, add a hull mesh built from all its voxels
        graph_data += [go.Mesh3d(alphahull=10.0,
                                 name='Graph nodes',
                                 x=voxels[c][:, 0],
                                 y=voxels[c][:, 1],
                                 z=voxels[c][:, 2],
                                 color=get_object_color(min_label, max_label,
                                                        clust_labels[i], colorscale),
                                 opacity=0.3,
                                 text=node_labels[i],
                                 hoverinfo='text',
                                 **kwargs) for i, c in enumerate(clusters)]

        # Define the edges closest pixel to closest pixel
        if draw_edges:
            edge_vertices = _closest_voxel_edge_vertices(voxels, clusters, edge_index)

    elif mode == 'scatter':
        # Simply draw all the voxels of each cluster, using labels as color
        cids = np.full(len(voxels), -1)
        for i, c in enumerate(clusters):
            cids[c] = i

        # Only keep the voxels that belong to a cluster
        mask = np.where(cids != -1)[0]
        colors = [clust_labels[i] for i in cids[mask]]
        node_labels = [node_labels[i] for i in cids[mask]]

        graph_data = [go.Scatter3d(x=voxels[mask][:, 0],
                                   y=voxels[mask][:, 1],
                                   z=voxels[mask][:, 2],
                                   mode='markers',
                                   name='Graph nodes',
                                   marker=dict(
                                       symbol='circle',
                                       color=colors,
                                       colorscale=colorscale,
                                       cmin=cmin,
                                       cmax=cmax,
                                       size=markersize
                                   ),
                                   text=node_labels,
                                   hoverinfo='text',
                                   **kwargs)]

        # Define the edges closest pixel to closest pixel
        if draw_edges:
            edge_vertices = _closest_voxel_edge_vertices(voxels, clusters, edge_index)

    else:
        raise ValueError("Network topology mode not supported")

    # Initialize a graph that contains the edges
    if draw_edges:
        if not len(edge_labels):
            edge_labels = np.ones(len(edge_index))

        # One color value per vertex row (start, end, separator) of each edge
        edge_colors = np.concatenate([[edge_labels[k]]*3 for k in range(len(edge_index))])

        graph_data.append(go.Scatter3d(x=edge_vertices[:, 0],
                                       y=edge_vertices[:, 1],
                                       z=edge_vertices[:, 2],
                                       mode='lines',
                                       name='Graph edges',
                                       line=dict(
                                           color=edge_colors,
                                           width=linewidth,
                                           colorscale='Picnic',
                                           cmin=0,
                                           cmax=1
                                       ),
                                       hoverinfo='none'))

    # Return
    return graph_data
def network_schematic(clusters, edge_index, clust_labels=[], edge_labels=[], linewidth=1, colorscale='Inferno'):
    """
    Network 2D schematic representation

    Each cluster is drawn as a marker at (label, cluster index), sized by its
    voxel count relative to the largest cluster. Edges are drawn either as a
    single black trace (no edge labels) or as one gray-scale trace per edge.

    Args:
        clusters ([np.ndarray])  : (C) List of arrays of voxel IDs in each cluster
        edge_index (np.ndarray)  : (E,2) List of connections
        clust_labels (np.ndarray): (C) Node labels (defaults to all zeros)
        edge_labels (np.ndarray) : (E) Edge labels
        linewidth (int)          : Width of the edge lines in pixels
        colorscale (str)         : Plotly color scale name
    Returns:
        [plotly.graph_objs.Scatter]: Node trace followed by the edge trace(s)
    """
    # Marker sizes: voxel counts rescaled so that the largest cluster gets 100
    sizes = np.array([len(c) for c in clusters])
    node_sizes = sizes * 100. / sizes.max()

    # Fall back on a common label if none are provided
    if not len(clust_labels):
        clust_labels = np.zeros(len(clusters))

    # Hover text of each node
    node_labels = []
    for idx, size in enumerate(sizes):
        node_labels.append('Cluster ID: %d<br>Cluster label: %0.3f<br>Cluster size: %d'
                           % (idx, clust_labels[idx], size))

    # Node positions: label value along x, cluster index along y
    pos = np.array([[label, idx] for idx, label in enumerate(clust_labels)])

    # Node trace
    node_marker = dict(color=clust_labels,
                       size=node_sizes,
                       colorscale=colorscale,
                       reversescale=True)
    graph_data = [go.Scatter(x=pos[:, 0], y=pos[:, 1],
                             mode='markers',
                             name='Graph nodes',
                             marker=node_marker,
                             text=node_labels,
                             hoverinfo='text')]

    # Nothing more to draw if there are no edges
    if not len(edge_index):
        return graph_data

    if not len(edge_labels):
        # Unlabeled edges: a single black trace, edges separated by None rows
        segments = [[pos[i], pos[j], [None, None]] for i, j in edge_index]
        edge_vertices = np.concatenate(segments)
        graph_data.append(go.Scatter(x=edge_vertices[:, 0], y=edge_vertices[:, 1],
                                     mode='lines',
                                     name='Graph edges',
                                     line=dict(color='black', width=linewidth),
                                     hoverinfo='none'))
        return graph_data

    # Labeled edges: one trace per edge so that each can carry its own gray level
    for k, (src, dst) in enumerate(edge_index):
        gray = 'rgb({0:0.2f}, {0:0.2f}, {0:0.2f})'.format(255*(1-edge_labels[k]))
        graph_data.append(go.Scatter(x=[pos[src, 0], pos[dst, 0]],
                                     y=[pos[src, 1], pos[dst, 1]],
                                     mode='lines',
                                     name='Graph edges',
                                     line=dict(color=gray, width=linewidth),
                                     hoverinfo='none',
                                     showlegend=False))

    # Only the last edge trace appears in the legend
    graph_data[-1]['showlegend'] = True

    return graph_data
def get_object_color(min, max, val, colorscale):
    """
    Get color given the value of an object and a plotly colorscale

    The value is mapped onto the [0, 1] range spanned by `min` and `max`
    (values outside of the range are clipped to it) and the color of the last
    colorscale entry whose position does not exceed that fraction is returned.
    If the range is empty (`max <= min`), the middle of the scale is used.

    Args:
        min (double)               : Minimum value of the range of values to be drawn
        max (double)               : Maximum value of the range of values to be drawn
        val (double)               : Value of the object
        colorscale (string or list): Either the name of a plotly express sequential
                                     colorscale or a user defined list of
                                     [position, color] pairs with increasing positions
    Returns:
        str: Plotly color
    """
    # NOTE: `min` and `max` shadow the builtins inside this function; the names
    # are part of the public signature and are kept for backward compatibility.

    # If the colorscale is a string, look for it in plotly express
    if isinstance(colorscale, str):
        import plotly.express as px
        colorscale = getattr(px.colors.sequential, colorscale)
        colorscale = [[i/(len(colorscale)-1), c] for i, c in enumerate(colorscale)]

    # Get the value adjusted to the value range, restricted to [0, 1]
    span = max - min
    frac_val = (val - min)/span if span > 0 else 0.5
    frac_val = float(np.clip(frac_val, 0., 1.))

    # Find the color ID: last entry whose position is <= frac_val. Working on an
    # array (rather than a list) makes this valid for plain Python numbers too.
    cs_limits = np.asarray([entry[0] for entry in colorscale], dtype=float)
    color_id = int(np.searchsorted(cs_limits, frac_val, side='right')) - 1
    color_id = int(np.clip(color_id, 0, len(colorscale) - 1))

    return colorscale[color_id][1]