Source code for athena.neighborhood.base_estimators

# %%
import networkx as nx
import pandas as pd
from pandas.api.types import CategoricalDtype
import numpy as np
from networkx import Graph
import os
import time
from astropy.stats import RipleysKEstimator
import logging

logging.basicConfig(level=logging.INFO)

from ..utils.general import make_iterable
from .utils import get_node_interactions, get_interaction_score, permute_labels
from collections import Counter


# %%

def _infiltration_local_deprecated(G: Graph,
                                   interaction1=('tumor', 'immune'),
                                   interaction2=('immune', 'immune')):
    """Deprecated, unfinished draft of a per-node (local) infiltration score."""
    ids = np.unique(interaction1 + interaction2)  # unique labels involved in either interaction
    nodes_inter1 = [node for node in G.nodes if G.nodes[node]['attr'] in interaction1]
    nodes_inter2 = [node for node in G.nodes if G.nodes[node]['attr'] in interaction2]

    for node in set(nodes_inter1) | set(nodes_inter2):  # nodes whose label is in `ids`
        neigh = G[node]  # neighbours of the current node
        counts = Counter([G.nodes[i]['attr'] for i in neigh])  # neighbourhood label counts
        # TODO: derive a local infiltration score from `counts`

def _infiltration_local(G: Graph,
                        interaction1=('tumor', 'immune'),
                        interaction2=('immune', 'immune')):
    """Unfinished draft of a per-node (local) infiltration score; not used by the estimators below."""
    ids = np.unique(interaction1 + interaction2)  # unique labels involved in either interaction
    nodes = [node for node in G.nodes if G.nodes[node]['attr'] in ids]
    for node in nodes:
        neigh = G[node]  # neighbours of the current node
        subG = G.subgraph(list(neigh) + [node])  # local neighbourhood subgraph
        # TODO: compute the infiltration score on `subG`

def _infiltration(node_interactions: pd.DataFrame, interaction1=('tumor', 'immune'),
                  interaction2=('immune', 'immune')) -> float:
    """
    Compute infiltration score between two species.

    Args:
        node_interactions: Dataframe with columns `source_label` and `target_label` that specifies interactions.
        interaction1: labels of the numerator interaction
        interaction2: labels of the denominator interaction

    Notes:
        The infiltration score is computed as #interactions1 / #interactions2.

    Returns:
        Infiltration score

    """
    nint = node_interactions  # short alias to keep the expressions below readable

    (a1, a2), (b1, b2) = interaction1, interaction2
    num = nint[(nint.source_label == a1) & (nint.target_label == a2)].shape[0]
    denom = nint[(nint.source_label == b1) & (nint.target_label == b2)].shape[0]

    return num / denom if denom > 0 else np.nan  # TODO: np.inf or np.nan
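

# %%
# Illustrative sketch (added for this listing; `_toy_infiltration_example` is not part of
# the library API): `_infiltration` counts directed source->target edges for the two label
# pairs and divides the counts.
def _toy_infiltration_example() -> float:
    # two tumor->immune edges and one immune->immune edge -> infiltration score of 2 / 1 = 2.0
    toy = pd.DataFrame({'source_label': ['tumor', 'tumor', 'immune', 'immune'],
                        'target_label': ['immune', 'immune', 'immune', 'tumor']})
    return _infiltration(toy)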


class Interactions:
    """Estimator to quantify interaction strength between different species in the sample."""

    VALID_MODES = ['classic', 'histoCAT', 'proportion']
    VALID_PREDICTION_TYPES = ['pvalue', 'observation', 'diff']

    def __init__(self, so, spl: str, attr: str = 'meta_id', mode: str = 'classic',
                 n_permutations: int = 500, random_seed=None, alpha: float = .01,
                 graph_key: str = 'knn'):
        """Estimator to quantify interaction strength between different species in the sample.

        Args:
            so: SpatialOmics
            spl: Sample for which to compute the interaction strength
            attr: Categorical feature in SpatialOmics.obs to use for the grouping
            mode: One of {classic, histoCAT, proportion}, see Notes
            n_permutations: Number of permutations used to compute p-values and the interaction strength score (mode diff)
            random_seed: Random seed for the permutations
            alpha: Threshold for significance
            graph_key: Specifies the graph representation in so.G[spl] to use

        Notes:
            classic and histoCAT are Python implementations of the corresponding methods published by the
            Bodenmiller lab at UZH. The proportion method is similar to the classic method but normalises
            the score by the number of edges and is thus bound to [0, 1].
        """
        self.so = so
        self.spl: str = spl
        self.graph_key = graph_key
        self.g: Graph = so.G[spl][graph_key]
        self.attr: str = attr
        self.data: pd.Series = so.obs[spl][attr]
        self.mode: str = mode
        self.n_perm: int = int(n_permutations)
        self.random_seed = random_seed if random_seed else so.random_seed
        self.rng = np.random.default_rng(self.random_seed)  # seed the generator with the resolved seed
        self.alpha: float = alpha
        self.fitted: bool = False

        # set dtype categories of data to attributes that are in the data
        self.data = self.data.astype(CategoricalDtype(categories=self.data.unique(), ordered=False))

        # path where h0 models are cached
        self.path = os.path.expanduser('~/.cache/spatialHeterogeneity/h0-models/')
        self.h0_file = f'{spl}_{attr}_{graph_key}_{mode}.pkl'
        self.h0 = None

    def fit(self, prediction_type: str = 'observation', try_load: bool = True) -> None:
        """Compute the interaction scores for the sample.

        Args:
            prediction_type: One of {observation, pvalue, diff}, see Notes
            try_load: Load pre-computed permutation results if available

        Notes:
            `observation`: computes the observed interaction strength in the sample
            `pvalue`: computes the P-value of a two-sided permutation test for the interaction strength based on the random permutations
            `diff`: computes the difference between observed and average interaction strength (across permutations)
        """
        if prediction_type not in self.VALID_PREDICTION_TYPES:
            raise ValueError(
                f'invalid `prediction_type` {prediction_type}. Available modes are {self.VALID_PREDICTION_TYPES}')
        self.prediction_type = prediction_type

        # extract observed interactions
        if self.mode == 'classic':
            relative_freq, observed = False, False
        elif self.mode == 'histoCAT':
            relative_freq, observed = False, True
        elif self.mode == 'proportion':
            relative_freq, observed = True, False
        else:
            raise ValueError(f'invalid mode {self.mode}. Available modes are {self.VALID_MODES}')

        node_interactions = get_node_interactions(self.g, self.data)
        obs_interaction = get_interaction_score(node_interactions, relative_freq=relative_freq, observed=observed)
        self.obs_interaction = obs_interaction.set_index(['source_label', 'target_label'])

        if prediction_type != 'observation':
            if try_load:
                if os.path.isdir(self.path) and self.h0_file in os.listdir(self.path):
                    logging.info(
                        f'loading h0 for {self.spl}, graph type {self.graph_key} and mode {self.mode}')
                    self.h0 = pd.read_pickle(os.path.join(self.path, self.h0_file))

            # if try_load was not successful
            if self.h0 is None:
                logging.info(
                    f'generate h0 for {self.spl}, graph type {self.graph_key} and mode {self.mode} and attribute {self.attr}')
                self.generate_h0(relative_freq=relative_freq, observed=observed, save=True)

        self.fitted = True

    def predict(self) -> pd.DataFrame:
        """Predict interaction strengths of observations.

        Returns:
            A dataframe with the interaction results.
        """
        if self.prediction_type == 'observation':
            return self.obs_interaction

        elif self.prediction_type == 'pvalue':
            # TODO: Check p-value computation
            data_perm = pd.concat((self.obs_interaction, self.h0), axis=1)
            data_perm.fillna(0, inplace=True)
            data_pval = pd.DataFrame(index=data_perm.index)

            # see h0_models_analysis.py for alternative p-value computation
            data_pval['score'] = self.obs_interaction.score
            data_pval['perm_mean'] = data_perm.apply(lambda x: np.mean(x[1:]), axis=1, raw=True)
            data_pval['perm_std'] = data_perm.apply(lambda x: np.std(x[1:]), axis=1, raw=True)
            data_pval['perm_median'] = data_perm.apply(lambda x: np.median(x[1:]), axis=1, raw=True)
            data_pval['p_gt'] = data_perm.apply(lambda x: np.sum(x[1:] >= x[0]) / self.n_perm, axis=1, raw=True)
            data_pval['p_lt'] = data_perm.apply(lambda x: np.sum(x[1:] <= x[0]) / self.n_perm, axis=1, raw=True)
            data_pval['perm_n'] = data_perm.apply(lambda x: self.n_perm, axis=1, raw=True)
            data_pval['p'] = data_pval.apply(lambda x: x.p_gt if x.p_gt <= x.p_lt else x.p_lt, axis=1)
            data_pval['sig'] = data_pval.apply(lambda x: x.p < self.alpha, axis=1)
            data_pval['attraction'] = data_pval.apply(lambda x: x.p_gt <= x.p_lt, axis=1)
            data_pval['sigval'] = data_pval.apply(lambda x: np.sign((x.attraction - .5) * x.sig), axis=1)

            return data_pval

        elif self.prediction_type == 'diff':
            data_perm = pd.concat((self.obs_interaction, self.h0), axis=1)
            data_perm.fillna(0, inplace=True)
            data_pval = pd.DataFrame(index=data_perm.index)

            # see h0_models_analysis.py for alternative p-value computation
            data_pval['score'] = self.obs_interaction.score
            data_pval['perm_mean'] = data_perm.apply(lambda x: np.mean(x[1:]), axis=1, raw=True)
            data_pval['perm_std'] = data_perm.apply(lambda x: np.std(x[1:]), axis=1, raw=True)
            data_pval['perm_median'] = data_perm.apply(lambda x: np.median(x[1:]), axis=1, raw=True)
            data_pval['diff'] = data_pval['score'] - data_pval['perm_mean']

            return data_pval

        else:
            raise ValueError(
                f'invalid `prediction_type` {self.prediction_type}. Available modes are {self.VALID_PREDICTION_TYPES}')

    def generate_h0(self, relative_freq, observed, save=True):
        """Generate the null distribution (h0) of interaction scores by permuting the node labels."""
        connectivity = get_node_interactions(self.g).reset_index(drop=True)

        res_perm, durations = [], []
        for i in range(self.n_perm):
            tic = time.time()

            data = permute_labels(self.data, self.rng)
            source_label = data.loc[connectivity.source].values.ravel()
            target_label = data.loc[connectivity.target].values.ravel()

            # create pd.Series and node_interaction pd.DataFrame
            source_label = pd.Series(source_label, name='source_label', dtype=self.data.dtype)
            target_label = pd.Series(target_label, name='target_label', dtype=self.data.dtype)
            df = pd.concat((connectivity, source_label, target_label), axis=1)

            # get interaction count
            perm = get_interaction_score(df, relative_freq=relative_freq, observed=observed)
            perm['permutation_id'] = i

            # save result
            res_perm.append(perm)

            # stats
            toc = time.time()
            durations.append(toc - tic)
            if (i + 1) % 10 == 0:
                print(f'{time.asctime()}: {i + 1}/{self.n_perm}, duration: {np.mean(durations):.2f} sec')

        print(
            f'{time.asctime()}: Finished, duration: {np.sum(durations) / 60:.2f} min ({np.mean(durations):.2f} sec/it)')

        h0 = pd.concat(res_perm)
        self.h0 = pd.pivot(h0, index=['source_label', 'target_label'], columns='permutation_id', values='score')

        if save:
            # create cache folder and persist the permutation results
            if not os.path.isdir(self.path):
                os.makedirs(self.path)
            self.h0.to_pickle(os.path.join(self.path, self.h0_file))
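
# %%
# Hedged usage sketch (added for this listing, not part of the library): how the
# Interactions estimator is typically driven. `so` is assumed to be a SpatialOmics object
# with sample `spl`, a categorical column 'meta_id' in so.obs[spl] and a 'knn' graph in
# so.G[spl]; the argument values below are illustrative only.
def _example_interactions(so, spl: str) -> pd.DataFrame:
    est = Interactions(so, spl=spl, attr='meta_id', mode='classic', n_permutations=100)
    est.fit(prediction_type='pvalue')  # builds (or loads) the permutation null h0
    res = est.predict()
    # `res` is indexed by (source_label, target_label); `sigval` is +1 for a significant
    # attraction, -1 for a significant avoidance and 0 if not significant.
    return res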

class RipleysK:

    def __init__(self, so, spl: str, id, attr: str):
        """Compute Ripley's K for a given sample and group.

        Args:
            so: SpatialOmics
            spl: Sample for which to compute Ripley's K
            id: The category in the categorical feature `attr` for which Ripley's K should be computed
            attr: Categorical feature in SpatialOmics.obs to use for the grouping
        """
        self.id = id
        self.area = so.spl.loc[spl].area
        self.width = so.spl.loc[spl].width
        self.height = so.spl.loc[spl].height

        # we need to cast to float since the implementation checks for type int/float
        # and does not recognise np.int64
        self.rkE = RipleysKEstimator(area=float(self.area),
                                     x_max=float(self.width), x_min=0,
                                     y_max=float(self.height), y_min=0)

        df = so.obs[spl][['x', 'y', attr]]
        self.df = df[df[attr] == id]

    def fit(self):
        pass  # no fitting required

    def predict(self, radii: list = None, correction: str = 'ripley', mode: str = 'K'):
        """Estimate Ripley's K.

        Args:
            radii: List of radii for which Ripley's K is computed
            correction: Correction method to use to correct for border effects, see [1].
            mode: {K, csr-deviation}. If `K`, Ripley's K is estimated; with `csr-deviation`,
                the deviation from a Poisson process (complete spatial randomness) is computed.

        Returns:
            Ripley's K estimates

        Notes:
            .. [1] https://docs.astropy.org/en/stable/stats/ripley.html
        """
        if radii is None:
            radii = np.linspace(0, min(self.height, self.width) / 2, 10)
        radii = make_iterable(radii)

        # if we have no observations of the given id, K is zero
        if len(self.df) > 0:
            K = self.rkE(data=self.df[['x', 'y']], radii=radii, mode=correction)
        else:
            K = np.zeros_like(radii)

        if mode == 'K':
            res = K
        elif mode == 'csr-deviation':
            L = np.sqrt(K / np.pi)  # transform, to stabilise variance
            res = L - radii
        else:
            raise ValueError(f'invalid mode {mode}. Available modes are [K, csr-deviation]')

        res = pd.Series(res, index=radii)
        return res

    def csr_deviation(self, radii, correction='ripley') -> np.ndarray:
        """Compute deviation from a random Poisson process (complete spatial randomness).

        Args:
            radii: List of radii for which Ripley's K is computed
            correction: Correction method to use to correct for border effects, see [1].

        Returns:
            Deviation of the variance-stabilised L function from its expectation under CSR.

        Notes:
            .. [1] http://doi.wiley.com/10.1002/9781118445112.stat07751
        """
        radii = make_iterable(radii)

        K = self.rkE(data=self.df[['x', 'y']], radii=radii, mode=correction)
        L = np.sqrt(K / np.pi)  # transform, to stabilise variance
        dev = L - radii
        return dev
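
# %%
# Self-contained sketch (added for this listing, not part of the library) of the astropy
# RipleysKEstimator that the class above wraps, including the variance-stabilised L - r
# transform used by `csr_deviation`. Window size and point pattern are toy values.
def _example_ripleysk_csr_deviation() -> pd.Series:
    rng = np.random.default_rng(0)
    xy = rng.uniform(0, 100, size=(50, 2))  # 50 random points in a 100 x 100 window
    rkE = RipleysKEstimator(area=10000., x_max=100., x_min=0., y_max=100., y_min=0.)
    radii = np.linspace(1, 25, 5)
    K = rkE(data=xy, radii=radii, mode='ripley')  # border-corrected Ripley's K
    dev = np.sqrt(K / np.pi) - radii              # deviation from the CSR expectation
    return pd.Series(dev, index=radii)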