# %%
import networkx as nx
import pandas as pd
from pandas.api.types import CategoricalDtype
import numpy as np
from networkx import Graph
import os
import time
from astropy.stats import RipleysKEstimator
import logging
logging.basicConfig(level=logging.INFO)
from ..utils.general import make_iterable
from .utils import get_node_interactions, get_interaction_score, permute_labels
from collections import Counter
# %%
def _infiltration_local_deprecated(G: Graph,
                        interaction1=('tumor', 'immune'),
                        interaction2=('immune', 'immune')):
    """Deprecated, unfinished draft of a per-node (local) infiltration score.

    The function selects the nodes whose ``'attr'`` label takes part in either
    interaction and tallies the labels found in each node's neighbourhood.
    The tallies are not aggregated or returned; the function returns ``None``.

    Args:
        G: Graph whose nodes carry an ``'attr'`` entry holding the node label.
        interaction1: labels of the numerator interaction
        interaction2: labels of the denominator interaction
    Returns:
        None
    """
    # np.unique(interaction1, interaction2) would interpret the second tuple as the
    # `return_index` flag, so the label union is built explicitly instead.
    ids = set(interaction1) | set(interaction2)
    nodes = [node for node in G.nodes if G.nodes[node]['attr'] in ids]
    nodes_inter1 = [node for node in nodes if G.nodes[node]['attr'] in interaction1]
    nodes_inter2 = [node for node in nodes if G.nodes[node]['attr'] in interaction2]
    # iterate over the selected nodes (not over the labels) to look up their neighbours
    for node in nodes:
        neigh = G[node]
        # NOTE(review): the per-node label counts are computed but never used - draft code
        counts = Counter(G.nodes[i]['attr'] for i in neigh)
def _infiltration_local(G: Graph,
                        interaction1=('tumor', 'immune'),
                        interaction2=('immune', 'immune')):
    """Unfinished stub of a per-node (local) infiltration score.

    Selects the nodes whose ``'attr'`` label takes part in either interaction
    and builds a subgraph around each of them. No score is computed yet; the
    function returns ``None``.

    Args:
        G: Graph whose nodes carry an ``'attr'`` entry holding the node label.
        interaction1: labels of the numerator interaction
        interaction2: labels of the denominator interaction
    Returns:
        None
    """
    # np.unique(interaction1, interaction2) would interpret the second tuple as the
    # `return_index` flag, so the label union is built explicitly instead.
    ids = set(interaction1) | set(interaction2)
    nodes = [node for node in G.nodes if G.nodes[node]['attr'] in ids]
    for node in nodes:
        neigh = G[node]
        # Graph.subgraph() requires the nodes to keep.
        # NOTE(review): assumes the intended subgraph is the node plus its direct
        # neighbours - TODO confirm before building the score on top of it.
        subG = G.subgraph([node, *neigh])
def _infiltration(node_interactions: pd.DataFrame, interaction1=('tumor', 'immune'),
                  interaction2=('immune', 'immune')) -> float:
    """
    Compute infiltration score between two species.
    Args:
        node_interactions: Dataframe with columns `source_label` and `target_label` that specifies interactions.
        interaction1: labels of enumerator interaction
        interaction2: labels of denominator interaction
    Notes:
        The infiltration score is computed as #interactions1 / #interactions2.
    Returns:
        Interaction score
    """
    nint = node_interactions  # verbose
    (a1, a2), (b1, b2) = interaction1, interaction2
    num = nint[(nint.source_label == a1) & (nint.target_label == a2)].shape[0]
    denom = nint[(nint.source_label == b1) & (nint.target_label == b2)].shape[0]
    return num / denom if denom > 0 else np.nan  # TODO: np.inf or np.nan
class Interactions:
    """
    Estimator to quantify interaction strength between different species in the sample.

    Typical usage is ``fit(...)`` followed by ``predict()``. For the prediction types
    ``pvalue`` and ``diff`` a null model (h0) is built by repeatedly permuting the node
    labels; it is cached as a pickle file below ``~/.cache/spatialHeterogeneity/h0-models/``.
    """
    VALID_MODES = ['classic', 'histoCAT', 'proportion']
    VALID_PREDICTION_TYPES = ['pvalue', 'observation', 'diff']
    # mode -> (relative_freq, observed) flags forwarded to `get_interaction_score`
    _MODE_FLAGS = {'classic': (False, False), 'histoCAT': (False, True), 'proportion': (True, False)}

    def __init__(self, so, spl: str, attr: str = 'meta_id', mode: str = 'classic', n_permutations: int = 500,
                 random_seed=None, alpha: float = .01, graph_key: str = 'knn'):
        """Estimator to quantify interaction strength between different species in the sample.
        Args:
            so: SpatialOmics
            spl: Sample for which to compute the interaction strength
            attr: Categorical feature in SpatialOmics.obs to use for the grouping
            mode: One of {classic, histoCAT, proportion}, see notes
            n_permutations: Number of permutations to compute p-values and the interactions strength score (mode diff)
            random_seed: Random seed for permutations. If None, `so.random_seed` is used.
            alpha: Threshold for significance
            graph_key: Specifies the graph representation to use in so.G[spl].
        Notes:
            classic and histoCAT are python implementations of the corresponding methods published by the Bodenmiller lab at UZH.
            The proportion method is similar to the classic method but normalises the score by the number of edges and is thus bound [0,1].
        """
        self.so = so
        self.spl: str = spl
        self.graph_key = graph_key
        self.g: Graph = so.G[spl][graph_key]
        self.attr: str = attr
        self.data: pd.Series = so.obs[spl][attr]
        self.mode: str = mode
        self.n_perm: int = int(n_permutations)
        # `is not None` so that an explicit seed of 0 is honoured
        self.random_seed = random_seed if random_seed is not None else so.random_seed
        # seed the generator with the resolved seed, so the fallback to so.random_seed takes effect
        self.rng = np.random.default_rng(self.random_seed)
        self.alpha: float = alpha
        self.fitted: bool = False
        # set dtype categories of data to attributes that are in the data
        self.data = self.data.astype(CategoricalDtype(categories=self.data.unique(), ordered=False))
        # path where h0 models would be
        self.path = os.path.expanduser('~/.cache/spatialHeterogeneity/h0-models/')
        # NOTE: the cache key does not include the number of permutations or the seed
        self.h0_file = f'{spl}_{attr}_{graph_key}_{mode}.pkl'
        self.h0 = None

    def fit(self, prediction_type: str = 'observation', try_load: bool = True) -> None:
        """Compute the interactions scores for the sample.
        Args:
            prediction_type: One of {observation, pvalue, diff}, see Notes
            try_load: load pre-computed permutation results if available
        Returns:
            None
        Raises:
            ValueError: if `prediction_type` or the estimator's `mode` is invalid
        Notes:
            `observation`: computes the observed interaction strength in the sample
            `pvalue`: computes an empirical permutation P-value, i.e. the smaller of the fractions of permutations
                with a score >= respectively <= the observed score
            `diff`: computes the difference between observed and average interaction strength (across permutations)
        """
        if prediction_type not in self.VALID_PREDICTION_TYPES:
            raise ValueError(
                f'invalid `prediction_type` {prediction_type}. Available modes are {self.VALID_PREDICTION_TYPES}')
        self.prediction_type = prediction_type

        # extract observed interactions
        if self.mode not in self._MODE_FLAGS:
            raise ValueError(f'invalid mode {self.mode}. Available modes are {self.VALID_MODES}')
        relative_freq, observed = self._MODE_FLAGS[self.mode]

        node_interactions = get_node_interactions(self.g, self.data)
        obs_interaction = get_interaction_score(node_interactions, relative_freq=relative_freq, observed=observed)
        self.obs_interaction = obs_interaction.set_index(['source_label', 'target_label'])

        if prediction_type != 'observation':
            h0_path = os.path.join(self.path, self.h0_file)
            if try_load and os.path.isfile(h0_path):
                logging.info(
                    f'loading h0 for {self.spl}, graph type {self.graph_key} and mode {self.mode}')
                # NOTE(security): unpickling executes arbitrary code; only the user's own cache directory is read
                self.h0 = pd.read_pickle(h0_path)
                if self.h0.shape[1] != self.n_perm:
                    logging.warning(
                        f'loaded h0 has {self.h0.shape[1]} permutations but n_permutations is {self.n_perm}; '
                        f'the loaded permutations are used')
            # if try_load was not successful
            if self.h0 is None:
                logging.info(
                    f'generate h0 for {self.spl}, graph type {self.graph_key} and mode {self.mode} and attribute {self.attr}')
                self.generate_h0(relative_freq=relative_freq, observed=observed, save=True)
        self.fitted = True

    def _permutation_summary(self):
        """Align observation and h0 and compute per-interaction summary statistics of the permutations.
        Returns:
            (summary, obs, perm): dataframe with columns score, perm_mean, perm_std and perm_median, the observed
            scores as 1d array and the permutation scores as 2d array (interactions x permutations).
        """
        data_perm = pd.concat((self.obs_interaction, self.h0), axis=1)
        # interactions missing in the observation or in a permutation count as score 0
        data_perm.fillna(0, inplace=True)
        values = data_perm.to_numpy(dtype=float)
        # first column is the observed score, the remaining columns are the permutations
        obs, perm = values[:, 0], values[:, 1:]

        summary = pd.DataFrame(index=data_perm.index)
        # see h0_models_analysis.py for alternative p-value computation
        summary['score'] = self.obs_interaction.score
        summary['perm_mean'] = perm.mean(axis=1)
        summary['perm_std'] = perm.std(axis=1)
        summary['perm_median'] = np.median(perm, axis=1)
        return summary, obs, perm

    def predict(self) -> pd.DataFrame:
        """Predict interactions strengths of observations.
        Returns: A dataframe with the interaction results.
        Raises:
            ValueError: if the stored `prediction_type` is invalid
        """
        if self.prediction_type == 'observation':
            return self.obs_interaction
        elif self.prediction_type == 'pvalue':
            # TODO: Check p-value computation
            data_pval, obs, perm = self._permutation_summary()
            # use the number of permutations actually present in h0; a cached h0 may have been
            # generated with a different `n_permutations` than this estimator
            n_perm = perm.shape[1]
            data_pval['p_gt'] = (perm >= obs[:, None]).sum(axis=1) / n_perm
            data_pval['p_lt'] = (perm <= obs[:, None]).sum(axis=1) / n_perm
            data_pval['perm_n'] = n_perm
            data_pval['p'] = np.minimum(data_pval['p_gt'], data_pval['p_lt'])
            data_pval['sig'] = data_pval['p'] < self.alpha
            data_pval['attraction'] = data_pval['p_gt'] <= data_pval['p_lt']
            # +1: significant attraction, -1: significant avoidance, 0: not significant
            data_pval['sigval'] = np.sign(
                (data_pval['attraction'].astype(float) - .5) * data_pval['sig'].astype(float))
            return data_pval
        elif self.prediction_type == 'diff':
            data_pval, _, _ = self._permutation_summary()
            data_pval['diff'] = (data_pval['score'] - data_pval['perm_mean'])
            return data_pval
        else:
            raise ValueError(
                f'invalid `prediction_type` {self.prediction_type}. Available modes are {self.VALID_PREDICTION_TYPES}')

    def generate_h0(self, relative_freq, observed, save=True):
        """Generate the null model by scoring the graph under randomly permuted node labels.
        Args:
            relative_freq: forwarded to `get_interaction_score`
            observed: forwarded to `get_interaction_score`
            save: if True, pickle the result to the h0 cache directory
        Returns:
            None, the result is stored in `self.h0` (interactions x permutations)
        Raises:
            ValueError: if the number of permutations is smaller than 1
        """
        if self.n_perm < 1:
            raise ValueError(f'n_permutations must be at least 1, got {self.n_perm}')

        # the graph connectivity is invariant across permutations, only the labels change
        connectivity = get_node_interactions(self.g).reset_index(drop=True)
        res_perm, durations = [], []
        for i in range(self.n_perm):
            tic = time.time()
            data = permute_labels(self.data, self.rng)
            source_label = data.loc[connectivity.source].values.ravel()
            target_label = data.loc[connectivity.target].values.ravel()
            # create pd.Series and node_interaction pd.DataFrame
            source_label = pd.Series(source_label, name='source_label', dtype=self.data.dtype)
            target_label = pd.Series(target_label, name='target_label', dtype=self.data.dtype)
            df = pd.concat((connectivity, source_label, target_label), axis=1)
            # get interaction count
            perm = get_interaction_score(df, relative_freq=relative_freq, observed=observed)
            perm['permutation_id'] = i
            # save result
            res_perm.append(perm)
            # stats
            toc = time.time()
            durations.append(toc - tic)
            if (i + 1) % 10 == 0:
                print(f'{time.asctime()}: {i + 1}/{self.n_perm}, duration: {np.mean(durations):.2f} sec')
        print(
            f'{time.asctime()}: Finished, duration: {np.sum(durations) / 60:.2f} min ({np.mean(durations):.2f}sec/it)')

        h0 = pd.concat(res_perm)
        self.h0 = pd.pivot(h0, index=['source_label', 'target_label'], columns='permutation_id', values='score')

        if save:
            # create folders
            os.makedirs(self.path, exist_ok=True)
            self.h0.to_pickle(os.path.join(self.path, self.h0_file))
class RipleysK():
    """Ripley's K estimator for the observations of one category within a single sample."""

    VALID_MODES = ('K', 'csr-deviation')

    def __init__(self, so, spl: str, id, attr: str):
        """Compute Ripley's K for a given sample and group.
        Args:
            so: SpatialOmics
            spl: Sample for which to compute Ripley's K
            id: The category in the categorical feature `attr`, for which Ripley's K should be computed
            attr: Categorical feature in SpatialOmics.obs to use for the grouping
        """
        self.id = id
        self.area = so.spl.loc[spl].area
        self.width = so.spl.loc[spl].width
        self.height = so.spl.loc[spl].height
        self.rkE = RipleysKEstimator(area=float(self.area),
                                     # we need to cast since the implementation checks for type int/float and does not recognise np.int64
                                     x_max=float(self.width), x_min=0,
                                     y_max=float(self.height), y_min=0)
        df = so.obs[spl][['x', 'y', attr]]
        # keep only the observations that belong to the requested category
        self.df = df[df[attr] == id]

    def _estimate_k(self, radii, correction):
        """Evaluate the estimator at `radii`; K is zero if the category has no observations."""
        if len(self.df) > 0:
            return self.rkE(data=self.df[['x', 'y']], radii=radii, mode=correction)
        return np.zeros_like(radii, dtype=float)

    def predict(self, radii: list = None, correction: str = 'ripley', mode: str = 'K'):
        """Estimate Ripley's K
        Args:
            radii: List of radii for which Ripley's K is computed. If None, 10 equally spaced radii between 0 and
                half of the smaller sample dimension are used.
            correction: Correction method to use to correct for border effects, see [1].
            mode: {K, csr-deviation}. If `K`, Ripley's K is estimated, with `csr-deviation` the deviation from a Poisson process is computed.
        Returns:
            Ripley's K estimates as pd.Series indexed by the radii
        Raises:
            ValueError: if `mode` is not one of {K, csr-deviation}
        Notes:
            .. [1] https://docs.astropy.org/en/stable/stats/ripley.html
        """
        # validate first: an unknown mode previously failed late with an UnboundLocalError
        if mode not in self.VALID_MODES:
            raise ValueError(f'invalid mode {mode}. Available modes are {self.VALID_MODES}')

        if radii is None:
            radii = np.linspace(0, min(self.height, self.width) / 2, 10)
        radii = make_iterable(radii)

        K = self._estimate_k(radii, correction)
        if mode == 'K':
            res = K
        else:
            L = np.sqrt(K / np.pi)  # transform, to stabilise variance
            res = L - radii
        return pd.Series(res, index=radii)

    def csr_deviation(self, radii, correction='ripley') -> np.ndarray:
        """
        Compute deviation from random poisson process.
        Args:
            radii: List of radii for which Ripley's K is computed
            correction: Correction method to use to correct for border effects, see
                https://docs.astropy.org/en/stable/stats/ripley.html
        Returns:
            Array with sqrt(K / pi) - r for every radius r
        """
        # http://doi.wiley.com/10.1002/9781118445112.stat07751
        radii = make_iterable(radii)
        K = self._estimate_k(radii, correction)
        L = np.sqrt(K / np.pi)  # transform, to stabilise variance
        return L - radii