Source code for athena.graph_builder.base_graph_builder

import abc

import networkx as nx
import numpy as np
import pandas as pd
from ..utils.tools.graph import df2node_attr
from abc import ABC

[docs]class BaseGraphBuilder(ABC):
[docs] def __init__(self, config: dict): """Base-Graph Builder constructor Args: config: Dictionary containing a dict called `builder_params` that provides function call arguments to the build_topology function """ self.config = config self.ndata = None self.edata = None self.graph = nx.Graph()
def __call__(self, ndata: pd.DataFrame, edata: pd.DataFrame = None, topo_data: dict = None) -> nx.Graph: """Builds graph Args: ndata: dataframe with node data, index is the node edata: dataframe with edge data, index specifies edges, i.e. (node1, node2) topo_data: dict with additional data for graph construction (necessary for contact graph) Returns: nx.Graph """ self.ndata = ndata self.edata = edata self._add_nodes() self._add_nodes_attr() if edata is None: self._build_topology(topo_data=topo_data) else: self._add_edges() self._add_edges_attr() return self.graph def _add_nodes(self): """Adds nodes in ndata to graph Returns: """ self.graph.add_nodes_from(self.ndata.index) def _add_nodes_attr(self) -> None: """Adds node attributes in ndata to graph Returns: """ attr = df2node_attr(self.ndata) nx.set_node_attributes(self.graph, attr) def _add_edges(self) -> None: """Adds edges in edata to graph Returns: """ self.graph.add_edges_from(self.edata.index) def _add_edges_attr(self) -> None: """Adds edge attributes in edata to graph Returns: """ attr = df2node_attr(self.edata) nx.set_edge_attributes(self.graph, attr) @abc.abstractmethod def _build_topology(self, **kwargs) -> None: """Builds graph topology. Implemented in subclasses. Args: **kwargs: Returns: """ raise NotImplementedError('Implemented in subclasses.') # Convenient method to build graph from cellmask
[docs] @classmethod def from_mask(cls, config: dict, mask: np.ndarray) -> nx.Graph: """Construct graph topology from segmentation masks. Args: config: config: Dictionary containing a dict called `builder_params` that provides function call arguments to the build_topology function mask: image file that provides the image segmentation Returns: nx.Graph """ # load required dependencies try: import numpy as np from skimage.io import imread from skimage.measure import regionprops_table except ImportError: raise ImportError( 'Please install the skimage: `conda install -c anaconda scikit-image`.') instance = cls(config) # extract location ndata = regionprops_table(mask, properties=['label', 'centroid']) ndata = pd.DataFrame.from_dict(ndata) ndata.columns = ['cell_id', 'y', 'x'] # NOTE: axis 0 is y and axis 1 is x ndata.set_index('cell_id', inplace=True) ndata.sort_index(axis=0, ascending=True, inplace=True) return instance(ndata, topo_data={'mask': mask})