Pruning functions to simplify the neural model
torch_explain.nn.functional.prune
- torch_explain.nn.functional.prune.prune_equal_fanin(model: Module, epoch: int, prune_epoch: int, k: int = 2, device: device = device(type='cpu')) → Module
Prune the linear layers of the network such that each neuron has the same fan-in.
- Parameters:
model – PyTorch model whose linear layers are pruned.
epoch – current training epoch.
prune_epoch – training epoch at which pruning is applied.
k – fan-in, i.e. the number of incoming connections each neuron keeps.
device – CPU or CUDA device.
- Returns:
The pruned model.
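The idea of equal fan-in pruning can be illustrated with a small sketch. It is not the torch_explain implementation. The function name prune_equal_fanin_sketch is hypothetical, and the sketch assumes two things the reference above does not state: pruning happens only when epoch == prune_epoch, and each neuron keeps the k incoming weights with the largest absolute value. The mask is applied with torch.nn.utils.prune.custom_from_mask.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as torch_prune


def prune_equal_fanin_sketch(model: nn.Module, epoch: int, prune_epoch: int,
                             k: int = 2,
                             device: torch.device = torch.device("cpu")) -> nn.Module:
    """Hypothetical sketch: give every neuron of every nn.Linear the same fan-in k."""
    # Assumption: pruning is triggered only at the chosen epoch.
    if epoch != prune_epoch:
        return model
    model.to(device)  # keep weights and masks on the same device
    for module in model.modules():
        if isinstance(module, nn.Linear):
            w = module.weight.detach()
            # Rows are output neurons, columns are their incoming connections.
            fan_in = min(k, w.shape[1])
            # Assumption: keep the k largest-magnitude weights in each row.
            keep = w.abs().topk(fan_in, dim=1).indices
            mask = torch.zeros_like(w)
            mask.scatter_(1, keep, 1.0)
            # Registers weight_orig and weight_mask; weight becomes weight_orig * mask.
            torch_prune.custom_from_mask(module, name="weight", mask=mask)
    return model


# Usage inside a training loop (pruning fires once, at epoch 1).
net = nn.Sequential(nn.Linear(5, 3), nn.ReLU(), nn.Linear(3, 2))
for epoch in range(3):
    net = prune_equal_fanin_sketch(net, epoch=epoch, prune_epoch=1, k=2)
```

Because the pruning is expressed as a mask, the pruned weights stay at zero during the remaining training epochs, so every neuron keeps exactly k active inputs.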