abaco.ABaCo module

class abaco.ABaCo.BatchDiscriminator(net)[source]#

Bases: Module

Batch discriminator for adversarial training: predicts the batch of origin of an observation.

forward(x)[source]#

Computes the forward pass through the discriminator.

Parameters:

x (torch.Tensor) – Input to be passed through the model

Returns:

batch_class – Prediction of the batch of origin of the observation

Return type:

torch.Tensor

loss(pred, true)[source]#
class abaco.ABaCo.ConditionalEnsembleVAE(prior, decoders: ModuleList, encoder, beta=1.0)[source]#

Bases: Module

Conditional variational autoencoder (VAE) model with an ensemble of decoders.

elbo(x)[source]#

Compute the ELBO for the given batch of data.

Parameters:
  • x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2, …)

  • n_samples (int) – Number of samples to use for the Monte Carlo estimate of the ELBO

forward(x)[source]#

Compute the negative ELBO for the given batch of data.

Parameters:

x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2)

get_posterior(x)[source]#

Given a set of points, compute the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

kl_div_loss(x)[source]#
log_prob(z)[source]#
pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

pca_prior(n_samples)[source]#

Draws the given number of samples from the prior distribution and computes their PCA.

sample(n_samples=1)[source]#

Sample from the model.

Parameters:

n_samples (int) – Number of samples to generate.

class abaco.ABaCo.ConditionalVAE(prior, decoder, encoder, beta=1.0)[source]#

Bases: Module

elbo(x)[source]#

Compute the ELBO for the given batch of data.

Parameters:
  • x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2, …)

  • n_samples (int) – Number of samples to use for the Monte Carlo estimate of the ELBO

Returns:

elbo – Evidence Lower Bound

Return type:

torch.Tensor

forward(x)[source]#

Compute the negative ELBO for the given batch of data.

Parameters:

x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2)

Returns:

loss – Negative ELBO

Return type:

torch.Tensor
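For reference, the quantity that elbo() and forward() compute can be sketched in plain PyTorch. This is a minimal, hypothetical example, not the ConditionalVAE implementation: it assumes a diagonal Gaussian encoder, a standard Normal prior and a Poisson likelihood chosen only for illustration, ignores the batch conditioning, and uses a single Monte Carlo sample. beta is assumed to weight the KL term, as described for abaco_run.

```python
import torch
from torch import nn
from torch.distributions import Independent, Normal, Poisson, kl_divergence


class TinyEncoder(nn.Module):
    """Hypothetical encoder: maps x to a diagonal Gaussian q(z|x)."""

    def __init__(self, d_in, d_z):
        super().__init__()
        self.net = nn.Linear(d_in, 2 * d_z)

    def forward(self, x):
        mu, log_std = self.net(x).chunk(2, dim=-1)
        return Independent(Normal(mu, log_std.exp()), 1)


class TinyDecoder(nn.Module):
    """Hypothetical decoder: maps z to a Poisson likelihood p(x|z) (illustration only)."""

    def __init__(self, d_z, d_out):
        super().__init__()
        self.net = nn.Linear(d_z, d_out)

    def forward(self, z):
        rate = nn.functional.softplus(self.net(z)) + 1e-6  # Poisson rate must be positive
        return Independent(Poisson(rate), 1)


def negative_elbo(x, encoder, decoder, prior, beta=1.0):
    q = encoder(x)                      # approximate posterior q(z|x)
    z = q.rsample()                     # one reparameterized sample per observation
    recon = decoder(z).log_prob(x)      # log p(x|z), summed over features
    kl = kl_divergence(q, prior)        # analytic KL(q || p), one value per observation
    return -(recon - beta * kl).mean()  # negative ELBO averaged over the batch


torch.manual_seed(0)
d_in, d_z = 5, 2
enc, dec = TinyEncoder(d_in, d_z), TinyDecoder(d_z, d_in)
prior = Independent(Normal(torch.zeros(d_z), torch.ones(d_z)), 1)
x = torch.poisson(torch.full((8, d_in), 3.0))
loss = negative_elbo(x, enc, dec, prior, beta=1.0)
loss.backward()  # gradients reach encoder and decoder through rsample()
```

Using rsample() rather than sample() is what lets the reconstruction term backpropagate into the encoder.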

get_posterior(x)[source]#

Given a set of points, draws samples from their posterior distribution.

Parameters:

x (torch.Tensor) – Samples passed to the encoder to obtain the posterior distribution

Returns:

z – Encoded points sampled from the posterior distribution

Return type:

torch.Tensor

kl_div_loss(x)[source]#

KL-divergence between the prior and the posterior distribution.

Parameters:

x (torch.Tensor) – Observation to be encoded in order to get the approximate posterior distribution

Returns:

kl_loss – KL-divergence loss

Return type:

torch.Tensor

log_prob(z)[source]#
pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

pca_prior(n_samples)[source]#

Draws the given number of samples from the prior distribution and computes their PCA.

Parameters:

n_samples (int) – Number of samples from the prior distribution

sample(n_samples=1)[source]#

Sample from the model.

Parameters:

n_samples (int) – Number of samples to generate

Returns:

samples – Samples generated from the model

Return type:

torch.Tensor

class abaco.ABaCo.DMDecoder(decoder_net, total_count, eps=1e-08)[source]#

Bases: Module

forward(z)[source]#

Computes the Dirichlet-Multinomial distribution over the data space. The decoder network outputs the concentration parameters of the distribution, so its output has shape (batch, features).

Parameters:

z (torch.Tensor) – Latent space embeddings

Returns:

Dirichlet-Multinomial distribution used to compute the likelihood via its log_prob() method

Return type:

torch.distributions.Distribution

class abaco.ABaCo.DirichletMultinomial(total_count: int, concentration: Tensor, validate_args=None)[source]#

Bases: Distribution

Dirichlet-Multinomial distribution, defined by:

p ~ Dirichlet(alpha)

x | p ~ Multinomial(total_count, p)

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'total_count': IntegerGreaterThan(lower_bound=0)}#
has_rsample = False#
log_prob(x: Tensor)[source]#

Defines the log_prob() method inherited from torch.distributions.Distribution.

Parameters:

x (torch.Tensor) – Observation whose log probability under the distribution is computed

Returns:

log_p – Log probability of the observation under the distribution.

Return type:

torch.Tensor

sample(sample_shape=())[source]#

Defines the sample() function, which is used to sample data points using the distribution parameters.

Parameters:

sample_shape – Shape of the batch of samples to draw. If not given, sample() draws a single observation from the distribution.

Returns:

dm_sample – Sample(s) from the Dirichlet-Multinomial distribution.

Return type:

torch.Tensor

support = IntegerGreaterThan(lower_bound=0)#
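Integrating p out of the two-stage definition gives a closed-form log pmf. The sketch below writes that formula with torch.lgamma and samples by following the two stages directly. dm_log_prob and dm_sample are hypothetical helpers, not this class's implementation, and x is assumed to hold counts whose rows sum to total_count.

```python
import torch
from torch.distributions import Dirichlet, Multinomial


def dm_log_prob(x, concentration):
    """Dirichlet-Multinomial log pmf with n = x.sum(-1) and alpha_0 = concentration.sum(-1):
    log n! - sum log x_k! + log G(alpha_0) - log G(n + alpha_0) + sum [log G(x_k + alpha_k) - log G(alpha_k)]
    """
    n = x.sum(-1)
    a0 = concentration.sum(-1)
    log_coef = torch.lgamma(n + 1) - torch.lgamma(x + 1).sum(-1)  # multinomial coefficient
    log_norm = torch.lgamma(a0) - torch.lgamma(n + a0)
    log_terms = (torch.lgamma(x + concentration) - torch.lgamma(concentration)).sum(-1)
    return log_coef + log_norm + log_terms


def dm_sample(total_count, concentration):
    """Two-stage sampling: p ~ Dirichlet(alpha), then x | p ~ Multinomial(total_count, p)."""
    p = Dirichlet(concentration).sample()
    return Multinomial(total_count, probs=p).sample()


alpha = torch.tensor([1.0, 1.0])
outcomes = torch.tensor([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]])
print(dm_log_prob(outcomes, alpha).exp())  # each of the 3 outcomes has probability 1/3
print(dm_sample(10, torch.tensor([2.0, 3.0, 5.0])))
```

With alpha = (1, 1) the Dirichlet is uniform, so every composition of n = 2 counts is equally likely, which makes the result easy to check by hand.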
class abaco.ABaCo.MixtureOfGaussians(mixture_logits, means, stds, temperature=1e-05, validate_args=None)[source]#

Bases: Distribution

A Mixture of Gaussians (MoG) distribution with reparameterized sampling, so gradients can be propagated through the samples.

arg_constraints = {}#
entropy()[source]#

Returns entropy of distribution, batched over batch_shape.

Returns:

Tensor of shape batch_shape.

has_rsample = True#
log_prob(value)[source]#

Compute the log probability of a given value. The log probability of a MoG is defined as:

log_prob(x) = log [ sum_k pi_k * N(x; mu_k, sigma_k^2) ]

where pi_k are the mixture probabilities.

Parameters:

value (torch.Tensor) – Value whose log probability under the MoG is computed

Returns:

log_prob – Log probability of value under the distribution

Return type:

torch.Tensor
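The formula above is usually evaluated in log space with logsumexp, which avoids underflow when the component densities are tiny. A minimal sketch, assuming diagonal Gaussian components with means and stds of shape (K, D) and unnormalized mixture logits; mog_log_prob is a hypothetical helper, checked here against PyTorch's built-in MixtureSameFamily.

```python
import math

import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal


def mog_log_prob(x, logits, means, stds):
    """log p(x) = log sum_k pi_k N(x; mu_k, diag(sigma_k^2)).

    x: (batch, D); logits: (K,) unnormalized; means, stds: (K, D).
    """
    log_pi = torch.log_softmax(logits, dim=-1)                          # log mixture probabilities
    diff = (x.unsqueeze(-2) - means) / stds                             # (batch, K, D)
    log_normal = (-0.5 * diff**2 - torch.log(stds) - 0.5 * math.log(2 * math.pi)).sum(-1)  # (batch, K)
    return torch.logsumexp(log_pi + log_normal, dim=-1)                 # stable log of the sum over k


torch.manual_seed(0)
logits = torch.tensor([0.2, -0.5, 1.0])
means, stds = torch.randn(3, 2), torch.rand(3, 2) + 0.5
x = torch.randn(4, 2)
reference = MixtureSameFamily(Categorical(logits=logits), Independent(Normal(means, stds), 1))
print(torch.allclose(mog_log_prob(x, logits, means, stds), reference.log_prob(x), atol=1e-5))  # True
```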

property mean#

Mixture mean: weighted sum of the component means, using the mixture probabilities as weights.

Type:

torch.Tensor

rsample(sample_shape=())[source]#

Reparameterized sampling using the Gumbel-softmax trick.

Parameters:

sample_shape – Shape of the batch of samples to draw. If not given, a single sample is returned.

Returns:

sample – Sample from the MoG distribution

Return type:

torch.Tensor
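One way to make a MoG sample differentiable with the Gumbel-softmax trick is to replace the hard component choice with a relaxed one-hot weight vector and combine it with reparameterized draws from every component. The sketch below assumes diagonal components and reuses the class's default temperature of 1e-5; mog_rsample is a hypothetical helper, and exactly how ABaCo combines the weights and draws is not shown on this page. The relaxed weights come from torch.nn.functional.gumbel_softmax.

```python
import torch
import torch.nn.functional as F


def mog_rsample(logits, means, stds, n, temperature=1e-5):
    """Relaxed, differentiable MoG sampling. logits: (K,); means, stds: (K, D). Returns (n, D)."""
    # Gumbel-softmax gives soft one-hot weights over components; near one-hot for small temperature.
    weights = F.gumbel_softmax(logits.expand(n, -1), tau=temperature, dim=-1)  # (n, K)
    # Reparameterized draw from every component: mu_k + sigma_k * eps.
    eps = torch.randn(n, *means.shape)                                        # (n, K, D)
    draws = means + stds * eps
    # Weighted sum over components (softly) selects one component per sample.
    return (weights.unsqueeze(-1) * draws).sum(dim=-2)


torch.manual_seed(0)
logits = torch.zeros(2, requires_grad=True)
means = torch.tensor([[-5.0, 0.0], [5.0, 0.0]], requires_grad=True)
stds = torch.full((2, 2), 0.1)
z = mog_rsample(logits, means, stds, n=8)
z.sum().backward()                        # gradients flow back to the component means
print(z.shape, means.grad is not None)    # torch.Size([8, 2]) True
```

A very small temperature makes the weights almost exactly one-hot, so samples look like true mixture draws, at the cost of high-variance gradients with respect to the mixture logits.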

sample(sample_shape=())[source]#

Sample from the MoG distribution.

support = Real()#
variance()[source]#

Mixture variance: weighted sum of (variance + squared mean) minus squared mixture mean
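The mean and variance formulas above follow from the law of total variance. A short check, assuming diagonal components with means and stds of shape (K, D); mog_mean_variance is a hypothetical helper. An equal-weight mixture of unit-variance components at -1 and +1 has mean 0 and variance 1 + 1 = 2.

```python
import torch


def mog_mean_variance(logits, means, stds):
    """Moments of a diagonal MoG. logits: (K,); means, stds: (K, D)."""
    pi = torch.softmax(logits, dim=-1).unsqueeze(-1)    # (K, 1) mixture probabilities
    mean = (pi * means).sum(0)                          # sum_k pi_k mu_k
    second_moment = (pi * (stds**2 + means**2)).sum(0)  # sum_k pi_k (sigma_k^2 + mu_k^2)
    return mean, second_moment - mean**2                # law of total variance


mean, var = mog_mean_variance(torch.zeros(2), torch.tensor([[-1.0], [1.0]]), torch.ones(2, 1))
print(mean, var)  # tensor([0.]) tensor([2.])
```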

class abaco.ABaCo.MoCPEncoder(encoder_net, n_comp)[source]#

Bases: Module

encode(x)[source]#
forward(x)[source]#

Computes the Categorical distribution over the latent space, sampling the component index and returning the parameters of the selected component.

Parameters:

x – [torch.Tensor]

class abaco.ABaCo.MoCPPrior(d_z, n_comp, multiplier=1.0)[source]#

Bases: Module

cluster_loss()[source]#

Compute the clustering loss for the MoG prior. This loss encourages the components to be well separated by maximizing the pairwise KL divergence between the Gaussian components.
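The KL divergence between two diagonal Gaussians has a closed form, so the separation term can be computed for all component pairs at once. The sketch below assumes components parameterized by means and log-variances of shape (K, D) and returns the negative mean off-diagonal KL, so minimizing it maximizes separation. How MoCPPrior actually aggregates or scales this term (for example through multiplier) is not stated here, and both function names are hypothetical.

```python
import torch


def pairwise_kl_diag_gaussians(means, log_vars):
    """KL(N_i || N_j) for every ordered pair of diagonal Gaussians. means, log_vars: (K, D) -> (K, K)."""
    mu_i, mu_j = means.unsqueeze(1), means.unsqueeze(0)
    lv_i, lv_j = log_vars.unsqueeze(1), log_vars.unsqueeze(0)
    kl = 0.5 * (lv_j - lv_i + (lv_i.exp() + (mu_i - mu_j) ** 2) / lv_j.exp() - 1.0)
    return kl.sum(dim=-1)  # sum over latent dimensions


def cluster_loss(means, log_vars):
    """Negative mean off-diagonal KL: minimizing it pushes the components apart."""
    K = means.shape[0]
    kl = pairwise_kl_diag_gaussians(means, log_vars)
    off_diagonal = ~torch.eye(K, dtype=torch.bool)
    return -kl[off_diagonal].mean()


same = cluster_loss(torch.zeros(3, 2), torch.zeros(3, 2))              # identical components: 0
apart = cluster_loss(torch.tensor([[0.0], [3.0]]), torch.zeros(2, 1))  # unit variances, 3 apart
print(apart)  # tensor(-4.5000), since KL = 0.5 * 3**2 in each direction
```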

forward(k_ohe)[source]#

Return prior distribution, allowing for the computation of the KL-divergence by calling self.prior().

Returns:

prior

Return type:

torch.distributions.Distribution

class abaco.ABaCo.MoGEncoder(encoder_net, n_comp)[source]#

Bases: Module

det_encode(x)[source]#

Computes the encoded point deterministically, without sampling.

Parameters:

x (torch.Tensor) – Tensor to be encoded

Returns:

z – Mean of the MoG posterior: the component means weighted by the mixing probabilities of the encoded point

Return type:

torch.Tensor

encode(x)[source]#
forward(x)[source]#

Computes the MoG distribution over the latent space.

Parameters:

x (torch.Tensor) – Tensor to be encoded

Returns:

mog_dist – MoG distribution representing the approximate posterior; latent samples are drawn with rsample()

Return type:

MixtureOfGaussians

monte_carlo_encode(x, K=100)[source]#

Encodes the same point K times and averages the samples, giving a Monte Carlo approximation of the posterior mean.

Parameters:
  • x (torch.Tensor) – Tensor to be encoded

  • K (int) – Number of Monte Carlo iterations

Returns:

sample – Mean from all samples (i.e. expectation of the encoded point estimated through Monte Carlo simulations)

Return type:

torch.Tensor

class abaco.ABaCo.MoGPrior(d_z, n_comp, multiplier=1.0)[source]#

Bases: Module

forward()[source]#

Return prior distribution, allowing for the computation of the KL-divergence by calling self.prior().

Returns:

prior

Return type:

torch.distributions.Distribution

class abaco.ABaCo.NBDecoder(decoder_net)[source]#

Bases: Module

forward(z)[source]#

Computes the Negative Binomial distribution over the data space. The decoder network outputs the mean and dispersion parameters, which are reparameterized into the NB distribution parameters: total_count (dispersion) and probs (dispersion/(dispersion + mean)).

Parameters:

z (torch.Tensor) – Latent space embeddings

Returns:

Negative Binomial distribution used to compute the likelihood via its log_prob() method

Return type:

torch.distributions.Distribution
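The mean/dispersion reparameterization can be sketched with torch.distributions.NegativeBinomial. In PyTorch, probs is the per-trial success probability and the mean is total_count * probs / (1 - probs), so matching a target mean mu with dispersion r needs probs = mu / (r + mu). The ratio dispersion/(dispersion + mean) quoted above is the complementary probability, which plays the role of p in conventions such as SciPy's nbinom; this page does not show which convention the ABaCo code uses. NBHead and its softplus activations are hypothetical.

```python
import torch
from torch import nn
from torch.distributions import NegativeBinomial


class NBHead(nn.Module):
    """Hypothetical NB decoder head: z -> (mean, dispersion) -> NegativeBinomial."""

    def __init__(self, d_z, n_features):
        super().__init__()
        self.net = nn.Linear(d_z, 2 * n_features)

    def forward(self, z):
        raw_mean, raw_disp = self.net(z).chunk(2, dim=-1)
        mean = nn.functional.softplus(raw_mean) + 1e-8  # mu > 0
        disp = nn.functional.softplus(raw_disp) + 1e-8  # r > 0
        # PyTorch convention: E[x] = total_count * probs / (1 - probs), hence probs = mu / (r + mu).
        return NegativeBinomial(total_count=disp, probs=mean / (disp + mean))


torch.manual_seed(0)
head = NBHead(d_z=4, n_features=3)
z = torch.randn(2, 4)
dist = head(z)
print(dist.mean.shape)  # torch.Size([2, 3]); dist.mean equals the softplus mean
```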

class abaco.ABaCo.NormalEncoder(encoder_net)[source]#

Bases: Module

det_encode(x)[source]#

Computes the encoded point deterministically, without sampling.

Parameters:

x (torch.Tensor) – Tensor to be encoded

Returns:

mu – Mean of the Gaussian (Normal) posterior for the encoded point

Return type:

torch.Tensor

encode(x)[source]#
forward(x)[source]#

Computes the Gaussian (Normal) distribution over the latent space.

Parameters:

x (torch.Tensor) – Tensor to be encoded

Returns:

Gaussian (Normal) approximate posterior distribution, from which latent points are drawn with .rsample()

Return type:

torch.distributions.Distribution

monte_carlo_encode(x, K=100)[source]#

Encodes the same point K times and averages the samples, giving a Monte Carlo approximation of the posterior mean.

Parameters:
  • x (torch.Tensor) – Tensor to be encoded

  • K (int) – Number of Monte Carlo iterations

Returns:

sample – Mean from all samples (i.e. expectation of the encoded point estimated through Monte Carlo simulations)

Return type:

torch.Tensor
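The difference between det_encode() and monte_carlo_encode() can be illustrated with a Normal posterior: det_encode() returns the mean directly, while monte_carlo_encode() averages K reparameterized draws, which converges to the same mean as K grows. This sketch assumes the encoder has already produced mu and std (the real class obtains them from encoder_net); monte_carlo_encode here is a hypothetical standalone version.

```python
import torch
from torch.distributions import Normal


def monte_carlo_encode(mu, std, K=100):
    """Average K reparameterized draws from q(z|x) = N(mu, std^2) (hypothetical standalone version)."""
    samples = Normal(mu, std).rsample((K,))  # (K, batch, d_z)
    return samples.mean(dim=0)               # Monte Carlo estimate of E_q[z]


torch.manual_seed(0)
mu = torch.tensor([[0.5, -1.0]])  # what a deterministic encoding would return directly
std = torch.tensor([[0.1, 0.1]])
estimate = monte_carlo_encode(mu, std, K=1000)
print(estimate)  # close to mu; the error shrinks like std / sqrt(K)
```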

class abaco.ABaCo.NormalPrior(d_z)[source]#

Bases: Module

forward()[source]#

Return prior distribution. This allows the computation of KL-divergence by calling self.prior() in the VAE class.

Returns:

prior

Return type:

torch.distributions.Distribution
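Returning the prior as a distribution object from forward() lets torch.distributions.kl_divergence compute the KL term analytically against a Normal posterior. A minimal sketch with a hypothetical StdNormalPrior module; storing the mean and standard deviation as non-trainable buffers is an assumption, not necessarily how NormalPrior is written.

```python
import torch
from torch import nn
from torch.distributions import Independent, Normal, kl_divergence


class StdNormalPrior(nn.Module):
    """Hypothetical stand-in for NormalPrior: a standard Normal over d_z latent dimensions."""

    def __init__(self, d_z):
        super().__init__()
        # Buffers move with .to(device) but are not trained (an assumption for this sketch).
        self.register_buffer("mean", torch.zeros(d_z))
        self.register_buffer("std", torch.ones(d_z))

    def forward(self):
        return Independent(Normal(self.mean, self.std), 1)


prior = StdNormalPrior(d_z=2)
posterior = Independent(Normal(torch.tensor([[1.0, 0.0]]), torch.ones(1, 2)), 1)
kl = kl_divergence(posterior, prior())
print(kl)  # tensor([0.5000]): 0.5 * mu**2 summed over dimensions for unit variances
```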

class abaco.ABaCo.SupervisedContrastiveLoss(temp=0.1)[source]#

Bases: Module

Supervised contrastive loss applied to latent points; temp is the temperature.

forward(latent_points, labels)[source]#

Parameters:
  • latent_points (torch.Tensor) – Latent points of shape [batch_size, d_z]

  • labels (torch.Tensor) – Labels of shape [batch_size]
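A minimal sketch of a supervised contrastive loss over latent points, following the common supervised contrastive (SupCon) formulation: for each anchor, points with the same label are positives and all other points form the softmax denominator. Normalizing to cosine similarity and averaging over anchors that have at least one positive are assumptions; this class's own implementation may differ, and supervised_contrastive_loss is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def supervised_contrastive_loss(latent_points, labels, temp=0.1):
    """SupCon-style loss. latent_points: (batch_size, d_z); labels: (batch_size,)."""
    z = F.normalize(latent_points, dim=-1)                          # cosine similarity (assumption)
    n = z.shape[0]
    self_mask = torch.eye(n, dtype=torch.bool)
    sim = (z @ z.T / temp).masked_fill(self_mask, float("-inf"))    # drop i == a from the denominator
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)      # log softmax over the other points
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
    pos_count = pos_mask.sum(dim=1)
    # Mean log-probability of the same-label points for every anchor.
    mean_log_pos = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1) / pos_count.clamp(min=1)
    has_pos = pos_count > 0                                         # anchors without positives are skipped
    return -mean_log_pos[has_pos].mean()


pts = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
clustered = supervised_contrastive_loss(pts, torch.tensor([0, 0, 1, 1]))
mixed = supervised_contrastive_loss(pts, torch.tensor([0, 1, 0, 1]))
print(clustered.item() < mixed.item())  # True: same-label points already coincide
```

Lowering temp sharpens the softmax, so small differences in similarity are penalized more strongly.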

class abaco.ABaCo.VAE(encoder, decoder, prior, input_size, device)[source]#

Bases: Module

decode(z)[source]#

Decode the latent space representation to the data space.

Parameters:

z – [torch.Tensor] Latent space tensor of shape (batch, d_z)

encode(x)[source]#

Encodes the input data into the latent space (forward pass through the encoder).

Parameters:

x – [torch.Tensor] Input data tensor of shape (batch, features)

get_posterior(x)[source]#

Given a set of points, compute the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

kl_divergence(z, k_ohe=None)[source]#

Compute the KL-divergence between the prior and the posterior distribution.

Parameters:

z – [torch.Tensor] Latent space tensor of shape (batch, d_z)

pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

class abaco.ABaCo.VMMPrior(d_z, n_features, n_comp, n_batch, encoder, multiplier=1.0, dataloader=None)[source]#

Bases: Module

cluster_loss()[source]#

Compute the clustering loss for the MoG prior. This loss encourages the components to be well separated by maximizing the pairwise KL divergence between the Gaussian components.

forward(k_ohe)[source]#

Return prior distribution, allowing for the computation of the KL-divergence by calling self.prior().

Parameters:
  • k_ohe – [torch.Tensor] One-hot encoded tensor indicating the component each point belongs to.

  • b_ohe – [torch.Tensor] One-hot encoded tensor with the corresponding batch label, which is appended to the encoder input.

  • encoder – [nn.Module] Encoder used to obtain the cluster centroid by encoding the pseudo-input.

Returns:

prior

Return type:

torch.distributions.Distribution

sample_from_dataloader()[source]#
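In a VampPrior, the prior components are the encoder's posteriors evaluated at learnable pseudo-inputs, which is what encoding the pseudo-input refers to above. The sketch below assumes a Gaussian encoder and returns the full mixture; the real VMMPrior.forward(k_ohe) instead selects components through k_ohe and also conditions on batch labels, neither of which is modeled here. All class names are hypothetical.

```python
import torch
from torch import nn
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal


class TinyEncoder(nn.Module):
    """Hypothetical encoder returning the mean and log-std of q(z|x)."""

    def __init__(self, n_features, d_z):
        super().__init__()
        self.net = nn.Linear(n_features, 2 * d_z)

    def forward(self, x):
        mu, log_std = self.net(x).chunk(2, dim=-1)
        return mu, log_std


class TinyVampPrior(nn.Module):
    """Hypothetical VampPrior: p(z) = sum_k w_k q(z | u_k) over learnable pseudo-inputs u_k."""

    def __init__(self, encoder, n_comp, n_features):
        super().__init__()
        self.encoder = encoder
        self.pseudo_inputs = nn.Parameter(torch.rand(n_comp, n_features))
        self.mix_logits = nn.Parameter(torch.zeros(n_comp))

    def forward(self):
        mu, log_std = self.encoder(self.pseudo_inputs)  # one Gaussian per pseudo-input, (K, d_z)
        components = Independent(Normal(mu, log_std.exp()), 1)
        return MixtureSameFamily(Categorical(logits=self.mix_logits), components)


torch.manual_seed(0)
vamp = TinyVampPrior(TinyEncoder(n_features=6, d_z=2), n_comp=3, n_features=6)
prior = vamp()
z = prior.sample((5,))
log_p = prior.log_prob(z)   # shape (5,)
log_p.sum().backward()      # gradients reach the pseudo-inputs through the encoder
```

Because the prior is built from the same encoder, it adapts to the data during training instead of staying fixed like a standard Normal.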
class abaco.ABaCo.VampPriorMixtureConditionalEnsembleVAE(encoder, decoders: ModuleList, input_dim, batch_dim, n_comps, d_z, beta=1.0, data_loader=None)[source]#

Bases: Module

VampPrior variational autoencoder (VAE) model with a mixture prior, conditioned on batch labels and using an ensemble of decoders.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_posterior(x)[source]#

Given a set of points, compute the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

get_prior()[source]#
log_prob(z)[source]#
pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

pca_prior(n_samples)[source]#

Draws the given number of samples from the prior distribution and computes their PCA.

sample(n_samples)[source]#
sample_from_dataloader(data_loader)[source]#
class abaco.ABaCo.VampPriorMixtureConditionalVAE(encoder, decoder, input_dim, batch_dim, n_comps, d_z, beta=1.0, data_loader=None)[source]#

Bases: Module

VampPrior variational autoencoder (VAE) model. This class is used for the baseline application of the ABaCo model.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_posterior(x)[source]#

Given a set of points, compute the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

get_prior()[source]#
log_prob(z)[source]#
pca_posterior(x)[source]#

Given a set of points, compute the PCA of the posterior distribution.

Parameters:

x (torch.Tensor) – Samples to pass to the encoder

pca_prior(n_samples)[source]#

Draws the given number of samples from the prior distribution and computes their PCA.

sample(n_samples)[source]#
sample_from_dataloader(data_loader)[source]#
class abaco.ABaCo.ZIDM(dm: DirichletMultinomial, pi: Tensor, eps: float = 1e-08, validate_args=None)[source]#

Bases: Distribution

Zero-inflated Dirichlet-Multinomial (ZIDM) distribution: a mixture of per-category structural zeros and a Dirichlet-Multinomial core.

arg_constraints = {}#
has_rsample = False#
log_prob(x: Tensor)[source]#

Defines the log_prob() method inherited from torch.distributions.Distribution.

Parameters:

x (torch.Tensor) – Observation whose log probability under the distribution is computed

Returns:

log_zidm – Log probability of the observation under the distribution

Return type:

torch.Tensor

sample(sample_shape=())[source]#

Defines the sample() method, which draws data points using the distribution parameters.

Parameters:

sample_shape – Shape of the batch of samples to draw. If not given, sample() draws a single observation from the distribution.

Returns:

zinb_sample – Sample(s) from the Zero-inflated Dirichlet-Multinomial distribution.

Return type:

torch.Tensor

support = IntegerGreaterThan(lower_bound=0)#
class abaco.ABaCo.ZIDMDecoder(decoder_net, total_count, eps=1e-08)[source]#

Bases: Module

forward(z)[source]#

Computes the Zero-inflated Dirichlet-Multinomial distribution over the data space. The decoder network outputs both the zero probabilities and the concentration parameters of the distribution, so its output has shape (batch, 2*features).

Parameters:

z (torch.Tensor) – Latent space embeddings

Returns:

Zero-inflated Dirichlet-Multinomial distribution used to compute the likelihood via its log_prob() method

Return type:

torch.distributions.Distribution
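The (batch, 2*features) layout can be split into the two parameter blocks. A hypothetical sketch, assuming a sigmoid maps the first half to per-category zero probabilities and a softplus plus eps keeps the concentrations positive; the real ZIDMDecoder then wraps these in ZIDM(DirichletMultinomial(total_count, concentration), pi), and its actual activations are not stated here.

```python
import torch
from torch import nn


class ZIDMHead(nn.Module):
    """Hypothetical head: split a (batch, 2*features) output into zero probabilities and concentrations."""

    def __init__(self, d_z, n_features, eps=1e-8):
        super().__init__()
        self.net = nn.Linear(d_z, 2 * n_features)
        self.eps = eps

    def forward(self, z):
        zero_logits, raw_conc = self.net(z).chunk(2, dim=-1)         # each (batch, features)
        pi = torch.sigmoid(zero_logits)                               # structural-zero probability in (0, 1)
        concentration = nn.functional.softplus(raw_conc) + self.eps  # Dirichlet concentration > 0
        return pi, concentration


torch.manual_seed(0)
pi, conc = ZIDMHead(d_z=4, n_features=3)(torch.randn(5, 4))
print(pi.shape, conc.shape)  # torch.Size([5, 3]) torch.Size([5, 3])
```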

class abaco.ABaCo.ZINB(nb, pi, validate_args=None)[source]#

Bases: Distribution

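Zero-inflated Negative Binomial (ZINB) distribution, with zero-inflation probability pi and Negative Binomial parameters r (total_count) and p (probs). Its log probability is defined as follows.

For x = 0:

ZINB.log_prob(0) = log(pi + (1 - pi) * NegativeBinomial(0 | r, p))

For x > 0:

ZINB.log_prob(x) = log(1 - pi) + log NegativeBinomial(x | r, p)

where NegativeBinomial(x | r, p) denotes the Negative Binomial probability mass at x.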

arg_constraints = {}#
has_rsample = False#
log_prob(x)[source]#

Defines the log_prob() method inherited from torch.distributions.Distribution.

Parameters:

x (torch.Tensor) – Observation whose log probability under the distribution is computed.

Returns:

log_p – Log probability of the observation under the model.

Return type:

torch.Tensor
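The two cases of the ZINB definition translate directly into a torch.where over the counts. A minimal sketch using torch.distributions.NegativeBinomial for the NB core; zinb_log_prob is a hypothetical helper, not this class's method, and x is assumed to be a float tensor of counts.

```python
import torch
from torch.distributions import NegativeBinomial


def zinb_log_prob(x, nb, pi):
    """ZINB log pmf. x: float counts; nb: NegativeBinomial; pi: zero-inflation probability."""
    nb_log_p = nb.log_prob(x)                                # log NegativeBinomial(x | r, p)
    log_zero = torch.log(pi + (1 - pi) * nb_log_p.exp())     # x = 0: structural zero or NB zero
    log_nonzero = torch.log1p(-pi) + nb_log_p                # x > 0: NB part only
    return torch.where(x == 0, log_zero, log_nonzero)


nb = NegativeBinomial(total_count=torch.tensor(2.0), probs=torch.tensor(0.3))
x = torch.arange(0, 200, dtype=torch.float32)
log_p = zinb_log_prob(x, nb, pi=torch.tensor(0.25))
print(log_p.exp().sum())  # close to 1: the probabilities over 0..199 sum to (almost) one
```

With pi = 0 the expression reduces to the plain Negative Binomial, which is a convenient sanity check.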

sample(sample_shape=())[source]#

Defines the sample() method, which draws data points using the distribution parameters. A binary mask marks the zero-inflated (structural zero) entries; the remaining entries are sampled from the Negative Binomial distribution.

Parameters:

sample_shape – Shape of the batch of samples to draw. If not given, sample() draws a single observation from the distribution.

Returns:

zinb_sample – Sample(s) from the Zero-inflated Negative Binomial distribution.

Return type:

torch.Tensor

support = IntegerGreaterThan(lower_bound=0)#
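The mask-then-sample procedure for ZINB can be sketched as follows, assuming the binary mask is drawn as Bernoulli(pi) independently per entry; zinb_sample is a hypothetical helper, not the class's method. The expected fraction of zeros is pi + (1 - pi) * NB(0), which for the values below is 0.3 + 0.7 * 0.5**5, about 0.32.

```python
import torch
from torch.distributions import Bernoulli, NegativeBinomial


def zinb_sample(nb, pi, sample_shape=torch.Size()):
    """Draw NB counts, then zero out entries flagged by an (assumed) Bernoulli(pi) mask."""
    counts = nb.sample(sample_shape)
    structural_zero = Bernoulli(probs=pi.expand_as(counts)).sample().bool()
    return counts.masked_fill(structural_zero, 0.0)


torch.manual_seed(0)
nb = NegativeBinomial(total_count=torch.full((4,), 5.0), probs=torch.full((4,), 0.5))
draws = zinb_sample(nb, pi=torch.full((4,), 0.3), sample_shape=(1000,))
print(draws.shape)                   # torch.Size([1000, 4])
print((draws == 0).float().mean())   # near the expected 0.3 + 0.7 * 0.5**5, about 0.32
```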
class abaco.ABaCo.ZINBDecoder(decoder_net)[source]#

Bases: Module

forward(z)[source]#

Computes the Zero-inflated Negative Binomial distribution over the data space. The decoder network outputs the zero probability, the mean and the dispersion parameters; the mean and dispersion are then reparameterized into the NB distribution parameters: total_count (dispersion) and probs (dispersion/(dispersion + mean)).

Parameters:

z (torch.Tensor) – Latent space embeddings

Returns:

Zero-inflated Negative Binomial distribution used to compute the likelihood via its log_prob() method

Return type:

torch.distributions.Distribution

abaco.ABaCo.abaco_recon(model, device: device, data: DataFrame, dataloader: DataLoader, sample_label: str, batch_label: str, bio_label, seed=42, det_encode=False, monte_carlo=100)[source]#

Reconstructs data using a trained ABaCo model.

Parameters:
  • model – Trained ABaCo model.

  • device (torch.device) – Device to run the model on, e.g., “cuda” or “cpu”.

  • data (pd.DataFrame) – DataFrame containing the data to be reconstructed.

  • dataloader (torch.utils.data.DataLoader) – PyTorch DataLoader for the data to be reconstructed.

  • sample_label (str) – Column name in the DataFrame that contains unique ids for the observations/samples.

  • batch_label (str) – Column name in the DataFrame that contains ids for the batch/factor groupings to be corrected by ABaCo, e.g., dates of sample analysis.

  • bio_label (str) – Column name in the DataFrame that contains the biological groupings, i.e., the biological/experimental variation that ABaCo should retain when correcting the batch effect (e.g., experimental condition).

  • seed (int, optional) – Random seed for reproducibility. Default is 42.

  • det_encode (bool, optional) – If True, use deterministic encoding. Default is False.

  • monte_carlo (int, optional) – Number of Monte Carlo samples to use for reconstruction. Default is 100. Setting it to 1 is equivalent to drawing a single sample from the final ZINB distribution produced by the trained model.

Returns:

otu_corrected_pd – DataFrame containing the reconstructed data with batch and biological labels.

Return type:

pd.DataFrame

abaco.ABaCo.abaco_recon_ensemble(model, device, data, dataloader, sample_label, batch_label, bio_label, seed=42, det_encode=False, monte_carlo=100)[source]#

Reconstructs data using a trained ensemble ABaCo model (multiple decoders).

abaco.ABaCo.abaco_run(dataloader: DataLoader, n_batches: int, n_bios: int, device: device, input_size: int, new_pre_train: bool = False, seed: int = 42, d_z: int = 16, prior: str = 'VMM', count: bool = True, pre_epochs: int = 2000, post_epochs: int = 2000, kl_cycle: bool = True, smooth_annealing: bool = False, encoder_net: list = [1024, 512, 256], decoder_net: list = [256, 512, 1024], vae_act_func=ReLU(), disc_net: list = [256, 128, 64], disc_act_func=ReLU(), disc_loss_type: str = 'CrossEntropy', w_elbo: float = 1.0, beta: float = 20.0, w_disc: float = 1.0, w_adv: float = 1.0, w_contra: float = 10.0, temp: float = 0.1, w_cycle: float = 0.1, vae_pre_lr: float = 0.001, vae_post_lr: float = 0.0001, disc_lr: float = 1e-05, adv_lr: float = 1e-05)[source]#

Function to run the ABaCo model training.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – PyTorch DataLoader for the training data.

  • n_batches (int) – Number of batches in the dataset. For example, if samples were sequenced in 5 batches (e.g., on 5 different dates), then n_batches = 5.

  • n_bios (int) – Number of labels or (potential) clusters based on biological variance. For example, if 2 experimental conditions (e.g., control and treatment) then n_bios = 2.

  • device (torch.device) – Device to run the model on, e.g., “cuda” or “cpu”.

  • input_size (int) – Number of features in the input data, columns. For example, if the input is a gene expression matrix with 1000 genes, then input_size = 1000.

  • new_pre_train (bool) – If True, use the new pre-training method with adversarial training and contrastive loss.

  • seed (int) – Random seed for reproducibility.

  • d_z (int) – Dimensionality of the latent space. For example, if d_z = 16, then the latent space will have 16 dimensions.

  • prior (str) – Prior distribution used. Baseline is “VMM” (VampPrior Mixture Model). Options are “VMM”, “MoG” (Mixture of Gaussians), or “Normal”.

  • count (bool) – If True, the model will use a zero-inflated negative binomial (ZINB) decoder. If False, it will use a zero-inflated Dirichlet (ZIDirichlet) decoder.

  • pre_epochs (int) – Number of epochs for the first phase of ABaCo (data reconstruction). Default is 2000.

  • post_epochs (int) – Number of epochs for the second phase of ABaCo (batch correction). Default is 2000.

  • kl_cycle (bool) – If True, the model uses a KL-divergence cycle loss during the second phase of ABaCo (batch correction). If False, the cycle loss is set to 0. Default is True.

  • smooth_annealing (bool) – If True, batch labels are masked gradually during the second phase of ABaCo (batch correction) to avoid exploding gradients. Default is False.

  • encoder_net (list) – List of integers defining the architecture of the encoder. Each integer is a layer size. For example, [1024, 512, 256] means the encoder will have three layers with 1024, 512, and 256 neurons respectively.

  • decoder_net (list) – List of integers defining the architecture of the decoder. Each integer is a layer size. For example, [256, 512, 1024] means the decoder will have three layers with 256, 512, and 1024 neurons respectively.

  • vae_act_func (nn.Module) – Activation function for the VAE encoder and decoder. Default is nn.ReLU().

  • disc_net (list) – List of integers defining the architecture of the discriminator. Each integer is a layer size. For example, [256, 128, 64] means the discriminator will have three layers with 256, 128, and 64 neurons respectively.

  • disc_act_func (nn.Module) – Activation function for the discriminator. Default is nn.ReLU().

  • disc_loss_type (str) – Type of loss function for the discriminator. Options are “CrossEntropy” or “Uniform”. Default is “CrossEntropy”.

  • w_elbo (float) – Weight of the ELBO loss in the pre-training phase. Default is 1.0.

  • beta (float) – KL-divergence coefficient: a higher value penalizes deviation from the prior distribution more strongly during training. Default is 20.0.

  • w_disc (float) – Weight of the discriminator loss in the pre-training phase. Default is 1.0.

  • w_adv (float) – Weight of the adversarial loss in the pre-training phase. Default is 1.0.

  • w_contra (float) – Weight of the contrastive loss. Higher values separate the biological groups more strongly in the latent space. Default is 10.0.

  • w_cycle (float) – Weight of the cycle loss during the second phase of ABaCo. Higher values make that phase less stable. Default is 0.1.

  • temp (float) – Temperature for the contrastive loss. Default is 0.1.

  • vae_pre_lr (float) – Learning rate for first phase of ABaCo. Default is 1e-3.

  • vae_post_lr (float) – Learning rate for second phase (batch correction). Default is 1e-4.

  • disc_lr (float) – Learning rate for the batch discriminator. Default is 1e-5.

  • adv_lr (float) – Learning rate for the adversarial update of the encoder, used when the batch effect is present in the latent space. Default is 1e-5.

Returns:

vae – Trained ABaCo model

Return type:

torch.nn.Module

abaco.ABaCo.abaco_run_ensemble(dataloader, n_batches, n_bios, device, input_size, seed=42, d_z=16, prior='VMM', count=True, pre_epochs=2000, post_epochs=2000, kl_cycle=True, smooth_annealing=False, encoder_net=[1024, 512, 256], n_dec=5, decoder_net=[256, 512, 1024], vae_act_func=ReLU(), disc_net=[256, 128, 64], disc_act_func=ReLU(), disc_loss_type='CrossEntropy', w_elbo=1.0, beta=20.0, w_disc=1.0, w_adv=1.0, w_contra=10.0, temp=0.1, vae_pre_lr=0.001, vae_post_lr=0.0001, disc_lr=1e-05, adv_lr=1e-05)[source]#

Full ABaCo training run with default settings, using an ensemble of n_dec decoders.

abaco.ABaCo.adversarial_loss(pred_logits: Tensor, true_labels: Tensor, loss_type: str = 'CrossEntropy')[source]#

Compute adversarial loss for generator (to fool discriminator).

Parameters:
  • pred_logits – [batch_size, n_classes] raw discriminator outputs.

  • true_labels – [batch_size] long tensor of class indices.

  • loss_type – “CrossEntropy” or “Uniform”.

Returns:

A scalar loss (negated for generator on CrossEntropy).
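A hedged sketch of the two loss types. The "CrossEntropy" branch negates the usual cross-entropy, as the Returns line states; the "Uniform" branch is an assumption, written here as cross-entropy against uniform targets over batches, which is minimized (at log n_classes) when the discriminator cannot tell batches apart. This adversarial_loss is a standalone illustration, not the module's code.

```python
import torch
import torch.nn.functional as F


def adversarial_loss(pred_logits, true_labels, loss_type="CrossEntropy"):
    """Generator-side loss. pred_logits: (batch_size, n_classes); true_labels: (batch_size,) long."""
    if loss_type == "CrossEntropy":
        # Negated cross-entropy: minimizing it maximizes the discriminator's error.
        return -F.cross_entropy(pred_logits, true_labels)
    if loss_type == "Uniform":
        # Assumed variant: cross-entropy against uniform targets over all classes,
        # minimized when the discriminator's prediction is uniform; true_labels are unused.
        log_probs = F.log_softmax(pred_logits, dim=-1)
        return -log_probs.mean(dim=-1).mean()
    raise ValueError(f"unknown loss_type: {loss_type}")


logits = torch.zeros(4, 3)                 # a maximally uncertain discriminator
labels = torch.tensor([0, 1, 2, 0])
print(adversarial_loss(logits, labels, "CrossEntropy"))  # -log(3), about -1.0986
print(adversarial_loss(logits, labels, "Uniform"))       # log(3), about 1.0986
```

The negated cross-entropy is unbounded below, which can destabilize training; a uniform-target loss is bounded, which is a common reason to offer it as an alternative.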

abaco.ABaCo.contour_plot(samples, n_levels=10, x_offset=5, y_offset=5, alpha=0.8)[source]#

Computes the contour plot of a set of 2-D points.

Parameters:

samples – [np.array] An array of shape (n, 2)

abaco.ABaCo.kl_mog_mog(p, q)[source]#

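Approximation of the KL-divergence between two Mixture of Gaussians distributions.

Parameters:
  • p (MixtureOfGaussians) – Mixture of Gaussians distribution 1

  • q (MixtureOfGaussians) – Mixture of Gaussians distribution 2

Returns: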

kl – KL-divergence between p and q

Return type:

torch.Tensor
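The KL divergence between two Gaussian mixtures has no closed form, which is why kl_mog_mog is described as an approximation. One common approximation is the Monte Carlo estimate E_p[log p(z) - log q(z)]; the sketch below uses it with PyTorch's MixtureSameFamily. Which approximation kl_mog_mog actually uses is not stated on this page, and make_mog and kl_mog_mog_mc are hypothetical names.

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal


def make_mog(logits, means, stds):
    """Diagonal-covariance MoG built from PyTorch's MixtureSameFamily."""
    return MixtureSameFamily(Categorical(logits=logits), Independent(Normal(means, stds), 1))


def kl_mog_mog_mc(p, q, n_samples=10_000):
    """Monte Carlo estimate of KL(p || q) = E_p[log p(z) - log q(z)]."""
    z = p.sample((n_samples,))
    return (p.log_prob(z) - q.log_prob(z)).mean()


torch.manual_seed(0)
p = make_mog(torch.zeros(2), torch.tensor([[-2.0], [2.0]]), torch.ones(2, 1))
q = make_mog(torch.zeros(2), torch.tensor([[-1.0], [1.0]]), torch.ones(2, 1))
print(kl_mog_mog_mc(p, p))  # exactly 0: the log-densities cancel
print(kl_mog_mog_mc(p, q))  # positive estimate
```

The estimate is unbiased but noisy; its variance falls as 1 / n_samples.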

abaco.ABaCo.kl_zinb_zinb(p, q)[source]#

Approximation of the KL-divergence between two Zero-inflated Negative Binomial distributions.

Parameters:
  • p (ZINB) – Zero-inflated negative binomial distribution 1

  • q (ZINB) – Zero-inflated negative binomial distribution 2

Returns:

kl – KL-divergence between p and q

Return type:

torch.Tensor

abaco.ABaCo.log_normal_diag(x, mu, log_var, reduction=None, dim=None)[source]#
class abaco.ABaCo.metaABaCo(data, n_bios, bio_label, n_batches, batch_label, n_features, device, prior='MoG', pdist='ZINB', d_z=16, epochs=[1000, 2000, 2000], encoder_net=[512, 256, 128], decoder_net=[128, 256, 512], vae_act_fun=ReLU(), disc_net=[128, 64], disc_act_fun=ReLU())[source]#

Bases: Module

batch_correct(train_loader, vae_optimizer, disc_optimizer, adv_optimizer, w_disc=1.0, w_adv=1.0, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0)[source]#

Train the conditional VAE model for batch correction. This step runs after the VAE prior parameters have been defined.

batch_mask(train_loader, decoder_optimizer, smooth_annealing=True, cycle_reg=None, w_elbo_nll=1.0, w_cycle=0.001)[source]#

The encoder of the pre-trained VAE is frozen and the batch labels are masked at the encoder.

correct(seed=None, mask=True)[source]#
fit(smooth_annealing=True, cycle_reg=None, seed=None, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0, w_cycle=0.001, w_disc=1.0, w_adv=1.0, phase_1_vae_lr=0.001, phase_2_vae_lr=0.001, phase_3_vae_lr=1e-06, disc_lr=0.001, adv_lr=0.001)[source]#
plot_pca_posterior(figsize=(14, 6), palette='tab10')[source]#

Plots the first two principal components of the posterior distribution.

train_vae(train_loader, optimizer, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0)[source]#

Train the conditional VAE model. If a clustering prior is used, a penalty term is applied to increase the sparsity of the clusters.

Parameters:
  • vae – [VAE] Variational Autoencoder model

  • train_loader – [torch.utils.data.DataLoader] DataLoader for the training data

  • optimizer – [torch.optim.Optimizer] Optimizer for training

  • epochs – [int] Number of training epochs

  • device – [str] Device to use for computations

abaco.ABaCo.new_pre_train_abaco(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, normal_epochs, device, mog_epochs=0, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#

Part of the new three-phase ABaCo run; two of the three phases are computed here. Pre-training of the conditional VAE with a contrastive loss and adversarial mixing in the latent space.

abaco.ABaCo.pre_train_abaco(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, epochs, device, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#

Pre-training of the conditional VAE with a contrastive loss and adversarial mixing in the latent space.

abaco.ABaCo.pre_train_abaco_ensemble(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, epochs, device, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#

Pre-training of the ensemble conditional VAE with a contrastive loss and adversarial mixing in the latent space.

abaco.ABaCo.train_abaco(vae, vae_optim_post, data_loader, epochs, device, w_elbo=1.0, w_cycle=1.0, cycle='KL', smooth_annealing=False)[source]#

Trains the decoder of a pre-trained ABaCo conditional VAE (cVAE) while masking the batch labels, so that the information passed to the decoder depends solely on the batch-mixed latent space.

abaco.ABaCo.train_abaco_ensemble(vae, vae_optim_post, data_loader, epochs, device, w_elbo=1.0, w_cycle=1.0, cycle='KL', smooth_annealing=False)[source]#

Trains the decoder ensemble of a pre-trained ABaCo conditional VAE (cVAE) while masking the batch labels, so that the information passed to the decoders depends solely on the batch-mixed latent space.