NeuPI: Inference and Test-Time Refinement

This notebook demonstrates the final step in the NeuPI pipeline: using a trained neural solver to perform Most Probable Explanation (MPE) or Marginal MAP (MMAP) inference. We will explore two methods:

  1. ``SinglePassInferenceEngine``: A fast method that runs a single forward pass of the neural network and discretizes its outputs into MPE assignments.

  2. ``ITSELF_Engine``: An advanced method that performs test-time refinement. It uses the PGM's log-likelihood as feedback to fine-tune the model on the specific inference queries at hand, often leading to significantly better results.

We will cover:

  1. Training a model to use for inference (a quick recap of the training process from Notebook 2).

  2. Creating a new dataset for inference.

  3. Running the SinglePassInferenceEngine and evaluating its results.

  4. Running the ITSELF_Engine to refine the predictions.

  5. Comparing the log-likelihood scores to demonstrate the improvement from ITSELF.

Setup

We import all necessary components.

[1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path

# Import neupi components
from neupi import (
    MLP,
    MarkovNetwork,
    SelfSupervisedTrainer,
    mpe_log_likelihood_loss,
    DiscreteEmbedder,
    ThresholdDiscretizer,
    SinglePassInferenceEngine,
    ITSELF_Engine,
)

# Define the device for computation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Path Setup ---
UAI_PATH = Path("networks") / "mn" / "Grids_17.uai"

print(f"Markov Network path: {UAI_PATH}")
assert UAI_PATH.exists(), f"File not found: {UAI_PATH}"
Using device: cpu
Markov Network path: networks/mn/Grids_17.uai

Step 1: Recap - Get a Pre-Trained Model

To keep this notebook self-contained, we quickly train a model with the library's current interface: a DiscreteEmbedder as the MLP's input embedding, and the four-tensor data format (evidence values, evidence mask, query mask, unobserved mask). This gives us the trained_model we need for inference.

[2]:
# Load the PGM evaluator
mn_evaluator = MarkovNetwork(uai_file=str(UAI_PATH), device=DEVICE)
num_vars = mn_evaluator.num_variables

# Create a dummy training dataloader
num_samples_train = 64
evidence_data_train = torch.randint(
    0, 2, (num_samples_train, num_vars), device=DEVICE, dtype=torch.float32
)
evidence_mask_train = torch.rand(num_samples_train, num_vars, device=DEVICE) > 0.5
query_mask_train = ~evidence_mask_train
unobs_mask_train = torch.zeros_like(evidence_mask_train, dtype=torch.bool)
train_dataset = TensorDataset(
    evidence_data_train, evidence_mask_train, query_mask_train, unobs_mask_train
)
train_dataloader = DataLoader(train_dataset, batch_size=16)

# Setup model with the new DiscreteEmbedder
embedding = DiscreteEmbedder(num_vars)
model = MLP(hidden_sizes=[32, 16], output_size=num_vars, embedding=embedding).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Setup trainer
trainer = SelfSupervisedTrainer(
    model=model,
    pgm_evaluator=mn_evaluator,
    loss_fn=mpe_log_likelihood_loss,
    optimizer=optimizer,
    device=DEVICE,
)

# Train for a few epochs
print("Training a model for inference demonstration...")
trained_model = trainer.fit(train_dataloader, num_epochs=3)
print("Training complete. We now have a trained model.")
Using 1d factors: False
PGM is pairwise.
Training a model for inference demonstration...
Epoch 1/3, Average Loss: -1.7110
Epoch 2/3, Average Loss: -14.5424
Epoch 3/3, Average Loss: -22.0633
Training complete. We now have a trained model.
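The log line "PGM is pairwise." indicates that every factor in Grids_17.uai involves at most two variables. For intuition about what an evaluator like mn_evaluator scores, here is a minimal pure-Python sketch, not NeuPI's implementation: for a binary pairwise Markov network, the unnormalized log-score of an assignment is the sum of its unary and pairwise log-potentials. The dictionary-based potential format, the function pairwise_log_score, and the toy three-variable chain are all hypothetical.

```python
import itertools

def pairwise_log_score(assignment, unary, pairwise):
    """Unnormalized log-score of a binary assignment in a pairwise Markov network.

    Hypothetical format:
      unary:    dict var -> (log_phi(x=0), log_phi(x=1))
      pairwise: dict (i, j) -> 2x2 nested list of log_psi(x_i, x_j)
    """
    score = 0.0
    for i, (lp0, lp1) in unary.items():
        score += lp1 if assignment[i] == 1 else lp0
    for (i, j), table in pairwise.items():
        score += table[assignment[i]][assignment[j]]
    return score

# Toy chain x0 - x1 - x2 whose pairwise potentials reward agreeing neighbours.
unary = {0: (0.0, 0.5), 1: (0.0, 0.0), 2: (0.2, 0.0)}
agree = [[1.0, -1.0], [-1.0, 1.0]]
pairwise = {(0, 1): agree, (1, 2): agree}

# Brute-force MPE: enumerate all 2^3 assignments and keep the highest score.
best = max(itertools.product([0, 1], repeat=3),
           key=lambda x: pairwise_log_score(x, unary, pairwise))
print(best)  # (1, 1, 1)
```

Brute-force enumeration costs 2^n evaluations, so it only works for tiny networks; a learned solver sidesteps that cost by predicting assignments directly.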

Step 2: Create an Inference DataLoader

Next, we create a new set of inference queries that the model did not see during training.

[3]:
num_samples_inf = 32

# The model takes evidence data and masks as input
evidence_data_inf = torch.randint(
    0, 2, (num_samples_inf, num_vars), device=DEVICE, dtype=torch.float32
)
evidence_mask_inf = torch.rand(num_samples_inf, num_vars, device=DEVICE) > 0.5
query_mask_inf = ~evidence_mask_inf
unobs_mask_inf = torch.zeros_like(evidence_mask_inf, dtype=torch.bool)

inf_dataset = TensorDataset(evidence_data_inf, evidence_mask_inf, query_mask_inf, unobs_mask_inf)
inf_dataloader = DataLoader(inf_dataset, batch_size=8)

print(f"Created an inference DataLoader with {len(inf_dataset)} samples.")
Created an inference DataLoader with 32 samples.
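Each inference query uses the same four-tensor tuple as training: evidence values plus evidence, query, and unobserved masks. In the cells above, every variable is either evidence or query and the unobserved mask is all False, which is the MPE setting. The sketch below reproduces that split with plain Python lists (make_mpe_masks and check_partition are hypothetical helpers) and checks that the three masks partition the variables. The MMAP variant at the end, where some non-evidence variables are marked unobserved instead of queried, is an assumption about how the unobserved mask is meant to be used.

```python
import random

def make_mpe_masks(num_vars, evidence_prob=0.5, rng=None):
    """Hypothetical helper mirroring the tensor construction in the cells above."""
    rng = rng or random.Random()
    # Each variable is evidence with probability `evidence_prob`.
    evidence = [rng.random() < evidence_prob for _ in range(num_vars)]
    query = [not e for e in evidence]  # every non-evidence variable is queried
    unobserved = [False] * num_vars    # nothing is marginalized out in MPE
    return evidence, query, unobserved

def check_partition(evidence, query, unobserved):
    """True if each variable belongs to exactly one of the three sets."""
    return all(int(e) + int(q) + int(u) == 1
               for e, q, u in zip(evidence, query, unobserved))

ev, q, un = make_mpe_masks(10, rng=random.Random(0))
print(check_partition(ev, q, un))  # True

# Assumed MMAP-style split: move even-indexed query variables to "unobserved".
q_mmap = [x and i % 2 == 1 for i, x in enumerate(q)]
un_mmap = [x and i % 2 == 0 for i, x in enumerate(q)]
print(check_partition(ev, q_mmap, un_mmap))  # True
```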

Step 3: Single-Pass Inference

We first use the SinglePassInferenceEngine. It runs the model once, applies a discretizer to get binary assignments, and returns the result. This is the fastest method.

[4]:
# A discretizer is needed to convert the model's continuous outputs (probabilities) into binary assignments.
discretizer = ThresholdDiscretizer(threshold=0.5)

simple_inference_engine = SinglePassInferenceEngine(
    model=trained_model, discretizer=discretizer, device=DEVICE
)

print("Running single-pass inference...")
initial_results = simple_inference_engine.run(inf_dataloader)
initial_assignments = initial_results["final_assignments"].to(DEVICE)

# Evaluate the quality of these assignments using the PGM
with torch.no_grad():
    initial_ll = mn_evaluator(initial_assignments).mean()

print(f"Single-Pass Avg Log-Likelihood: {initial_ll.item():.4f}")
Running single-pass inference...
Single-Pass Avg Log-Likelihood: 7.2980
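A threshold discretizer maps each predicted probability to 0 or 1. The sketch below is a hypothetical plain-Python stand-in for ThresholdDiscretizer(threshold=0.5). Two details are assumptions rather than documented NeuPI behaviour: a probability exactly equal to the threshold maps to 1, and evidence variables keep their observed values, which is what scoring the full assignment with the PGM, as in the cell above, suggests.

```python
def threshold_discretize(probs, evidence_values, evidence_mask, threshold=0.5):
    """Hypothetical sketch: binarize query variables, keep evidence as observed.

    Assumptions: p >= threshold maps to 1, and evidence is clamped.
    """
    return [
        int(v) if is_evidence else int(p >= threshold)
        for p, v, is_evidence in zip(probs, evidence_values, evidence_mask)
    ]

probs = [0.9, 0.2, 0.5, 0.4]      # model outputs for four variables
values = [0, 1, 0, 1]             # observed values (used only where evidence)
mask = [False, False, True, True]  # variables 2 and 3 are evidence

print(threshold_discretize(probs, values, mask))  # [1, 0, 0, 1]
```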

Step 4: ITSELF Inference

Now we use the ITSELF_Engine. For each batch of inference queries, it runs several optimization steps that fine-tune the model on that specific batch, using the PGM evaluator's log-likelihood as the training signal. This test-time adaptation aims to find better assignments than a single forward pass.

[5]:
itself_engine = ITSELF_Engine(
    model=trained_model,
    pgm_evaluator=mn_evaluator,
    loss_fn=mpe_log_likelihood_loss,
    optimizer_cls=torch.optim.Adam,  # The optimizer to use for refinement
    discretizer=discretizer,
    refinement_lr=1e-3,  # Learning rate for the refinement steps
    refinement_steps=5,  # Number of refinement steps per instance
    device=DEVICE,
)

print("Running ITSELF inference with test-time refinement...")
refined_results = itself_engine.run(inf_dataloader)
refined_assignments = refined_results["final_assignments"].to(DEVICE)

# Evaluate the quality of the refined assignments
with torch.no_grad():
    refined_ll = mn_evaluator(refined_assignments).mean()

print(f"ITSELF Refined Avg Log-Likelihood: {refined_ll.item():.4f}")
Running ITSELF inference with test-time refinement...
ITSELF Refined Avg Log-Likelihood: 207.5446
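ITSELF_Engine refines the network's parameters with optimizer_cls for refinement_steps steps at refinement_lr. The toy sketch below conveys the same idea in pure Python under a deliberate simplification: instead of fine-tuning a network, it runs plain gradient ascent directly on per-variable logits, maximizing the expected log-score of a pairwise network when each query variable is an independent Bernoulli(sigmoid(logit)). Evidence variables stay clamped. This relaxed objective, the helper names, the learning rate of 1.0, and the 50 steps are illustrative assumptions; the exact form of mpe_log_likelihood_loss is not shown in this notebook.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def log_score(x, unary, pairwise):
    """Unnormalized log-score: sum of unary and pairwise log-potentials."""
    s = sum(unary[i][x[i]] for i in unary)
    return s + sum(t[x[i]][x[j]] for (i, j), t in pairwise.items())

def relaxed_grad(q, unary, pairwise):
    """d/dq_i of the expected log-score under independent Bernoulli(q_i)."""
    g = [unary[i][1] - unary[i][0] for i in range(len(q))]
    for (i, j), t in pairwise.items():
        g[i] += q[j] * (t[1][1] - t[0][1]) + (1 - q[j]) * (t[1][0] - t[0][0])
        g[j] += q[i] * (t[1][1] - t[1][0]) + (1 - q[i]) * (t[0][1] - t[0][0])
    return g

def refine(logits, evidence, unary, pairwise, lr=1.0, steps=50):
    """Gradient ascent on query-variable logits; evidence stays clamped."""
    z = list(logits)
    for _ in range(steps):
        q = [float(evidence[i]) if i in evidence else sigmoid(z[i])
             for i in range(len(z))]
        g = relaxed_grad(q, unary, pairwise)
        for i in range(len(z)):
            if i not in evidence:
                z[i] += lr * g[i] * q[i] * (1 - q[i])  # chain rule through sigmoid
    return z

def discretize(z, evidence):
    return [evidence[i] if i in evidence else int(sigmoid(z[i]) >= 0.5)
            for i in range(len(z))]

# Toy chain x0 - x1 - x2 rewarding agreement; x0 = 1 is observed evidence.
unary = {0: (0.0, 0.5), 1: (0.0, 0.0), 2: (0.2, 0.0)}
agree = [[1.0, -1.0], [-1.0, 1.0]]
pairwise = {(0, 1): agree, (1, 2): agree}
evidence = {0: 1}
logits = [0.0, -0.5, -0.5]  # stand-in for a network's single-pass output

initial = discretize(logits, evidence)
refined = discretize(refine(logits, evidence, unary, pairwise), evidence)
print(initial)  # [1, 0, 0]
print(refined)  # [1, 1, 1]
```

Starting from logits that threshold to the suboptimal [1, 0, 0], the ascent pulls x1 and then x2 toward agreeing with the observed x0 = 1, reaching the conditional MPE [1, 1, 1].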

Step 5: Comparison and Conclusion

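Finally, we compare the average log-likelihoods. A higher log-likelihood indicates a better solution to the MPE problem. Because the reported values are positive, they are presumably unnormalized log-scores (sums of the network's log-potentials) rather than normalized log-probabilities; since the partition function is the same for every assignment, comparing them is still valid.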

[6]:
print(f"Initial Avg Log-Likelihood (Single Pass): {initial_ll.item():.4f}")
print(f"Refined Avg Log-Likelihood (ITSELF):     {refined_ll.item():.4f}")

improvement = refined_ll - initial_ll
print(f"\nImprovement from ITSELF: {improvement.item():.4f}")

assert refined_ll > initial_ll, "ITSELF failed to improve the log-likelihood!"
print("\nSuccessfully demonstrated that ITSELF improves inference quality.")
Initial Avg Log-Likelihood (Single Pass): 7.2980
Refined Avg Log-Likelihood (ITSELF):     207.5446

Improvement from ITSELF: 200.2466

Successfully demonstrated that ITSELF improves inference quality.
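The Step 5 comparison reports positive values, which a normalized log-probability can never be, so the evaluator presumably returns unnormalized log-scores. For MPE that is enough: log P(x) = score(x) - log Z, and log Z is the same constant for every assignment, so differences and the best assignment are unchanged. The toy network below (hypothetical, small enough to compute Z by enumeration) checks this numerically.

```python
import itertools
import math

# Hypothetical toy pairwise network: chain x0 - x1 - x2.
unary = {0: (0.0, 0.5), 1: (0.0, 0.0), 2: (0.2, 0.0)}
agree = [[1.0, -1.0], [-1.0, 1.0]]
pairwise = {(0, 1): agree, (1, 2): agree}

def log_score(x):
    """Unnormalized log-score of assignment x."""
    return (sum(unary[i][x[i]] for i in unary)
            + sum(t[x[i]][x[j]] for (i, j), t in pairwise.items()))

assignments = list(itertools.product([0, 1], repeat=3))
# Partition function by brute-force enumeration (feasible only for tiny n).
log_z = math.log(sum(math.exp(log_score(x)) for x in assignments))

def log_prob(x):
    """Normalized log-probability: always <= 0."""
    return log_score(x) - log_z

a, b = (1, 1, 1), (0, 0, 0)
print(log_score(a) > 0, log_prob(a) < 0)  # True True
# Differences, and hence the MPE, are identical under both quantities.
print(math.isclose(log_score(a) - log_score(b), log_prob(a) - log_prob(b)))  # True
```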