Pruning functions to simplify neural models

torch_explain.nn.functional.prune

torch_explain.nn.functional.prune.prune_equal_fanin(model: Module, epoch: int, prune_epoch: int, k: int = 2, device: device = device(type='cpu')) → Module

Prune the linear layers of the network such that each neuron has the same fan-in.
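The description does not say which connections survive. Below is a minimal sketch that assumes the k largest-magnitude incoming weights of each neuron are kept. That is a common magnitude criterion, but this page does not confirm it is the library's rule. The sketch builds a per-row 0/1 mask for one `nn.Linear` and attaches it with PyTorch's `torch.nn.utils.prune.custom_from_mask`. The helper `equal_fanin_mask` is hypothetical and not part of torch_explain.

```python
import torch
import torch.nn.utils.prune as prune


def equal_fanin_mask(weight: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: 0/1 mask keeping the k largest-|w| entries per row.

    Row j of an nn.Linear weight of shape (out_features, in_features) holds the
    incoming connections of output neuron j, so keeping k entries per row gives
    every neuron a fan-in of exactly k.
    """
    if not 0 < k <= weight.shape[1]:
        raise ValueError("k must satisfy 1 <= k <= in_features")
    # Column indices of the k largest-magnitude weights in each row.
    keep = torch.topk(weight.abs(), k=k, dim=1).indices
    mask = torch.zeros_like(weight)
    mask.scatter_(1, keep, 1.0)  # 1 where a connection survives, 0 elsewhere
    return mask


torch.manual_seed(0)
layer = torch.nn.Linear(in_features=5, out_features=3)
mask = equal_fanin_mask(layer.weight.detach(), k=2)

# Registers weight_orig and weight_mask; layer.weight becomes their product.
prune.custom_from_mask(layer, name="weight", mask=mask)

print((layer.weight != 0).sum(dim=1))  # tensor([2, 2, 2])
```

Because the mask is stored as a buffer rather than baked into the weights, the pruned entries stay zero in every later forward pass.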

Parameters:
  • model – PyTorch model.

  • epoch – current training epoch.

  • prune_epoch – training epoch at which pruning is applied.

  • k – fan-in, i.e. the number of incoming connections each neuron keeps after pruning.

  • device – CPU or CUDA device.

Returns:
  The pruned model.
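The `epoch` and `prune_epoch` parameters suggest the function is meant to be called every epoch inside a training loop and to act only once. The sketch below uses a hypothetical, self-contained stand-in, `prune_equal_fanin_sketch`, which is not torch_explain code. It mirrors the documented signature and rests on two assumptions: pruning fires only when `epoch == prune_epoch`, and every `nn.Linear` keeps its k largest-magnitude weights per neuron. With torch_explain installed, one would presumably call `prune_equal_fanin` with the same arguments instead.

```python
import torch
import torch.nn.utils.prune as prune


def prune_equal_fanin_sketch(model: torch.nn.Module, epoch: int, prune_epoch: int,
                             k: int = 2,
                             device: torch.device = torch.device("cpu")) -> torch.nn.Module:
    """Hypothetical stand-in with the documented signature (not torch_explain code).

    Assumes pruning happens only when epoch == prune_epoch and that every
    nn.Linear keeps its k largest-magnitude weights per output neuron.
    """
    if epoch != prune_epoch:
        return model  # other epochs: leave the model untouched
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            w = module.weight.detach()
            keep = torch.topk(w.abs(), k=k, dim=1).indices
            mask = torch.zeros_like(w).scatter_(1, keep, 1.0)
            # The mask must sit on the same device as the layer's weights.
            prune.custom_from_mask(module, name="weight", mask=mask.to(device))
    return model


torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.LeakyReLU(), torch.nn.Linear(8, 1)
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
x, y = torch.randn(32, 4), torch.randn(32, 1)

for epoch in range(5):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    # Called every epoch, but it only prunes once, at epoch 2; afterwards the
    # registered masks keep the pruned connections at zero.
    model = prune_equal_fanin_sketch(model, epoch, prune_epoch=2, k=2)
```

Returning `model` unchanged on non-pruning epochs lets the call sit unconditionally in the loop. The optimizer keeps working after pruning because `custom_from_mask` re-registers the original weight parameter as `weight_orig`, so it is still the object the optimizer updates.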