abaco.ABaCo module#
- class abaco.ABaCo.BatchDiscriminator(net)[source]#
Bases:
Module
Define the batch discriminator for adversarial training.
- forward(x)[source]#
Computes the forward pass through the discriminator.
- Parameters:
x (torch.Tensor) – Input to be passed through the model
- Returns:
batch_class – Prediction of the batch of origin of the observation
- Return type:
- class abaco.ABaCo.ConditionalEnsembleVAE(prior, decoders: ModuleList, encoder, beta=1.0)[source]#
Bases:
Module
Define a conditional Variational Autoencoder (VAE) model.
- elbo(x)[source]#
Compute the ELBO for the given batch of data.
- Parameters:
x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2, …)
n_samples (int) – Number of samples to use for the Monte Carlo estimate of the ELBO
- forward(x)[source]#
Compute the negative ELBO for the given batch of data.
- Parameters:
x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2)
- get_posterior(x)[source]#
Given a set of points, compute the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- pca_posterior(x)[source]#
Given a set of points, compute the PCA of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- class abaco.ABaCo.ConditionalVAE(prior, decoder, encoder, beta=1.0)[source]#
Bases:
Module
- elbo(x)[source]#
Compute the ELBO for the given batch of data.
- Parameters:
x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2, …)
n_samples (int) – Number of samples to use for the Monte Carlo estimate of the ELBO
- Returns:
elbo – Evidence Lower Bound
- Return type:
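As a rough illustration of the quantity returned here, the sketch below computes a beta-weighted ELBO for a single observation in plain Python. It assumes a diagonal Gaussian posterior and a standard Normal prior so that the KL term has a closed form. The names `elbo_sketch`, `kl_normal_standard` and `recon_log_lik` are hypothetical; the library itself works on torch tensors and also supports other priors.

```python
import math


def kl_normal_standard(mu, sigma):
    """KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions."""
    return sum(0.5 * (m * m + s * s - 1.0) - math.log(s) for m, s in zip(mu, sigma))


def elbo_sketch(recon_log_lik, mu, sigma, beta=1.0):
    """ELBO = E[log p(x | z)] - beta * KL(q(z | x) || p(z)).

    recon_log_lik is assumed to be an already-computed (Monte Carlo) estimate
    of the reconstruction log-likelihood.
    """
    return recon_log_lik - beta * kl_normal_standard(mu, sigma)


# A posterior that matches the prior contributes no KL penalty.
print(elbo_sketch(-10.0, mu=[0.0, 0.0], sigma=[1.0, 1.0]))
# A larger beta penalises the same posterior shift more strongly.
print(elbo_sketch(-10.0, mu=[1.0], sigma=[1.0], beta=20.0))
```

The negative of this value is what a `forward()` returning the "negative ELBO" would minimise.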
- forward(x)[source]#
Compute the negative ELBO for the given batch of data.
- Parameters:
x (torch.Tensor) – A tensor of dimension (batch_size, feature_dim1, feature_dim2)
- Returns:
loss – Negative ELBO
- Return type:
- get_posterior(x)[source]#
Given a set of points, compute the samples of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder to obtain the posterior distribution from
- Returns:
z – Encoded points sampled from the posterior distribution
- Return type:
- kl_div_loss(x)[source]#
KL-divergence between the prior and the posterior distribution.
- Parameters:
x (torch.Tensor) – Observation to be encoded in order to get the approximate posterior distribution
- Returns:
kl_loss – KL-divergence loss
- Return type:
- pca_posterior(x)[source]#
Given a set of points, compute the PCA of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- pca_prior(n_samples)[source]#
Given a number of samples, get the PCA from the sampling of the prior distribution.
- Parameters:
n_samples (int) – Number of samples from the prior distribution
- class abaco.ABaCo.DMDecoder(decoder_net, total_count, eps=1e-08)[source]#
Bases:
Module
- forward(z)[source]#
Computes the Dirichlet-Multinomial distribution over the data space. We are obtaining the concentration parameter of the distribution, hence the decoder output would have shape (batch, features).
- Parameters:
z (torch.Tensor) – Latent space embeddings
- Returns:
Dirichlet-Multinomial distribution used to calculate the likelihood via its log_prob() method
- Return type:
torch.distribution
- class abaco.ABaCo.DirichletMultinomial(total_count: int, concentration: Tensor, validate_args=None)[source]#
Bases:
Distribution
Dirichlet-Multinomial distribution, defined by:
p ~ Dirichlet(alpha)
x | p ~ Multinomial(total_count, p)
- arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'total_count': IntegerGreaterThan(lower_bound=0)}#
- has_rsample = False#
- log_prob(x: Tensor)[source]#
Implements the log_prob() method inherited from torch.distributions.Distribution.
- Parameters:
x (torch.Tensor) – Observation whose log probability under the distribution is computed
- Returns:
log_p – Log probability of the observation under the distribution.
- Return type:
- sample(sample_shape=())[source]#
Defines the sample() function, which is used to sample data points using the distribution parameters.
- Parameters:
sample_shape – Number of samples to draw. If not given, sample() draws a single observation from the distribution.
- Returns:
Sample(s) from the Dirichlet Multinomial distribution.
- Return type:
dm_sample
- support = IntegerGreaterThan(lower_bound=0)#
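The two-stage definition above integrates to a closed-form pmf. The following is a minimal, dependency-free sketch of that log pmf for a single count vector; `dm_log_prob` is a hypothetical helper written for illustration and is not the library's tensor implementation.

```python
import math
from itertools import product


def dm_log_prob(x, concentration):
    """Log pmf of a Dirichlet-Multinomial for one count vector x."""
    n = sum(x)               # total_count implied by the observation
    a0 = sum(concentration)  # sum of the concentration parameters
    # log[ n! * Gamma(a0) / Gamma(n + a0) ]
    out = math.lgamma(n + 1) + math.lgamma(a0) - math.lgamma(n + a0)
    # plus sum_k log[ Gamma(x_k + a_k) / (x_k! * Gamma(a_k)) ]
    for xk, ak in zip(x, concentration):
        out += math.lgamma(xk + ak) - math.lgamma(xk + 1) - math.lgamma(ak)
    return out


alpha = [0.5, 2.0, 1.5]
total_count = 4
# Enumerate every count vector with the given total; the pmf should sum to 1.
support = [x for x in product(range(total_count + 1), repeat=3) if sum(x) == total_count]
total_prob = sum(math.exp(dm_log_prob(x, alpha)) for x in support)
print(total_prob)
```

Working with `lgamma` keeps the computation stable for large counts, which is why the log pmf is evaluated directly rather than exponentiating Gamma functions.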
- class abaco.ABaCo.MixtureOfGaussians(mixture_logits, means, stds, temperature=1e-05, validate_args=None)[source]#
Bases:
Distribution
A Mixture of Gaussians distribution with reparameterized sampling. Computation of gradients is possible.
- arg_constraints = {}#
- entropy()[source]#
Returns entropy of distribution, batched over batch_shape.
- Returns:
Tensor of shape batch_shape.
- has_rsample = True#
- log_prob(value)[source]#
Compute the log probability of a given value. The log prob of a MoG is defined as:
log_prob(x) = log [sum_k (pi_k * N(x; mu_k, sigma_k^2))]
Where pi_k are the mixture probabilities.
- Parameters:
value (torch.Tensor) – Value at which the MoG log probability is evaluated
- Returns:
log_prob – Log probability of the value under the distribution
- Return type:
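The mixture log probability is usually evaluated with the log-sum-exp trick to avoid underflow. The sketch below does this for a scalar value in plain Python; the one-dimensional setting and the helper names `normal_log_pdf` and `mog_log_prob` are assumptions for illustration, whereas the library operates on batched torch tensors.

```python
import math


def normal_log_pdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def mog_log_prob(x, mixture_logits, means, stds):
    """log sum_k pi_k * N(x; mu_k, sigma_k^2), with pi = softmax(mixture_logits)."""
    # Normalise the logits into log mixture probabilities.
    m = max(mixture_logits)
    log_norm = m + math.log(sum(math.exp(l - m) for l in mixture_logits))
    log_pi = [l - log_norm for l in mixture_logits]
    # Per-component terms log pi_k + log N(x; mu_k, sigma_k^2).
    terms = [lp + normal_log_pdf(x, mu, s) for lp, mu, s in zip(log_pi, means, stds)]
    # Stable log-sum-exp over components.
    t = max(terms)
    return t + math.log(sum(math.exp(v - t) for v in terms))


print(mog_log_prob(0.0, [0.0, math.log(3.0)], [-1.0, 2.0], [1.0, 0.5]))
```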
- property mean#
Mixture mean: weighted sum of component means.
- rsample(sample_shape=())[source]#
Reparameterized sampling using the Gumbel-softmax trick.
- Parameters:
sample_shape – Number of samples to draw. If not given, a single sample is returned
- Returns:
sample – Sample from the MoG distribution
- Return type:
- support = Real()#
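To illustrate the idea behind reparameterized mixture sampling, the sketch below relaxes the discrete component choice with Gumbel-softmax weights and combines reparameterized draws from each Gaussian component. This is one common construction and only an assumption about how `rsample()` works internally; the helper names and the scalar setting are hypothetical.

```python
import math
import random


def gumbel_softmax_weights(logits, temperature, rng=random):
    """Soft one-hot weights: softmax((logits + Gumbel noise) / temperature)."""
    noisy = []
    for l in logits:
        u = max(rng.random(), 1e-12)         # guard against log(0)
        gumbel = -math.log(-math.log(u))     # Gumbel(0, 1) noise
        noisy.append((l + gumbel) / temperature)
    m = max(noisy)                           # subtract the max for stability
    exps = [math.exp(v - m) for v in noisy]
    z = sum(exps)
    return [e / z for e in exps]


def mog_rsample(logits, means, stds, temperature=1e-5, rng=random):
    """One relaxed sample: soft selection over reparameterized component draws."""
    weights = gumbel_softmax_weights(logits, temperature, rng)
    draws = [mu + s * rng.gauss(0.0, 1.0) for mu, s in zip(means, stds)]
    return sum(w * d for w, d in zip(weights, draws))


rng = random.Random(0)
print(mog_rsample([0.0, 0.0], [-3.0, 3.0], [0.1, 0.1], rng=rng))
```

With a very small temperature, such as the `temperature=1e-05` default in the class signature, the weights are close to one-hot, so each sample effectively comes from a single component while remaining a differentiable function of the parameters.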
- class abaco.ABaCo.MoCPPrior(d_z, n_comp, multiplier=1.0)[source]#
Bases:
Module
- class abaco.ABaCo.MoGEncoder(encoder_net, n_comp)[source]#
Bases:
Module
- det_encode(x)[source]#
Computes the encoded point without stochastic component.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
- Returns:
z – Sum of the MoG component means, weighted by the mixing probabilities of the encoded point
- Return type:
- forward(x)[source]#
Computes the MoG distribution over the latent space.
- Parameters:
x (torch.Tensor)
- Returns:
mog_dist – MoG distribution used to calculate the approximate posterior with rsample()
- Return type:
- monte_carlo_encode(x, K=100)[source]#
Computes a Monte Carlo simulation of the same point to approximate the true posterior distribution.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
K (int) – Number of Monte Carlo iterations
- Returns:
sample – Mean from all samples (i.e. expectation of the encoded point estimated through Monte Carlo simulations)
- Return type:
- class abaco.ABaCo.NBDecoder(decoder_net)[source]#
Bases:
Module
- forward(z)[source]#
Computes the Negative Binomial distribution over the data space. The decoder network outputs the mean and dispersion parameters, which are then reparameterized into the NB distribution parameters: total_count (dispersion) and probs (dispersion/(dispersion + mean)).
- Parameters:
z (torch.Tensor) – Latent space embeddings
- Returns:
Negative Binomial distribution used to calculate the likelihood via its log_prob() method
- Return type:
torch.distribution
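The mean/dispersion reparameterization described above can be checked numerically. The sketch below uses plain Python and assumes the pmf convention Gamma(x + r) / (x! Gamma(r)) * p^r * (1 - p)^x, under which probs = dispersion/(dispersion + mean) recovers the requested mean. The helper names are hypothetical. Note that torch.distributions.NegativeBinomial defines its mean as total_count * probs / (1 - probs), so the role of probs may be flipped relative to this sketch; it is worth checking the source before relying on either convention.

```python
import math


def nb_params(mean, dispersion):
    """Map (mean, dispersion) to (total_count, probs) as stated in the docstring."""
    total_count = dispersion
    probs = dispersion / (dispersion + mean)
    return total_count, probs


def nb_log_prob(x, total_count, probs):
    """Log pmf under the convention assumed here: p^r * (1 - p)^x."""
    return (math.lgamma(x + total_count) - math.lgamma(x + 1) - math.lgamma(total_count)
            + total_count * math.log(probs) + x * math.log1p(-probs))


r, p = nb_params(mean=3.0, dispersion=2.0)
# Recover the mean by summing x * P(x); the tail beyond 2000 is negligible here.
recovered_mean = sum(x * math.exp(nb_log_prob(x, r, p)) for x in range(2000))
print(r, p, recovered_mean)
```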
- class abaco.ABaCo.NormalEncoder(encoder_net)[source]#
Bases:
Module
- det_encode(x)[source]#
Computes the encoded point without stochastic component.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
- Returns:
mu – Gaussian Normal mean of the encoded point
- Return type:
- forward(x)[source]#
Computes the Gaussian Normal distribution over the latent space.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
- Returns:
Gaussian Normal distribution used to calculate the posterior distribution with .rsample()
- Return type:
torch.distribution
- monte_carlo_encode(x, K=100)[source]#
Computes a Monte Carlo simulation of the same point to approximate the true posterior distribution.
- Parameters:
x (torch.Tensor) – Tensor to be encoded
K (int) – Number of Monte Carlo iterations
- Returns:
sample – Mean from all samples (i.e. expectation of the encoded point estimated through Monte Carlo simulations)
- Return type:
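The Monte Carlo encoding amounts to averaging K reparameterized draws from the posterior. A minimal sketch for a diagonal Gaussian posterior follows; it takes the posterior mean and standard deviation directly (in the library these would come from the encoder network), and the function name is a hypothetical stand-in.

```python
import random


def monte_carlo_encode_sketch(mu, sigma, K=100, rng=random):
    """Average K draws z = mu + sigma * eps, eps ~ N(0, 1), per latent dimension."""
    dims = len(mu)
    total = [0.0] * dims
    for _ in range(K):
        for d in range(dims):
            total[d] += mu[d] + sigma[d] * rng.gauss(0.0, 1.0)
    return [t / K for t in total]


rng = random.Random(0)
print(monte_carlo_encode_sketch([1.0, -2.0], [0.5, 0.1], K=100, rng=rng))
```

For a Gaussian posterior the estimate converges to the same mean that `det_encode()` returns, so the Monte Carlo route mainly matters when the posterior is not summarized by a single mean.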
- class abaco.ABaCo.SupervisedContrastiveLoss(temp=0.1)[source]#
Bases:
Module
Contrastive loss definition
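For orientation, the sketch below implements a common form of the supervised contrastive loss in plain Python: embeddings are L2-normalized, and each anchor is pulled towards samples sharing its label relative to all other samples. Whether the class uses exactly this formulation is an assumption; `sup_con_loss_sketch` is a hypothetical name, and `temp` mirrors the constructor argument.

```python
import math


def sup_con_loss_sketch(embeddings, labels, temp=0.1):
    """Mean over anchors of -1/|P(i)| * sum_p log softmax_i(sim(i, p) / temp)."""
    # L2-normalise so that dot products are cosine similarities.
    normed = []
    for z in embeddings:
        norm = math.sqrt(sum(v * v for v in z))
        normed.append([v / norm for v in z])
    n = len(normed)
    losses = []
    for i in range(n):
        # Scaled similarities between anchor i and every other sample.
        sims = {j: sum(a * b for a, b in zip(normed[i], normed[j])) / temp
                for j in range(n) if j != i}
        positives = [j for j in sims if labels[j] == labels[i]]
        if not positives:
            continue  # anchors without a positive pair are skipped
        m = max(sims.values())
        log_denom = m + math.log(sum(math.exp(s - m) for s in sims.values()))
        losses.append(-sum(sims[j] - log_denom for j in positives) / len(positives))
    return sum(losses) / len(losses) if losses else 0.0


points = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(sup_con_loss_sketch(points, [0, 0, 1, 1]))  # labels agree with geometry
print(sup_con_loss_sketch(points, [0, 1, 0, 1]))  # labels cut across geometry
```

A lower temperature sharpens the softmax, so mismatches between labels and geometry are penalized more heavily.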
- class abaco.ABaCo.VAE(encoder, decoder, prior, input_size, device)[source]#
Bases:
Module
- decode(z)[source]#
Decode the latent space representation to the data space.
- Parameters:
z – [torch.Tensor] Latent space tensor of shape (batch, d_z)
- encode(x)[source]#
Forward pass through the VAE.
- Parameters:
x – [torch.Tensor] Input data tensor of shape (batch, features)
- get_posterior(x)[source]#
Given a set of points, compute the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- class abaco.ABaCo.VMMPrior(d_z, n_features, n_comp, n_batch, encoder, multiplier=1.0, dataloader=None)[source]#
Bases:
Module
- cluster_loss()[source]#
Compute the clustering loss for the MoG prior. This loss encourages the components to be well separated by maximizing the pairwise KL divergence between the Gaussian components.
- forward(k_ohe)[source]#
Return prior distribution, allowing for the computation of the KL-divergence by calling self.prior().
- Parameters:
k_ohe – [torch.tensor] One-hot encoded tensor with the corresponding component the point belongs to.
b_ohe – [torch.tensor] One-hot encoded tensor with the corresponding batch label, necessary to append at the Encoder input.
encoder – [nn.Module] Encoder used for getting the centroid of the cluster by encoding the pseudo-input.
- Returns:
[torch.distributions.Distribution]
- Return type:
prior
- class abaco.ABaCo.VampPriorMixtureConditionalEnsembleVAE(encoder, decoders: ModuleList, input_dim, batch_dim, n_comps, d_z, beta=1.0, data_loader=None)[source]#
Bases:
Module
Define a VampPrior Variational Autoencoder model.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- get_posterior(x)[source]#
Given a set of points, compute the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- pca_posterior(x)[source]#
Given a set of points, compute the PCA of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- class abaco.ABaCo.VampPriorMixtureConditionalVAE(encoder, decoder, input_dim, batch_dim, n_comps, d_z, beta=1.0, data_loader=None)[source]#
Bases:
Module
Define a VampPrior Variational Autoencoder model. This class is used for the baseline application of the ABaCo model.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- get_posterior(x)[source]#
Given a set of points, compute the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- pca_posterior(x)[source]#
Given a set of points, compute the PCA of the posterior distribution.
- Parameters:
x (torch.Tensor) – Samples to pass to the encoder
- class abaco.ABaCo.ZIDM(dm: DirichletMultinomial, pi: Tensor, eps: float = 1e-08, validate_args=None)[source]#
Bases:
Distribution
Zero-inflated Dirichlet-Multinomial (ZIDM) distribution: a mixture of per-category structural zeros with a Dirichlet-Multinomial core.
- arg_constraints = {}#
- has_rsample = False#
- log_prob(x: Tensor)[source]#
Implements the log_prob() method inherited from torch.distributions.Distribution.
- Parameters:
x (torch.Tensor) – Observation whose log probability under the distribution is computed
- Returns:
log_zidm – Log probability of the observation under the distribution
- Return type:
- sample(sample_shape=())[source]#
Defines the sample() function, which is used to sample data points using the distribution parameters.
- Parameters:
sample_shape – Number of samples to draw. If not given, sample() draws a single observation from the distribution.
- Returns:
Sample(s) from the Zero-inflated Dirichlet Multinomial distribution.
- Return type:
zinb_sample
- support = IntegerGreaterThan(lower_bound=0)#
- class abaco.ABaCo.ZIDMDecoder(decoder_net, total_count, eps=1e-08)[source]#
Bases:
Module
- forward(z)[source]#
Computes the Zero-inflated Dirichlet-Multinomial distribution over the data space. We are obtaining the zero probability and concentration parameter of the distribution, hence the decoder output would have shape (batch, 2*features).
- Parameters:
z (torch.Tensor) – Latent space embeddings
- Returns:
Zero-inflated Dirichlet-Multinomial distribution used to calculate the likelihood via its log_prob() method
- Return type:
torch.distribution
- class abaco.ABaCo.ZINB(nb, pi, validate_args=None)[source]#
Bases:
Distribution
Zero-inflated Negative Binomial (ZINB) distribution definition. The ZINB distribution is defined by the following:
- For x = 0:
ZINB.log_prob(0) = log(pi + (1 - pi) * NegativeBinomial(0 | r, p))
- For x > 0:
ZINB.log_prob(x) = log(1 - pi) + NegativeBinomial(x | r, p).log_prob(x)
- arg_constraints = {}#
- has_rsample = False#
- log_prob(x)[source]#
Implements the log_prob() method inherited from torch.distributions.Distribution.
- Parameters:
x (torch.Tensor) – Observation whose log probability under the distribution is computed.
- Returns:
log_p – Log probability of the observation under the model.
- Return type:
- sample(sample_shape=())[source]#
Defines the sample() function, which is used to sample data points using the distribution parameters.
- For x = 0:
Binary mask for zero-inflated values
- For x > 0:
Sample from Negative Binomial distribution
- Parameters:
sample_shape – Number of samples to draw. If not given, sample() draws a single observation from the distribution.
- Returns:
Sample(s) from the Zero-inflated Negative Binomial distribution.
- Return type:
zinb_sample
- support = IntegerGreaterThan(lower_bound=0)#
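The two-case definition of the ZINB log probability can be written out directly. The sketch below handles scalar counts in plain Python and assumes the Negative Binomial pmf convention Gamma(x + r) / (x! Gamma(r)) * p^r * (1 - p)^x and 0 <= pi < 1. The helper names are hypothetical, and the library version operates on torch tensors.

```python
import math


def nb_log_prob(x, r, p):
    """Negative Binomial log pmf under the convention assumed in this sketch."""
    return (math.lgamma(x + r) - math.lgamma(x + 1) - math.lgamma(r)
            + r * math.log(p) + x * math.log1p(-p))


def zinb_log_prob(x, pi, r, p):
    """ZINB log pmf: structural zeros with probability pi, otherwise NB(r, p)."""
    if x == 0:
        # A zero is either structural or an NB zero.
        return math.log(pi + (1.0 - pi) * math.exp(nb_log_prob(0, r, p)))
    # A positive count can only come from the NB component.
    return math.log(1.0 - pi) + nb_log_prob(x, r, p)


# The pmf still sums to one after zero inflation (tail beyond 2000 is negligible).
total = sum(math.exp(zinb_log_prob(x, 0.3, 2.0, 0.4)) for x in range(2000))
print(total)
```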
- class abaco.ABaCo.ZINBDecoder(decoder_net)[source]#
Bases:
Module
- forward(z)[source]#
Computes the Zero-inflated Negative Binomial distribution over the data space. The decoder network outputs the zero probability, mean and dispersion parameters; the latter two are then reparameterized into the NB distribution parameters: total_count (dispersion) and probs (dispersion/(dispersion + mean)).
- Parameters:
z (torch.Tensor) – Latent space embeddings
- Returns:
Zero-inflated Negative Binomial distribution used to calculate the likelihood via its log_prob() method
- Return type:
torch.distribution
- abaco.ABaCo.abaco_recon(model, device: device, data: DataFrame, dataloader: DataLoader, sample_label: str, batch_label: str, bio_label, seed=42, det_encode=False, monte_carlo=100)[source]#
Function used to reconstruct data using trained ABaCo model.
- Parameters:
model – Trained ABaCo model.
device (torch.device) – Device to run the model on, e.g., “cuda” or “cpu”.
data (pd.DataFrame) – DataFrame containing the data to be reconstructed.
dataloader (torch.utils.data.DataLoader) – Pytorch DataLoader for the data to be reconstructed.
sample_label (str) – Column name in the DataFrame that contains unique ids for the observations/samples.
batch_label (str) – Column name in the DataFrame that contains ids for the batch/factor groupings to be corrected by ABaCo (e.g., dates of sample analysis).
bio_label (str) – Column name in the DataFrame that contains the biological groupings, i.e. the biological/experimental variation that ABaCo should retain when correcting batch effects (e.g., experimental condition).
seed (int, optional) – Random seed for reproducibility. Default is 42.
det_encode (bool, optional) – If True, use deterministic encoding. Default is False.
monte_carlo (int, optional) – Number of Monte Carlo samples to use for reconstruction. Default is 100. Setting it to 1 is equivalent to sampling once from the final ZINB distribution obtained from the trained model.
- Returns:
otu_corrected_pd – DataFrame containing the reconstructed data with batch and biological labels.
- Return type:
pd.DataFrame
- abaco.ABaCo.abaco_recon_ensemble(model, device, data, dataloader, sample_label, batch_label, bio_label, seed=42, det_encode=False, monte_carlo=100)[source]#
Function used to reconstruct data using trained ABaCo model.
- abaco.ABaCo.abaco_run(dataloader: DataLoader, n_batches: int, n_bios: int, device: device, input_size: int, new_pre_train: bool = False, seed: int = 42, d_z: int = 16, prior: str = 'VMM', count: bool = True, pre_epochs: int = 2000, post_epochs: int = 2000, kl_cycle: bool = True, smooth_annealing: bool = False, encoder_net: list = [1024, 512, 256], decoder_net: list = [256, 512, 1024], vae_act_func=ReLU(), disc_net: list = [256, 128, 64], disc_act_func=ReLU(), disc_loss_type: str = 'CrossEntropy', w_elbo: float = 1.0, beta: float = 20.0, w_disc: float = 1.0, w_adv: float = 1.0, w_contra: float = 10.0, temp: float = 0.1, w_cycle: float = 0.1, vae_pre_lr: float = 0.001, vae_post_lr: float = 0.0001, disc_lr: float = 1e-05, adv_lr: float = 1e-05)[source]#
Function to run the ABaCo model training.
- Parameters:
dataloader (torch.utils.data.DataLoader) – Pytorch dataLoader for the training data.
n_batches (int) – Number of batches in the dataset. For example, if samples were sequenced in 5 batches (e.g., 5 different dates) then n_batches = 5.
n_bios (int) – Number of labels or (potential) clusters based on biological variance. For example, if 2 experimental conditions (e.g., control and treatment) then n_bios = 2.
device (torch.device) – Device to run the model on, e.g., “cuda” or “cpu”.
input_size (int) – Number of features in the input data, columns. For example, if the input is a gene expression matrix with 1000 genes, then input_size = 1000.
new_pre_train (bool) – If True, use the new pre-training method with adversarial training and contrastive loss.
seed (int) – Random seed for reproducibility.
d_z (int) – Dimensionality of the latent space. For example, if d_z = 16, then the latent space will have 16 dimensions.
prior (str) – Prior distribution used. Baseline is “VMM” (VampPrior Mixture Model). Options are “VMM”, “MoG” (Mixture of Gaussians), or “Normal”.
count (bool) – If True, the model will use a zero-inflated negative binomial (ZINB) decoder. If False, it will use a zero-inflated Dirichlet (ZIDirichlet) decoder.
pre_epochs (int) – Number of epochs for first phase of ABaCo: data reconstruction. Default is 2000.
post_epochs (int) – Number of epochs for second phase of ABaCo: batch correction. Default is 2000.
kl_cycle (bool) – If True, the model will use a KL divergence cycle loss during second phase of ABaCo (batch correction). If False, cross loss 0
smooth_annealing (bool) – Slow batch masking during ABaCo second phase (batch correction) to avoid exploding gradients. Default is False.
encoder_net (list) – List of integers defining the architecture of the encoder. Each integer is a layer size. For example, [1024, 512, 256] means the encoder will have three layers with 1024, 512, and 256 neurons respectively.
decoder_net (list) – List of integers defining the architecture of the decoder. Each integer is a layer size. For example, [256, 512, 1024] means the decoder will have three layers with 256, 512, and 1024 neurons respectively.
vae_act_func (nn.Module) – Activation function for the VAE encoder and decoder. Default is nn.ReLU().
disc_net (list) – List of integers defining the architecture of the discriminator. Each integer is a layer size. For example, [256, 128, 64] means the discriminator will have three layers with 256, 128, and 64 neurons respectively.
disc_act_func (nn.Module) – Activation function for the discriminator. Default is nn.ReLU().
disc_loss_type (str) – Type of loss function for the discriminator. Options are “CrossEntropy” or “Uniform”. Default is “CrossEntropy”.
w_elbo (float) – Weight of the ELBO loss in the pre-training phase. Default is 1.0
beta (float) – KL-divergence coefficient: higher value yields bigger penalization from the prior distribution during training. Default is 20.0.
w_disc (float) – Weight of the discriminator loss in the pre-training phase. Default is 1.0
w_adv (float) – Weight of the adversarial loss in the pre-training phase. Default is 1.0
w_contra (float) – Contrastive learning power. Higher value yields a higher separation of biological groups at the latent space. Sometimes higher is better.
w_cycle (float) – Weight of the cycle loss. Higher values lead to more instability during the ABaCo second phase. Default is 0.1.
temp (float) – Temperature for the contrastive loss. Default is 0.1.
vae_pre_lr (float) – Learning rate for first phase of ABaCo. Default is 1e-3.
vae_post_lr (float) – Learning rate for second phase (batch correction). Default is 1e-4.
disc_lr (float) – Learning rate for the batch discriminator. Default is 1e-5.
adv_lr (float) – Adversarial learning rate: to the encoder, if batch effect is on latent space. Default is 1e-5.
- Returns:
Trained ABaCo model
- Return type:
vae
- abaco.ABaCo.abaco_run_ensemble(dataloader, n_batches, n_bios, device, input_size, seed=42, d_z=16, prior='VMM', count=True, pre_epochs=2000, post_epochs=2000, kl_cycle=True, smooth_annealing=False, encoder_net=[1024, 512, 256], n_dec=5, decoder_net=[256, 512, 1024], vae_act_func=ReLU(), disc_net=[256, 128, 64], disc_act_func=ReLU(), disc_loss_type='CrossEntropy', w_elbo=1.0, beta=20.0, w_disc=1.0, w_adv=1.0, w_contra=10.0, temp=0.1, vae_pre_lr=0.001, vae_post_lr=0.0001, disc_lr=1e-05, adv_lr=1e-05)[source]#
Full ABaCo run with default setting
- abaco.ABaCo.adversarial_loss(pred_logits: Tensor, true_labels: Tensor, loss_type: str = 'CrossEntropy')[source]#
Compute adversarial loss for generator (to fool discriminator).
- Parameters:
pred_logits – [batch_size, n_classes] raw discriminator outputs.
true_labels – [batch_size] long tensor of class indices.
loss_type – “CrossEntropy” or “Uniform”.
- Returns:
A scalar loss (negated for generator on CrossEntropy).
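The two loss types can be sketched in plain Python as follows. The "CrossEntropy" branch negates the discriminator's cross-entropy, as the return description states. The "Uniform" branch is written here as the cross-entropy between the predictions and a uniform target over classes, which is an assumption about the library's behavior. The function name `adversarial_loss_sketch` is hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def adversarial_loss_sketch(pred_logits, true_labels, loss_type="CrossEntropy"):
    n = len(pred_logits)
    if loss_type == "CrossEntropy":
        # Generator maximises the discriminator's cross-entropy: return -CE.
        ce = -sum(log_softmax(row)[y] for row, y in zip(pred_logits, true_labels)) / n
        return -ce
    if loss_type == "Uniform":
        # Assumed form: push predictions towards the uniform distribution.
        return -sum(sum(log_softmax(row)) / len(row) for row in pred_logits) / n
    raise ValueError(f"unknown loss_type: {loss_type!r}")


confident = [[10.0, 0.0, 0.0]]  # discriminator is sure about batch 0
print(adversarial_loss_sketch(confident, [0], "CrossEntropy"))
print(adversarial_loss_sketch(confident, [0], "Uniform"))
```

In this sketch the negated cross-entropy is unbounded below, while the uniform-target variant has a finite minimum of log(n_classes), reached when the discriminator output is uniform.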
- abaco.ABaCo.contour_plot(samples, n_levels=10, x_offset=5, y_offset=5, alpha=0.8)[source]#
Given an array of samples, computes the contour plot.
- Parameters:
samples – [np.array] An array of shape (n_samples, 2)
- abaco.ABaCo.kl_mog_mog(p, q)[source]#
Approximation of the KL-divergence between two different Mixture of Gaussians distributions.
- Parameters:
p (MixtureOfGaussians) – MoG distribution 1
q (MixtureOfGaussians) – MoG distribution 2
- Returns:
kl – KL-divergence between p and q
- Return type:
- abaco.ABaCo.kl_zinb_zinb(p, q)[source]#
Approximation of the KL-divergence between two different Zero-inflated Negative Binomial distributions.
- Parameters:
- Returns:
kl – KL-divergence between p and q
- Return type:
- class abaco.ABaCo.metaABaCo(data, n_bios, bio_label, n_batches, batch_label, n_features, device, prior='MoG', pdist='ZINB', d_z=16, epochs=[1000, 2000, 2000], encoder_net=[512, 256, 128], decoder_net=[128, 256, 512], vae_act_fun=ReLU(), disc_net=[128, 64], disc_act_fun=ReLU())[source]#
Bases:
Module
- batch_correct(train_loader, vae_optimizer, disc_optimizer, adv_optimizer, w_disc=1.0, w_adv=1.0, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0)[source]#
Train the conditional VAE model for batch correction. This is trained after VAE prior parameters are defined,
- batch_mask(train_loader, decoder_optimizer, smooth_annealing=True, cycle_reg=None, w_elbo_nll=1.0, w_cycle=0.001)[source]#
Pre-trained VAE will now have frozen encoder and batch labels masked at the encoder.
- fit(smooth_annealing=True, cycle_reg=None, seed=None, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0, w_cycle=0.001, w_disc=1.0, w_adv=1.0, phase_1_vae_lr=0.001, phase_2_vae_lr=0.001, phase_3_vae_lr=1e-06, disc_lr=0.001, adv_lr=0.001)[source]#
- plot_pca_posterior(figsize=(14, 6), palette='tab10')[source]#
Get the plot of the first 2 principal components of the posterior distribution.
- train_vae(train_loader, optimizer, w_elbo_nll=1.0, w_elbo_kl=1.0, w_bio_penalty=1.0, w_cluster_penalty=1.0)[source]#
Train the conditional VAE model. If clustering prior is used, penalization term is applied to increase sparsity of the clusters.
- Parameters:
vae – [VAE] Variational Autoencoder model
train_loader – [torch.utils.data.DataLoader] DataLoader for the training data
optimizer – [torch.optim.Optimizer] Optimizer for training
epochs – [int] Number of training epochs
device – [str] Device to use for computations
- abaco.ABaCo.new_pre_train_abaco(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, normal_epochs, device, mog_epochs=0, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#
Part of the new three-phase ABaCo run; two of the phases are computed here. Pre-training of conditional VAE with contrastive loss and adversarial mixing in latent space.
- abaco.ABaCo.pre_train_abaco(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, epochs, device, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#
Pre-training of conditional VAE with contrastive loss and adversarial mixing in latent space.
- abaco.ABaCo.pre_train_abaco_ensemble(vae, vae_optim_pre, discriminator, disc_optim, adv_optim, data_loader, epochs, device, w_contra=1.0, temp=0.1, w_elbo=1.0, w_disc=1.0, w_adv=1.0, disc_loss_type='CrossEntropy', n_disc_updates=1, label_smooth=0.1, normal=False, count=True)[source]#
Pre-training of conditional VAE with contrastive loss and adversarial mixing in latent space.
- abaco.ABaCo.train_abaco(vae, vae_optim_post, data_loader, epochs, device, w_elbo=1.0, w_cycle=1.0, cycle='KL', smooth_annealing=False)[source]#
Trains the decoder of a pre-trained ABaCo cVAE while masking the batch labels, so that the information passed to the decoder depends solely on the latent space, where the batches have already been mixed.
- abaco.ABaCo.train_abaco_ensemble(vae, vae_optim_post, data_loader, epochs, device, w_elbo=1.0, w_cycle=1.0, cycle='KL', smooth_annealing=False)[source]#
Trains the decoder of a pre-trained ABaCo cVAE while masking the batch labels, so that the information passed to the decoder depends solely on the latent space, where the batches have already been mixed.