Suppose you want to fine-tune Keras' VGG16 model by replacing only the final dense (fully-connected) layer and keeping the rest, so you load it with include_top=True:
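from keras.applications.vgg16 import VGG16
from keras.layers import Input
input_shape = (224, 224, 3)  # include_top=True requires 224x224 RGB inputs
input = Input(shape=input_shape, name='image_input')
vgg16 = VGG16(weights='imagenet', include_top=True)
vgg16.summary()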
Here are the last few layers:
fc1 (Dense) (None, 4096) 102764544
_________________________________________________________________
fc2 (Dense) (None, 4096) 16781312
_________________________________________________________________
predictions (Dense) (None, 1000) 4097000
=================================================================
Total params: 138,357,544
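For reference, each Dense layer's count in the summary is its input width times its number of units (the weights) plus one bias per unit. Below is a quick plain-Python check. It assumes fc1 is fed by the flattened 7 × 7 × 512 output of VGG16's last pooling block, which is not shown in the excerpt above. `dense_params` is just an illustrative helper, not a Keras function:

```python
def dense_params(n_in: int, units: int) -> int:
    """Parameters of a Dense layer: an n_in x units weight matrix plus one bias per unit."""
    return n_in * units + units

# Assumption: fc1 sees the flattened 7 x 7 x 512 output of the last pooling block
# (224 x 224 input), i.e. 25,088 features.
flatten_dim = 7 * 7 * 512

fc1 = dense_params(flatten_dim, 4096)
fc2 = dense_params(4096, 4096)
predictions = dense_params(4096, 1000)

# Whatever is left of the reported total belongs to the convolutional base.
conv_base = 138_357_544 - (fc1 + fc2 + predictions)

print(fc1, fc2, predictions, conv_base)  # 102764544 16781312 4097000 14714688
```

The leftover 14,714,688 parameters are the convolutional base. This is the same count Keras reports for VGG16 with include_top=False.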
As expected, it's predicting among 1000 classes. Let's pop off the top:
vgg16.layers.pop()  # drops 'predictions' from the layers list (see the caveat further down)
vgg16.summary()
Output:
…
fc1 (Dense) (None, 4096) 102764544
_________________________________________________________________
fc2 (Dense) (None, 4096) 16781312
=================================================================
Total params: 134,260,544
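The new total is the old one minus exactly the parameters of the removed predictions layer, which has 4,096 × 1,000 weights plus 1,000 biases:

```python
total_before = 138_357_544
removed = 4096 * 1000 + 1000  # weights + biases of the 'predictions' layer
total_after = total_before - removed
print(total_after)  # 134260544
```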
Great, the 1000-unit dense layer is gone! Now just add our own dense layer for 2 categories:
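from keras.layers import Dense
from keras.models import Model
vgg16_output = vgg16(input)  # calls the whole VGG16 model as if it were one layer
x = Dense(2, activation='softmax', name='predictions')(vgg16_output)
my_model = Model(inputs=input, outputs=x)
my_model.summary()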
Now when you print the summary, it smashes all of VGG16 into a single nested layer:
vgg16 (Model) (None, 1000) 134260544
_________________________________________________________________
predictions (Dense) (None, 2) 2002
=================================================================
Total params: 134,262,546
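The 2,002 parameters of the new predictions layer show what it is connected to. A Dense layer has n_in × units + units parameters, so you can recover the input width from the count. Here is a small plain-Python sketch; `dense_input_width` is a hypothetical helper, not part of Keras:

```python
def dense_input_width(params: int, units: int) -> int:
    """Recover n_in from a Dense layer's parameter count: params = n_in * units + units."""
    n_in, remainder = divmod(params - units, units)
    if remainder != 0:
        raise ValueError("count is not consistent with a Dense layer of this size")
    return n_in

print(dense_input_width(2002, 2))  # 1000: fed by the old 1000-way output
print(dense_input_width(8194, 2))  # 4096: what it would be if fed by fc2
```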
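Notice the (None, 1000) output shape. Popping from the layers list did not disconnect the original predictions layer from the graph, so the new 2-way layer is stacked on top of the 1000-way softmax instead of replacing it. The workaround is to attach the new layer to the output of what is now the last layer in the list, fc2: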
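# After the pop, vgg16.layers[-1] is fc2, so the new layer sees its 4096-d output
x = Dense(2, activation='softmax', name='predictions')(vgg16.layers[-1].output)
my_model = Model(inputs=vgg16.input, outputs=x)
my_model.summary()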
This gives you the correct model, with the new layer attached to fc2:
fc1 (Dense) (None, 4096) 102764544
_________________________________________________________________
fc2 (Dense) (None, 4096) 16781312
_________________________________________________________________
predictions (Dense) (None, 2) 8194
=================================================================
Total params: 134,268,738
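Both totals are the 134,260,544 parameters left after the pop, plus the new 2-unit head. The only thing that differs is the input width of that head. `with_new_head` is an illustrative helper:

```python
base = 134_260_544  # total reported after popping 'predictions'

def with_new_head(n_in: int, units: int = 2) -> int:
    """Total parameters after adding a Dense(units) head on an n_in-wide input."""
    return base + n_in * units + units

nested = with_new_head(1000)  # head on the old 1000-way output (nested model)
fixed = with_new_head(4096)   # head on fc2 (workaround)
print(nested, fixed)  # 134262546 134268738
```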