abaco.ABaCo module#
- class abaco.ABaCo.BatchDiscriminator(*args: Any, **kwargs: Any)[source]#
Bases: Module
Define the Batch Discriminator for adversarial training.
- forward(x)[source]#
Computes the forward pass through the discriminator.
- Parameters:
x (torch.Tensor) – Input to be passed through the model
- Returns:
batch_class – Prediction of the batch of origin of the observation
- Return type:
- class abaco.ABaCo.ConditionalEnsembleVAE(*args: Any, **kwargs: Any)[source]#
Bases: Module
Define a conditional Variational Autoencoder (VAE) model.
- elbo(x)[source]#
Compute the ELBO for the given batch of data.
- Parameters:
x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2, …)
n_samples (int) – Number of samples to use for the Monte Carlo estimate of the ELBO
- forward(x)[source]#
Compute the negative ELBO for the given batch of data.
- Parameters:
x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2)
- get_posterior(x)[source]#
Given a set of points, compute the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- pca_posterior(x)[source]#
Given a set of points, compute the PCA of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- class abaco.ABaCo.ConditionalVAE(*args: Any, **kwargs: Any)[source]#
Bases: Module
- elbo(x)[source]#
Compute the ELBO for the given batch of data.
- Parameters:
x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2, …)
n_samples (int) – Number of samples to use for the Monte Carlo estimate of the ELBO
- Returns:
elbo – Evidence Lower Bound
- Return type:
- forward(x)[source]#
Compute the negative ELBO for the given batch of data.
- Parameters:
x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2)
- Returns:
loss – Negative ELBO
- Return type:
- get_posterior(x)[source]#
Given a set of points, compute the samples of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder to obtain the posterior distribution from
- Returns:
z – Encoded points sampled from the posterior distribution
- Return type:
- kl_div_loss(x)[source]#
KL-divergence between the prior and the posterior distribution.
- Parameters:
x (torch.Tensor) – Observation to be encoded in order to get the approximate posterior distribution
- Returns:
kl_loss – KL-divergence loss
- Return type:
- pca_posterior(x)[source]#
Given a set of points, compute the PCA of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- pca_prior(n_samples)[source]#
Given a number of samples, get the PCA from the sampling of the prior distribution.
- Parameters:
n_samples (int) – Number of samples from the prior distribution
- class abaco.ABaCo.DMDecoder(*args: Any, **kwargs: Any)[source]#
Bases: Module
- forward(z)[source]#
Computes the Dirichlet-Multinomial distribution over the data space. The decoder outputs the concentration parameter of the distribution, so its output has shape (batch, features).
- Parameters:
z (torch.Tensor) – Latent space embeddings
- Returns:
Dirichlet-Multinomial distribution used to calculate the likelihood through its log_prob() method
- Return type:
torch.distribution
- class abaco.ABaCo.DirichletMultinomial(*args: Any, **kwargs: Any)[source]#
Bases: Distribution
Dirichlet-Multinomial distribution, defined by:
p ~ Dirichlet(alpha)
x | p ~ Multinomial(total_count, p)
- arg_constraints = {'concentration': torch.distributions.constraints.positive, 'total_count': torch.distributions.constraints.nonnegative_integer}#
- has_rsample = False#
- log_prob(x: torch.Tensor)[source]#
Implements the log_prob() method inherited from torch.distributions.Distribution.
- Parameters:
x (torch.Tensor) – Observation whose log probability under the distribution is computed
- Returns:
log_p – Log probability of the observation under the distribution
- Return type:
- sample(sample_shape=torch.Size)[source]#
Implements the sample() method, which draws data points using the distribution parameters.
- Parameters:
sample_shape – Number of samples. If not given, sample() draws a single observation from the distribution.
- Returns:
dm_sample – Sample(s) from the Dirichlet-Multinomial distribution.
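The two-stage definition above has a closed-form log probability once p is integrated out. The following is a minimal standalone sketch of that formula using only the standard library; the function name and list-based inputs are illustrative assumptions and are not the package's implementation.

```python
import math


def dirichlet_multinomial_log_prob(counts, concentration):
    """Log PMF of a Dirichlet-Multinomial for a single observation.

    counts: non-negative integer counts, one per category.
    concentration: positive alpha values, same length as counts.
    (Hypothetical helper for illustration only.)
    """
    n = sum(counts)
    alpha_sum = sum(concentration)
    # log of the multinomial coefficient n! / prod(x_i!)
    log_coeff = math.lgamma(n + 1) - sum(math.lgamma(x + 1) for x in counts)
    # log of Gamma(alpha_0) / Gamma(alpha_0 + n)
    log_norm = math.lgamma(alpha_sum) - math.lgamma(alpha_sum + n)
    # log of prod Gamma(alpha_i + x_i) / Gamma(alpha_i)
    log_terms = sum(
        math.lgamma(a + x) - math.lgamma(a)
        for a, x in zip(concentration, counts)
    )
    return log_coeff + log_norm + log_terms
```

With alpha = [1, 1] and a total count of 2, the three possible outcomes are equally likely, which gives a quick sanity check of the formula.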
- class abaco.ABaCo.MixtureOfGaussians(*args: Any, **kwargs: Any)[source]#
Bases: Distribution
A Mixture of Gaussians distribution with reparameterized sampling. Computation of gradients is possible.
- arg_constraints = {}#
- has_rsample = True#
- log_prob(value)[source]#
Compute the log probability of a given value. The log probability of a MoG is defined as:
log_prob(x) = log [sum_k pi_k * N(x; mu_k, sigma_k^2)]
where pi_k are the mixture probabilities.
- Parameters:
value (torch.Tensor) – Value whose log probability under the MoG is computed
- Returns:
log_prob – Log probability of the value under the distribution
- Return type:
- property mean#
Mixture mean: weighted sum of component means.
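The MoG log probability is usually evaluated with the log-sum-exp trick to avoid underflow when component densities are tiny. Below is a minimal one-dimensional sketch in plain Python; the function names and the scalar inputs are assumptions for illustration, whereas the class itself operates on torch tensors.

```python
import math


def normal_log_pdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at x."""
    return (
        -0.5 * math.log(2.0 * math.pi)
        - math.log(sigma)
        - 0.5 * ((x - mu) / sigma) ** 2
    )


def mog_log_prob(x, weights, means, sigmas):
    """log sum_k pi_k * N(x; mu_k, sigma_k^2), computed stably."""
    # Per-component terms: log(pi_k) + log N(x; mu_k, sigma_k^2)
    terms = [
        math.log(w) + normal_log_pdf(x, m, s)
        for w, m, s in zip(weights, means, sigmas)
    ]
    # Log-sum-exp: subtract the maximum before exponentiating
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def mog_mean(weights, means):
    """Mixture mean: weighted sum of component means."""
    return sum(w * m for w, m in zip(weights, means))
```

A mixture whose components are all identical should reproduce the single Gaussian density, regardless of the weights.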
- class abaco.ABaCo.MoCPPrior(*args: Any, **kwargs: Any)[source]#
Bases: Module
- class abaco.ABaCo.MoGEncoder(*args: Any, **kwargs: Any)[source]#
Bases: Module
- det_encode(x)[source]#
Computes the encoded point without stochastic component.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
- Returns:
z – Sum of the MoG component means of the encoded point, weighted by the mixing probabilities
- Return type:
- forward(x)[source]#
Computes the MoG distribution over the latent space.
- Parameters:
x (torch.Tensor)
- Returns:
mog_dist – MoG distribution used to sample from the approximate posterior with rsample()
- Return type:
- monte_carlo_encode(x, K=100)[source]#
Computes a Monte Carlo simulation of the same point to approximate the true posterior distribution.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
K (int) – Number of Monte Carlo iterations
- Returns:
sample – Mean from all samples (i.e. expectation of the encoded point estimated through Monte Carlo simulations)
- Return type:
- class abaco.ABaCo.NBDecoder(*args: Any, **kwargs: Any)[source]#
Bases: Module
- forward(z)[source]#
Computes the Negative Binomial distribution over the data space. The decoder outputs the mean and dispersion parameters, which are then converted into the NB distribution parameters: total_count (dispersion) and probs (dispersion/(dispersion + mean)).
- Parameters:
z (torch.Tensor) – Latent space embeddings
- Returns:
Negative Binomial distribution used to calculate the likelihood through its log_prob() method
- Return type:
torch.distribution
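The mean/dispersion conversion described for forward(z) can be checked numerically. The sketch below is a standalone illustration, not the decoder's code: it assumes the convention P(x) = C(x + r - 1, x) * p^r * (1 - p)^x, under which probs = dispersion/(dispersion + mean) recovers the requested mean. Note that torch.distributions.NegativeBinomial attaches probs to the count term instead (its probs plays the role of 1 - p here), so which convention a given implementation expects is worth verifying. The helper names are hypothetical.

```python
import math


def nb_params_from_mean_dispersion(mean, dispersion):
    """Map (mean, dispersion) to (total_count, probs) as described above."""
    total_count = dispersion
    probs = dispersion / (dispersion + mean)
    return total_count, probs


def nb_log_prob(x, total_count, probs):
    """Log PMF assuming P(x) = C(x + r - 1, x) * p^r * (1 - p)^x."""
    log_coeff = (
        math.lgamma(x + total_count)
        - math.lgamma(x + 1)
        - math.lgamma(total_count)
    )
    return log_coeff + total_count * math.log(probs) + x * math.log1p(-probs)


def nb_mean_numeric(total_count, probs, max_x=2000):
    """Mean obtained by summing x * P(x) over a truncated support."""
    return sum(
        x * math.exp(nb_log_prob(x, total_count, probs)) for x in range(max_x)
    )
```

For example, a mean of 3 and a dispersion of 2 map to total_count = 2 and probs = 0.4, and summing x * P(x) over the support returns the mean again under the assumed convention.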
- class abaco.ABaCo.NormalEncoder(*args: Any, **kwargs: Any)[source]#
Bases: Module
- det_encode(x)[source]#
Computes the encoded point without stochastic component.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
- Returns:
mu – Mean of the Gaussian (Normal) distribution of the encoded point
- Return type:
- forward(x)[source]#
Computes the Gaussian Normal distribution over the latent space.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
- Returns:
Gaussian (Normal) distribution used to sample from the posterior distribution with .rsample()
- Return type:
torch.distribution
- monte_carlo_encode(x, K=100)[source]#
Computes a Monte Carlo simulation of the same point to approximate the true posterior distribution.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
K (int) – Number of Monte Carlo iterations
- Returns:
sample – Mean from all samples (i.e. expectation of the encoded point estimated through Monte Carlo simulations)
- Return type:
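The difference between det_encode and monte_carlo_encode can be sketched for a diagonal Gaussian posterior. In this minimal illustration the posterior mean and standard deviation are passed in directly (an assumption; the real methods take x and run the encoder network first), and the function names with a `_sketch` suffix are hypothetical.

```python
import random


def det_encode_sketch(mu, sigma):
    """Deterministic encoding: return the posterior mean, no sampling."""
    return list(mu)


def monte_carlo_encode_sketch(mu, sigma, K=100, seed=0):
    """Average K reparameterized samples z = mu + sigma * eps."""
    rng = random.Random(seed)
    totals = [0.0] * len(mu)
    for _ in range(K):
        for i, (m, s) in enumerate(zip(mu, sigma)):
            # rng.gauss(m, s) draws m + s * eps with eps ~ N(0, 1)
            totals[i] += rng.gauss(m, s)
    return [t / K for t in totals]
```

As K grows, the Monte Carlo average approaches the deterministic encoding for a Gaussian posterior; for small K the result keeps some sampling noise.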
- class abaco.ABaCo.SupervisedContrastiveLoss(*args: Any, **kwargs: Any)[source]#
Bases: Module
Contrastive loss definition.
- class abaco.ABaCo.VAE(*args: Any, **kwargs: Any)[source]#
Bases: Module
- decode(z)[source]#
Decode the latent space representation to the data space.
- Parameters:
z (torch.Tensor) – Latent space tensor of shape (batch, d_z)
- encode(x)[source]#
Forward pass through the VAE.
- Parameters:
x (torch.Tensor) – Input data tensor of shape (batch, features)
- get_posterior(x)[source]#
Given a set of points, compute the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- class abaco.ABaCo.VMMPrior(*args: Any, **kwargs: Any)[source]#
Bases: Module
- cluster_loss()[source]#
Compute the clustering loss for the MoG prior. This loss encourages the components to be well separated by maximizing the pairwise KL divergence between the Gaussian components.
- forward(k_ohe)[source]#
Return prior distribution, allowing for the computation of the KL-divergence by calling self.prior().
- Parameters:
k_ohe (torch.Tensor) – One-hot encoded tensor indicating the component each point belongs to.
b_ohe (torch.Tensor) – One-hot encoded tensor with the corresponding batch label, which must be appended to the Encoder input.
encoder (nn.Module) – Encoder used to obtain the centroid of each cluster by encoding the pseudo-input.
- Returns:
prior – Prior distribution (torch.distributions.Distribution)
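The idea behind cluster_loss(), separating components by maximizing pairwise KL divergence, can be sketched with the closed-form KL between diagonal Gaussians. This is an illustration under stated assumptions: the averaging over ordered pairs, the sign convention (negative mean KL as a loss to minimize), and the function names are not taken from the package.

```python
import math
from itertools import permutations


def kl_diag_gaussians(mu_p, sigma_p, mu_q, sigma_q):
    """KL(N(mu_p, diag(sigma_p^2)) || N(mu_q, diag(sigma_q^2)))."""
    kl = 0.0
    for mp, sp, mq, sq in zip(mu_p, sigma_p, mu_q, sigma_q):
        kl += (
            math.log(sq / sp)
            + (sp ** 2 + (mp - mq) ** 2) / (2.0 * sq ** 2)
            - 0.5
        )
    return kl


def cluster_separation_loss(means, sigmas):
    """Negative mean pairwise KL; minimizing it pushes components apart.

    (Assumed form of the penalty, for illustration only.)
    """
    pairs = list(permutations(range(len(means)), 2))
    total = sum(
        kl_diag_gaussians(means[i], sigmas[i], means[j], sigmas[j])
        for i, j in pairs
    )
    return -total / len(pairs)
```

Moving two component means further apart increases their pairwise KL and therefore lowers this loss, which is the behavior the description asks for.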
- class abaco.ABaCo.VampPriorMixtureConditionalEnsembleVAE(*args: Any, **kwargs: Any)[source]#
Bases: Module
Define a VampPrior Variational Autoencoder model.
- get_posterior(x)[source]#
Given a set of points, compute the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- pca_posterior(x)[source]#
Given a set of points, compute the PCA of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- class abaco.ABaCo.VampPriorMixtureConditionalVAE(*args: Any, **kwargs: Any)[source]#
Bases: Module
Define a VampPrior Variational Autoencoder model. This class is used for the baseline application of the ABaCo model.
- get_posterior(x)[source]#
Given a set of points, compute the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- pca_posterior(x)[source]#
Given a set of points, compute the PCA of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- class abaco.ABaCo.ZIDM(*args: Any, **kwargs: Any)[source]#
Bases: Distribution
Zero-inflated Dirichlet-Multinomial (ZIDM) distribution: a mixture of per-category structural zeros with a Dirichlet-Multinomial core.
- arg_constraints = {}#
- has_rsample = False#
- log_prob(x: torch.Tensor)[source]#
Implements the log_prob() method inherited from torch.distributions.Distribution.
- Parameters:
x (torch.Tensor) – Observation whose log probability under the distribution is computed
- Returns:
log_zidm – Log probability of the observation under the distribution
- Return type:
- sample(sample_shape=torch.Size)[source]#
Implements the sample() method, which draws data points using the distribution parameters.
- Parameters:
sample_shape – Number of samples. If not given, sample() draws a single observation from the distribution.
- Returns:
zinb_sample – Sample(s) from the Zero-inflated Dirichlet-Multinomial distribution.
- class abaco.ABaCo.ZIDMDecoder(*args: Any, **kwargs: Any)[source]#
Bases: Module
- forward(z)[source]#
Computes the Zero-inflated Dirichlet-Multinomial distribution over the data space. The decoder outputs the zero probability and the concentration parameter of the distribution, so its output has shape (batch, 2*features).
- Parameters:
z (torch.Tensor) – Latent space embeddings
- Returns:
Zero-inflated Dirichlet-Multinomial distribution used to calculate the likelihood through its log_prob() method
- Return type:
torch.distribution
- class abaco.ABaCo.ZINB(*args: Any, **kwargs: Any)[source]#
Bases: Distribution
Zero-inflated Negative Binomial (ZINB) distribution. With zero-inflation probability pi and Negative Binomial parameters r and p, the log probability is defined as follows:
- For x = 0:
ZINB.log_prob(0) = log(pi + (1 - pi) * NegativeBinomial(0 | r, p))
- For x > 0:
ZINB.log_prob(x) = log(1 - pi) + log NegativeBinomial(x | r, p)
- arg_constraints = {}#
- has_rsample = False#
- log_prob(x)[source]#
Implements the log_prob() method inherited from torch.distributions.Distribution.
- Parameters:
x (torch.Tensor) – Observation whose log probability under the distribution is computed.
- Returns:
log_p – Log probability of the observation under the model.
- Return type:
- sample(sample_shape=torch.Size)[source]#
Implements the sample() method, which draws data points using the distribution parameters:
- For x = 0:
Binary mask for zero-inflated values
- For x > 0:
Sample from Negative Binomial distribution
- Parameters:
sample_shape – Number of samples. If not given, sample() draws a single observation from the distribution.
- Returns:
zinb_sample – Sample(s) from the Zero-inflated Negative Binomial distribution.
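The two-case ZINB log probability defined for this class can be written out as a short standalone sketch. It works on Python scalars rather than tensors, and it assumes the Negative Binomial convention P(x) = C(x + r - 1, x) * p^r * (1 - p)^x, which this page does not specify; the function names are illustrative.

```python
import math


def nb_log_prob(x, r, p):
    """Negative Binomial log PMF (assumed convention: p^r * (1 - p)^x)."""
    log_coeff = math.lgamma(x + r) - math.lgamma(x + 1) - math.lgamma(r)
    return log_coeff + r * math.log(p) + x * math.log1p(-p)


def zinb_log_prob(x, pi, r, p):
    """ZINB log PMF with zero-inflation probability pi."""
    nb_lp = nb_log_prob(x, r, p)
    if x == 0:
        # A zero is either structural (prob. pi) or drawn from the NB
        return math.log(pi + (1.0 - pi) * math.exp(nb_lp))
    # A positive count can only come from the NB component
    return math.log1p(-pi) + nb_lp
```

Setting pi = 0 recovers the plain Negative Binomial, and any pi > 0 raises the probability of observing a zero while the probabilities still sum to one.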
- class abaco.ABaCo.ZINBDecoder(*args: Any, **kwargs: Any)[source]#
Bases: Module
- forward(z)[source]#
Computes the Zero-inflated Negative Binomial distribution over the data space. The decoder outputs the zero probability, mean and dispersion parameters; the mean and dispersion are then converted into the NB distribution parameters: total_count (dispersion) and probs (dispersion/(dispersion + mean)).
- Parameters:
z (torch.Tensor) – Latent space embeddings
- Returns:
Zero-inflated Negative Binomial distribution used to calculate the likelihood through its log_prob() method
- Return type:
torch.distribution
- abaco.ABaCo.abaco_recon(model, device: torch.device, data: pandas.DataFrame, dataloader: torch.utils.data.DataLoader, sample_label: str, batch_label: str, bio_label, seed=42, det_encode=False, monte_carlo=100)[source]#
Function used to reconstruct data using trained ABaCo model.
- Parameters:
model – Trained ABaCo model.
device (torch.device) – Device to run the model on, e.g., “cuda” or “cpu”.
data (pd.DataFrame) – DataFrame containing the data to be reconstructed.
dataloader (torch.utils.data.DataLoader) – Pytorch DataLoader for the data to be reconstructed.
sample_label (str) – Column name in the DataFrame that contains unique ids for the observations/samples.
batch_label (str) – Column name in the DataFrame that contains ids for the batch/factor groupings to be corrected by ABaCo, e.g., dates of sample analysis.
bio_label (str) – Column name in the DataFrame that contains the biological groupings whose biological/experimental variation ABaCo should retain when correcting the batch effect, e.g., experimental condition.
seed (int, optional) – Random seed for reproducibility. Default is 42.
det_encode (bool, optional) – If True, use deterministic encoding. Default is False.
monte_carlo (int, optional) – Number of Monte Carlo samples to use for reconstruction. Default is 100. Setting it to 1 is equivalent to drawing a single sample from the final ZINB distribution obtained from the trained model.
- Returns:
otu_corrected_pd – DataFrame containing the reconstructed data with batch and biological labels.
- Return type:
pd.DataFrame
- abaco.ABaCo.abaco_recon_ensemble(model, device, data, dataloader, sample_label, batch_label, bio_label, seed=42, det_encode=False, monte_carlo=100)[source]#
Function used to reconstruct data using trained ABaCo model.
- abaco.ABaCo.abaco_run(dataloader: torch.utils.data.DataLoader, n_batches: int, n_bios: int, device: torch.device, input_size: int, new_pre_train: bool = False, seed: int = 42, d_z: int = 16, prior: str = 'VMM', count: bool = True, pre_epochs: int = 2000, post_epochs: int = 2000, kl_cycle: bool = True, smooth_annealing: bool = False, encoder_net: list = [1024, 512, 256], decoder_net: list = [256, 512, 1024], vae_act_func=torch.nn.ReLU, disc_net: list = [256, 128, 64], disc_act_func=torch.nn.ReLU, disc_loss_type: str = 'CrossEntropy', w_elbo: float = 1.0, beta: float = 20.0, w_disc: float = 1.0, w_adv: float = 1.0, w_contra: float = 10.0, temp: float = 0.1, w_cycle: float = 0.1, vae_pre_lr: float = 0.001, vae_post_lr: float = 0.0001, disc_lr: float = 1e-05, adv_lr: float = 1e-05)[source]#
Function to run the ABaCo model training.
- Parameters:
dataloader (torch.utils.data.DataLoader) – Pytorch dataLoader for the training data.
n_batches (int) – Number of batches in the dataset. For example, if samples were sequenced in 5 batches (e.g., 5 different dates), then n_batches = 5.
n_bios (int) – Number of labels or (potential) clusters based on biological variance. For example, if 2 experimental conditions (e.g., control and treatment) then n_bios = 2.
device (torch.device) – Device to run the model on, e.g., “cuda” or “cpu”.
input_size (int) – Number of features in the input data, columns. For example, if the input is a gene expression matrix with 1000 genes, then input_size = 1000.
new_pre_train (bool) – If True, use the new pre-training method with adversarial training and contrastive loss.
seed (int) – Random seed for reproducibility.
d_z (int) – Dimensionality of the latent space. For example, if d_z = 16, then the latent space will have 16 dimensions.
prior (str) – Prior distribution used. Baseline is “VMM” (VampPrior Mixture Model). Options are “VMM”, “MoG” (Mixture of Gaussians), or “Normal”.
count (bool) – If True, the model will use a zero-inflated negative binomial (ZINB) decoder. If False, it will use a zero-inflated Dirichlet (ZIDirichlet) decoder.
pre_epochs (int) – Number of epochs for first phase of ABaCo: data reconstruction. Default is 2000.
post_epochs (int) – Number of epochs for second phase of ABaCo: batch correction. Default is 2000.
kl_cycle (bool) – If True, the model will use a KL divergence cycle loss during second phase of ABaCo (batch correction). If False, cross loss 0
smooth_annealing (bool) – Slow batch masking during ABaCo second phase (batch correction) to avoid exploding gradients. Default is False.
encoder_net (list) – List of integers defining the architecture of the encoder. Each integer is a layer size. For example, [1024, 512, 256] means the encoder will have three layers with 1024, 512, and 256 neurons respectively.
decoder_net (list) – List of integers defining the architecture of the decoder. Each integer is a layer size. For example, [256, 512, 1024] means the decoder will have three layers with 256, 512, and 1024 neurons respectively.
vae_act_func (nn.Module) – Activation function for the VAE encoder and decoder. Default is nn.ReLU().
disc_net (list) – List of integers defining the architecture of the discriminator. Each integer is a layer size. For example, [256, 128, 64] means the discriminator will have three layers with 256, 128, and 64 neurons respectively.
disc_act_func (nn.Module) – Activation function for the discriminator. Default is nn.ReLU().
disc_loss_type (str) – Type of loss function for the discriminator. Options are “CrossEntropy” or “Uniform”. Default is “CrossEntropy”.
w_elbo (float) – Weight of the ELBO loss in the pre-training phase. Default is 1.0
beta (float) – KL-divergence coefficient: higher value yields bigger penalization from the prior distribution during training. Default is 20.0.
w_disc (float) – Weight of the discriminator loss in the pre-training phase. Default is 1.0
w_adv (float) – Weight of the adversarial loss in the pre-training phase. Default is 1.0
w_contra (float) – Weight of the contrastive loss. Higher values yield a stronger separation of biological groups in the latent space. Default is 10.0.
w_cycle (float) – Weight of the cycle loss. Higher values lead to more instability during the second phase of ABaCo. Default is 0.1.
temp (float) – Temperature for the contrastive loss. Default is 0.1.
vae_pre_lr (float) – Learning rate for first phase of ABaCo. Default is 1e-3.
vae_post_lr (float) – Learning rate for second phase (batch correction). Default is 1e-4.
disc_lr (float) – Learning rate for the batch discriminator. Default is 1e-5.
adv_lr (float) – Learning rate for the adversarial update applied to the encoder, used when the batch effect is present in the latent space. Default is 1e-5.
- Returns:
vae – Trained ABaCo model
- abaco.ABaCo.abaco_run_ensemble(dataloader, n_batches, n_bios, device, input_size, seed=42, d_z=16, prior='VMM', count=True, pre_epochs=2000, post_epochs=2000, kl_cycle=True, smooth_annealing=False, encoder_net=[1024, 512, 256], n_dec=5, decoder_net=[256, 512, 1024], vae_act_func=torch.nn.ReLU, disc_net=[256, 128, 64], disc_act_func=torch.nn.ReLU, disc_loss_type='CrossEntropy', w_elbo=1.0, beta=20.0, w_disc=1.0, w_adv=1.0, w_contra=10.0, temp=0.1, vae_pre_lr=0.001, vae_post_lr=0.0001, disc_lr=1e-05, adv_lr=1e-05)[source]#
Full ABaCo run with default setting
- abaco.ABaCo.adversarial_loss(pred_logits: torch.Tensor, true_labels: torch.Tensor, loss_type: str = 'CrossEntropy')[source]#
Compute adversarial loss for generator (to fool discriminator).
- Parameters:
pred_logits – [batch_size, n_classes] raw discriminator outputs.
true_labels – [batch_size] long tensor of class indices.
loss_type – “CrossEntropy” or “Uniform”.
- Returns:
A scalar loss (negated for generator on CrossEntropy).
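The two loss types can be sketched on plain Python lists. This is a hedged illustration rather than the library's code: "CrossEntropy" is taken to be the negated discriminator cross-entropy, as the Returns note states, while the "Uniform" branch assumes a cross-entropy against a uniform target over classes, which this page does not spell out. The name `adversarial_loss_sketch` is hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(v - top) for v in logits))
    return [v - log_z for v in logits]


def adversarial_loss_sketch(pred_logits, true_labels, loss_type="CrossEntropy"):
    """Generator-side loss computed from raw discriminator outputs."""
    log_probs = [log_softmax(row) for row in pred_logits]
    n = len(pred_logits)
    if loss_type == "CrossEntropy":
        ce = -sum(lp[y] for lp, y in zip(log_probs, true_labels)) / n
        # Negated: the generator benefits when the discriminator is wrong
        return -ce
    if loss_type == "Uniform":
        # Assumed definition: cross-entropy against a uniform target,
        # smallest when the discriminator predicts every batch equally
        return -sum(sum(lp) / len(lp) for lp in log_probs) / n
    raise ValueError(f"Unknown loss_type: {loss_type}")
```

With three classes and equal logits, the "Uniform" variant reaches its minimum of log(3), and the "CrossEntropy" variant returns the negative of that value.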
- abaco.ABaCo.contour_plot(samples, n_levels=10, x_offset=5, y_offset=5, alpha=0.8)[source]#
Given an array, computes the contour plot.
- Parameters:
samples (np.array) – An array of shape (n, 2)
- abaco.ABaCo.kl_mog_mog(p, q)#
Approximation of the KL-divergence between two Mixture of Gaussians distributions.
- Parameters:
p (MixtureOfGaussians) – MoG distribution 1
q (MixtureOfGaussians) – MoG distribution 2
- Returns:
kl – KL-divergence between p and q
- Return type:
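The KL divergence between two mixtures has no closed form, which is why an approximation is needed. One common option is a Monte Carlo estimate, KL(p || q) ≈ mean over z ~ p of [log p(z) - log q(z)]. The one-dimensional sketch below illustrates that estimator only; whether kl_mog_mog uses this or another approximation is not stated here, and all names and the tuple layout (weights, means, sigmas) are assumptions.

```python
import math
import random


def mog_log_prob(x, weights, means, sigmas):
    """Stable log density of a 1-D Mixture of Gaussians."""
    terms = [
        math.log(w)
        - 0.5 * math.log(2.0 * math.pi)
        - math.log(s)
        - 0.5 * ((x - m) / s) ** 2
        for w, m, s in zip(weights, means, sigmas)
    ]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def mog_sample(rng, weights, means, sigmas):
    """Pick a component by its weight, then sample from that Gaussian."""
    k = rng.choices(range(len(weights)), weights=weights)[0]
    return rng.gauss(means[k], sigmas[k])


def kl_mog_mog_mc(p, q, n_samples=5000, seed=0):
    """Monte Carlo estimate of KL(p || q); p and q are (weights, means, sigmas)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = mog_sample(rng, *p)
        total += mog_log_prob(z, *p) - mog_log_prob(z, *q)
    return total / n_samples
```

For single-component mixtures the estimate can be compared with the closed-form Gaussian KL, for instance 0.5 between N(0, 1) and N(1, 1).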
- abaco.ABaCo.kl_zinb_zinb(p, q)#
Approximation of the KL-divergence between two Zero-inflated Negative Binomial distributions.
- Parameters:
- Returns:
kl – KL-divergence between p and q
- Return type:
- class abaco.ABaCo.metaABaCo(*args: Any, **kwargs: Any)[source]#
Bases: Module
- batch_correct(train_loader, vae_optimizer, disc_optimizer, adv_optimizer, w_disc=1.0, w_adv=1.0, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0)[source]#
Train the conditional VAE model for batch correction. This is trained after VAE prior parameters are defined,
- batch_mask(train_loader, decoder_optimizer, smooth_annealing=True, cycle_reg=None, w_elbo_nll=1.0, w_cycle=0.001)[source]#
Pre-trained VAE will now have frozen encoder and batch labels masked at the encoder.
- fit(smooth_annealing=True, cycle_reg=None, seed=None, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0, w_cycle=0.001, w_disc=1.0, w_adv=1.0, phase_1_vae_lr=0.001, phase_2_vae_lr=0.001, phase_3_vae_lr=1e-06, disc_lr=0.001, adv_lr=0.001)[source]#
- plot_pca_posterior(figsize=(14, 6), palette='tab10')[source]#
Get the plot of the first 2 principal components of the posterior distribution.
- train_vae(train_loader, optimizer, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0)[source]#
Train the conditional VAE model. If clustering prior is used, penalization term is applied to increase sparsity of the clusters.
- Parameters:
vae – [VAE] Variational Autoencoder model
train_loader – [torch.utils.data.DataLoader] DataLoader for the training data
optimizer – [torch.optim.Optimizer] Optimizer for training
epochs – [int] Number of training epochs
device – [str] Device to use for computations
- abaco.ABaCo.new_pre_train_abaco(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, normal_epochs, device, mog_epochs=0, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#
Part of the new ABaCo run with three phases, two of which are computed here. Pre-training of conditional VAE with contrastive loss and adversarial mixing in latent space.
- abaco.ABaCo.pre_train_abaco(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, epochs, device, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#
Pre-training of conditional VAE with contrastive loss and adversarial mixing in latent space.
- abaco.ABaCo.pre_train_abaco_ensemble(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, epochs, device, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#
Pre-training of conditional VAE with contrastive loss and adversarial mixing in latent space.
- abaco.ABaCo.train_abaco(vae, vae_optim_post, data_loader, epochs, device, w_elbo=1.0, w_cycle=1.0, cycle='KL', smooth_annealing=False)[source]#
Trains the decoder of a pre-trained ABaCo cVAE with the batch labels masked, so that the information passed to the decoder depends solely on the batch-mixed latent space.
- abaco.ABaCo.train_abaco_ensemble(vae, vae_optim_post, data_loader, epochs, device, w_elbo=1.0, w_cycle=1.0, cycle='KL', smooth_annealing=False)[source]#
Trains the decoder of a pre-trained ABaCo cVAE with the batch labels masked, so that the information passed to the decoder depends solely on the batch-mixed latent space.