# Source code for neupi.inference.itself

import copy
from typing import Dict, List, Type

import torch
from torch.utils.data import DataLoader

from neupi.core.inference_module import BaseInferenceModule
from neupi.registry import register
from neupi.utils.pgm_utils import apply_evidence


@register("inference_module")
class ITSELF_Engine(BaseInferenceModule):
    """
    Handles Inference Time Self-Supervised Training (ITSELF).

    For each batch of data, this engine fine-tunes a copy of the base model for
    a fixed number of optimization steps on that batch, then uses the refined
    copy to produce the predictions for that specific batch.

    Args:
        model (torch.nn.Module): The base, pre-trained neural network model.
            It is moved to ``device`` and put in eval mode; refinement is done
            on per-batch deep copies, so this model's weights are not updated.
        pgm_evaluator (torch.nn.Module): The PGM evaluator passed to
            ``loss_fn`` for loss calculation.
        loss_fn (callable): The loss function used during refinement. It is
            called as ``loss_fn(assignments, pgm_evaluator)`` and must return a
            scalar tensor (``.backward()`` is called on it without arguments).
        optimizer_cls (Type[torch.optim.Optimizer]): The class of the optimizer
            to use for refinement (e.g., torch.optim.Adam). It is instantiated
            as ``optimizer_cls(params, lr=refinement_lr)``.
        discretizer (torch.nn.Module): Applied to the final assignments after
            refinement (presumably thresholds/rounds them to discrete values).
            If it is an ``nn.Module`` it is moved to ``device``.
        refinement_lr (float): The learning rate for the test-time optimizer.
        refinement_steps (int): The number of optimization steps to perform per
            batch.
        device (str): The device to run inference on ('cpu' or 'cuda').

    References
    ----------
    Arya, S., Rahman, T., & Gogate, V. G. (2024). A neural network approach for
    efficiently answering most probable explanation queries in probabilistic
    models. NeurIPS 2024. https://openreview.net/forum?id=ufPPf9ghzP
    """

    def __init__(
        self,
        model: torch.nn.Module,
        pgm_evaluator: torch.nn.Module,
        loss_fn: callable,
        optimizer_cls: Type[torch.optim.Optimizer],
        discretizer: torch.nn.Module,
        refinement_lr: float,
        refinement_steps: int,
        device: str = "cpu",
    ):
        self.base_model = model.to(device)
        # The base model only serves as the template that _process_batch deep-copies.
        self.base_model.eval()
        self.pgm_evaluator = pgm_evaluator.to(device)
        self.loss_fn = loss_fn
        self.optimizer_cls = optimizer_cls
        # Keep the discretizer on the same device as the other modules when it
        # can carry tensors (parameters/buffers); plain callables are kept as-is.
        if isinstance(discretizer, torch.nn.Module):
            discretizer = discretizer.to(device)
        self.discretizer = discretizer
        self.refinement_lr = refinement_lr
        self.refinement_steps = refinement_steps
        self.device = device

    def run(self, data) -> Dict[str, torch.Tensor]:
        """
        Performs test-time refinement for each batch in the dataloader or a single batch.

        Refinement needs autograd, so this method is intentionally NOT wrapped in
        ``torch.no_grad()``; only the final prediction of each batch is.

        Args:
            data: Either a DataLoader or a tuple/list of four tensors
                (evidence_data, evidence_mask, query_mask, unobs_mask).

        Returns:
            Dict[str, torch.Tensor]: ``{"final_assignments": tensor}``, where the
            per-batch assignments are moved to CPU, concatenated along dim 0 and
            cast with ``.int()``.

        Raises:
            ValueError: If ``data`` is neither a DataLoader nor a 4-element
                tuple/list.
        """
        all_final_assignments: List[torch.Tensor] = []

        # Determine if input is DataLoader or a single batch
        if isinstance(data, DataLoader):
            batch_iter = data
        elif isinstance(data, (tuple, list)) and len(data) == 4:
            batch_iter = [data]
        else:
            raise ValueError("Input must be a DataLoader or a tuple of four tensors.")

        for batch_data in batch_iter:
            final_assignment, _ = self._process_batch(batch_data)
            all_final_assignments.append(final_assignment.cpu())

        results = {"final_assignments": torch.cat(all_final_assignments, dim=0).int()}
        return results

    def _process_batch(self, batch_data, trained_model=None):
        """
        Refine a model on a single batch and return its predictions for that batch.

        Args:
            batch_data: A tuple of (evidence_data, evidence_mask, query_mask, unobs_mask).
            trained_model: Optional model to continue refining. When given it is
                fine-tuned in place; otherwise a deep copy of the base model is
                refined, leaving the base model untouched.

        Returns:
            final_assignment: The discretized assignment for the batch (the
                sigmoid predictions combined with the evidence via
                ``apply_evidence``, then passed through the discretizer).
            temp_model: The model after refinement on this batch.
        """
        # Move all four tensors unconditionally: ``Tensor.to`` is a no-op for
        # tensors already on the target device. The previous guard compared a
        # ``torch.device`` with a device *string* (not a reliable equality, e.g.
        # 'cuda' vs 'cuda:0') and only inspected evidence_data, so masks living
        # on another device could be left behind.
        evidence_data, evidence_mask, query_mask, unobs_mask = (
            tensor.to(self.device) for tensor in batch_data
        )

        if trained_model is None:
            temp_model = copy.deepcopy(self.base_model)
        else:
            temp_model = trained_model
        temp_model.train()
        optimizer = self.optimizer_cls(temp_model.parameters(), lr=self.refinement_lr)

        for _ in range(self.refinement_steps):
            optimizer.zero_grad()
            raw_preds = temp_model(evidence_data, evidence_mask, query_mask, unobs_mask)
            prob_preds = torch.sigmoid(raw_preds)
            # The loss is computed on the assignment after apply_evidence has
            # combined predictions with the observed evidence (presumably fixing
            # the evidence positions to their observed values).
            final_assigns = apply_evidence(prob_preds, evidence_data, evidence_mask)
            loss = self.loss_fn(final_assigns, self.pgm_evaluator)
            loss.backward()
            optimizer.step()

        # Final prediction with the refined model; no gradients are needed here.
        temp_model.eval()
        with torch.no_grad():
            final_raw = temp_model(evidence_data, evidence_mask, query_mask, unobs_mask)
            final_prob = torch.sigmoid(final_raw)
            final_assignment = apply_evidence(final_prob, evidence_data, evidence_mask)
            final_assignment = self.discretizer(final_assignment)
        return final_assignment, temp_model