Source code for neupi.embedding.identity

import torch

from neupi.core.embedding import Embedding
from neupi.registry import register


@register("embedding")
class IdentityEmbedding(Embedding):
    """Pass-through embedding that hands the evidence tensor back untouched.

    ``embedding_size`` is set equal to ``num_vars``. Constructing an instance
    prints a two-line notice to stdout saying this embedding is not
    recommended and that masking out query variables is preferred.

    NOTE(review): the base-class initializer is not invoked here; ``Embedding``
    is not visible from this file, so confirm it needs no initialization.
    """

    def __init__(self, num_vars: int):
        """Print the usage notice and record the variable count.

        Args:
            num_vars: Number of variables; stored as both ``num_vars`` and
                ``embedding_size``.
        """
        notice = "This is not recommended; since it uses the complete assignment as input. Masking out query variables is recommended."
        print("Initializing IdentityEmbedding")
        print(notice)

        # The output width equals the input width because nothing is transformed.
        self.num_vars = num_vars
        self.embedding_size = num_vars

    def __call__(
        self,
        evidence_data: torch.Tensor,
        evidence_mask: torch.Tensor,
        query_mask: torch.Tensor,
        unobs_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``evidence_data`` exactly as received.

        Args:
            evidence_data: Tensor handed back to the caller (same object, no copy).
            evidence_mask: Accepted but not used.
            query_mask: Accepted but not used.
            unobs_mask: Accepted but not used.

        Returns:
            The ``evidence_data`` argument itself.
        """
        return evidence_data