# Source code for neupi.discretize.kn
from typing import Callable
import torch
from neupi.core.discretizer import BaseDiscretizer
from neupi.registry import register
from .cython_kn.kn_binary_vectors import cython_process_assignments
# [docs]
@register("discretizer_knn")
class KNearestDiscretizer(BaseDiscretizer):
    """
    Finds k-nearest binary vectors for query variables using a scoring function.

    This method generates candidate binary assignments for query variables that are
    "close" to the continuous predictions, scores them using the PGM evaluator,
    and selects the best one. The selected assignment is finally compared against
    plain 0.5-thresholding, and whichever scores higher is returned.

    Args:
        pgm_evaluator (Callable): The PGM evaluator (e.g., MarkovNetwork) which acts
            as the scoring function. It is called with a 2-D tensor of full
            assignments and is expected to return one score per row, where a
            higher score is better.
        k (int): Number of nearest binary vectors to consider.
        batch_size (int): Batch size for scoring candidate assignments.

    References
    ----------
    Arya, Shivvrat, Rahman, Tahrima, and Gogate, Vibhav Giridhar. SINE: Scalable MPE Inference for Probabilistic Graphical Models Using Advanced Neural Embeddings. Proceedings of the 28th International Conference on Artificial Intelligence and Statistics (AISTATS), 2025.
    """

    def __init__(self, pgm_evaluator: Callable, k: int, batch_size: int = 300):
        super().__init__()
        self.pgm_evaluator = pgm_evaluator
        self.k = k
        self.batch_size = batch_size

    def _best_candidate(
        self,
        row: torch.Tensor,
        query_idx: torch.Tensor,
        candidates: torch.Tensor,
    ):
        """
        Score ``candidates`` in chunks of ``self.batch_size`` and return the best one.

        Args:
            row (torch.Tensor): One row of the network output; used as the template
                into which each candidate is written at ``query_idx``.
            query_idx (torch.Tensor): Indices of the query variables within ``row``.
            candidates (torch.Tensor): Candidate assignments for the query variables,
                one candidate per row.

        Returns:
            The candidate row with the highest score, or ``None`` when no candidate
            scores strictly above ``-inf`` (empty candidate set, or every score is
            ``-inf``/NaN).
        """
        best_score = torch.tensor(float("-inf"), device=row.device)
        best = None
        for start in range(0, candidates.shape[0], self.batch_size):
            chunk = candidates[start : start + self.batch_size]
            # Build full assignments that differ only in the query positions.
            # `repeat` allocates a fresh tensor, so `row` itself is not modified.
            full_assignments = row.unsqueeze(0).repeat(len(chunk), 1)
            full_assignments[:, query_idx] = chunk
            scores = self.pgm_evaluator(full_assignments)
            max_score, max_idx = torch.max(scores, dim=0)
            # Strict comparison: ties keep the earlier candidate, and NaN never wins.
            if max_score > best_score:
                best_score = max_score
                best = chunk[max_idx]
        return best

    @torch.no_grad()
    def __call__(
        self,
        prob_outputs: torch.Tensor,
        query_mask: torch.Tensor,
        evidence_mask: torch.Tensor = None,
        unobs_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Selects k-nearest binary vectors for each continuous output from the NN and selects the best one.

        Args:
            prob_outputs (torch.Tensor): Continuous predictions from the network.
                Shape: (batch_size, num_variables).
            query_mask (torch.Tensor): Boolean mask for query variables, one row per
                example.
            evidence_mask (torch.Tensor): Boolean mask for evidence variables.
                Accepted as part of the call signature; not read by this method.
            unobs_mask (torch.Tensor): Boolean mask for unobserved variables.
                Accepted as part of the call signature; not read by this method.

        Returns:
            torch.Tensor: The final discrete assignments.
            Shape: (batch_size, num_variables).
        """
        # Unpacking also enforces that `prob_outputs` is 2-D.
        num_examples, _ = prob_outputs.shape
        device = prob_outputs.device
        dtype = prob_outputs.dtype

        # Plain thresholding: the baseline for the final comparison and the
        # per-row fallback when the candidate search yields nothing usable.
        thresholded_assignments = (prob_outputs >= 0.5).to(dtype)

        query_indices = [torch.where(qm)[0] for qm in query_mask]
        best_assignments = prob_outputs.clone()

        for i in range(num_examples):
            q_idx = query_indices[i]
            query_probs_np = prob_outputs[i, q_idx].detach().cpu().numpy()
            # Get k candidate binary assignments from the Cython helper
            _, candidate_np = cython_process_assignments(query_probs_np, self.k)
            candidates = torch.tensor(candidate_np, dtype=dtype, device=device)

            best = self._best_candidate(prob_outputs[i], q_idx, candidates)
            if best is None:
                # No candidate beat -inf (none generated, or all scored -inf/NaN).
                # Previously this path tried to assign None into the tensor and
                # raised; fall back to the thresholded query values instead.
                best = thresholded_assignments[i, q_idx]
            best_assignments[i, q_idx] = best

        # Compare with simple thresholding and return the better result
        thresholded_scores = self.pgm_evaluator(thresholded_assignments)
        best_scores = self.pgm_evaluator(best_assignments)
        final_assignments = torch.where(
            best_scores.unsqueeze(1) > thresholded_scores.unsqueeze(1),
            best_assignments,
            thresholded_assignments,
        )
        return final_assignments