scQUEST.abnormality module
Summary
Classes:
Abnormality | Estimator to quantify the abnormality of a cell's expression profile.
AbnormalityLitModule | pytorch_lightning Abnormality module
DefaultAE | Default AE as implemented in [Wagner2019]
Reference
- class DefaultAE(n_in, hidden=(10, 2, 10), bias=True, activation=ReLU(), activation_last=Sigmoid(), seed=None)
Bases: torch.nn.modules.module.Module
Default autoencoder (AE) as implemented in [Wagner2019]
- __init__(n_in, hidden=(10, 2, 10), bias=True, activation=ReLU(), activation_last=Sigmoid(), seed=None)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
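The layer layout implied by the default hidden=(10, 2, 10) can be sketched in NumPy. This is an illustration under assumptions not confirmed on this page: each entry of hidden is taken to be the width of one fully connected layer between input and output, activation is applied after every hidden layer, activation_last after the output layer, and bias toggles the bias terms. The helpers init_layers and forward are hypothetical; the real DefaultAE is a torch.nn.Module.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_layers(n_in, hidden=(10, 2, 10), bias=True, seed=None):
    """Hypothetical layout: n_in -> hidden[0] -> ... -> hidden[-1] -> n_in."""
    rng = np.random.default_rng(seed)
    sizes = [n_in, *hidden, n_in]
    layers = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        W = rng.normal(scale=1.0 / np.sqrt(d_in), size=(d_in, d_out))
        b = np.zeros(d_out) if bias else None
        layers.append((W, b))
    return layers


def forward(layers, x):
    """ReLU after each hidden layer (assumed), sigmoid on the output layer."""
    h = x
    for i, (W, b) in enumerate(layers):
        h = h @ W
        if b is not None:
            h = h + b
        h = sigmoid(h) if i == len(layers) - 1 else relu(h)
    return h


layers = init_layers(n_in=20, seed=0)
x = np.random.default_rng(1).random((5, 20))  # 5 cells, 20 features
recon = forward(layers, x)
print([W.shape for W, _ in layers])  # [(20, 10), (10, 2), (2, 10), (10, 20)]
print(recon.shape)  # (5, 20)
```

The 2-unit middle layer is the bottleneck that forces the model to compress each cell's profile; with a sigmoid output, reconstructions lie in (0, 1), which presumes the input features are scaled to that range.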
- class AbnormalityLitModule(*args, **kwargs)
Bases: scQUEST.utils.LitModule
pytorch_lightning Abnormality module
- class Abnormality(n_in=None, model=None, loss_fn=None, metrics=None, seed=None)
Bases: scQUEST.utils.Estimator
Estimator to quantify the abnormality of a cell’s expression profile. Abnormality is defined as the average reconstruction error of the autoencoder trained on a reference (normal) cell population.
- Parameters
n_in – number of features for the estimator
model – model used to train the estimator (a torch.nn.Module or a pytorch_lightning.LightningModule)
loss_fn – loss function used for optimization
metrics – metrics tracked at test time
seed – seed for model weight initialisation
Note
The abnormality model (Abn.model) predicts the abnormality, a.k.a. the reconstruction error, i.e. \(X-F(X)\), where \(F\) is the autoencoder (saved to ad.layers['abnormality']). The base torch model (Abn.model.model), on the other hand, predicts the reconstruction, i.e. \(F(X)\).
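To make the distinction in the note concrete, here is a small NumPy sketch in which a made-up function stands in for the trained autoencoder \(F\) (reconstruct is hypothetical and not part of scQUEST). It shows that the per-feature abnormality \(X-F(X)\) and the reconstruction \(F(X)\) carry the same information: adding them back together recovers \(X\).

```python
import numpy as np


def reconstruct(X):
    """Hypothetical stand-in for F: shrinks each value halfway toward its column mean."""
    return 0.5 * X + 0.5 * X.mean(axis=0, keepdims=True)


X = np.array([[0.2, 0.8, 0.5],
              [0.4, 0.6, 0.1]])  # 2 cells x 3 features

reconstruction = reconstruct(X)   # analogous to the base model (Abn.model.model): F(X)
abnormality = X - reconstruction  # analogous to the abnormality model (Abn.model): X - F(X)

# The two outputs are related by X = F(X) + (X - F(X)).
assert np.allclose(reconstruction + abnormality, X)
```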
- __init__(n_in=None, model=None, loss_fn=None, metrics=None, seed=None)
Base estimator class
- Parameters
n_in – number of features for the estimator
model – model used to train the estimator (a torch.nn.Module or a pytorch_lightning.LightningModule)
loss_fn – loss function used for optimization
metrics – metrics tracked at test time
seed – seed for model weight initialisation
- fit(ad=None, layer=None, datamodule=None, max_epochs=100, callbacks=None, seed=None, **kwargs)
Fit the abnormality estimator (autoencoder) to the cell-expression profiles in ad.X or ad.layers[layer]. By default, the data is randomly split 90/10 into training and test sets. To customize training, provide a datamodule with the desired train/validation/test splits.
- Parameters
ad (Optional[AnnData]) – AnnData object to fit
layer (Optional[str]) – layer in ad.layers to use instead of ad.X
datamodule (Optional[LightningDataModule]) – pytorch lightning data module with custom configurations of train, val and test splits
preprocessing – list of processors (Preprocessor) that should be applied to the dataset
early_stopping – configured EarlyStopping class
max_epochs (int) – maximum number of epochs for which the model is trained
callbacks (Optional[list]) – additional pytorch_lightning callbacks
- Return type
None
- Returns
None
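The default 90/10 split can be sketched with the standard library. This assumes a uniformly random split over cells controlled by a seed; how scQUEST actually performs the split (for instance, whether a validation set is also carved out, or how the seed argument is used) is not described on this page, and split_indices is a hypothetical helper. For precise control, pass a LightningDataModule instead.

```python
import random


def split_indices(n_cells, test_fraction=0.1, seed=None):
    """Randomly partition cell indices into (train, test) lists (hypothetical helper)."""
    rng = random.Random(seed)
    idx = list(range(n_cells))
    rng.shuffle(idx)
    n_test = round(n_cells * test_fraction)
    return idx[n_test:], idx[:n_test]


train_idx, test_idx = split_indices(1000, seed=42)
print(len(train_idx), len(test_idx))  # 900 100
```

Fixing the seed makes the split reproducible across runs, which matters when comparing abnormality scores from separately trained models.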
- predict(ad, layer=None, inplace=True)
Predict the abnormality of each cell and feature as the difference between target and reconstruction (y - pred).
- static aggregate(ad, agg_fun='mse', key='abnormality', layer='abnormality')
Aggregate the high-dimensional (one value per feature) reconstruction error of each cell into a single per-cell score.
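With the default agg_fun='mse', aggregation plausibly reduces each row of the per-feature residuals (the 'abnormality' layer) to the mean of its squared entries, matching the class description of abnormality as an average reconstruction error. The sketch below assumes exactly that; aggregate_mse is hypothetical, and where scQUEST stores the result (presumably under the name given by key) is not spelled out on this page.

```python
import numpy as np


def aggregate_mse(residuals):
    """Collapse a (cells x features) residual matrix into one mean squared error per cell."""
    return np.mean(residuals ** 2, axis=1)


residuals = np.array([[0.1, -0.1, 0.2],
                      [0.0,  0.3, 0.0]])
scores = aggregate_mse(residuals)  # approximately [0.02, 0.03]
```

Squaring before averaging keeps positive and negative residuals from cancelling, so a cell that deviates in either direction gets a higher score.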