Source code for athena.graph_builder.contact_graph_builder

import networkx as nx
import numpy as np

from skimage.morphology import binary_dilation

from .base_graph_builder import BaseGraphBuilder
from .constants import EDGE_WEIGHT
from .constants import DILATION_KERNELS

from tqdm import tqdm


# %%
def dilation(args) -> list:
    """Find the objects that touch ``obj`` once ``obj`` is dilated by ``kernel``.

    Args:
        args: A ``(mask, obj, kernel)`` triple passed as a single argument.
            ``mask`` is the labelled segmentation array, ``obj`` is the label
            whose pixels are dilated, and ``kernel`` is the footprint handed
            to ``binary_dilation``.

    Returns:
        A list of ``(obj, neighbour, {EDGE_WEIGHT: 1})`` edge tuples, one for
        every distinct label (ascending, as produced by ``np.unique``) found
        under the dilated object, excluding ``obj`` itself and the background
        label ``0``.
    """
    mask, obj, kernel = args

    grown = binary_dilation(mask == obj, kernel)
    touched_labels = np.unique(mask[grown])

    # Drop the object's own label and the background label 0.
    keep = (touched_labels != obj) & (touched_labels != 0)

    return [(obj, neighbour, {EDGE_WEIGHT: 1}) for neighbour in touched_labels[keep]]
# %%
class ContactGraphBuilder(BaseGraphBuilder):
    '''Contact-Graph class.

    Build contact graph based on pixel expansion of cell masks: every object
    in the segmentation mask is dilated and connected to each other object its
    dilated footprint overlaps.
    '''

    def __init__(self, config: dict):
        """Contact-Graph Builder constructor.

        Args:
            config: Dictionary containing a dict called `builder_params`.
                `_build_topology` reads the keys `dilation_kernel` (a key of
                `DILATION_KERNELS`) and `radius` (passed to the selected
                kernel factory) from it.
        """
        super().__init__(config)

    def _build_topology(self, topo_data: dict, **kwargs) -> None:
        """Build topology using pixel expansion of segmentation masks.

        Each non-background object in ``topo_data['mask']`` is dilated with the
        kernel selected by ``builder_params['dilation_kernel']`` (built with
        ``builder_params['radius']``). Objects that overlap after expansion are
        connected in ``self.graph`` with an edge attribute ``EDGE_WEIGHT`` of 1.

        Args:
            topo_data: dict providing the segmentation mask under key 'mask'.
                Label 0 is treated as background.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If ``builder_params['dilation_kernel']`` is not a key
                of ``DILATION_KERNELS``.
        """
        # type hints
        self.graph: nx.Graph

        params = self.config['builder_params']
        mask = topo_data['mask']

        kernel_name = params['dilation_kernel']
        if kernel_name not in DILATION_KERNELS:
            raise ValueError(
                f'Specified dilate kernel not available. Please use one of {{{", ".join(DILATION_KERNELS)}}}.')
        kernel = DILATION_KERNELS[kernel_name](params['radius'])

        # get object ids, 0 is background.
        objs = np.unique(mask)
        objs = objs[objs != 0]

        # Compute neighbours with the module-level `dilation` helper rather
        # than duplicating its dilate/unique/filter logic inline.
        edges = []
        for obj in tqdm(objs):
            edges.extend(dilation((mask, obj, kernel)))

        # NOTE(review): only edges are added here, so an object that touches no
        # other object is not added as a node by this method — confirm whether
        # the base class adds all objects as nodes separately.
        self.graph.add_edges_from(edges)