Source code for neupi.discretize.threshold

import torch

from neupi.core.discretizer import BaseDiscretizer
from neupi.registry import register


[docs] @register("discretizer") class ThresholdDiscretizer(BaseDiscretizer): """ Discretizes a tensor of probabilities to binary assignments based on a threshold. This class acts as a callable function. An instance can be passed to an inference engine and called to perform the discretization. Args: threshold (float): The threshold value. Values >= threshold will be 1, and values < threshold will be 0. Defaults to 0.5. References ---------- Arya, S., Rahman, T., & Gogate, V. (2024). Learning to Solve the Constrained Most Probable Explanation Task in Probabilistic Graphical Models. Proceedings of The 27th International Conference on Artificial Intelligence and Statistics. International Conference on Artificial Intelligence and Statistics, PMLR, pp. 2791–2799. https://proceedings.mlr.press/v238/arya24b.html Arya, S., Rahman, T., & Gogate, V. (2024). Neural Network Approximators for Marginal MAP in Probabilistic Circuits. Proceedings of the AAAI Conference on Artificial Intelligence, 38(10), 10918–10926. https://doi.org/10.1609/aaai.v38i10.28966 Arya, S., Rahman, T., & Gogate, V. G. (2024). A neural network approach for efficiently answering most probable explanation queries in probabilistic models. NeurIPS 2024. https://openreview.net/forum?id=ufPPf9ghzP """ def __init__(self, threshold: float = 0.5): super().__init__() if not 0.0 <= threshold <= 1.0: raise ValueError("Threshold must be between 0.0 and 1.0.") self.threshold = threshold def __call__( self, prob_outputs: torch.Tensor, query_mask: torch.Tensor = None, evidence_mask: torch.Tensor = None, unobs_mask: torch.Tensor = None, ) -> torch.Tensor: """ Converts a tensor of probabilities to binary assignments. Args: prob_outputs (torch.Tensor): A tensor of probabilities, typically the output of a sigmoid function. Returns: torch.Tensor: A tensor of binary assignments (0s and 1s) with the same shape as the input, on the same device. """ return (prob_outputs >= self.threshold).to( dtype=prob_outputs.dtype, device=prob_outputs.device )