Source code for neupi.training.trainers.ssl_trainer
import torch
from neupi.core.trainer import BaseTrainer
from neupi.registry import register
from neupi.utils.pgm_utils import apply_evidence
from torch.utils.data import DataLoader
[docs]
@register("trainer")
class SelfSupervisedTrainer(BaseTrainer):
    """
    A trainer for self-supervised learning of neural PGM solvers.

    This class implements a single optimization step (forward pass, loss
    calculation, backpropagation, and optimizer update) and an epoch-based
    training loop built on top of that step.

    Args:
        model (torch.nn.Module): The neural network model to be trained.
        pgm_evaluator (torch.nn.Module): The PGM evaluator (e.g., MarkovNetwork)
            used for calculating log-likelihoods.
        loss_fn (callable): The loss function. It is called as
            ``loss_fn(final_assignments, pgm_evaluator)`` and must return a
            single-element tensor that supports ``backward()``.
        optimizer (torch.optim.Optimizer): The optimizer for updating model weights.
        device (str): The device to run training on ('cpu' or 'cuda').
    """

    def __init__(
        self,
        model: torch.nn.Module,
        pgm_evaluator: torch.nn.Module,
        loss_fn: callable,
        optimizer: torch.optim.Optimizer,
        device: str = "cpu",
    ):
        # Both modules are moved to the target device up front so that every
        # later call in ``step`` operates on a single device.
        self.model = model.to(device)
        self.pgm_evaluator = pgm_evaluator.to(device)
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = device

    def step(self, batch_data):
        """Performs a single training step on a batch of data.

        Args:
            batch_data: An iterable of exactly four tensors, in the order
                ``(evidence_data, evidence_mask, query_mask, unobs_mask)``.
                Adjust the unpacking below if your DataLoader yields a
                different structure.

        Returns:
            float: The loss value computed for this batch.
        """
        # Unpack the batch and move every tensor to the training device.
        evidence_data, evidence_mask, query_mask, unobs_mask = (
            tensor.to(self.device) for tensor in batch_data
        )

        # Forward pass. Training mode is (re-)enabled on every step so the
        # step is correct even if the model was switched to eval mode between
        # calls.
        self.model.train()
        raw_predictions = self.model(evidence_data, evidence_mask, query_mask, unobs_mask)

        # Apply a sigmoid to get probabilities in the range [0, 1].
        predictions = torch.sigmoid(raw_predictions)

        # The network's predictions for evidence variables are replaced with
        # their true values.
        final_assignments = apply_evidence(predictions, evidence_data, evidence_mask)

        # Calculate loss.
        loss = self.loss_fn(final_assignments, self.pgm_evaluator)

        # Backward pass and optimization.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def fit(self, dataloader: DataLoader, num_epochs: int):
        """
        Runs the full training loop for a specified number of epochs.

        After each epoch the average per-batch loss is printed. If the
        dataloader yields no batches in an epoch, the average is reported as
        NaN instead of raising.

        Args:
            dataloader (DataLoader): The DataLoader providing training data.
            num_epochs (int): The number of epochs to train for.

        Returns:
            torch.nn.Module: The trained model (``self.model``).
        """
        for epoch in range(num_epochs):
            total_loss = 0.0
            # Count batches explicitly rather than calling ``len(dataloader)``:
            # a loader may be empty (division by zero) or may not define a
            # length at all (e.g. when it wraps an iterable-style dataset).
            num_batches = 0
            for batch_data in dataloader:
                total_loss += self.step(batch_data)
                num_batches += 1

            avg_loss = total_loss / num_batches if num_batches else float("nan")
            print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        return self.model