Fine-tuning models in Keras: a gotcha

Suppose you want to fine-tune Keras' VGG16 model. Let's say we just want to replace the final dense (fully-connected) layer and keep the rest (include_top=True):
 from keras.models import Model
 from keras.layers import Input, Dense
 from keras.applications.vgg16 import VGG16

 input_shape = (224, 224, 3)  # VGG16's default input size
 input = Input(shape=input_shape, name='image_input')
 vgg16 = VGG16(weights='imagenet', include_top=True)
 vgg16.summary()


Here are the last few layers:

fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   
=================================================================
Total params: 138,357,544
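
These counts follow directly from the Dense layer formula: params = inputs × units + units (a weight per connection plus a bias per unit). A quick sanity check in plain Python — the 25088 is the flattened 7×7×512 output of VGG16's last conv block for a 224×224 input:

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

# fc1 takes the flattened 7x7x512 = 25088 conv features
print(dense_params(25088, 4096))  # 102764544
print(dense_params(4096, 4096))   # 16781312
print(dense_params(4096, 1000))   # 4097000
```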


As expected, it's predicting among 1000 classes. Let's pop off the top:
 vgg16.layers.pop()  
 vgg16.summary()  


Output:

fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
=================================================================
Total params: 134,260,544



Great, the 1000-way dense layer is gone! Now we just add our own dense layer for 2 categories:
 vgg16_output = vgg16(input)  
 x = Dense(2, activation='softmax', name='predictions')(vgg16_output)  
 my_model = Model(inputs=input, outputs=x)  



Now when you print my_model.summary(), all of VGG16 is collapsed into a single row:

vgg16 (Model)                (None, 1000)              134260544 
_________________________________________________________________
predictions (Dense)          (None, 2)                 2002      
=================================================================
Total params: 134,262,546



The vgg16 row reports the post-pop parameter count, yet its output shape is still (None, 1000): pop() removes the layer from model.layers (which is what summary() counts), but the model's output tensor still comes from the removed predictions layer, so the popped layer is still being applied!
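
The 2002 is the smoking gun: a Dense(2) layer wired to a 1000-way output has exactly 1000 × 2 + 2 parameters, which is what the summary shows:

```python
# params of Dense(2) = inputs * 2 + 2 biases
print(1000 * 2 + 2)  # 2002 -- the new head is wired to the 1000-way softmax, not to fc2
```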

Turns out this issue has been known for over a year: https://github.com/fchollet/keras/issues/2371#issuecomment-211120172


The workaround is to build the new head on the output of the last remaining layer, and take the input from the original model:
 x = Dense(2, activation='softmax', name='predictions')(vgg16.layers[-1].output)
 my_model = Model(inputs=vgg16.input, outputs=x)  

Which gives you the correct model:

fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 2)                 8194      
=================================================================
Total params: 134,268,738
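
And the totals check out: the new head now sees fc2's 4096 outputs, and the overall count is the popped VGG16 body plus the new head:

```python
# predictions: Dense(2) on fc2's 4096 outputs
print(4096 * 2 + 2)      # 8194
# total: popped VGG16 body + new head
print(134260544 + 8194)  # 134268738
```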


