
I am following this tutorial to train an autoencoder.

The training has gone well. Next, I am interested in extracting features from the hidden layer (between the encoder and decoder).

How should I do that?


The cleanest and most straightforward way is to add methods that produce partial outputs; this can even be done after the model has been trained, since adding methods does not change its parameters.

import torch
from torch import Tensor, nn

class AE(nn.Module):
    def __init__(self, **kwargs):
        ...  # layer definitions from the tutorial stay unchanged

    def encode(self, features: Tensor) -> Tensor:
        h = torch.relu(self.encoder_hidden_layer(features))
        return torch.relu(self.encoder_output_layer(h))

    def decode(self, encoded: Tensor) -> Tensor:
        h = torch.relu(self.decoder_hidden_layer(encoded))
        return torch.relu(self.decoder_output_layer(h))

    def forward(self, features: Tensor) -> Tensor:
        encoded = self.encode(features)
        return self.decode(encoded)

You can now query the model for encoder hidden states by simply calling encode with the corresponding input tensor.
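For reference, here is a self-contained version of the class that can be run end to end. The layer sizes (784 inputs, 128 hidden and code units) are assumptions modeled on a flattened-MNIST setup and are not given in the answer; __init__ is filled in only so the sketch runs. Substitute your own trained model in practice.

```python
import torch
from torch import Tensor, nn


class AE(nn.Module):
    # Assumed sizes: 784 inputs (flattened 28x28 images), 128-unit layers.
    def __init__(self, input_shape: int = 784, hidden: int = 128, code: int = 128):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(input_shape, hidden)
        self.encoder_output_layer = nn.Linear(hidden, code)
        self.decoder_hidden_layer = nn.Linear(code, hidden)
        self.decoder_output_layer = nn.Linear(hidden, input_shape)

    def encode(self, features: Tensor) -> Tensor:
        h = torch.relu(self.encoder_hidden_layer(features))
        return torch.relu(self.encoder_output_layer(h))

    def decode(self, encoded: Tensor) -> Tensor:
        h = torch.relu(self.decoder_hidden_layer(encoded))
        return torch.relu(self.decoder_output_layer(h))

    def forward(self, features: Tensor) -> Tensor:
        return self.decode(self.encode(features))


model = AE()
model.eval()  # switch to inference mode before extracting features
batch = torch.rand(16, 784)  # random stand-in for a batch of flattened images
with torch.no_grad():  # gradients are not needed for feature extraction
    hidden_features = model.encode(batch)  # code-layer activations
    reconstruction = model(batch)  # full forward pass, unchanged
```

Calling encode runs only the encoder half, while model(batch) still returns the reconstruction, so the training code does not need to change.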

If you'd rather not add any methods to the model class (I don't see why you wouldn't), you could write an external function instead:

def get_encoder_state(model: AE, features: Tensor) -> Tensor:
    # Same computation as AE.encode, but defined outside the class
    h = torch.relu(model.encoder_hidden_layer(features))
    return torch.relu(model.encoder_output_layer(h))
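A quick sanity check that the external function and the encode method agree. It assumes only the two encoder layer names used in the answer; the sizes are hypothetical, and the decoder half is omitted because it is not needed for this comparison.

```python
import torch
from torch import Tensor, nn


class AE(nn.Module):
    # Hypothetical sizes; only the encoder half is needed for this check.
    def __init__(self, input_shape: int = 784, code: int = 128):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(input_shape, code)
        self.encoder_output_layer = nn.Linear(code, code)

    def encode(self, features: Tensor) -> Tensor:
        h = torch.relu(self.encoder_hidden_layer(features))
        return torch.relu(self.encoder_output_layer(h))


def get_encoder_state(model: AE, features: Tensor) -> Tensor:
    # Reads the model's layers directly, so the class needs no extra method.
    h = torch.relu(model.encoder_hidden_layer(features))
    return torch.relu(model.encoder_output_layer(h))


torch.manual_seed(0)
model = AE()
x = torch.rand(4, 784)
with torch.no_grad():
    external = get_encoder_state(model, x)
    internal = model.encode(x)
same = torch.equal(external, internal)
```

The trade-off is that the external function must be kept in sync with the layer names and activations by hand, whereas a method stays next to the layers it uses.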

  • Thank you very much. So is it correct to add this function to the class?
    def encode(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(activation)
        code = self.encoder_output_layer(activation)
        code = torch.relu(code)
        return code
    Then, in that Medium tutorial, the training loop uses outputs = model(batch_features). After training, how can I use my function with the trained model to extract the features? Is this correct: hidden_features = model.encode(my_input)?
    – Kadaj13
    Commented Mar 3, 2021 at 9:28
  • The outputs = model(batch_features) line stays as it is (it still gives you the reconstructed input). For the hidden features, hidden_features = model.encode(input) should work! Commented Mar 3, 2021 at 9:59
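To extract features for a whole dataset after training, one option is to loop over a DataLoader in eval mode under torch.no_grad() and concatenate the encoded batches. This is a sketch under assumptions: the layer sizes are hypothetical, random data stands in for the real dataset, and extract_features is a helper name introduced here for illustration. With a labeled loader such as MNIST's, which yields (images, labels), unpack with for batch_features, _ in loader instead.

```python
import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, TensorDataset


class AE(nn.Module):
    # Hypothetical sizes; in practice use your trained model instead.
    def __init__(self, input_shape: int = 784, code: int = 128):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(input_shape, code)
        self.encoder_output_layer = nn.Linear(code, code)
        self.decoder_hidden_layer = nn.Linear(code, code)
        self.decoder_output_layer = nn.Linear(code, input_shape)

    def encode(self, features: Tensor) -> Tensor:
        h = torch.relu(self.encoder_hidden_layer(features))
        return torch.relu(self.encoder_output_layer(h))

    def decode(self, encoded: Tensor) -> Tensor:
        h = torch.relu(self.decoder_hidden_layer(encoded))
        return torch.relu(self.decoder_output_layer(h))

    def forward(self, features: Tensor) -> Tensor:
        return self.decode(self.encode(features))


def extract_features(model: AE, loader: DataLoader) -> Tensor:
    # Hypothetical helper: encodes every batch and stacks the results.
    model.eval()
    chunks = []
    with torch.no_grad():
        for (batch_features,) in loader:
            flat = batch_features.view(batch_features.size(0), -1)  # flatten inputs
            chunks.append(model.encode(flat))
    return torch.cat(chunks)


data = TensorDataset(torch.rand(100, 784))  # stand-in for the real dataset
loader = DataLoader(data, batch_size=32, shuffle=False)
model = AE()
features = extract_features(model, loader)  # one code vector per sample
```

shuffle=False keeps the rows of features in the same order as the dataset, so each code vector can be matched back to its original sample or label.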
