Source code for neupi.training.pm_ssl.pc.spn
import json
import torch
from neupi.core.model import BaseProbModel
from neupi.registry import register
from torch import nn
from .spn_utils import (
BernoulliLeaf,
ProductNode,
SumNode,
get_distributions,
preprocess_links,
)
[docs]
@register("prob_model")
class SumProductNetwork(BaseProbModel):
    """Sum-product network whose structure and parameters are loaded from a JSON file.

    Nodes live in a flat list indexed ``0 .. num_nodes_in_spn - 1``.
    :meth:`evaluate` computes them from the highest index down to index 0 and
    returns the value of node 0 (the root).
    """

    def __init__(self, json_file, device="cpu"):
        """
        Initialize the SumProductNetwork with configuration from a JSON file.

        :param json_file: Path to a JSON file containing the model configuration.
            It must contain a ``"nodes"`` entry; the parsed document is handed to
            ``preprocess_links`` and ``get_distributions``.
        :param device: Device to perform computations ('cpu' or 'cuda:0').

        Note: For now, models trained with DeeProb-kit are supported
        (https://github.com/deeprob-org/deeprob-kit). Other libraries can be
        supported by rewriting the get_distributions and preprocess_links functions.

        Example:
            ```
            from neupi.training.pm_ssl.pc.spn import SumProductNetwork
            spn = SumProductNetwork("path/to/spn.json")
            ```
        """
        super().__init__()
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(json_file, encoding="utf-8") as f:
            data = json.load(f)
        self.json_file = json_file
        # Starts at 0; presumably updated by get_distributions (it receives `self`)
        # -- TODO confirm. evaluate() requires it to equal the input width.
        self.num_var = 0
        self.num_nodes_in_spn: int = len(data["nodes"])
        self.all_nodes = torch.arange(self.num_nodes_in_spn, dtype=torch.long)
        self.eps = 1e-6
        self.device = device
        # Per-type lists of node indices; presumably filled by get_distributions
        # (it receives `self`) and converted to tensors at the end of __init__.
        self.node_types = {"Sum": [], "Product": [], "Bernoulli": []}
        edge_parent, edge_child_start_index = preprocess_links(data)
        self.edge_parent = edge_parent.to(device)
        self.edge_child_start_index = edge_child_start_index.to(device)
        distributions = get_distributions(data, self)
        # ModuleList registers every node module (and its parameters) with this model.
        # _evaluate_node indexes this list by node index, so the order of
        # `distributions` must follow node indices.
        self.all_distributions = nn.ModuleList(list(distributions.values()))
        # Convert node types to long tensors on the target device (mutated in place
        # so any existing reference to the dict stays valid).
        for node_type in self.node_types:
            self.node_types[node_type] = torch.tensor(
                self.node_types[node_type], dtype=torch.long, device=self.device
            )

    def _get_children_indices(self, node_idx: int):
        """Return the indices of the children of node ``node_idx``.

        Slices ``self.all_nodes`` between consecutive entries of
        ``edge_child_start_index``, with both bounds shifted by one.
        NOTE(review): the +1 offset follows the layout produced by
        ``preprocess_links`` -- confirm there.
        """
        start_idx = self.edge_child_start_index[node_idx]
        end_idx = self.edge_child_start_index[node_idx + 1]
        return self.all_nodes[start_idx + 1 : end_idx + 1]

    def _check_is_binary(self, tensor):
        """Return a 0-d boolean tensor that is True iff every element is 0 or 1."""
        return ((tensor == 0) | (tensor == 1)).all()

    def evaluate(self, x):
        """
        Evaluate the SPN model on input data.

        :param x: Tensor representing input data (batch_size, input_size);
            ``input_size`` must equal ``self.num_var`` and no entry may be -1.
        :return: Value of the root node (index 0), shape (batch_size,).
        """
        assert self.num_var == x.size(1), "Input size must match the number of variables"
        # Make sure there are no -1s in the input
        assert torch.sum(x == -1) == 0, "Input cannot contain -1s"
        # Row i holds the value of node i for every example:
        # shape (num_nodes_in_spn, batch_size).
        function_values_at_each_index = torch.empty(
            (self.num_nodes_in_spn, x.size(0)), device=self.device
        )
        # Evaluate nodes from the highest index down to 0 (leaves to root). Inner
        # nodes read this table, so this order relies on every child having a
        # larger index than its parent.
        for node_idx in reversed(range(self.num_nodes_in_spn)):
            function_values_at_each_index[node_idx] = self._evaluate_node(
                node_idx,
                x,
                function_values_at_each_index,
            )
        return function_values_at_each_index[0]

    def _get_node_type_and_children_indices(self, node_idx: int):
        """
        Get the children indices of a node.

        Despite the name, the node type is not returned; see
        ``self.all_distributions[node_idx]`` for the node object itself.

        :param node_idx: Index of the node.
        :return: Tensor of children indices (same as ``_get_children_indices``).
        """
        return self._get_children_indices(node_idx)

    def _evaluate_node(self, node_idx: int, x, function_values_at_each_index):
        """
        Evaluate a single node in the SPN.

        :param node_idx: Index of the current node.
        :param x: Input data tensor.
        :param function_values_at_each_index: Tensor holding values for each node.
        :return: Evaluated value of the node.
        :raises NotImplementedError: If the node is not a Sum, Product, or Bernoulli node.
        """
        node = self.all_distributions[node_idx]
        if isinstance(node, (ProductNode, SumNode)):
            # Inner nodes receive the table of node values computed so far.
            return node(function_values_at_each_index)
        if isinstance(node, BernoulliLeaf):
            # Leaves read the raw input.
            return node(x)
        raise NotImplementedError("Unknown node type")

    def forward(self, x):
        """Alias for :meth:`evaluate`, so the module can be called directly."""
        return self.evaluate(x)