Fine-tuning

May 20, 2023

Fine-tuning is a technique in machine learning that involves taking a pre-trained model and continuing to train its weights on a new dataset to optimize its performance on a new task. (Strictly speaking, what gets adjusted are the model's parameters, i.e., its weights; hyperparameters such as the learning rate control how that training proceeds.) This is a common practice in transfer learning, where a model trained on a large dataset is repurposed for a new task that has a smaller dataset. By fine-tuning the pre-trained model on the new task-specific dataset, the model can learn to recognize patterns and features relevant to the new task, thereby improving its accuracy.

Transfer Learning

Transfer learning is a technique that involves taking a model trained on a large dataset and repurposing it for a new task that has a smaller dataset. The idea behind transfer learning is that the pre-trained model has already learned to recognize general patterns and features that are useful for many tasks. By leveraging this pre-trained model, we can save the time and computational resources that would be required to train a new model from scratch.
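
In Keras, the difference comes down to where the initial weights come from. A minimal sketch, using VGG16 as in the example later in this post:

from tensorflow.keras.applications.vgg16 import VGG16

# Training from scratch: the weights are randomly initialized
scratch_model = VGG16(weights=None, include_top=False, input_shape=(32, 32, 3))

# Transfer learning: the weights come from training on ImageNet
pretrained_model = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

The pre-trained variant starts from features already learned on ImageNet, which is what makes it a useful starting point for a smaller dataset.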

For example, consider an image classification task that involves classifying different types of vehicles. A pre-trained model like VGG16 that was trained on the ImageNet dataset could be used as a starting point for this task. By fine-tuning the VGG16 model on the new vehicle dataset, the model can learn to recognize specific features and patterns relevant to the vehicle classification task, such as the shape of car headlights or the features of a truck’s cargo area.

How Fine-tuning Works

Fine-tuning involves taking a pre-trained model and adapting it to a new task by continuing to train some or all of its weights on the new dataset. Several hyperparameters control how this training proceeds, including the learning rate, the number of epochs, the batch size, and the number of layers that are frozen. A short code sketch showing how these are set appears after the list below.

  1. Learning Rate: The learning rate controls the step size of the gradient descent algorithm used to optimize the model’s weights. During fine-tuning, the learning rate is usually reduced so that updates stay small and do not destroy the useful features the pre-trained weights already encode.

  2. Number of Epochs: The number of epochs determines how many times the model sees the entire dataset during training. During fine-tuning, fewer epochs are usually needed because the model starts from already-useful weights; training for too long risks overfitting the smaller dataset.

  3. Batch Size: The batch size determines how many samples are processed at once during each iteration of training. During fine-tuning, a smaller batch size is sometimes used, which gives more frequent (if noisier) updates to the model’s weights and reduces memory usage.

  4. Number of Layers Frozen: The number of layers that are frozen determines how many layers of the pre-trained model are kept fixed during training. During fine-tuning, the early layers (which capture generic features such as edges and textures) are typically frozen, while the later layers are trained on the new task-specific dataset.
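
As a rough sketch, here is how each of these four knobs maps onto Keras code; the layer cutoff and learning rate below are illustrative assumptions, not prescriptions:

import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16

base_model = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

# (4) Layers frozen: keep the early, generic layers fixed; only the last
# few layers will be updated. The cutoff index is an illustrative choice.
for layer in base_model.layers[:-4]:
    layer.trainable = False

# (1) Learning rate: reduced so updates to the pre-trained weights stay small
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)

# (2) and (3) The number of epochs and the batch size are passed to fit(), e.g.:
# model.fit(x_train, y_train, epochs=5, batch_size=32)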

Fine-tuning Example

Let’s consider an example of fine-tuning a pre-trained image classification model on a new dataset. We’ll use the VGG16 model as our pre-trained model and the CIFAR-10 dataset as our new task-specific dataset.

import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10

# Load the pre-trained VGG16 model; CIFAR-10 images are 32x32x3, so the input
# shape must match (include_top=False drops the ImageNet classifier head)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

# Freeze the entire convolutional base so its pre-trained weights are
# not updated while the new classification head is trained
for layer in base_model.layers:
    layer.trainable = False

# Add a new classification head for the task-specific dataset; with
# 32x32 inputs, the VGG16 base outputs a 1x1x512 feature map
x = Flatten()(base_model.output)
x = Dense(4096, activation='relu')(x)
x = Dense(4096, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)  # CIFAR-10 has 10 classes

# Create the new fine-tuned model
fine_tuned_model = Model(inputs=base_model.input, outputs=predictions)

# Compile the fine-tuned model
fine_tuned_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Preprocess CIFAR-10 with the same transformation VGG16 expects
# (cast to float, convert RGB to BGR, subtract the ImageNet channel means)
x_train = tf.keras.applications.vgg16.preprocess_input(x_train.astype('float32'))
x_test = tf.keras.applications.vgg16.preprocess_input(x_test.astype('float32'))

# Train the new fine-tuned model on the CIFAR-10 dataset
fine_tuned_model.fit(x_train, tf.keras.utils.to_categorical(y_train),
                     validation_data=(x_test, tf.keras.utils.to_categorical(y_test)),
                     epochs=10, batch_size=32)

In this example, we first load the pre-trained VGG16 model (with an input shape matching CIFAR-10’s 32x32x3 images) and freeze its convolutional base so the pre-trained weights are not updated while the new head is trained. We then add new classification layers for the CIFAR-10 dataset, which has 10 classes. Finally, we compile the model and train it on the CIFAR-10 dataset for 10 epochs with a batch size of 32.
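
Once the new head has been trained, a common second phase is to unfreeze the top of the convolutional base and continue training with a much lower learning rate, so the pre-trained weights shift only slightly. A sketch of that phase, continuing from the code above (the layer cutoff, learning rate, and epoch count here are illustrative assumptions):

# Unfreeze the last convolutional block of the base model
for layer in base_model.layers[-4:]:
    layer.trainable = True

# After changing which layers are trainable, the model must be re-compiled;
# use a much lower learning rate to keep updates small
fine_tuned_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
                         loss='categorical_crossentropy',
                         metrics=['accuracy'])

# Continue training for a few more epochs
fine_tuned_model.fit(x_train, tf.keras.utils.to_categorical(y_train),
                     validation_data=(x_test, tf.keras.utils.to_categorical(y_test)),
                     epochs=5, batch_size=32)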