Source code for neupi.discretize.oauai

import itertools
from typing import Callable

import torch

from neupi.core.discretizer import BaseDiscretizer
from neupi.registry import register


@register("discretizer_oauai")
class OAUAI(BaseDiscretizer):
    """
    Discretizes by performing an exhaustive search over the k most uncertain variables.

    This method identifies the 'k' query variables with probabilities closest to 0.5,
    generates all 2^k possible binary assignments for this subset, scores each one
    using the PGM evaluator, and selects the best assignment.

    Note:
        This is a naive oracle that looks at all 2^k possible assignments and selects
        the best one. Other oracles can also be used (such as daoopt).

    Args:
        pgm_evaluator (Callable): The PGM evaluator (e.g., MarkovNetwork) which acts as
            the scoring function. It is called with a 2-D tensor of assignments and its
            result is indexed per row; higher scores are treated as better.
        k (int): The number of most uncertain query variables to search over. For a
            sample with fewer than ``k`` query variables, only the available query
            variables are searched.
        threshold (float): The baseline threshold for comparison. Defaults to 0.5.

    Raises:
        ValueError: If ``k`` is negative.

    References
    ----------
    Arya, Shivvrat, Rahman, Tahrima, and Gogate, Vibhav Giridhar.
    SINE: Scalable MPE Inference for Probabilistic Graphical Models Using Advanced
    Neural Embeddings. Proceedings of the 28th International Conference on Artificial
    Intelligence and Statistics (AISTATS), 2025.
    """

    def __init__(self, pgm_evaluator: Callable, k: int, threshold: float = 0.5):
        super().__init__()
        # Fail fast: a negative k would otherwise only surface later, inside
        # itertools.product, as a less obvious ValueError during __call__.
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}.")
        if k > 10:
            print(f"Warning: k={k} may lead to slow performance (2^{k} evaluations per sample).")
        self.pgm_evaluator = pgm_evaluator
        self.k = k
        self.threshold = threshold

    @staticmethod
    def _enumerate_binary(
        width: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        """Return all 2^width binary rows as a (2^width, width) tensor."""
        return torch.tensor(
            list(itertools.product([0, 1], repeat=width)), dtype=dtype, device=device
        )

    @torch.no_grad()
    def __call__(
        self,
        prob_outputs: torch.Tensor,
        query_mask: torch.Tensor,
        evidence_mask: torch.Tensor,
        unobs_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Performs the OAUAI discretization for each instance in the batch.

        Every sample starts from the thresholded assignment
        (``prob_outputs >= threshold``). The most uncertain query variables of the
        sample are then enumerated exhaustively, and the best-scoring candidate
        replaces the baseline only if its score is strictly higher.

        Args:
            prob_outputs (torch.Tensor): 2-D tensor of probabilities, one row per
                sample and one column per variable.
            query_mask (torch.Tensor): Per-sample mask selecting the query variables,
                indexed the same way as ``prob_outputs``. Non-zero entries are treated
                as query variables.
            evidence_mask (torch.Tensor): Not used by this discretizer; accepted to
                match the call signature.
            unobs_mask (torch.Tensor): Not used by this discretizer; accepted to
                match the call signature.

        Returns:
            torch.Tensor: Binary assignments with the same shape and dtype as
            ``prob_outputs``.
        """
        num_examples, _ = prob_outputs.shape
        device = prob_outputs.device
        dtype = prob_outputs.dtype

        # Get baseline assignments and scores from simple thresholding
        final_assignments = (prob_outputs >= self.threshold).to(dtype)
        final_scores = self.pgm_evaluator(final_assignments)

        # `~mask` is a logical NOT only for bool tensors; on a 0/1 integer mask it is a
        # bitwise NOT (-1/-2) that would then be used as an integer index.
        query_mask = query_mask.to(torch.bool)
        # One host transfer for the whole batch instead of one per sample.
        query_counts = query_mask.sum(dim=1).tolist()

        # Candidate tables keyed by search width; normally only `self.k` is needed,
        # smaller widths appear for samples with fewer than k query variables.
        tables = {}

        for i in range(num_examples):
            # Never search more variables than the sample has query variables:
            # otherwise topk would fall through to the inf-masked non-query
            # positions and the search would overwrite evidence values.
            width = min(self.k, int(query_counts[i]))
            if width == 0:
                # Nothing to search over; the thresholded baseline stands.
                continue

            sample_query_mask = query_mask[i]

            # Find the most uncertain query variables
            certainty = torch.abs(prob_outputs[i] - 0.5)
            # Mask out non-query variables so they are not selected
            certainty[~sample_query_mask] = float("inf")
            _, top_k_indices = torch.topk(certainty, width, largest=False)

            if width not in tables:
                tables[width] = self._enumerate_binary(width, dtype, device)
            binary_assignments = tables[width]
            num_candidates = binary_assignments.shape[0]

            # Create candidate assignments by modifying the thresholded base
            candidates = final_assignments[i].unsqueeze(0).repeat(num_candidates, 1)
            candidates[:, top_k_indices] = binary_assignments

            # Score all candidates and find the best one
            candidate_scores = self.pgm_evaluator(candidates)
            best_candidate_score, best_candidate_idx = torch.max(candidate_scores, dim=0)

            # If the best candidate is better than the simple thresholded result, use it
            if best_candidate_score > final_scores[i]:
                final_assignments[i] = candidates[best_candidate_idx]

        return final_assignments