import torch
from torch import nn
from neupi.core.model import BaseProbModel
from neupi.registry import register
from ..io.uai_reader_cython import UAIParser
@register("prob_model")
class MarkovNetwork(BaseProbModel):
    """
    A PyTorch module to evaluate the log-likelihood of assignments in a binary Markov Network.

    This class loads a Markov Network from a .uai file and provides an efficient,
    vectorized method to compute the log-likelihood for a batch of variable assignments.
    It supports both pairwise and higher-order factor models.

    Args:
        uai_file (str): Path to the .uai file defining the Markov Network.
        device (str or torch.device): The device to perform computations on ('cpu' or 'cuda').

    Raises:
        ValueError: If ``uai_file`` does not end with ``.uai`` or the parsed
            network type is not ``"MARKOV"``.

    Note:
        For now, only models in the UAI format are supported. Other formats can
        be supported by rewriting the UAIParser class.

    Example:
        ```
        from neupi.pm.pgm.mn import MarkovNetwork
        mn = MarkovNetwork("path/to/mn.uai")
        ```
    """

    def __init__(self, uai_file: str, device: str = "cpu"):
        super().__init__()
        self.device = torch.device(device)

        if not uai_file.endswith(".uai"):
            raise ValueError("Only .uai format is supported for Markov Networks.")
        self.uai_file = uai_file

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(uai_file, "r", encoding="utf-8") as f:
            file_content = f.read()

        self.pgm = UAIParser(model_str=file_content, one_d_factors=0, device=self.device)
        self.num_variables = self.pgm.num_vars
        if self.pgm.network_type != "MARKOV":
            raise ValueError("The .uai file must define a MARKOV network.")

        # Decide once which evaluation strategy applies; `evaluate` dispatches on it.
        self._pairwise_only = bool(self.pgm.pairwise_only)
        if not self._pairwise_only:
            self._precompute_for_higher_order()

    def _precompute_for_higher_order(self):
        """
        Precomputes binary combinations for efficient higher-order evaluation.

        For every clique size ``s`` reported by the parser, stores in
        ``self.precomputed_data[s]`` the ``(2**s, s)`` table of all binary
        assignments together with the parser's variable indices and factor
        tables for cliques of that size. Memory grows as ``2**s`` per size.
        """
        self.precomputed_data = {}
        for size, clique_class in self.pgm.clique_dict_class.items():
            # Row j is the big-endian binary expansion of j (first variable is
            # the most significant bit).
            # NOTE(review): this presumes the parser flattens each factor table
            # in the same row-major order -- confirm against UAIParser.
            binary_combinations = torch.tensor(
                [[(j >> k) & 1 for k in range(size - 1, -1, -1)] for j in range(2**size)],
                dtype=torch.float32,
                device=self.device,
            )
            self.precomputed_data[size] = {
                "binary_combinations": binary_combinations,
                "all_vars": clique_class.variables,
                "all_factors": clique_class.tables,
            }

    def _validate_input(self, x: torch.Tensor) -> None:
        """
        Checks that ``x`` has shape (batch_size, num_variables).

        Raises:
            ValueError: If ``x`` is not 2-D or its second dimension differs
                from the number of variables in the model.
        """
        if x.dim() != 2 or x.shape[1] != self.num_variables:
            raise ValueError("Input dimension does not match the number of variables in the model.")

    def _compute_clique_scores(self, x, binary_combinations, all_vars, all_factors):
        """
        Computes scores for a batch of assignments for a given clique size.

        Args:
            x (torch.Tensor): Assignments of shape (batch_size, num_variables).
            binary_combinations (torch.Tensor): All binary assignments of one
                clique, shape (2**size, size).
            all_vars (torch.Tensor): 2-D index tensor; row ``c`` holds the
                variable indices of clique ``c`` (num_cliques, size).
            all_factors (torch.Tensor): Factor tables whose first dimension
                indexes the clique; the remaining dimensions are flattened to
                ``2**size`` entries.

        Returns:
            torch.Tensor: Per-clique scores of shape (batch_size, num_cliques).
        """
        num_cliques, clique_size = all_vars.shape[0], all_vars.shape[1]

        # Select the values of variables involved in the cliques: (B, C, S).
        all_values = x[:, all_vars.flatten()].view(x.shape[0], num_cliques, clique_size)

        # Broadcast assignments against every binary combination: (B, 2**S, C, S).
        values = all_values.unsqueeze(1)
        combos = binary_combinations.unsqueeze(0).unsqueeze(2)

        # Per-variable agreement: v*b + (1-v)*(1-b). For binary v this is 1 on a
        # match and 0 otherwise; for fractional v it interpolates linearly.
        selected_values = values * combos + (1 - values) * (1 - combos)

        # Weight of each table entry for each clique: (B, 2**S, C).
        product_term = torch.prod(selected_values, dim=3)

        # reshape (not view) so non-contiguous factor tables are handled too.
        all_factors_flat = all_factors.reshape(all_factors.shape[0], -1)

        # Sum the log-potentials from the correct factor entries.
        scores = torch.sum(product_term * all_factors_flat.permute(1, 0).unsqueeze(0), dim=1)
        return scores

    def _evaluate_higher_order(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluates log-likelihood for models with higher-order factors.

        Args:
            x (torch.Tensor): A binary tensor of assignments of shape (batch_size, num_variables).

        Returns:
            torch.Tensor: A tensor of log-likelihood scores of shape (batch_size,).

        Raises:
            ValueError: If ``x`` does not have shape (batch_size, num_variables).
        """
        # Validate before moving data or allocating the accumulator.
        self._validate_input(x)
        x = x.to(self.device)

        ll_scores = torch.zeros(x.shape[0], device=self.device)
        for data in self.precomputed_data.values():
            clique_scores = self._compute_clique_scores(
                x, data["binary_combinations"], data["all_vars"], data["all_factors"]
            )
            ll_scores += torch.sum(clique_scores, dim=1)
        return ll_scores

    def _evaluate_pairwise(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluates log-likelihood for models with only pairwise and unary factors.

        Args:
            x (torch.Tensor): A binary tensor of assignments of shape (batch_size, num_variables).

        Returns:
            torch.Tensor: A tensor of log-likelihood scores of shape (batch_size,).

        Raises:
            ValueError: If ``x`` does not have shape (batch_size, num_variables).
        """
        # Same input contract as the higher-order path.
        self._validate_input(x)
        x = x.to(self.device)

        # Unary factor contributions: table[:, 0] when x == 0, table[:, 1] when x == 1.
        x_uni = x[:, self.pgm.univariate_vars]
        uni_tables = self.pgm.univariate_tables
        univariate_contrib = (1 - x_uni) * uni_tables[:, 0] + x_uni * uni_tables[:, 1]

        # Pairwise factor contributions: bilinear selection of table[:, a, b].
        x_biv_0 = x[:, self.pgm.bivariate_vars[:, 0]]
        x_biv_1 = x[:, self.pgm.bivariate_vars[:, 1]]
        biv_tables = self.pgm.bivariate_tables
        bivariate_contrib = (
            (1 - x_biv_0) * (1 - x_biv_1) * biv_tables[:, 0, 0].unsqueeze(0)
            + (1 - x_biv_0) * x_biv_1 * biv_tables[:, 0, 1].unsqueeze(0)
            + x_biv_0 * (1 - x_biv_1) * biv_tables[:, 1, 0].unsqueeze(0)
            + x_biv_0 * x_biv_1 * biv_tables[:, 1, 1].unsqueeze(0)
        )

        total_log_likelihood = torch.sum(univariate_contrib, dim=1) + torch.sum(
            bivariate_contrib, dim=1
        )
        return total_log_likelihood

    def evaluate(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the log-likelihood of a batch of assignments.

        Uses the pairwise routine when the parser reported a pairwise-only
        model, and the higher-order routine otherwise.

        Args:
            x (torch.Tensor): A binary tensor of assignments of shape (batch_size, num_variables).

        Returns:
            torch.Tensor: A tensor of log-likelihood scores of shape (batch_size,).

        Raises:
            ValueError: If ``x`` does not have shape (batch_size, num_variables).
        """
        if self._pairwise_only:
            return self._evaluate_pairwise(x)
        return self._evaluate_higher_order(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Alias for the evaluate method.

        Args:
            x (torch.Tensor): A binary tensor of assignments of shape (batch_size, num_variables).

        Returns:
            torch.Tensor: A tensor of log-likelihood scores of shape (batch_size,).
        """
        return self.evaluate(x)