import numpy as np
import tensorflow as tf
from tensorflow import keras

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
# Freeze the first layer
layer1.trainable = False
# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()
# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
# A nested sub-model (assumed here for completeness; any sub-model works)
inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid")]
)
model.trainable = False  # Freeze the outer model
assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,  # Do not include the ImageNet classifier at the top.
)
Then, freeze the base model.
base_model.trainable = False
Create a new model on top of the base model.
inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
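Before fine-tuning, the new top layer is trained first while the base stays frozen. A minimal sketch of that step (`new_dataset` is a placeholder for your own data, as in the snippet further below):

# Train only the new classifier head; the frozen base model is left untouched.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)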
# Unfreeze the base model first!
base_model.trainable = True
# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are taken into account
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
# Create base model
base_model = keras.applications.Xception(
    weights="imagenet",
    input_shape=(150, 150, 3),
    include_top=False,
)

# Freeze base model
base_model.trainable = False
# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
# An optimizer and a loss function are needed for the loop below
# (assumed here, to match the single-unit logistic classifier defined above)
optimizer = keras.optimizers.Adam()
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
Likewise for fine-tuning.
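A sketch of what that could look like with the same loop: unfreeze the base model and use a much lower learning rate. Since there is no compile() step in a custom loop, the change to `trainable` takes effect immediately. `new_dataset`, `model`, and `loss_fn` are reused from above:

# Unfreeze the base model so its weights become trainable again.
base_model.trainable = True

# Use a very low learning rate to avoid destroying the pre-trained weights.
optimizer = keras.optimizers.Adam(1e-5)

for inputs, targets in new_dataset:
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss_value = loss_fn(targets, predictions)
    # `model.trainable_weights` now also includes the base model's weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))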
An end-to-end example: fine-tuning an image classification model on the Dogs vs. Cats dataset
To solidify these concepts, let's walk through a concrete end-to-end transfer learning and fine-tuning example. We will load the Xception model, pre-trained on ImageNet, and use it on the Kaggle Dogs vs. Cats classification dataset.
import tensorflow_datasets as tfds

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)
print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds)) print( "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds) ) print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Here are the first 9 images in the training dataset. As you can see, they come in different sizes.
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")
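Since the images come in different sizes and the base model below expects 150x150 inputs, the dataset needs to be standardized and batched first. A minimal sketch using standard tf.data transformations (the batch size and prefetch buffer are example values):

size = (150, 150)
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

batch_size = 32
train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)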
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,  # Do not include the ImageNet classifier at the top.
)
# Freeze the base_model
base_model.trainable = False
# A simple random-augmentation stage (assumed here; flips and small rotations
# are a common choice for this dataset)
data_augmentation = keras.Sequential(
    [
        keras.layers.experimental.preprocessing.RandomFlip("horizontal"),
        keras.layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation
# Pre-trained Xception weights require that input be normalized
# from (0, 255) to a range (-1., +1.); the normalization layer
# does the following: outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])
# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()
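After unfreezing, the model must be recompiled for the change to take effect, then trained end-to-end at a very low learning rate. A sketch of this final fine-tuning step (the epoch count is just an example; stop early if the validation metrics degrade):

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
model.fit(train_ds, epochs=10, validation_data=validation_ds)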