cellcharter.tl.TRVAE

class cellcharter.tl.TRVAE(adata, condition_key=None, conditions=None, hidden_layer_sizes=(256, 64), latent_dim=10, dr_rate=0.05, use_mmd=True, mmd_on='z', mmd_boundary=None, recon_loss='nb', beta=1, use_bn=False, use_ln=True)

scArches’s trVAE model adapted to image-based proteomics data.

The final ReLU activation of the neural network is removed, so outputs can take any real value (including negative values) instead of being clipped at zero.
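The effect of dropping that last ReLU can be seen in a minimal pure-Python sketch of a single linear output unit with and without a trailing ReLU. This is not CellCharter's code. The weights, bias and input are hypothetical values, chosen only so that the pre-activation is negative.

```python
def linear(x, weights, bias):
    """Affine output unit: w . x + b."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias


def relu(v):
    """Rectified linear unit: max(0, v)."""
    return max(0.0, v)


# Hypothetical weights, bias and input (illustration only).
weights = [0.5, -1.0]
bias = 0.2
x = [1.0, 2.0]

pre_activation = linear(x, weights, bias)  # 0.5 - 2.0 + 0.2, roughly -1.3
with_relu = relu(pre_activation)           # negative value is clipped to 0.0
without_relu = pre_activation              # plain linear output keeps the sign

print(f"with ReLU: {with_relu}, without ReLU: {without_relu}")
```

With the ReLU, every negative prediction collapses to zero. Without it, the output layer can represent the full real line.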

Parameters:
  • adata (anndata.AnnData) – Annotated data matrix. Must contain count data for the ‘nb’ and ‘zinb’ losses, and normalized, log-transformed data for the ‘mse’ loss.

  • condition_key (String) – Name of the column in adata.obs that holds the condition labels.

  • conditions (List) – Names of the conditions the data will contain. Used to keep the condition encoding consistent when the model is reloaded.

  • hidden_layer_sizes (List) – Sizes of the encoder’s hidden layers. The decoder uses the same sizes in reverse order.

  • latent_dim (Integer) – Size of the bottleneck (latent) layer z.

  • dr_rate (Float) – Dropout rate applied to all layers. If dr_rate == 0, no dropout is applied.

  • use_mmd (Boolean) – If True, an additional MMD loss is computed on either the latent layer ‘z’ or the first decoder layer ‘y’, as selected by mmd_on.

  • mmd_on (String) – Layer on which the MMD loss is computed when use_mmd=True: ‘z’ for the latent layer or ‘y’ for the first decoder layer.

  • mmd_boundary (Integer or None) – Number of conditions on which the MMD loss is computed. If None, the MMD loss is computed on all conditions.

  • recon_loss (String) – Reconstruction loss: ‘mse’ (mean squared error), ‘nb’ (negative binomial) or ‘zinb’ (zero-inflated negative binomial).

  • beta (Float) – Scaling factor for the MMD loss.

  • use_bn (Boolean) – If True, batch normalization is applied to the layers.

  • use_ln (Boolean) – If True, layer normalization is applied to the layers.
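The MMD-related parameters (use_mmd, mmd_on, mmd_boundary, beta) can be illustrated with a minimal pure-Python sketch of a squared Maximum Mean Discrepancy penalty between conditions. This sketch rests on several assumptions. It uses a single Gaussian (RBF) kernel with a fixed bandwidth and the biased MMD² estimator, and it sums over all pairs of conditions, which mirrors the mmd_boundary=None case. scArches’s implementation may use a different kernel, bandwidths and pairing. The helper names rbf, mean_kernel, mmd2 and mmd_penalty are hypothetical. Each list of points stands in for the rows of z (or y) that belong to one condition.

```python
import math
from itertools import combinations


def rbf(a, b, sigma=1.0):
    """Gaussian (RBF) kernel between two equal-length vectors."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mean_kernel(xs, ys, sigma):
    """Average kernel value over all pairs (x, y) with x in xs and y in ys."""
    return sum(rbf(x, y, sigma) for x in xs for y in ys) / (len(xs) * len(ys))


def mmd2(xs, ys, sigma=1.0):
    """Biased estimate of the squared MMD between samples xs and ys."""
    return (mean_kernel(xs, xs, sigma)
            + mean_kernel(ys, ys, sigma)
            - 2.0 * mean_kernel(xs, ys, sigma))


def mmd_penalty(latent, conditions, beta=1.0, sigma=1.0):
    """Sum of pairwise MMD^2 over all pairs of conditions, scaled by beta.

    latent: list of latent vectors, one per cell (hypothetical toy data)
    conditions: list of condition labels aligned with latent
    """
    groups = {}
    for z, c in zip(latent, conditions):
        groups.setdefault(c, []).append(z)
    total = 0.0
    for a, b in combinations(sorted(groups), 2):
        total += mmd2(groups[a], groups[b], sigma)
    return beta * total


# Toy latent codes for two conditions: one overlapping pair, one shifted pair.
z_a = [[0.0, 0.0], [0.1, -0.1], [-0.1, 0.1]]
z_b_close = [[0.0, 0.05], [0.05, 0.0], [-0.05, -0.05]]
z_b_far = [[3.0, 3.0], [3.1, 2.9], [2.9, 3.1]]

labels = ["A"] * 3 + ["B"] * 3
close = mmd_penalty(z_a + z_b_close, labels, beta=1.0)
far = mmd_penalty(z_a + z_b_far, labels, beta=1.0)
print(f"MMD penalty, overlapping conditions: {close:.4f}")
print(f"MMD penalty, shifted conditions:     {far:.4f}")
```

The penalty is small when the conditions’ latent points overlap and large when they are separated. Minimizing it during training therefore pushes the per-condition distributions together, and beta sets how strongly this term is weighted.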

Methods table

load(dir_path[, adata, map_location])

Instantiate a model from the saved output.

Methods

classmethod TRVAE.load(dir_path, adata=None, map_location=None)

Instantiate a model from the saved output.

Parameters:
  • dir_path (str) – Path to saved outputs.

  • adata (Optional[AnnData] (default: None)) – AnnData object. If None, the AnnData saved with the model is loaded, if one was saved.

  • map_location (Optional[str] (default: None)) – Device on which all tensors are loaded (e.g., torch.device('cpu')).

Returns:

Model with loaded state dictionaries.