abaco.ABaCo module

class abaco.ABaCo.BatchDiscriminator(*args: Any, **kwargs: Any)[source]#

Bases: Module

Define the Batch Discriminator for adversarial training

forward(x)[source]#

Computes the forward pass through the discriminator.

Parameters:

x (torch.Tensor) – Input to be passed through the model

Returns:

batch_class – Prediction of the batch of origin of the observation

Return type:

torch.Tensor

loss(pred, true)[source]#
class abaco.ABaCo.ConditionalEnsembleVAE(*args: Any, **kwargs: Any)[source]#

Bases: Module

Define a conditional Variational Autoencoder (VAE) model.

elbo(x)[source]#

Compute the ELBO for the given batch of data.

Parameters:
  • x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2, …)

  • n_samples (int) – Number of samples to use for the Monte Carlo estimate of the ELBO

forward(x)[source]#

Compute the negative ELBO for the given batch of data.

Parameters:

x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2)

get_posterior(x)[source]#

Given a set of points, compute the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

kl_div_loss(x)[source]#
log_prob(z)[source]#
pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

pca_prior(n_samples)[source]#

Given a number of samples, get the PCA from the sampling of the prior distribution.

sample(n_samples=1)[source]#

Sample from the model.

Parameters:

n_samples (int) – Number of samples to generate

class abaco.ABaCo.ConditionalVAE(*args: Any, **kwargs: Any)[source]#

Bases: Module

elbo(x)[source]#

Compute the ELBO for the given batch of data.

Parameters:
  • x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2, …)

  • n_samples (int) – Number of samples to use for the Monte Carlo estimate of the ELBO

Returns:

elbo – Evidence Lower Bound

Return type:

torch.Tensor
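
As a rough illustration of a Monte Carlo ELBO estimate, the sketch below uses a one-dimensional toy model in plain Python rather than torch. The model (z ~ N(0, 1), x | z ~ N(z, 1)), the Gaussian approximate posterior, and the helper names `log_normal` and `elbo_estimate` are assumptions made for the example; they are not taken from ABaCo.

```python
import math
import random


def log_normal(x, mu, sigma):
    """Log density of a univariate Normal(mu, sigma^2) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def elbo_estimate(x, q_mu, q_sigma, n_samples=1, seed=0):
    """Monte Carlo ELBO for the toy model z ~ N(0, 1), x | z ~ N(z, 1).

    The approximate posterior is q(z | x) = N(q_mu, q_sigma^2).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        # Reparameterized draw from the approximate posterior
        z = q_mu + q_sigma * rng.gauss(0.0, 1.0)
        log_lik = log_normal(x, z, 1.0)         # log p(x | z)
        log_prior = log_normal(z, 0.0, 1.0)     # log p(z)
        log_post = log_normal(z, q_mu, q_sigma) # log q(z | x)
        total += log_lik + log_prior - log_post
    return total / n_samples
```

In this toy model the true posterior is N(x/2, 1/2), so plugging it in as q makes the bound tight: the estimate equals log p(x) for any number of samples. A poorly matched q gives a lower value.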

forward(x)[source]#

Compute the negative ELBO for the given batch of data.

Parameters:

x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2)

Returns:

loss – Negative ELBO

Return type:

torch.Tensor

get_posterior(x)[source]#

Given a set of points, compute the samples of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder to obtain the posterior distribution

Returns:

z – Encoded points sampled from the posterior distribution

Return type:

torch.Tensor

kl_div_loss(x)[source]#

KL-divergence between the prior and the posterior distribution.

Parameters:

x (torch.Tensor) – Observation to be encoded in order to get the approximate posterior distribution

Returns:

kl_loss – KL-divergence loss

Return type:

torch.Tensor

log_prob(z)[source]#
pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

pca_prior(n_samples)[source]#

Given a number of samples, get the PCA from the sampling of the prior distribution.

Parameters:

n_samples (int) – Number of samples from the prior distribution

sample(n_samples=1)[source]#

Sample from the model.

Parameters:

n_samples (int) – Number of samples to generate

Returns:

samples – Samples generated from the model

Return type:

torch.Tensor

class abaco.ABaCo.DMDecoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

forward(z)[source]#

Computes the Dirichlet-Multinomial distribution over the data space. The decoder outputs the concentration parameter of the distribution, so its output has shape (batch, features).

Parameters:

z (torch.Tensor) – Latent space embeddings

Returns:

Dirichlet-Multinomial distribution used to calculate the likelihood via its log_prob() method

Return type:

torch.distributions.Distribution

class abaco.ABaCo.DirichletMultinomial(*args: Any, **kwargs: Any)[source]#

Bases: Distribution

Dirichlet-Multinomial distribution, defined by:

p ~ Dirichlet(alpha)
x | p ~ Multinomial(total_count, p)
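
The two-stage definition above can be sketched in plain Python as follows. This is a minimal illustration using only the standard library; the function name `sample_dirichlet_multinomial` is hypothetical and the code is not ABaCo's implementation.

```python
import random


def sample_dirichlet_multinomial(alpha, total_count, rng):
    """Draw one count vector x with p ~ Dirichlet(alpha), x | p ~ Multinomial(total_count, p)."""
    # Dirichlet draw: normalize independent Gamma(alpha_i, 1) variables
    gammas = [rng.gammavariate(a, 1.0) for a in alpha]
    norm = sum(gammas)
    p = [g / norm for g in gammas]
    # Multinomial draw: total_count categorical draws, tallied per category
    counts = [0] * len(alpha)
    for idx in rng.choices(range(len(alpha)), weights=p, k=total_count):
        counts[idx] += 1
    return counts
```

Each call returns a list of non-negative integer counts that sums to `total_count`.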

arg_constraints = {'concentration': torch.distributions.constraints.positive, 'total_count': torch.distributions.constraints.nonnegative_integer}#
has_rsample = False#
log_prob(x: torch.Tensor)[source]#

Defines the log_prob() function inherited from torch.distributions.Distribution.

Parameters:

x (torch.Tensor) – Observation at which to evaluate the log probability

Returns:

log_p – Log probability of the observation under the distribution

Return type:

torch.Tensor
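
For reference, the textbook Dirichlet-Multinomial log probability mass can be written with log-gamma functions. The sketch below is plain Python under that standard formula; whether ABaCo evaluates it in exactly this form is an assumption, and the function name is hypothetical.

```python
import math


def dirichlet_multinomial_log_prob(x, concentration):
    """Log pmf of a count vector x under a Dirichlet-Multinomial with the given concentration."""
    n = sum(x)               # total_count implied by the observation
    a0 = sum(concentration)  # sum of concentration parameters
    # log n! + log Gamma(a0) - log Gamma(n + a0)
    log_p = math.lgamma(n + 1) + math.lgamma(a0) - math.lgamma(n + a0)
    # Per-category terms: log Gamma(x_i + a_i) - log Gamma(a_i) - log x_i!
    for x_i, a_i in zip(x, concentration):
        log_p += math.lgamma(x_i + a_i) - math.lgamma(a_i) - math.lgamma(x_i + 1)
    return log_p
```

A quick sanity check is that the probabilities of all count vectors with a fixed total sum to one.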

sample(sample_shape=torch.Size)[source]#

Defines the sample() function, which is used to sample data points using the distribution parameters.

Parameters:

sample_shape – Number of samples to draw. If not given, sample() draws a single observation from the distribution.

Returns:

dm_sample – Sample(s) from the Dirichlet-Multinomial distribution.

class abaco.ABaCo.MixtureOfGaussians(*args: Any, **kwargs: Any)[source]#

Bases: Distribution

A Mixture of Gaussians distribution with reparameterized sampling. Computation of gradients is possible.

arg_constraints = {}#
entropy()[source]#
has_rsample = True#
log_prob(value)[source]#

Compute the log probability of a given value. The log probability of a MoG is defined as:

log_prob(x) = log [sum_k pi_k * N(x; mu_k, sigma_k^2)]

where pi_k are the mixture probabilities.

Parameters:

value (torch.Tensor) – Value at which to evaluate the MoG log probability

Returns:

log_prob – Log probability of the value under the distribution

Return type:

torch.Tensor
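
The mixture log probability is usually evaluated with the log-sum-exp trick for numerical stability. The following one-dimensional sketch in plain Python shows the idea; the helper names are hypothetical and the code is an illustration, not the library's implementation.

```python
import math


def log_normal(x, mu, sigma):
    """Log density of a univariate Normal(mu, sigma^2) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def mog_log_prob(x, weights, means, sigmas):
    """log sum_k pi_k * N(x; mu_k, sigma_k^2), computed via log-sum-exp."""
    terms = [math.log(w) + log_normal(x, m, s) for w, m, s in zip(weights, means, sigmas)]
    peak = max(terms)  # subtract the largest term before exponentiating
    return peak + math.log(sum(math.exp(t - peak) for t in terms))
```

With a single component of weight 1, the result reduces to the ordinary Gaussian log density.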

property mean#

Mixture mean: weighted sum of component means.

rsample(sample_shape=torch.Size)[source]#

Reparameterized sampling using the Gumbel-softmax trick.

Parameters:

sample_shape – Number of samples to draw. If not given, a single sample is returned.

Returns:

sample – Sample from the MoG distribution

Return type:

torch.Tensor
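
One common way to combine the Gumbel-softmax trick with a Gaussian mixture is to draw relaxed one-hot weights over the components and use them to softly select among reparameterized Gaussian draws. The plain-Python sketch below assumes that scheme; the temperature `tau`, the soft selection step, and the function names are assumptions, not details confirmed by this page.

```python
import math
import random


def gumbel_softmax_weights(log_pi, tau, rng):
    """Relaxed one-hot weights: softmax((log_pi + Gumbel noise) / tau)."""
    perturbed = []
    for lp in log_pi:
        # Clamp the uniform draw away from 0 and 1 so both logs are finite
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        perturbed.append((lp + gumbel) / tau)
    peak = max(perturbed)
    exps = [math.exp(v - peak) for v in perturbed]
    total = sum(exps)
    return [e / total for e in exps]


def mog_rsample(weights, means, sigmas, tau=0.5, seed=0):
    """One relaxed sample from a 1-D mixture of Gaussians."""
    rng = random.Random(seed)
    y = gumbel_softmax_weights([math.log(w) for w in weights], tau, rng)
    # Reparameterized draw from every component: mu_k + sigma_k * eps_k
    draws = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(means, sigmas)]
    # Soft selection of the component draws by the relaxed weights
    return sum(y_k * d for y_k, d in zip(y, draws)), y
```

Because every step is a smooth function of the mixture parameters and external noise, the same computation written with tensors lets gradients flow through the sample.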

sample(sample_shape=torch.Size)[source]#

Sample from the MoG distribution.

variance()[source]#

Mixture variance: weighted sum of (variance + squared mean) minus squared mixture mean
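
The mean and variance formulas can be checked on a small one-dimensional example. The sketch below is plain Python with hypothetical function names.

```python
def mixture_mean(weights, means):
    """Weighted sum of component means."""
    return sum(w * m for w, m in zip(weights, means))


def mixture_variance(weights, means, sigmas):
    """Weighted sum of (variance + squared mean) minus the squared mixture mean."""
    second_moment = sum(w * (s ** 2 + m ** 2) for w, m, s in zip(weights, means, sigmas))
    return second_moment - mixture_mean(weights, means) ** 2
```

For two equally weighted unit-variance components centred at -1 and 1, the mixture mean is 0 and the mixture variance is 2: the within-component variance of 1 plus the between-component spread of 1.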

class abaco.ABaCo.MoCPEncoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

encode(x)[source]#
forward(x)[source]#

Computes the Categorical distribution over the latent space, sampling the component index and returning the parameters of the selected component.

Parameters:

x (torch.Tensor)

class abaco.ABaCo.MoCPPrior(*args: Any, **kwargs: Any)[source]#

Bases: Module

cluster_loss()[source]#

Compute the clustering loss for the MoG prior. This loss encourages the components to be well separated by maximizing the pairwise KL divergence between the Gaussian components.
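
A loss of this kind can be sketched with the closed-form KL divergence between diagonal Gaussians. The plain-Python example below averages the pairwise KL over all ordered pairs of components and negates it, so that minimizing the loss pushes components apart. The averaging and the sign convention are assumptions, and the function names are hypothetical.

```python
import math


def kl_diag_gaussians(mu_p, sigma_p, mu_q, sigma_q):
    """KL(p || q) for diagonal Gaussians given as lists of means and standard deviations."""
    kl = 0.0
    for mp, sp, mq, sq in zip(mu_p, sigma_p, mu_q, sigma_q):
        kl += math.log(sq / sp) + (sp ** 2 + (mp - mq) ** 2) / (2.0 * sq ** 2) - 0.5
    return kl


def cluster_loss_sketch(means, sigmas):
    """Negative mean pairwise KL between mixture components."""
    total, pairs = 0.0, 0
    for i in range(len(means)):
        for j in range(len(means)):
            if i != j:
                total += kl_diag_gaussians(means[i], sigmas[i], means[j], sigmas[j])
                pairs += 1
    if pairs == 0:
        return 0.0  # a single component has no pairs to separate
    return -total / pairs
```

Well-separated components give a more negative loss than overlapping ones.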

forward(k_ohe)[source]#

Return prior distribution, allowing for the computation of the KL-divergence by calling self.prior().

Returns:

prior

Return type:

torch.distributions.Distribution

class abaco.ABaCo.MoGEncoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

det_encode(x)[source]#

Computes the encoded point without stochastic component.

Parameters:

x (torch.Tensor) – Tensor to be encoded

Returns:

z – Sum of the MoG component means, weighted by the mixing probabilities of the encoded point

Return type:

torch.Tensor

encode(x)[source]#
forward(x)[source]#

Computes the MoG distribution over the latent space.

Parameters:

x (torch.Tensor)

Returns:

mog_dist – MoG distribution used to calculate the approximate posterior with rsample()

Return type:

MixtureOfGaussians

monte_carlo_encode(x, K=100)[source]#

Computes a Monte Carlo simulation of the same point to approximate the true posterior distribution.

Parameters:
  • x (torch.Tensor) – Tensor to be encoded

  • K (int) – Number of Monte Carlo iterations

Returns:

sample – Mean from all samples (i.e. expectation of the encoded point estimated through Monte Carlo simulations)

Return type:

torch.Tensor
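
The relationship between the deterministic and the Monte Carlo encodings can be illustrated on a one-dimensional mixture of Gaussians. In the plain-Python sketch below, the mixture parameters stand in for the encoder output for a single point; the function names are hypothetical and the code is not ABaCo's implementation.

```python
import random


def det_encode_sketch(weights, means):
    """Deterministic encoding: component means weighted by the mixing probabilities."""
    return sum(w * m for w, m in zip(weights, means))


def monte_carlo_encode_sketch(weights, means, sigmas, K=100, seed=0):
    """Average of K samples drawn from the mixture posterior of one point."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(K):
        # Pick a component, then sample from its Gaussian
        k = rng.choices(range(len(weights)), weights=weights, k=1)[0]
        total += rng.gauss(means[k], sigmas[k])
    return total / K
```

As K grows, the Monte Carlo average approaches the weighted sum of component means, so the two encodings agree in expectation.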

class abaco.ABaCo.MoGPrior(*args: Any, **kwargs: Any)[source]#

Bases: Module

forward()[source]#

Return prior distribution, allowing for the computation of the KL-divergence by calling self.prior().

Returns:

prior

Return type:

torch.distributions.Distribution

class abaco.ABaCo.NBDecoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

forward(z)[source]#

Computes the Negative Binomial distribution over the data space. The network outputs the mean and dispersion parameters, which are reparameterized into the NB distribution parameters: total_count (dispersion) and probs (dispersion/(dispersion + mean)).

Parameters:

z (torch.Tensor) – Latent space embeddings

Returns:

Negative Binomial distribution used to calculate the likelihood via its log_prob() method

Return type:

torch.distributions.Distribution
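
The mean/dispersion reparameterization can be checked numerically. The plain-Python sketch below assumes the convention in which the Negative Binomial counts failures x before total_count successes, each success having probability probs; under that convention the mean is total_count * (1 - probs) / probs, so probs = dispersion / (dispersion + mean) recovers the requested mean. Libraries that define probs as the complementary probability would need 1 - probs instead; which convention ABaCo uses internally is not shown on this page. The function names are hypothetical.

```python
import math


def nb_params_from_mean_dispersion(mean, dispersion):
    """Map (mean, dispersion) to (total_count, probs) as described above."""
    return dispersion, dispersion / (dispersion + mean)


def nb_log_pmf(x, total_count, probs):
    """Log pmf of x failures before total_count successes with success probability probs."""
    return (
        math.lgamma(x + total_count)
        - math.lgamma(total_count)
        - math.lgamma(x + 1)
        + total_count * math.log(probs)
        + x * math.log(1.0 - probs)
    )
```

Summing x * pmf(x) over a long enough range of counts gives back the mean that was passed in.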

class abaco.ABaCo.NormalEncoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

det_encode(x)[source]#

Computes the encoded point without stochastic component.

Parameters:

x (torch.Tensor) – Tensor to be encoded

Returns:

mu – Mean of the Gaussian (Normal) distribution of the encoded point

Return type:

torch.Tensor

encode(x)[source]#
forward(x)[source]#

Computes the Gaussian Normal distribution over the latent space.

Parameters:

x (torch.Tensor) – Tensor to be encoded

Returns:

Gaussian Normal distribution used to calculate the posterior distribution with .rsample()

Return type:

torch.distributions.Distribution

monte_carlo_encode(x, K=100)[source]#

Computes a Monte Carlo simulation of the same point to approximate the true posterior distribution.

Parameters:
  • x (torch.Tensor) – Tensor to be encoded

  • K (int) – Number of Monte Carlo iterations

Returns:

sample – Mean from all samples (i.e. expectation of the encoded point estimated through Monte Carlo simulations)

Return type:

torch.Tensor

class abaco.ABaCo.NormalPrior(*args: Any, **kwargs: Any)[source]#

Bases: Module

forward()[source]#

Return prior distribution. This allows the computation of KL-divergence by calling self.prior() in the VAE class.

Returns:

prior

Return type:

torch.distributions.Distribution

class abaco.ABaCo.SupervisedContrastiveLoss(*args: Any, **kwargs: Any)[source]#

Bases: Module

Contrastive loss definition

forward(latent_points, labels)[source]#

Parameters:
  • latent_points – [batch_size, d_z]

  • labels – [batch_size]
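
For orientation, a widely used form of the supervised contrastive loss treats points with the same label as positives and all other points as the contrast set. The plain-Python sketch below follows that general form with L2-normalized latent points and a temperature `temp`; whether ABaCo's class matches it term for term is an assumption, and the function name is hypothetical.

```python
import math


def supervised_contrastive_loss_sketch(latent_points, labels, temp=0.1):
    """Mean over anchors of the negative average log-softmax score of their positives."""
    # L2-normalize each latent point so similarities are cosine similarities
    normed = []
    for z in latent_points:
        norm = math.sqrt(sum(v * v for v in z)) or 1.0
        normed.append([v / norm for v in z])

    n = len(normed)
    total, n_anchors = 0.0, 0
    for i in range(n):
        sims = [sum(a * b for a, b in zip(normed[i], normed[j])) / temp for j in range(n)]
        others = [j for j in range(n) if j != i]
        positives = [j for j in others if labels[j] == labels[i]]
        if not positives:
            continue  # anchors without a same-label partner are skipped
        # Log of the softmax denominator over all other points (log-sum-exp)
        peak = max(sims[j] for j in others)
        log_denom = peak + math.log(sum(math.exp(sims[j] - peak) for j in others))
        total += -sum(sims[j] - log_denom for j in positives) / len(positives)
        n_anchors += 1
    return total / max(n_anchors, 1)
```

The loss is small when points sharing a label are close in the latent space and large when the labels cut across the geometric clusters.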

class abaco.ABaCo.VAE(*args: Any, **kwargs: Any)[source]#

Bases: Module

decode(z)[source]#

Decode the latent space representation to the data space.

Parameters:

z (torch.Tensor) – Latent space tensor of shape (batch, d_z)

encode(x)[source]#

Forward pass through the VAE.

Parameters:

x (torch.Tensor) – Input data tensor of shape (batch, features)

get_posterior(x)[source]#

Given a set of points, compute the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

kl_divergence(z, k_ohe=None)[source]#

Compute the KL-divergence between the prior and the posterior distribution.

Parameters:

z (torch.Tensor) – Latent space tensor of shape (batch, d_z)

pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

class abaco.ABaCo.VMMPrior(*args: Any, **kwargs: Any)[source]#

Bases: Module

cluster_loss()[source]#

Compute the clustering loss for the MoG prior. This loss encourages the components to be well separated by maximizing the pairwise KL divergence between the Gaussian components.

forward(k_ohe)[source]#

Return prior distribution, allowing for the computation of the KL-divergence by calling self.prior().

Parameters:
  • k_ohe (torch.Tensor) – One-hot encoded tensor indicating the component each point belongs to.

  • b_ohe (torch.Tensor) – One-hot encoded tensor with the corresponding batch label, which is appended to the Encoder input.

  • encoder (nn.Module) – Encoder used to obtain the centroid of each cluster by encoding the pseudo-input.

Returns:

prior

Return type:

torch.distributions.Distribution

sample_from_dataloader()[source]#
class abaco.ABaCo.VampPriorMixtureConditionalEnsembleVAE(*args: Any, **kwargs: Any)[source]#

Bases: Module

Define a VampPrior Variational Autoencoder model.

forward(x)[source]#
get_posterior(x)[source]#

Given a set of points, compute the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

get_prior()[source]#
log_prob(z)[source]#
pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

pca_prior(n_samples)[source]#

Given a number of samples, get the PCA from the sampling of the prior distribution.

sample(n_samples)[source]#
sample_from_dataloader(data_loader)[source]#
class abaco.ABaCo.VampPriorMixtureConditionalVAE(*args: Any, **kwargs: Any)[source]#

Bases: Module

Define a VampPrior Variational Autoencoder model. This class is used for the baseline application of the ABaCo model.

forward(x)[source]#
get_posterior(x)[source]#

Given a set of points, compute the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

get_prior()[source]#
log_prob(z)[source]#
pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

pca_prior(n_samples)[source]#

Given a number of samples, get the PCA from the sampling of the prior distribution.

sample(n_samples)[source]#
sample_from_dataloader(data_loader)[source]#
class abaco.ABaCo.ZIDM(*args: Any, **kwargs: Any)[source]#

Bases: Distribution

Zero-inflated Dirichlet-Multinomial (ZIDM) distribution, mixture of structural zeros per-category with a Dirichlet-Multinomial core.

arg_constraints = {}#
has_rsample = False#
log_prob(x: torch.Tensor)[source]#

Defines the log_prob() function inherited from torch.distributions.Distribution.

Parameters:

x (torch.Tensor) – Observation at which to evaluate the log probability

Returns:

log_zidm – Log probability of the observation under the distribution

Return type:

torch.Tensor

sample(sample_shape=torch.Size)[source]#

Defines the sample() function, which is used to sample data points using the distribution parameters.

Parameters:

sample_shape – Number of samples to draw. If not given, sample() draws a single observation from the distribution.

Returns:

zinb_sample – Sample(s) from the Zero-inflated Dirichlet-Multinomial distribution.

class abaco.ABaCo.ZIDMDecoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

forward(z)[source]#

Computes the Zero-inflated Dirichlet-Multinomial distribution over the data space. The decoder outputs the zero probability and the concentration parameter of the distribution, so its output has shape (batch, 2*features).

Parameters:

z (torch.Tensor) – Latent space embeddings

Returns:

Zero-inflated Dirichlet-Multinomial distribution used to calculate the likelihood via its log_prob() method

Return type:

torch.distributions.Distribution

class abaco.ABaCo.ZINB(*args: Any, **kwargs: Any)[source]#

Bases: Distribution

Zero-inflated Negative Binomial (ZINB) distribution. Its log probability is defined piecewise.

For x = 0:

ZINB.log_prob(0) = log(pi + (1 - pi) * NegativeBinomial(0 | r, p))

For x > 0:

ZINB.log_prob(x) = log(1 - pi) + NegativeBinomial(x | r, p).log_prob(x)
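
The piecewise definition can be sketched in plain Python as follows. The Negative Binomial convention used here (x counts failures before r successes, each with success probability p) is an assumption chosen for the example, and the function names are hypothetical.

```python
import math


def nb_log_pmf(x, r, p):
    """Log pmf of x failures before r successes with success probability p."""
    return (
        math.lgamma(x + r)
        - math.lgamma(r)
        - math.lgamma(x + 1)
        + r * math.log(p)
        + x * math.log(1.0 - p)
    )


def zinb_log_prob(x, pi, r, p):
    """Zero-inflated Negative Binomial log pmf with zero-inflation probability pi."""
    nb = nb_log_pmf(x, r, p)
    if x == 0:
        # A zero is either a structural zero or a Negative Binomial zero
        return math.log(pi + (1.0 - pi) * math.exp(nb))
    # Positive counts can only come from the Negative Binomial part
    return math.log(1.0 - pi) + nb
```

The probability of observing zero is at least pi, and the probabilities over all counts still sum to one.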

arg_constraints = {}#
has_rsample = False#
log_prob(x)[source]#

Defines the log_prob() function inherited from torch.distributions.Distribution.

Parameters:

x (torch.Tensor) – Observation at which to evaluate the log probability.

Returns:

log_p – Log probability of the observation under the model.

Return type:

torch.Tensor

sample(sample_shape=torch.Size)[source]#
Defines the sample() function, which is used to sample data points using the distribution parameters.

For x = 0:

Binary mask for zero-inflated values

For x > 0:

Sample from the Negative Binomial distribution

Parameters:

sample_shape – Number of samples to draw. If not given, sample() draws a single observation from the distribution.

Returns:

zinb_sample – Sample(s) from the Zero-inflated Negative Binomial distribution.

class abaco.ABaCo.ZINBDecoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

forward(z)[source]#

Computes the Zero-inflated Negative Binomial distribution over the data space. The network outputs the zero probability, mean and dispersion parameters; the mean and dispersion are reparameterized into the NB distribution parameters: total_count (dispersion) and probs (dispersion/(dispersion + mean)).

Parameters:

z (torch.Tensor) – Latent space embeddings

Returns:

Zero-inflated Negative Binomial distribution used to calculate the likelihood via its log_prob() method

Return type:

torch.distributions.Distribution

abaco.ABaCo.abaco_recon(model, device: torch.device, data: pandas.DataFrame, dataloader: torch.utils.data.DataLoader, sample_label: str, batch_label: str, bio_label, seed=42, det_encode=False, monte_carlo=100)[source]#

Reconstruct data using a trained ABaCo model.

Parameters:
  • model – Trained ABaCo model.

  • device (torch.device) – Device to run the model on, e.g., “cuda” or “cpu”.

  • data (pd.DataFrame) – DataFrame containing the data to be reconstructed.

  • dataloader (torch.utils.data.DataLoader) – Pytorch DataLoader for the data to be reconstructed.

  • sample_label (str) – Column name in the DataFrame that contains unique ids for the observations/samples.

  • batch_label (str) – Column name in the DataFrame that contains ids for the batch/factor groupings to be corrected by ABaCo, e.g., dates of sample analysis.

  • bio_label (str) – Column name in the DataFrame that contains the biological groupings, i.e., the biological/experimental variation that ABaCo should retain when correcting the batch effect, e.g., experimental condition.

  • seed (int, optional) – Random seed for reproducibility. Default is 42.

  • det_encode (bool, optional) – If True, use deterministic encoding. Default is False.

  • monte_carlo (int, optional) – Number of Monte Carlo samples to use for reconstruction. Default is 100. Setting it to 1 is equivalent to drawing a single sample from the final ZINB distribution obtained from the trained model.

Returns:

otu_corrected_pd – DataFrame containing the reconstructed data with batch and biological labels.

Return type:

pd.DataFrame

abaco.ABaCo.abaco_recon_ensemble(model, device, data, dataloader, sample_label, batch_label, bio_label, seed=42, det_encode=False, monte_carlo=100)[source]#

Function used to reconstruct data using trained ABaCo model.

abaco.ABaCo.abaco_run(dataloader: torch.utils.data.DataLoader, n_batches: int, n_bios: int, device: torch.device, input_size: int, new_pre_train: bool = False, seed: int = 42, d_z: int = 16, prior: str = 'VMM', count: bool = True, pre_epochs: int = 2000, post_epochs: int = 2000, kl_cycle: bool = True, smooth_annealing: bool = False, encoder_net: list = [1024, 512, 256], decoder_net: list = [256, 512, 1024], vae_act_func=torch.nn.ReLU, disc_net: list = [256, 128, 64], disc_act_func=torch.nn.ReLU, disc_loss_type: str = 'CrossEntropy', w_elbo: float = 1.0, beta: float = 20.0, w_disc: float = 1.0, w_adv: float = 1.0, w_contra: float = 10.0, temp: float = 0.1, w_cycle: float = 0.1, vae_pre_lr: float = 0.001, vae_post_lr: float = 0.0001, disc_lr: float = 1e-05, adv_lr: float = 1e-05)[source]#

Function to run the ABaCo model training.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – Pytorch dataLoader for the training data.

  • n_batches (int) – Number of batches in the dataset. For example, if samples were sequenced in 5 batches (e.g., 5 different dates) then n_batches = 5.

  • n_bios (int) – Number of labels or (potential) clusters based on biological variance. For example, if 2 experimental conditions (e.g., control and treatment) then n_bios = 2.

  • device (torch.device) – Device to run the model on, e.g., “cuda” or “cpu”.

  • input_size (int) – Number of features in the input data, columns. For example, if the input is a gene expression matrix with 1000 genes, then input_size = 1000.

  • new_pre_train (bool) – If True, use the new pre-training method with adversarial training and contrastive loss.

  • seed (int) – Random seed for reproducibility.

  • d_z (int) – Dimensionality of the latent space. For example, if d_z = 16, then the latent space will have 16 dimensions.

  • prior (str) – Prior distribution used. Baseline is “VMM” (VampPrior Mixture Model). Options are “VMM”, “MoG” (Mixture of Gaussians), or “Normal”.

  • count (bool) – If True, the model will use a zero-inflated negative binomial (ZINB) decoder. If False, it will use a zero-inflated Dirichlet (ZIDirichlet) decoder.

  • pre_epochs (int) – Number of epochs for first phase of ABaCo: data reconstruction. Default is 2000.

  • post_epochs (int) – Number of epochs for second phase of ABaCo: batch correction. Default is 2000.

  • kl_cycle (bool) – If True, the model will use a KL divergence cycle loss during second phase of ABaCo (batch correction). If False, cross loss 0

  • smooth_annealing (bool) – Slow batch masking during ABaCo second phase (batch correction) to avoid exploding gradients. Default is False.

  • encoder_net (list) – List of integers defining the architecture of the encoder. Each integer is a layer size. For example, [1024, 512, 256] means the encoder will have three layers with 1024, 512, and 256 neurons respectively.

  • decoder_net (list) – List of integers defining the architecture of the decoder. Each integer is a layer size. For example, [256, 512, 1024] means the decoder will have three layers with 256, 512, and 1024 neurons respectively.

  • vae_act_func (nn.Module) – Activation function for the VAE encoder and decoder. Default is nn.ReLU().

  • disc_net (list) – List of integers defining the architecture of the discriminator. Each integer is a layer size. For example, [256, 128, 64] means the discriminator will have three layers with 256, 128, and 64 neurons respectively.

  • disc_act_func (nn.Module) – Activation function for the discriminator. Default is nn.ReLU().

  • disc_loss_type (str) – Type of loss function for the discriminator. Options are “CrossEntropy” or “Uniform”. Default is “CrossEntropy”.

  • w_elbo (float) – Weight of the ELBO loss in the pre-training phase. Default is 1.0

  • beta (float) – KL-divergence coefficient: higher value yields bigger penalization from the prior distribution during training. Default is 20.0.

  • w_disc (float) – Weight of the discriminator loss in the pre-training phase. Default is 1.0

  • w_adv (float) – Weight of the adversarial loss in the pre-training phase. Default is 1.0

  • w_contra (float) – Weight of the contrastive loss. Higher values yield a stronger separation of biological groups in the latent space, which can sometimes improve results. Default is 10.0.

  • w_cycle (float) – Weight of the cycle loss. Higher values lead to more instability during the second phase of ABaCo. Default is 0.1.

  • temp (float) – Temperature for the contrastive loss. Default is 0.1.

  • vae_pre_lr (float) – Learning rate for first phase of ABaCo. Default is 1e-3.

  • vae_post_lr (float) – Learning rate for second phase (batch correction). Default is 1e-4.

  • disc_lr (float) – Learning rate for the batch discriminator. Default is 1e-5.

  • adv_lr (float) – Adversarial learning rate, applied to the encoder when the batch effect is present in the latent space. Default is 1e-5.

Returns:

vae – Trained ABaCo model

abaco.ABaCo.abaco_run_ensemble(dataloader, n_batches, n_bios, device, input_size, seed=42, d_z=16, prior='VMM', count=True, pre_epochs=2000, post_epochs=2000, kl_cycle=True, smooth_annealing=False, encoder_net=[1024, 512, 256], n_dec=5, decoder_net=[256, 512, 1024], vae_act_func=torch.nn.ReLU, disc_net=[256, 128, 64], disc_act_func=torch.nn.ReLU, disc_loss_type='CrossEntropy', w_elbo=1.0, beta=20.0, w_disc=1.0, w_adv=1.0, w_contra=10.0, temp=0.1, vae_pre_lr=0.001, vae_post_lr=0.0001, disc_lr=1e-05, adv_lr=1e-05)[source]#

Full ABaCo run with default setting

abaco.ABaCo.adversarial_loss(pred_logits: torch.Tensor, true_labels: torch.Tensor, loss_type: str = 'CrossEntropy')[source]#

Compute adversarial loss for generator (to fool discriminator).

Parameters:
  • pred_logits – [batch_size, n_classes] raw discriminator outputs.

  • true_labels – [batch_size] long tensor of class indices.

  • loss_type – “CrossEntropy” or “Uniform”.

Returns:

A scalar loss (negated for generator on CrossEntropy).
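
The two loss types can be sketched in plain Python as follows. The "CrossEntropy" branch returns the negated discriminator cross-entropy, as the return description states. The "Uniform" branch is assumed here to be the cross-entropy between the predicted distribution and a uniform target over classes, which this page does not spell out. The function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def adversarial_loss_sketch(pred_logits, true_labels, loss_type="CrossEntropy"):
    """Generator-side loss computed from raw discriminator outputs."""
    log_probs = [log_softmax(row) for row in pred_logits]
    if loss_type == "CrossEntropy":
        # Negated cross-entropy: the generator gains when the discriminator is wrong
        ce = -sum(lp[y] for lp, y in zip(log_probs, true_labels)) / len(true_labels)
        return -ce
    if loss_type == "Uniform":
        # Cross-entropy against a uniform target: smallest when predictions are uniform
        return -sum(sum(lp) / len(lp) for lp in log_probs) / len(log_probs)
    raise ValueError(f"Unknown loss_type: {loss_type}")
```

With uniform logits over three classes, the "Uniform" loss equals log(3), its minimum, and the "CrossEntropy" loss equals -log(3).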

abaco.ABaCo.contour_plot(samples, n_levels=10, x_offset=5, y_offset=5, alpha=0.8)[source]#

Given an array of samples, compute the contour plot.

Parameters:

samples (np.array) – An array of shape (n, 2)

abaco.ABaCo.kl_mog_mog(p, q)#

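Approximation of the KL-divergence between two Mixture of Gaussians distributions.

Parameters:
  • p (MixtureOfGaussians) – Mixture of Gaussians distribution 1

  • q (MixtureOfGaussians) – Mixture of Gaussians distribution 2

Returns: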

kl – KL-divergence between p and q

Return type:

torch.Tensor
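
The KL divergence between two mixtures of Gaussians has no closed form, so it has to be approximated. One standard option is a Monte Carlo estimate: draw samples from p and average log p(z) - log q(z). The plain-Python sketch below assumes that estimator for one-dimensional mixtures; this page does not say which approximation ABaCo uses, and the function names are hypothetical.

```python
import math
import random


def log_normal(x, mu, sigma):
    """Log density of a univariate Normal(mu, sigma^2) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def mog_log_prob(x, weights, means, sigmas):
    """Mixture log density via log-sum-exp."""
    terms = [math.log(w) + log_normal(x, m, s) for w, m, s in zip(weights, means, sigmas)]
    peak = max(terms)
    return peak + math.log(sum(math.exp(t - peak) for t in terms))


def mog_sample(weights, means, sigmas, rng):
    """Pick a component by its weight, then sample from its Gaussian."""
    k = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    return rng.gauss(means[k], sigmas[k])


def kl_mog_mog_mc(p, q, n_samples=5000, seed=0):
    """Monte Carlo estimate of KL(p || q); p and q are (weights, means, sigmas) tuples."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = mog_sample(*p, rng)
        total += mog_log_prob(z, *p) - mog_log_prob(z, *q)
    return total / n_samples
```

For two single-component mixtures N(0, 1) and N(3, 1), the exact KL divergence is 4.5, which the estimate approaches as the number of samples grows.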

abaco.ABaCo.kl_zinb_zinb(p, q)#

Approximation of the KL-divergence between two Zero-inflated Negative Binomial distributions.

Parameters:
  • p (ZINB) – Zero-inflated negative binomial distribution 1

  • q (ZINB) – Zero-inflated negative binomial distribution 2

Returns:

kl – KL-divergence between p and q

Return type:

torch.Tensor

abaco.ABaCo.log_normal_diag(x, mu, log_var, reduction=None, dim=None)[source]#
class abaco.ABaCo.metaABaCo(*args: Any, **kwargs: Any)[source]#

Bases: Module

batch_correct(train_loader, vae_optimizer, disc_optimizer, adv_optimizer, w_disc=1.0, w_adv=1.0, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0)[source]#

Train the conditional VAE model for batch correction. This step runs after the VAE prior parameters are defined.

batch_mask(train_loader, decoder_optimizer, smooth_annealing=True, cycle_reg=None, w_elbo_nll=1.0, w_cycle=0.001)[source]#

Pre-trained VAE will now have frozen encoder and batch labels masked at the encoder.

correct(seed=None, mask=True)[source]#
fit(smooth_annealing=True, cycle_reg=None, seed=None, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0, w_cycle=0.001, w_disc=1.0, w_adv=1.0, phase_1_vae_lr=0.001, phase_2_vae_lr=0.001, phase_3_vae_lr=1e-06, disc_lr=0.001, adv_lr=0.001)[source]#
plot_pca_posterior(figsize=(14, 6), palette='tab10')[source]#

Get the plot of the first 2 principal components of the posterior distribution.

train_vae(train_loader, optimizer, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0)[source]#

Train the conditional VAE model. If clustering prior is used, penalization term is applied to increase sparsity of the clusters.

Parameters:
  • vae (VAE) – Variational Autoencoder model

  • train_loader (torch.utils.data.DataLoader) – DataLoader for the training data

  • optimizer (torch.optim.Optimizer) – Optimizer for training

  • epochs (int) – Number of training epochs

  • device (str) – Device to use for computations

abaco.ABaCo.new_pre_train_abaco(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, normal_epochs, device, mog_epochs=0, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#

Part of the new ABaCo run with three phases, two of which are computed here. Pre-training of the conditional VAE with contrastive loss and adversarial mixing in latent space.

abaco.ABaCo.pre_train_abaco(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, epochs, device, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#

Pre-training of conditional VAE with contrastive loss and adversarial mixing in latent space.

abaco.ABaCo.pre_train_abaco_ensemble(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, epochs, device, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#

Pre-training of conditional VAE with contrastive loss and adversarial mixing in latent space.

abaco.ABaCo.train_abaco(vae, vae_optim_post, data_loader, epochs, device, w_elbo=1.0, w_cycle=1.0, cycle='KL', smooth_annealing=False)[source]#

Train the decoder of a pre-trained ABaCo cVAE with the batch labels masked, so that the information passed to the decoder depends solely on the latent space, where batch mixing has already taken place.

abaco.ABaCo.train_abaco_ensemble(vae, vae_optim_post, data_loader, epochs, device, w_elbo=1.0, w_cycle=1.0, cycle='KL', smooth_annealing=False)[source]#

Train the decoder of a pre-trained ABaCo cVAE with the batch labels masked, so that the information passed to the decoder depends solely on the latent space, where batch mixing has already taken place.