LENs tutorial

For this simple tutorial, let’s solve the XOR problem (augmented with 100 dummy features):

import torch
import torch_explain as te

x0 = torch.zeros((4, 100))
x_train = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
], dtype=torch.float)
x_train = torch.cat([x_train, x0], dim=1)
y_train = torch.tensor([0, 1, 1, 0], dtype=torch.long)
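The 100 extra columns are called "dummy" features because they carry no signal. The plain-Python sketch below needs no PyTorch. It rebuilds the same four samples and checks two things: the label equals the XOR of the first two columns, and every padded column is constant. The names `x_rows`, `labels` and `N_DUMMY` are illustrative only.

```python
# Rebuild the toy dataset with plain Python lists: two informative
# features followed by 100 constant-zero "dummy" features.
N_DUMMY = 100
informative = [[0, 0], [0, 1], [1, 0], [1, 1]]
x_rows = [row + [0] * N_DUMMY for row in informative]
labels = [0, 1, 1, 0]

# The label is exactly the XOR of the first two features.
xor_labels = [a ^ b for a, b, *_ in x_rows]

# Every dummy column takes a single value across all samples,
# so it cannot help distinguish the two classes.
dummy_constant = all(
    len({row[j] for row in x_rows}) == 1 for j in range(2, 2 + N_DUMMY)
)

print(len(x_rows), len(x_rows[0]))
print(xor_labels == labels)
print(dummy_constant)
```

A model that explains this data well should therefore mention only the first two features.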

We can instantiate a simple feed-forward neural network with three linear layers, using an EntropyLinear layer as the first one:

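layers = [
    # entropy-based layer: learns which input features are relevant for each class
    te.nn.EntropyLinear(x_train.shape[1], 10, n_classes=2),
    torch.nn.LeakyReLU(),  # non-linearities are needed to represent XOR
    torch.nn.Linear(10, 4),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(4, 1),
]
model = torch.nn.Sequential(*layers)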
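The EntropyLinear layer is what makes the network explainable. For each class, it scores how relevant each input feature is and rescales the inputs accordingly, so that only a few features get through. The snippet below is a minimal pure-Python sketch of that idea, following the entropy-based LEN formulation as we understand it. It is not torch_explain's implementation. It assumes three things:

- The relevance `gamma` of feature j is the L1 norm of the first-layer weights attached to j.
- `gamma` is turned into a distribution `alpha` by a softmax with temperature `tau`.
- `alpha` is then divided by its maximum.

The weight matrix `W` and the value `tau=0.6` are made up for illustration.

```python
import math

def entropy_attention(weights, tau=0.6):
    """Hypothetical sketch of per-class feature relevance.

    weights: out_features x in_features matrix (list of rows) for ONE class.
    Returns (alpha, alpha_norm), both with one entry per input feature.
    """
    n_in = len(weights[0])
    # gamma_j: L1 norm of the weights attached to input feature j
    gamma = [sum(abs(row[j]) for row in weights) for j in range(n_in)]
    # softmax(gamma / tau), subtracting the max for numerical stability
    scaled = [g / tau for g in gamma]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # rescale so that the most relevant feature gets weight exactly 1
    top = max(alpha)
    alpha_norm = [a / top for a in alpha]
    return alpha, alpha_norm

def apply_attention(x_row, alpha_norm):
    # element-wise rescaling of the inputs before the linear map
    return [x * a for x, a in zip(x_row, alpha_norm)]

# Toy weights for one class: 3 hidden units, 4 input features.
# Features 0 and 1 have large weights, features 2 and 3 tiny ones.
W = [
    [2.0, -1.5, 0.01, 0.0],
    [-1.0, 2.5, 0.0, 0.02],
    [1.5, 1.0, 0.01, 0.01],
]
alpha, alpha_norm = entropy_attention(W)
print([round(a, 3) for a in alpha_norm])
print(apply_attention([1, 1, 1, 1], alpha_norm))
```

Features with small weights end up multiplied by values close to zero. This mirrors what we expect the trained layer to do with the 100 dummy columns.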

We can now train the network by minimizing the cross-entropy loss plus entropy_logic_loss, a regularization term that encodes the human prior towards simple explanations:

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form = torch.nn.CrossEntropyLoss()
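for epoch in range(1001):
    optimizer.zero_grad()
    # one logit per class once the trailing dimension is squeezed
    y_pred = model(x_train).squeeze(-1)
    loss = loss_form(y_pred, y_train) + \
           0.00001 * te.nn.functional.entropy_logic_loss(model)
    loss.backward()
    optimizer.step()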
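The regularizer rewards attention that concentrates on a few features. We assume, as in the entropy-based LEN formulation, that entropy_logic_loss sums the Shannon entropy of each class's attention distribution. Under that assumption, the pure-Python sketch below (not the library's code) shows why minimizing it favours simple explanations. Attention spread evenly over all 102 inputs has the maximal entropy log(102). Attention concentrated on the two XOR features scores only log(2). The helper names are hypothetical.

```python
import math

def shannon_entropy(p):
    # H(p) = -sum_j p_j * log(p_j), skipping zero entries (0 * log 0 = 0)
    return -sum(pj * math.log(pj) for pj in p if pj > 0)

def entropy_penalty(alphas):
    # Hypothetical stand-in for the regularizer: sum of per-class entropies
    return sum(shannon_entropy(a) for a in alphas)

n_features = 102
uniform = [1.0 / n_features] * n_features        # attention spread everywhere
focused = [0.5, 0.5] + [0.0] * (n_features - 2)  # attention on f1 and f2 only

print(shannon_entropy(uniform))  # equals log(102), the maximum for 102 features
print(shannon_entropy(focused))  # equals log(2)
# two classes (n_classes=2), both focused on the same two features
print(entropy_penalty([focused, focused]))
```

The tiny weight (0.00001) keeps this term from dominating the cross-entropy, so accuracy is not sacrificed for sparsity.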

Once the network is trained, we can extract first-order logic formulas describing how it composed the input features to obtain its predictions:

from torch_explain.logic import explain_class
from torch.nn.functional import one_hot

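y1h = one_hot(y_train)  # one-hot encode the integer labels
# extract a logic explanation for class 1 (the samples where f1 XOR f2 holds)
explanation, _ = explain_class(model, x_train, y1h, x_train, y1h, target_class=1)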
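explain_class does the extraction internally. To build intuition for how a formula in disjunctive normal form can be read off the data, here is a simplified, hypothetical sketch that is not the library's algorithm. It takes the relevant features as already selected: f1 and f2, i.e. columns 0 and 1, which a well-trained EntropyLinear layer should attend to. It then works in three steps:

1. Booleanize each input with a 0.5 threshold.
2. Turn every sample of the target class into a conjunction (minterm) of literals.
3. Join the distinct minterms with OR.

```python
def minterm(row, relevant, names, threshold=0.5):
    # One conjunction per sample: the feature if active, its negation otherwise
    literals = []
    for j in relevant:
        literals.append(names[j] if row[j] > threshold else "~" + names[j])
    return " & ".join(literals)

def dnf_for_class(x_rows, labels, target_class, relevant, names):
    # Collect the distinct minterms of all samples labelled target_class,
    # in first-seen order, then join them with OR
    terms = []
    for row, y in zip(x_rows, labels):
        if y == target_class:
            term = "(" + minterm(row, relevant, names) + ")"
            if term not in terms:
                terms.append(term)
    return " | ".join(terms)

# Same toy data as the tutorial: two informative features plus 100 dummy zeros
x_rows = [r + [0] * 100 for r in ([0, 0], [0, 1], [1, 0], [1, 1])]
labels = [0, 1, 1, 0]
names = {0: "f1", 1: "f2"}  # hypothetical feature names
relevant = [0, 1]           # assumed: the features kept by the attention layer

print(dnf_for_class(x_rows, labels, 1, relevant, names))
```

Restricting the minterms to the few relevant features keeps the formula short. Without that selection, every minterm would also list all 100 dummy literals.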

Explanations are logic formulas in disjunctive normal form. In this case, the explanation will be y=1 IFF (f1 AND ~f2) OR (f2 AND ~f1), which is equivalent to y=1 IFF f1 XOR f2.
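To check that this two-term DNF is exactly XOR, enumerate all four truth assignments of f1 and f2 and compare the formula with f1 != f2, which is XOR on booleans. The plain-Python check below does this. The helper `dnf` is illustrative only.

```python
from itertools import product

def dnf(f1, f2):
    # (f1 AND ~f2) OR (f2 AND ~f1)
    return (f1 and not f2) or (f2 and not f1)

rows = []
for f1, f2 in product([False, True], repeat=2):
    rows.append((f1, f2, dnf(f1, f2), f1 != f2))
    print(f1, f2, dnf(f1, f2))

print(all(r[2] == r[3] for r in rows))  # True
```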