NeuPI: Inference and Test-Time Refinement
This notebook demonstrates the final step in the NeuPI pipeline: using a trained neural solver to perform Most Probable Explanation (MPE) or Marginal MAP (MMAP) inference. We will explore two methods:
``SinglePassInferenceEngine``: A fast method that performs a single forward pass of the neural network to get the MPE assignments.
``ITSELF_Engine``: An advanced method that performs test-time refinement. It uses the PGM’s feedback to fine-tune the model on each specific inference instance, often leading to significantly better results.
We will cover:
Setting up a pre-trained model (recapping the updated training process from Notebook 2).
Creating a new dataset for inference.
Running the SinglePassInferenceEngine and evaluating its results.
Running the ITSELF_Engine to refine the predictions.
Comparing the log-likelihood scores to demonstrate the improvement from ITSELF.
Setup
We import all necessary components.
[1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
import os
# Import neupi components
from neupi import (
MLP,
MarkovNetwork,
SelfSupervisedTrainer,
mpe_log_likelihood_loss,
DiscreteEmbedder,
ThresholdDiscretizer,
SinglePassInferenceEngine,
ITSELF_Engine,
)
# Define the device for computation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
# --- Path Setup ---
UAI_PATH = Path("networks") / "mn" / "Grids_17.uai"
print(f"Markov Network path: {UAI_PATH}")
assert UAI_PATH.exists(), f"File not found: {UAI_PATH}"
Using device: cpu
Markov Network path: networks/mn/Grids_17.uai
Step 1: Recap - Get a Pre-Trained Model
For this notebook to be self-contained, we’ll quickly train a model, incorporating the library’s recent updates (DiscreteEmbedder and the new data format). This will provide the trained_model we need for inference.
[2]:
# Load the PGM evaluator
mn_evaluator = MarkovNetwork(uai_file=str(UAI_PATH), device=DEVICE)
num_vars = mn_evaluator.num_variables
# Create a dummy training dataloader
num_samples_train = 64
evidence_data_train = torch.randint(
0, 2, (num_samples_train, num_vars), device=DEVICE, dtype=torch.float32
)
evidence_mask_train = torch.rand(num_samples_train, num_vars, device=DEVICE) > 0.5
query_mask_train = ~evidence_mask_train
unobs_mask_train = torch.zeros_like(evidence_mask_train, dtype=torch.bool)
train_dataset = TensorDataset(
evidence_data_train, evidence_mask_train, query_mask_train, unobs_mask_train
)
train_dataloader = DataLoader(train_dataset, batch_size=16)
# Setup model with the new DiscreteEmbedder
embedding = DiscreteEmbedder(num_vars)
model = MLP(hidden_sizes=[32, 16], output_size=num_vars, embedding=embedding).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Setup trainer
trainer = SelfSupervisedTrainer(
model=model,
pgm_evaluator=mn_evaluator,
loss_fn=mpe_log_likelihood_loss,
optimizer=optimizer,
device=DEVICE,
)
# Train for a few epochs
print("Training a model for inference demonstration...")
trained_model = trainer.fit(train_dataloader, num_epochs=3)
print("Training complete. We now have a trained model.")
Using 1d factors: False
PGM is pairwise.
Training a model for inference demonstration...
Epoch 1/3, Average Loss: -1.7110
Epoch 2/3, Average Loss: -14.5424
Epoch 3/3, Average Loss: -22.0633
Training complete. We now have a trained model.
Step 2: Create an Inference DataLoader
Now we create a new set of inference queries that the model did not see during training.
[3]:
num_samples_inf = 32
# The model takes evidence data and masks as input
evidence_data_inf = torch.randint(
0, 2, (num_samples_inf, num_vars), device=DEVICE, dtype=torch.float32
)
evidence_mask_inf = torch.rand(num_samples_inf, num_vars, device=DEVICE) > 0.5
query_mask_inf = ~evidence_mask_inf
unobs_mask_inf = torch.zeros_like(evidence_mask_inf, dtype=torch.bool)
inf_dataset = TensorDataset(evidence_data_inf, evidence_mask_inf, query_mask_inf, unobs_mask_inf)
inf_dataloader = DataLoader(inf_dataset, batch_size=8)
print(f"Created an inference DataLoader with {len(inf_dataset)} samples.")
Created an inference DataLoader with 32 samples.
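The three masks built above split each row's variables into evidence, query, and unobserved groups. As a minimal sketch in plain PyTorch, the helper below (`make_masks`, a hypothetical name that is not part of neupi) shows how such a split could be generated and checked. The MPE case matches this notebook, where no variable is left unobserved. The MMAP-style split, with a non-empty unobserved group, is an assumption about how that mask would be used.

```python
import torch


def make_masks(num_samples, num_vars, evidence_frac=0.5, query_frac=0.5, seed=0):
    """Randomly assign every variable of every sample to exactly one group.

    Hypothetical helper for illustration only; not part of neupi.
    """
    gen = torch.Generator().manual_seed(seed)
    u = torch.rand(num_samples, num_vars, generator=gen)
    evidence = u < evidence_frac
    query = (u >= evidence_frac) & (u < evidence_frac + query_frac)
    unobs = ~(evidence | query)  # whatever is left over
    return evidence, query, unobs


# MPE-style split: every non-evidence variable is a query variable.
ev, q, un = make_masks(4, 10, evidence_frac=0.5, query_frac=0.5)

# MMAP-style split (assumed usage): some variables are neither observed nor queried.
ev2, q2, un2 = make_masks(4, 10, evidence_frac=0.4, query_frac=0.3)

# Sanity check: the three masks are disjoint and cover all variables.
assert ((ev.int() + q.int() + un.int()) == 1).all()
assert ((ev2.int() + q2.int() + un2.int()) == 1).all()
```

Drawing a single uniform tensor and slicing it into intervals guarantees that the groups never overlap, which is easier to verify than sampling three masks independently.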
Step 3: Single-Pass Inference
We first use the SinglePassInferenceEngine. It runs the model once, applies a discretizer to get binary assignments, and returns the result. This is the fastest method.
[4]:
# A discretizer is needed to convert the model's continuous outputs (probabilities) into binary assignments.
discretizer = ThresholdDiscretizer(threshold=0.5)
simple_inference_engine = SinglePassInferenceEngine(
model=trained_model, discretizer=discretizer, device=DEVICE
)
print("Running single-pass inference...")
initial_results = simple_inference_engine.run(inf_dataloader)
initial_assignments = initial_results["final_assignments"].to(DEVICE)
# Evaluate the quality of these assignments using the PGM
with torch.no_grad():
    initial_ll = mn_evaluator(initial_assignments).mean()
print(f"Single-Pass Avg Log-Likelihood: {initial_ll.item():.4f}")
Running single-pass inference...
Single-Pass Avg Log-Likelihood: 7.2980
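To make the discretization step concrete, here is a minimal sketch of thresholding in plain PyTorch. It is not the implementation of ThresholdDiscretizer: the strict `>` comparison, the function name `threshold_discretize`, and the optional clamping of evidence positions are all assumptions made for illustration.

```python
import torch


def threshold_discretize(probs, threshold=0.5):
    """Map probabilities to hard 0/1 assignments (strictly above threshold -> 1)."""
    return (probs > threshold).to(probs.dtype)


probs = torch.tensor([[0.10, 0.50, 0.51, 0.99]])
assignments = threshold_discretize(probs)
print(assignments)  # tensor([[0., 0., 1., 1.]])

# Assumption: observed variables keep their evidence value regardless of the prediction.
evidence = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
evidence_mask = torch.tensor([[True, False, False, True]])
clamped = torch.where(evidence_mask, evidence, assignments)
print(clamped)  # tensor([[1., 0., 1., 0.]])
```

Note that a value exactly at the threshold maps to 0 under the strict comparison used here. A discretizer that uses `>=` would map it to 1 instead.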
Step 4: ITSELF Inference
Now we use the ITSELF_Engine. For each batch of data, it performs several optimization steps (refinement_steps), fine-tuning the model specifically for that batch before producing the final assignments. This test-time adaptation uses the PGM evaluator as the training signal, so it needs no ground-truth labels.
[5]:
itself_engine = ITSELF_Engine(
model=trained_model,
pgm_evaluator=mn_evaluator,
loss_fn=mpe_log_likelihood_loss,
optimizer_cls=torch.optim.Adam, # The optimizer to use for refinement
discretizer=discretizer,
refinement_lr=1e-3, # Learning rate for the refinement steps
refinement_steps=5, # Number of refinement steps per instance
device=DEVICE,
)
print("Running ITSELF inference with test-time refinement...")
refined_results = itself_engine.run(inf_dataloader)
refined_assignments = refined_results["final_assignments"].to(DEVICE)
# Evaluate the quality of the refined assignments
with torch.no_grad():
    refined_ll = mn_evaluator(refined_assignments).mean()
print(f"ITSELF Refined Avg Log-Likelihood: {refined_ll.item():.4f}")
Running ITSELF inference with test-time refinement...
ITSELF Refined Avg Log-Likelihood: 207.5446
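The idea behind test-time refinement can be sketched without neupi. The example below is not the ITSELF_Engine implementation. It uses a made-up pairwise log-score in place of the PGM evaluator, a small `nn.Sequential` in place of the MLP, and it assumes that the model is copied for each batch so that the original weights stay untouched. Whether ITSELF_Engine resets weights between batches is not stated in this notebook. All names (`log_score`, `predict`, `refine_and_predict`) are hypothetical.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
num_vars = 6

# Made-up pairwise log-score: score(x) = b.x + sum_{i<j} W_ij x_i x_j
b = torch.randn(num_vars)
W = torch.triu(torch.randn(num_vars, num_vars), diagonal=1)


def log_score(x):
    """Differentiable score for relaxed assignments x of shape (batch, num_vars)."""
    return x @ b + ((x @ W) * x).sum(dim=1)


def predict(model, evidence, evidence_mask):
    """Predict probabilities and clamp observed variables to their evidence."""
    inputs = torch.cat([evidence * evidence_mask, evidence_mask], dim=1)
    probs = torch.sigmoid(model(inputs))
    return torch.where(evidence_mask.bool(), evidence, probs)


def refine_and_predict(model, evidence, evidence_mask, steps=5, lr=1e-3):
    """Fine-tune a copy of the model on one batch, then discretize its output."""
    local = copy.deepcopy(model)  # assumption: the original model is not modified
    opt = torch.optim.Adam(local.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        # Maximize the score by minimizing its negative mean over the batch.
        loss = -log_score(predict(local, evidence, evidence_mask)).mean()
        loss.backward()
        opt.step()
    with torch.no_grad():
        probs = predict(local, evidence, evidence_mask)
    return (probs > 0.5).float()


model = nn.Sequential(nn.Linear(2 * num_vars, 16), nn.ReLU(), nn.Linear(16, num_vars))
evidence = torch.randint(0, 2, (8, num_vars)).float()
evidence_mask = (torch.rand(8, num_vars) > 0.5).float()
refined = refine_and_predict(model, evidence, evidence_mask)
print(refined.shape)  # torch.Size([8, 6])
```

Because the loss is computed from the score function alone, the same loop works on any unlabeled query batch. The trade-off is extra compute per batch, which grows with the number of refinement steps.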
Step 5: Comparison and Conclusion
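Finally, we compare the average log-likelihoods. A higher log-likelihood indicates a better solution to the MPE problem. The values printed in this notebook are positive, so they cannot be normalized log-probabilities; they are presumably unnormalized log-scores that omit the log partition function. Since that term is the same constant for every assignment, the comparison between the two engines remains valid.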
[6]:
print(f"Initial Avg Log-Likelihood (Single Pass): {initial_ll.item():.4f}")
print(f"Refined Avg Log-Likelihood (ITSELF): {refined_ll.item():.4f}")
improvement = refined_ll - initial_ll
print(f"\nImprovement from ITSELF: {improvement.item():.4f}")
assert refined_ll > initial_ll, "ITSELF failed to improve the log-likelihood!"
print("\nSuccessfully demonstrated that ITSELF improves inference quality.")
Initial Avg Log-Likelihood (Single Pass): 7.2980
Refined Avg Log-Likelihood (ITSELF): 207.5446
Improvement from ITSELF: 200.2466
Successfully demonstrated that ITSELF improves inference quality.
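Comparing only the two means can hide instances where refinement made things worse. As a hedged sketch, the helper below (`compare_scores`, a hypothetical name) summarizes per-instance differences. It assumes the evaluator returns one score per instance, and the tensors used here are made-up numbers, not results from this notebook.

```python
import torch


def compare_scores(initial, refined):
    """Summarize per-instance score changes between two sets of assignments."""
    delta = refined - initial
    return {
        "mean_gain": delta.mean().item(),
        "frac_improved": (delta > 0).float().mean().item(),
        "worst_change": delta.min().item(),
    }


# Made-up per-instance scores for illustration.
initial = torch.tensor([1.0, 2.0, 3.0, 4.0])
refined = torch.tensor([3.0, 2.0, 2.5, 8.0])
summary = compare_scores(initial, refined)
print(summary)  # {'mean_gain': 1.375, 'frac_improved': 0.5, 'worst_change': -0.5}
```

In this toy case the mean gain is positive even though one instance got worse and one did not change, which is exactly the situation a mean-only check would miss.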