
I have trained an autoencoder and saved it using Keras's built-in save() method. Now I want to split it into two parts: encoder and decoder. I can successfully load the model and obtain the encoder part by creating a new model from the old one:

    encoder_model = keras.models.Model(
        inputs=self.model.input,
        outputs=self.model.get_layer(layer_of_activations).get_output_at(0))

However, I cannot do the same for the decoder. I tried various approaches, none of which worked. Then I found a similar issue (Keras replacing input layer) and tried the method it describes with the code below:

    for i, l in enumerate(self.model.layers[0:19]):
        self.model.layers.pop(0)
    newInput = Input(batch_shape=(None, None, None, 64))
    newOutputs = self.model(newInput)
    newModel = keras.models.Model(newInput, newOutputs)

The output shape of the last layer I remove is (None, None, None, 64), but this code produces the following error:

ValueError: number of input channels does not match corresponding dimension of filter, 64 != 3

I assume this is because the model's input dimensions are not updated after the original layers are popped, as noted in the second comment on the first answer to that question (Keras replacing input layer).

Simply looping through the layers and recreating them in a new model does not work either, as my model is not sequential.

1 Answer
I resolved this by building a new model with the exact same architecture as the decoder part of the original autoencoder network, and then copying the weights over layer by layer.

Here's the code:

    # Looping through the old model and popping the encoder part + encoded layer
    for i, l in enumerate(self.model.layers[0:19]): 
        self.model.layers.pop(0)

    # Building a clean model that is the exact same architecture as the decoder part of the autoencoder
    new_model = nb.build_decoder()

    # Looping through both models and setting the weights on the new decoder;
    # the index i+1 skips the new model's fresh Input layer at index 0
    for i, l in enumerate(self.model.layers):
        new_model.layers[i + 1].set_weights(l.get_weights())
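To make the weight-copy idea concrete, here is a self-contained sketch. The architecture, layer counts, and shapes are invented for illustration (the answer's `nb.build_decoder()` is not shown in the post, so a decoder is rebuilt inline here); the key point is the `i + 1` offset that skips the new decoder's Input layer:

```python
# Sketch of the copy-the-weights approach on a hypothetical autoencoder:
# rebuild the decoder half from scratch with its own Input, then copy
# the trained weights layer by layer.
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Original autoencoder: InputLayer + Conv2D + MaxPooling2D encode,
# UpSampling2D + Conv2D decode
inp = keras.Input(shape=(8, 8, 3))
x = layers.Conv2D(64, 3, padding="same", activation="relu")(inp)
encoded = layers.MaxPooling2D(2)(x)
x = layers.UpSampling2D(2)(encoded)
out = layers.Conv2D(3, 3, padding="same", activation="sigmoid")(x)
autoencoder = keras.Model(inp, out)

# Fresh decoder with an identical architecture, but its own Input
# whose shape matches the bottleneck output (4, 4, 64)
dec_in = keras.Input(shape=(4, 4, 64))
d = layers.UpSampling2D(2)(dec_in)
dec_out = layers.Conv2D(3, 3, padding="same", activation="sigmoid")(d)
decoder = keras.Model(dec_in, dec_out)

# Copy weights from the decoder half of the trained autoencoder;
# decoder.layers[i + 1] skips the decoder's own Input layer
n_encoder_layers = 3  # InputLayer + Conv2D + MaxPooling2D
for i, l in enumerate(autoencoder.layers[n_encoder_layers:]):
    decoder.layers[i + 1].set_weights(l.get_weights())

# Sanity check: the standalone decoder accepts a bottleneck tensor
code = np.random.rand(1, 4, 4, 64).astype("float32")
print(decoder.predict(code, verbose=0).shape)  # (1, 8, 8, 3)
```

A nice property of this approach is that the weight copy can be verified directly: the new decoder's Conv2D kernel should be identical to the corresponding kernel in the original autoencoder after the loop runs.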
