Logic Explained Networks (LENs) tutorial
Entropy-based LENs
For this simple tutorial, let’s solve the XOR problem (augmented with 100 dummy features):
import torch
import torch_explain as te
from torch.nn.functional import one_hot
x0 = torch.zeros((4, 100))
x_train = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
], dtype=torch.float)
x_train = torch.cat([x_train, x0], dim=1)
y_train = torch.tensor([0, 1, 1, 0], dtype=torch.long)
y_train_1h = one_hot(y_train).to(torch.float)
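As a quick sanity check on the data above, the following sketch rebuilds the same tensors and confirms that the labels depend only on the first two columns, while the 100 appended columns are all zeros. The names `x_xor`, `x_dummy`, and `xor_labels` are introduced here for illustration only.

```python
import torch
from torch.nn.functional import one_hot

# The two informative XOR inputs.
x_xor = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float)
# 100 uninformative (all-zero) dummy features per sample.
x_dummy = torch.zeros((4, 100))
x_train = torch.cat([x_xor, x_dummy], dim=1)

y_train = torch.tensor([0, 1, 1, 0], dtype=torch.long)
y_train_1h = one_hot(y_train).to(torch.float)

# The labels are the XOR of the first two columns.
xor_labels = torch.logical_xor(x_train[:, 0].bool(), x_train[:, 1].bool()).long()

print(x_train.shape)     # torch.Size([4, 102])
print(y_train_1h.shape)  # torch.Size([4, 2])
print(torch.equal(xor_labels, y_train))  # True
```

A good explanation should therefore mention only the first two features and ignore the dummy ones.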
We can instantiate a simple feed-forward neural network with 3 layers, using an EntropyLinear layer as the first one:
layers = [
    te.nn.EntropyLinear(x_train.shape[1], 10, n_classes=y_train_1h.shape[1]),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(10, 4),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(4, 1),
]
model = torch.nn.Sequential(*layers)
We can now train the network by optimizing the binary cross-entropy loss together with the entropy_logic_loss term, which encodes the human prior towards simple explanations:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
loss_form = torch.nn.BCEWithLogitsLoss()
model.train()
for epoch in range(2001):
    optimizer.zero_grad()
    y_pred = model(x_train).squeeze(-1)
    loss = loss_form(y_pred, y_train_1h) + 0.0001 * te.nn.functional.entropy_logic_loss(model)
    loss.backward()
    optimizer.step()
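To build intuition for why an entropy penalty pushes towards simple explanations, the sketch below computes the Shannon entropy of two hypothetical feature-importance distributions. It is an illustration only: `shannon_entropy`, `spread`, and `focused` are names made up here, and this is not the implementation of `entropy_logic_loss`.

```python
import math

def shannon_entropy(probs):
    """Shannon entropy (in nats) of a discrete probability distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# Hypothetical importance scores over 4 features, each list summing to 1.
spread = [0.25, 0.25, 0.25, 0.25]   # every feature matters equally
focused = [0.49, 0.49, 0.01, 0.01]  # two features dominate

print(shannon_entropy(spread))   # maximal for 4 features: log(4)
print(shannon_entropy(focused))  # smaller: importance is concentrated
```

Minimizing a term of this kind rewards distributions concentrated on a few features, which is the behaviour one wants when the explanation should mention as few features as possible.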
Once trained we can extract first-order logic formulas describing how the network composed the input features to obtain the predictions:
from torch_explain.logic.nn import entropy
from torch.nn.functional import one_hot
y1h = one_hot(y_train)
global_explanations, local_explanations = entropy.explain_classes(model, x_train, y_train, c_threshold=0.5, y_threshold=0.)
Explanations will be logic formulas in disjunctive normal form.
In this case, the explanation will be y=1 if and only if (f1 AND ~f2) OR (f2 AND ~f1), corresponding to f1 XOR f2.
The function automatically assesses the quality of logic explanations in terms of classification accuracy and rule complexity. In this case the accuracy is 100% and the complexity is 4.
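The two figures quoted above can be reproduced by hand for the XOR formula. The sketch below is a dependency-free illustration, assuming that complexity means the number of literals in the DNF formula; the data structure and the helpers `evaluate_dnf` and `literal_count` are hypothetical and are not part of torch_explain.

```python
# A DNF formula as a list of terms; each term is a list of
# (feature_index, required_value) literals.
xor_dnf = [
    [(0, True), (1, False)],   # f1 AND ~f2
    [(1, True), (0, False)],   # f2 AND ~f1
]

def evaluate_dnf(dnf, sample):
    """True if at least one term has all of its literals satisfied."""
    return any(all(bool(sample[i]) == value for i, value in term) for term in dnf)

def literal_count(dnf):
    """Assumed complexity measure: total number of literals in the formula."""
    return sum(len(term) for term in dnf)

samples = [(0, 0), (0, 1), (1, 0), (1, 1)]
labels = [0, 1, 1, 0]

correct = sum(evaluate_dnf(xor_dnf, s) == bool(y) for s, y in zip(samples, labels))
accuracy = correct / len(labels)
rule_complexity = literal_count(xor_dnf)

print(accuracy, rule_complexity)  # 1.0 4
```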
\(\psi\) LENs
For this simple tutorial, let’s solve the XOR problem using a \(\psi\) LEN:
import torch
import torch_explain as te
x_train = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
], dtype=torch.float)
y_train = torch.tensor([0, 1, 1, 0], dtype=torch.float).unsqueeze(1)
We can instantiate a simple \(\psi\) network with 3 layers using sigmoid activation functions only:
layers = [
    torch.nn.Linear(x_train.shape[1], 10),
    torch.nn.Sigmoid(),
    torch.nn.Linear(10, 5),
    torch.nn.Sigmoid(),
    torch.nn.Linear(5, 1),
    torch.nn.Sigmoid(),
]
model = torch.nn.Sequential(*layers)
We can now train the network by optimizing the binary cross-entropy loss together with the l1_loss term, which encodes the human prior towards simple explanations. The \(\psi\) network needs to be pruned during training to simplify its internal architecture (here pruning happens at epoch 1000):
from torch_explain.nn.functional import prune_equal_fanin
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form = torch.nn.BCELoss()
model.train()
for epoch in range(6001):
    optimizer.zero_grad()
    y_pred = model(x_train)
    loss = loss_form(y_pred, y_train) + 0.000001 * te.nn.functional.l1_loss(model)
    loss.backward()
    optimizer.step()

    model = prune_equal_fanin(model, epoch, prune_epoch=1000, k=2)
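The idea of pruning to an equal fan-in can be illustrated without any library. The sketch below assumes that the goal is to leave every neuron with exactly k incoming connections, keeping the largest-magnitude weights. This is an assumption about the intent suggested by the name and the `k=2` argument, not a copy of how `prune_equal_fanin` is implemented, and `keep_top_k_inputs` is a hypothetical helper.

```python
def keep_top_k_inputs(weights, k):
    """Zero all but the k largest-magnitude incoming weights of each neuron.

    `weights` is a list of rows, one row of incoming weights per neuron.
    """
    pruned = []
    for row in weights:
        # Indices of the k weights with the largest absolute value.
        by_magnitude = sorted(range(len(row)), key=lambda i: abs(row[i]), reverse=True)
        keep = set(by_magnitude[:k])
        pruned.append([w if i in keep else 0.0 for i, w in enumerate(row)])
    return pruned

# Hypothetical weight matrix: 2 neurons, 4 inputs each.
weights = [
    [0.9, -0.1, 0.05, -1.2],
    [0.02, 0.7, -0.6, 0.1],
]
pruned = keep_top_k_inputs(weights, k=2)
print(pruned)  # [[0.9, 0.0, 0.0, -1.2], [0.0, 0.7, -0.6, 0.0]]
```

With only two inputs per neuron, each neuron can be described by a truth table with four rows, which keeps the extracted formulas short.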
Once trained we can extract first-order logic formulas describing how the network composed the input features to obtain the predictions:
from torch_explain.logic.nn import psi
from torch.nn.functional import one_hot
y1h = one_hot(y_train.squeeze().long())
explanation = psi.explain_class(model, x_train)
Explanations will be logic formulas in disjunctive normal form.
In this case, the explanation will be y=1 IFF (f1 AND ~f2) OR (f2 AND ~f1), corresponding to y=1 IFF f1 XOR f2.
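One way to see why sigmoid activations and a small fan-in make formulas easy to read off is to enumerate the truth table of a single neuron. The sketch below does this for a hypothetical neuron with hand-picked weights, treating an activation above 0.5 as true. It is an illustration under these assumptions and does not reproduce the internals of `psi.explain_class`; `neuron_to_dnf` is a made-up helper.

```python
import math
from itertools import product

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def neuron_to_dnf(weights, bias, names):
    """List every Boolean input combination for which the neuron fires (> 0.5)."""
    terms = []
    for bits in product([0, 1], repeat=len(weights)):
        activation = sigmoid(sum(w * b for w, b in zip(weights, bits)) + bias)
        if activation > 0.5:
            literals = [name if bit else f"~{name}" for name, bit in zip(names, bits)]
            terms.append("(" + " AND ".join(literals) + ")")
    return " OR ".join(terms) if terms else "False"

# Hypothetical weights: the neuron fires only when f1 is on and f2 is off.
print(neuron_to_dnf([10.0, -10.0], -5.0, ["f1", "f2"]))  # (f1 AND ~f2)
```

The truth table has 2^k rows for a neuron with k inputs, so keeping k small is what makes this enumeration practical.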
The quality of the logic explanation can be quantitatively assessed in terms of classification accuracy and rule complexity as follows:
from torch_explain.logic.metrics import test_explanation, complexity
accuracy, preds = test_explanation(explanation, x_train, y1h, target_class=1)
explanation_complexity = complexity(explanation)
In this case the accuracy is 100% and the complexity is 4.