# Source code for neupi.training.pm_ssl.nam.made

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------------------------------------------------------
from neupi.core.model import BaseProbModel
from neupi.registry import register


@register("nn_model")
class MaskedLinear(BaseProbModel):
    """Linear layer whose weights are multiplied elementwise by a fixed mask.

    The mask lives in a buffer named ``mask`` with shape
    ``(out_features, in_features)`` and is initialised to all ones.

    NOTE(review): ``forward`` reads ``self.weight`` and ``self.bias``, which
    this class never creates itself. They are presumably set up by the
    base-class constructor, which is called with the ``nn.Linear``-style
    arguments ``(in_features, out_features, bias)`` -- confirm that
    ``BaseProbModel`` actually provides them.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        all_ones = torch.ones(out_features, in_features)
        self.register_buffer("mask", all_ones)

    def set_mask(self, mask):
        """Overwrite the mask buffer in place from a NumPy array.

        ``mask`` is cast to ``uint8`` and transposed before being copied, so
        it is the transpose that has to fit the buffer's
        ``(out_features, in_features)`` shape.
        """
        transposed = mask.astype(np.uint8).T
        self.mask.data.copy_(torch.from_numpy(transposed))

    def forward(self, input):
        """Apply the linear map using the elementwise-masked weight matrix."""
        masked_weight = self.mask * self.weight
        return F.linear(input, masked_weight, self.bias)


class MADE(nn.Module):
    """MLP of ``MaskedLinear`` layers whose masks are derived from a sampled ordering.

    Every hidden unit is assigned an integer "degree"; the masks only keep
    connections that go from a lower-or-equal degree to a higher one (strictly
    higher for the output layer), see ``update_masks``.
    """

    def __init__(self, nin, hidden_sizes, nout, num_masks=1, natural_ordering=False):
        """
        nin: integer; number of inputs
        hidden_sizes: a sequence of integers; number of units in hidden layers
        nout: integer; number of outputs, which usually collectively parameterize
              some kind of 1D distribution
              note: if nout is e.g. 2x larger than nin (perhaps the mean and std),
              then the first nin will be all the means and the second nin will be
              stds. i.e. output dimensions depend on the same input dimensions in
              "chunks" and should be carefully decoded downstream appropriately.
        num_masks: can be used to train ensemble over orderings/connections
        natural_ordering: force natural ordering of dimensions, don't use random
              permutations
        """
        super().__init__()
        self.nin = nin
        self.nout = nout
        self.hidden_sizes = hidden_sizes
        self.sigmoid = nn.Sigmoid()
        assert self.nout % self.nin == 0, "nout must be integer multiple of nin"

        # Define a simple MLP: MaskedLinear -> ReLU -> ... -> MaskedLinear.
        # Unpacking (rather than list concatenation) lets hidden_sizes be any
        # sequence, e.g. a tuple.
        widths = [nin, *hidden_sizes, nout]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend([MaskedLinear(fan_in, fan_out), nn.ReLU()])
        stack.pop()  # no ReLU after the output layer
        self.net = nn.Sequential(*stack)

        # seeds for orders/connectivities of the model ensemble
        self.natural_ordering = natural_ordering
        self.num_masks = num_masks
        self.seed = 0  # for cycling through num_masks orderings

        # self.m[-1] holds the input ordering, self.m[i] the degrees of hidden layer i
        self.m = {}
        self.update_masks()

    def update_masks(self):
        """Sample the next ordering/connectivity and install the resulting masks.

        Does nothing when masks were already built and ``num_masks == 1``.
        Otherwise the current ``self.seed`` seeds a fresh random stream and the
        seed then advances cyclically through ``range(num_masks)``.
        """
        if self.m and self.num_masks == 1:
            return  # a single mask set that has already been built
        num_hidden = len(self.hidden_sizes)

        # fetch the next seed and construct a random stream
        rng = np.random.RandomState(self.seed)
        self.seed = (self.seed + 1) % self.num_masks

        # sample the order of the inputs and the connectivity of all neurons
        self.m[-1] = np.arange(self.nin) if self.natural_ordering else rng.permutation(self.nin)
        # randint's upper bound is exclusive and must be greater than the lower
        # bound. With nin == 1 the original bound (nin - 1 == 0) made randint
        # raise ValueError; clamping to 1 gives every hidden unit degree 0.
        # For nin >= 2 the bound, and therefore the random stream, is unchanged.
        degree_high = max(self.nin - 1, 1)
        for idx in range(num_hidden):
            self.m[idx] = rng.randint(
                self.m[idx - 1].min(), degree_high, size=self.hidden_sizes[idx]
            )

        # construct the mask matrices, each laid out as (fan_in, fan_out)
        masks = [
            self.m[idx - 1][:, None] <= self.m[idx][None, :] for idx in range(num_hidden)
        ]
        # the output layer uses a strict comparison against the input ordering
        masks.append(self.m[num_hidden - 1][:, None] < self.m[-1][None, :])

        # handle the case where nout = nin * k, for integer k > 1
        if self.nout > self.nin:
            k = self.nout // self.nin
            # replicate the mask across the other outputs
            masks[-1] = np.concatenate([masks[-1]] * k, axis=1)

        # set the masks in all MaskedLinear layers
        masked_layers = [
            module for module in self.net.modules() if isinstance(module, MaskedLinear)
        ]
        for layer, mask in zip(masked_layers, masks):
            layer.set_mask(mask)

    def forward(self, x):
        """Run ``x`` through the masked network and return its raw outputs."""
        return self.net(x)

    def evaluate(self, x):
        """Return the per-row log-likelihood of ``x`` under the network outputs.

        The outputs are treated as logits and ``x`` as the targets of a binary
        cross-entropy; the negated element-wise loss is summed over dim 1.
        NOTE(review): ``binary_cross_entropy_with_logits`` needs the outputs
        and ``x`` to have the same shape, so this presumably assumes
        ``nout == nin`` -- confirm against callers.
        """
        logits = self(x)
        # Working directly on logits is more numerically stable than applying
        # a sigmoid and then taking logs.
        elementwise_ll = -F.binary_cross_entropy_with_logits(logits, x, reduction="none")
        return torch.sum(elementwise_ll, dim=1)