# Source code for neupi.embedding.discrete
import torch
from neupi.core.embedding import Embedding
from neupi.registry import register
# [docs]
@register("embedding")
class DiscreteEmbedder(Embedding):
    """
    Creates feature embeddings from variable assignments and variable-role masks.

    This preprocessor converts a tensor of binary variable assignments into a
    new feature space suitable for input to a neural network. It expands each
    variable `v` into a two-dimensional representation `[v, 1-v]` and then
    replaces both features of every query and unobserved variable with a
    fixed embedding value.

    Args:
        num_vars (int): Number of variables. Stored as `num_vars` and used to
            compute `embedding_size` (`num_vars * 2`); it is not checked
            against the width of the tensors passed to `__call__`.
        query_embedding (float): The value to use for query variables. Defaults to 0.0.
        unobserved_embedding (float): The value to use for unobserved variables.
            Defaults to 1.0.
    """

    def __init__(
        self, num_vars: int, query_embedding: float = 0.0, unobserved_embedding: float = 1.0
    ):
        self.num_vars = num_vars
        # Every variable contributes two features: [v, 1 - v].
        self.embedding_size = num_vars * 2
        self.query_embedding = query_embedding
        self.unobserved_embedding = unobserved_embedding

    @staticmethod
    def _batched_mask(mask: torch.Tensor, num_samples: int, n_vars: int) -> torch.Tensor:
        """Return `mask` as a boolean `(num_samples, n_vars)` tensor.

        A 1-D mask is interpreted as one flag per variable and is shared by
        every sample; a 2-D mask is used as given (after broadcasting a
        leading dimension of size 1).
        """
        # Cast so that 0/1 integer masks select by position, not by index value.
        mask = mask.to(torch.bool)
        if mask.dim() == 1:
            mask = mask.unsqueeze(0)
        return mask.expand(num_samples, n_vars)

    def __call__(
        self,
        evidence_data: torch.Tensor,
        evidence_mask: torch.Tensor,
        query_mask: torch.Tensor,
        unobs_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Processes the assignments to create embeddings.

        Args:
            evidence_data (torch.Tensor): A tensor of variable assignments.
                Shape: (num_vars,) or (batch_size, num_vars).
            evidence_mask (torch.Tensor): Accepted for interface compatibility;
                it is not used by this embedder.
            query_mask (torch.Tensor): Mask marking query variables, or None to
                skip the query embedding. Shape: (num_vars,) or
                (batch_size, num_vars).
            unobs_mask (torch.Tensor): Mask marking unobserved variables, or
                None to skip the unobserved embedding. Same shapes as
                `query_mask`. Applied after `query_mask`, so it wins where
                both masks are set.

        Returns:
            torch.Tensor: The embedded feature tensor, laid out as
                `[v0, 1-v0, v1, 1-v1, ...]` per sample, with the same dtype
                and device as `evidence_data`.
                Shape: (num_vars*2,) or (batch_size, num_vars*2).

        Raises:
            ValueError: If `evidence_data` is neither 1-D nor 2-D.
        """
        if evidence_data.dim() not in (1, 2):
            raise ValueError(
                "evidence_data must have shape (num_vars,) or (batch_size, num_vars), "
                f"got {tuple(evidence_data.shape)}"
            )

        is_batch = evidence_data.dim() == 2
        if not is_batch:
            # Add a temporary batch dimension for consistent processing
            evidence_data = evidence_data.unsqueeze(0)
        num_samples, n_vars = evidence_data.size()

        # Pair each value with its complement: shape (num_samples, n_vars, 2).
        # Keeping the pair on its own axis lets one masked assignment overwrite
        # both features of a variable at once.
        paired = torch.stack((evidence_data, 1 - evidence_data), dim=-1)

        # Apply embeddings for query and unobserved variables
        if query_mask is not None:
            paired[self._batched_mask(query_mask, num_samples, n_vars)] = self.query_embedding
        if unobs_mask is not None:
            paired[self._batched_mask(unobs_mask, num_samples, n_vars)] = (
                self.unobserved_embedding
            )

        # Flatten the pair axis so features interleave as [v, 1-v] per variable.
        embedded_features = paired.reshape(num_samples, n_vars * 2)

        if not is_batch:
            # Remove the temporary batch dimension if the input was a single instance
            return embedded_features.squeeze(0)
        return embedded_features