linna.network
Classes
| Network | Wraps a sequential PyTorch model. |
| NetworkLayer | |
- class linna.network.Network(torch_model: torch.nn.Sequential)
Wraps a sequential PyTorch model. A layer consists of a linear layer and, optionally, a non-linear function. Suppose a sequential PyTorch model with 6 modules has the following structure:
nn.Linear, nn.ReLU, nn.Linear, nn.ReLU, nn.Linear, nn.Softmax
In our setting, this model has 3 layers, organized as follows:
Layer 0: nn.Linear, nn.ReLU
Layer 1: nn.Linear, nn.ReLU
Layer 2: nn.Linear, nn.Softmax
Throughout the class, neurons are uniquely identified by their index, i.e., the corresponding row index in the weight matrix.
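The grouping rule above can be illustrated with a short sketch. The helper `group_into_layers` is hypothetical and not part of linna; it recognizes linear modules by class name, so it accepts the modules of a `torch.nn.Sequential` (iterating a `Sequential` yields its modules in order) as well as plain stand-in objects.

```python
from typing import Callable, List, Sequence


def _is_linear(module: object) -> bool:
    # Matches torch.nn.Linear (and any stand-in class named "Linear").
    return type(module).__name__ == "Linear"


def group_into_layers(
    modules: Sequence[object],
    is_linear: Callable[[object], bool] = _is_linear,
) -> List[List[object]]:
    """Group a flat module sequence into layers, each starting with a linear module."""
    layers: List[List[object]] = []
    for module in modules:
        if is_linear(module):
            layers.append([module])  # a linear module opens a new layer
        elif layers:
            layers[-1].append(module)  # activations attach to the current layer
        else:
            raise ValueError("the model must start with a linear module")
    return layers
```

Applied to the six-module model above, it returns three groups, (Linear, ReLU), (Linear, ReLU) and (Linear, Softmax), matching layers 0 to 2.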
- reset()
- get_num_neurons()
Returns the number of neurons
- Returns:
Number of neurons
- Return type:
int
- forward(X: torch.Tensor, layer_idx: Optional[int] = None, grad: bool = False)
Computes the forward pass up to and including the given layer
- Parameters:
X (torch.Tensor) – Inputs
layer_idx (int, optional) – Index of the last layer to evaluate
grad (bool, default=False) – Whether gradients are tracked during the forward pass
- Returns:
Output of the neural network
- Return type:
torch.Tensor
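A minimal sketch of "forward pass up to layer_idx (inclusive)", assuming each layer is a sequence of callables (the linear map followed by its activations, as in the grouping described for this class) and assuming that `layer_idx=None` means the whole network. `forward_until` is a hypothetical name, and plain Python lists stand in for tensors so the sketch runs without PyTorch.

```python
from typing import Callable, Optional, Sequence

# Hypothetical representation: a layer is a sequence of callables applied in order.
Layer = Sequence[Callable]


def forward_until(layers: Sequence[Layer], x, layer_idx: Optional[int] = None):
    """Apply layers 0..layer_idx (inclusive); apply every layer if layer_idx is None."""
    last = len(layers) - 1 if layer_idx is None else layer_idx
    if not 0 <= last < len(layers):
        raise IndexError(f"layer_idx {layer_idx} is out of range")
    for layer in layers[: last + 1]:
        for fn in layer:
            x = fn(x)
    return x
```

With real tensors, a `grad` flag like the one above could be honored by running the loop inside `torch.set_grad_enabled(grad)`; that is an assumption about how the flag is used, not documented behavior.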
- classify(X: torch.Tensor)
Classifies the given inputs
- Parameters:
X (torch.Tensor) – Inputs to be classified
- Returns:
Classification result
- Return type:
torch.Tensor
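Assuming the classification result is the index of the largest network output for each input (the usual convention, not confirmed here), classification reduces to a row-wise argmax. The sketch below uses nested lists; on a tensor the equivalent would be `torch.argmax(output, dim=1)`.

```python
from typing import List, Sequence


def classify(outputs: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each input row, the index of its largest output (row-wise argmax)."""
    return [max(range(len(row)), key=lambda i: row[i]) for row in outputs]
```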
- delete_neuron(layer_idx: int, neuron: int)
Removes a neuron from the given layer
- Parameters:
layer_idx (int) – Layer from which the neuron should be deleted
neuron (int) – Neuron to be deleted
- restore_neuron(layer_idx: int, neuron: int)
Restores a previously deleted neuron in the given layer
- Parameters:
layer_idx (int) – Layer in which the neuron should be restored
neuron (int) – Neuron to be restored
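The effect of deleting neuron j of a layer on the weights can be sketched as follows, assuming PyTorch's `nn.Linear` layout, where the weight matrix has shape (out_features, in_features): the neuron's incoming weights are row j of its layer's weight matrix, its bias is entry j of the bias vector, and its outgoing weights are column j of the next layer's weight matrix. Both functions and the list-of-lists representation are hypothetical; linna may store or mask weights differently.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def remove_neuron_weights(
    w: Matrix, b: List[float], w_next: Matrix, j: int
) -> Tuple[Matrix, List[float], Matrix]:
    """Drop neuron j: row j of w, entry j of b, and column j of w_next."""
    w_new = [row for i, row in enumerate(w) if i != j]  # incoming weights
    b_new = [v for i, v in enumerate(b) if i != j]  # bias
    w_next_new = [[v for i, v in enumerate(row) if i != j] for row in w_next]  # outgoing
    return w_new, b_new, w_next_new


def reinsert_neuron_weights(
    w: Matrix, b: List[float], w_next: Matrix, j: int,
    row: List[float], bias: float, col: List[float],
) -> Tuple[Matrix, List[float], Matrix]:
    """Inverse of remove_neuron_weights: put the saved row, bias and column back at j."""
    w_new = w[:j] + [row] + w[j:]
    b_new = b[:j] + [bias] + b[j:]
    w_next_new = [r[:j] + [c] + r[j:] for r, c in zip(w_next, col)]
    return w_new, b_new, w_next_new
```

Restoring only works if the removed row, bias and column were saved beforehand, which is why a restorable deletion needs some record of the original weights.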
- set_basis(layer_idx: int, basis: List[int])
Sets the basis for the given layer. Note that due to
- Parameters:
layer_idx (int) – Layer
basis (List[int]) – List of basis neurons
- readjust_weights(layer_idx: int, neuron: int, coef: torch.Tensor)
Readjusts the outgoing weights of the neuron in the given layer. Effectively, the weight matrix of layer layer_idx + 1 is modified.
- Parameters:
layer_idx (int) – Layer whose outgoing weights are adjusted
neuron (int) – Neuron whose outgoing weights are adjusted
coef (torch.Tensor) – Linear coefficients
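One common way to use such linear coefficients, sketched here as an assumption rather than linna's confirmed procedure: if the output of neuron j is approximated by a linear combination of basis neurons, h_j ≈ sum of c_i * h_i over the basis B, then the next layer's pre-activation W h is preserved (up to that approximation) when column j of W, scaled by c_i, is added to each basis column i and column j is then zeroed. The function name `fold_into_basis` and its `basis` argument are hypothetical; linna's `coef` may be indexed differently.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def fold_into_basis(
    w_next: Matrix, neuron: int, basis: Sequence[int], coef: Sequence[float]
) -> Matrix:
    """Redistribute column `neuron` of w_next onto the basis columns.

    Assumes h[neuron] is approximated by sum(coef[k] * h[basis[k]]), so adding
    coef[k] * w_next[:, neuron] to column basis[k] and zeroing column `neuron`
    preserves w_next @ h up to that approximation.
    """
    if len(basis) != len(coef):
        raise ValueError("basis and coef must have the same length")
    if neuron in basis:
        raise ValueError("the removed neuron cannot be part of its own basis")
    out = [list(row) for row in w_next]  # work on a copy
    for row in out:
        w_j = row[neuron]
        for i, c in zip(basis, coef):
            row[i] += c * w_j
        row[neuron] = 0.0
    return out
```

When the linear relation holds exactly, the next layer's output is unchanged, so the neuron can be deleted without affecting the network on those inputs.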
- get_io_matrix(layer_idx: int, loader, size=1000) → ndarray
Computes the IO matrix for the given layer
- Parameters:
layer_idx (int) – Layer for which the IO matrix should be computed
loader (DataLoader) – Data loader that provides the inputs
size (int) – Number of images to be considered
- Returns:
IO matrix
- Return type:
torch.Tensor
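A sketch of one plausible reading of the IO matrix: the outputs of the given layer for the first `size` inputs drawn from `loader`, one row per input and one column per neuron. This interpretation, the `io_matrix` name and the `forward` callback are assumptions; the loader is assumed to yield `(inputs, labels)` batches, as a PyTorch `DataLoader` over a labeled dataset typically does.

```python
from typing import Callable, Iterable, List, Sequence, Tuple


def io_matrix(
    forward: Callable[[Sequence[float]], List[float]],
    loader: Iterable[Tuple[Sequence[Sequence[float]], object]],
    size: int = 1000,
) -> List[List[float]]:
    """Stack the layer outputs of the first `size` inputs (one row per input)."""
    rows: List[List[float]] = []
    if size <= 0:
        return rows
    for inputs, _labels in loader:  # each batch is (inputs, labels)
        for x in inputs:
            rows.append(forward(x))
            if len(rows) == size:
                return rows
    return rows  # the loader ran out before `size` inputs were seen
```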
- export_to_nnet(filename: str)
- update_torch_model()
- class linna.network.NetworkLayer(torch_model: torch.nn.Sequential, layer_idx: int)
- get_weight()
Returns the weight of the layer
- Returns:
Weight of the layer
- Return type:
torch.Tensor
- get_bias()
Returns the bias of the layer
- Returns:
Bias of the layer
- Return type:
torch.Tensor
- set_weight(weight: torch.Tensor)
Sets the weight of the layer
- Parameters:
weight (torch.Tensor) – Weight
- set_bias(bias: torch.Tensor)
Sets the bias of the layer
- Parameters:
bias (torch.Tensor) – Bias
- get_input_weight(neuron: int)
Returns the input weights of the neuron, i.e., its row in the weight matrix
- Parameters:
neuron (int) – Neuron
- Returns:
Input weights of the neuron
- Return type:
torch.Tensor
- delete_output(neuron: int)
Deletes the neuron from the layer
- Parameters:
neuron (int) – Neuron to be deleted
- delete_input(neuron: int)
Deletes an input neuron from the layer
- Parameters:
neuron (int) – Input neuron to be deleted
- get_neuron_index(neuron: int)
Returns the index of the neuron
- Parameters:
neuron (int) – Neuron
- Returns:
Index of the neuron
- Return type:
int
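If deleted neurons are physically removed from the weight matrices, a neuron's original index (its identity, as described for the Network class) and its current row index diverge. The hypothetical `NeuronIndex` class below sketches the kind of bookkeeping that `get_neuron_index` suggests; it is an illustration under that assumption, not linna's implementation.

```python
import bisect
from typing import List


class NeuronIndex:
    """Track which original neuron ids are still present and where their rows are."""

    def __init__(self, num_neurons: int) -> None:
        # Sorted original ids of the neurons that are currently present.
        self.active: List[int] = list(range(num_neurons))

    def index(self, neuron: int) -> int:
        """Current row index of `neuron`; raises ValueError if it was deleted."""
        return self.active.index(neuron)

    def delete(self, neuron: int) -> int:
        """Mark `neuron` as deleted and return the row that must be removed."""
        idx = self.index(neuron)
        del self.active[idx]
        return idx

    def restore(self, neuron: int) -> int:
        """Mark `neuron` as present again and return the row to re-insert it at."""
        if neuron in self.active:
            raise ValueError(f"neuron {neuron} is not deleted")
        bisect.insort(self.active, neuron)
        return self.active.index(neuron)
```

For example, after deleting neurons 1 and 3 out of 5, neuron 4 sits in row 2; restoring neuron 1 moves it back to row 3.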
- readjust_weights(neuron: int, coef: torch.Tensor)
Readjusts the weights of the layer for the given neuron using the coefficients of a linear combination
- Parameters:
neuron (int) – Neuron
coef (torch.Tensor) – A tensor containing coefficients of linear combination
- restore_neuron(neuron: int)
Restores a given neuron
- Parameters:
neuron (int) – Neuron to be restored
- restore_input(neuron)
Restores an input neuron to the layer
- Parameters:
neuron (int) – Input neuron
- reset()
Resets the layer
- restore_weights(neuron)
Restores the weights of the layer with respect to a previously removed neuron
- Parameters:
neuron (int) – Neuron whose weights should be restored