Training and Prediction with a VGG16-Based Image Classification Model
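This code uses a pretrained VGG16 model for image classification. It shows the training process, runs predictions on images drawn at random from the test set, and visualizes the prediction results.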
import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import VGG16
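train_dir = 'C:/Users/chaofan/Desktop/class/class/archive/seg_train/seg_train'  # training set directory
val_dir = 'C:/Users/chaofan/Desktop/class/class/archive/seg_test/seg_test'  # test set directory (also used for validation)
# Sort the class names so their order matches the alphanumeric class indices assigned by flow_from_directory
classes = sorted(os.listdir(train_dir))
batch_size = 64  # batch size
IMG_HEIGHT = 150  # image height
IMG_WIDTH = 150  # image width
epochs = 10  # number of training epochs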
train_image_generator = ImageDataGenerator(
    rescale=1./255,  # scale pixel values to [0, 1]
    horizontal_flip=True  # random horizontal flip
)
val_image_generator = ImageDataGenerator(
    rescale=1./255  # scale pixel values to [0, 1]
)
train_data_gen = train_image_generator.flow_from_directory(
    batch_size=batch_size,  # batch size
    directory=train_dir,  # training set directory
    shuffle=True,  # shuffle the data
    target_size=(IMG_HEIGHT, IMG_WIDTH),  # resize images to this size
    class_mode='categorical'  # one-hot encoded labels
)
val_data_gen = val_image_generator.flow_from_directory(
    batch_size=batch_size,  # batch size
    directory=val_dir,  # validation set directory
    shuffle=True,  # shuffle the data
    target_size=(IMG_HEIGHT, IMG_WIDTH),  # resize images to this size
    class_mode='categorical'  # one-hot encoded labels
)
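total_train = len(train_data_gen)  # number of training batches per epoch (not the number of samples)
total_val = len(val_data_gen)  # number of validation batches per epoch (not the number of samples)
print("Total training batches:", total_train)
print("Total validation batches:", total_val)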
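# Load the VGG16 convolutional base with ImageNet weights, without its original classifier head
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))
# Freeze the pretrained layers so that only the new dense layers are trained
for layer in base_model.layers:
    layer.trainable = False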
model = Sequential()
model.add(base_model)
model.add(Flatten())
model.add(Dense(1024, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(len(classes), activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
history = model.fit(
    train_data_gen,  # training data generator
    steps_per_epoch=total_train,  # batches per training epoch
    epochs=epochs,  # number of training epochs
    validation_data=val_data_gen,  # validation data generator
    validation_steps=total_val  # batches per validation pass
)
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, 1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
# Randomly draw 5 images per class from the test set for prediction
test_images = []
test_labels = []
for class_name in classes:
    class_path = os.path.join(val_dir, class_name)
    img_list = os.listdir(class_path)
    # Sample without replacement so the same image is not drawn twice
    for img_name in random.sample(img_list, min(5, len(img_list))):
        img_path = os.path.join(class_path, img_name)
        img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array /= 255.0  # same rescaling as the data generators
        test_images.append(img_array)
        test_labels.append(class_name)
test_images = np.concatenate(test_images, axis=0)
# Run prediction
predictions = model.predict(test_images)
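# Display the predictions; the grid grows with the number of images (5 per row),
# because 5 images per class can exceed the 25 cells of a fixed 5x5 grid
num_images = len(test_images)
num_cols = 5
num_rows = (num_images + num_cols - 1) // num_cols
plt.figure(figsize=(2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, num_cols, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(test_images[i])
    predicted_label = classes[np.argmax(predictions[i])]
    true_label = test_labels[i]
    # Green marks a correct prediction, red an incorrect one
    color = 'green' if predicted_label == true_label else 'red'
    plt.xlabel('{} ({})'.format(predicted_label, true_label), color=color)
plt.show()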
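The prediction step above turns each row of softmax scores into a class name with `np.argmax` and lays the images out on a grid. A minimal NumPy-only sketch of those two pieces follows. The helper names `decode_predictions` and `grid_shape`, the class names, and the scores are all hypothetical and chosen only for illustration; the sketch assumes the class-name list is in the same order as the model's output indices.

```python
import numpy as np


def decode_predictions(probabilities, class_names):
    """Map each row of class probabilities to a (label, confidence) pair.

    Assumes class_names is ordered the same way as the model's output units.
    """
    probabilities = np.asarray(probabilities)
    indices = np.argmax(probabilities, axis=1)  # most likely class per row
    return [
        (class_names[idx], float(probabilities[row, idx]))
        for row, idx in enumerate(indices)
    ]


def grid_shape(num_images, num_cols=5):
    """Return (rows, cols) of a subplot grid large enough for num_images."""
    num_rows = (num_images + num_cols - 1) // num_cols  # ceiling division
    return num_rows, num_cols


# Hypothetical class names and scores, for illustration only
class_names = ["buildings", "forest", "sea"]
probs = np.array([[0.1, 0.7, 0.2],
                  [0.6, 0.3, 0.1]])

decoded = decode_predictions(probs, class_names)
print(decoded)
print(grid_shape(30))
```

Deriving the grid from the image count keeps the plot valid when the number of classes changes, since 5 images per class quickly exceeds a fixed 5x5 layout.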
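Code walkthrough:
- Import the required libraries
- Define the training set and test set paths
- Define the image preprocessing parameters
- Use ImageDataGenerator to rescale the images and apply horizontal-flip augmentation to the training data
- Build the VGG16 base model and freeze its pretrained layers
- Add new fully connected layers
- Compile the model
- Train the model
- Plot the accuracy and loss curves from training
- Randomly draw 5 images per class from the test set and predict them
- Display the predictions, marking correct and incorrect ones in different colors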
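The training step passes `len(train_data_gen)` as `steps_per_epoch`. For a Keras directory iterator, that length is the number of batches per epoch, not the number of images. The arithmetic can be sketched as below; the function name `batches_per_epoch` and the sample count of 1000 are hypothetical, chosen only to show the rounding.

```python
import math


def batches_per_epoch(num_samples, batch_size):
    """Number of batches needed to see every sample once.

    The last batch may be smaller than batch_size, hence the ceiling.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return math.ceil(num_samples / batch_size)


# Hypothetical dataset size, for illustration only
num_samples = 1000
steps = batches_per_epoch(num_samples, batch_size=64)
print(steps)  # 1000 / 64 = 15.625, rounded up to 16
```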
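Notes on running the code:
- Replace the training set and test set paths in the code with your own paths
- Make sure the required libraries are installed
- After the code runs, it displays the accuracy and loss curves from training, followed by the predictions for the randomly drawn test images (5 per class)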
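Before substituting your own paths, it can help to confirm that both directories exist and contain the same class subfolders. The sketch below uses only the standard library. The helper names `list_class_dirs` and `check_dataset` are hypothetical, and the demo builds a throwaway directory tree instead of touching a real dataset. Sorting the names is meant to mirror the alphanumeric class order that `flow_from_directory` is documented to use.

```python
import os
import tempfile


def list_class_dirs(root):
    """Return the sorted names of the class subdirectories under root."""
    if not os.path.isdir(root):
        raise FileNotFoundError(f"Dataset directory not found: {root}")
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )


def check_dataset(train_root, val_root):
    """Verify that both splits exist and expose the same class folders."""
    train_classes = list_class_dirs(train_root)
    val_classes = list_class_dirs(val_root)
    if train_classes != val_classes:
        raise ValueError(
            f"Class folders differ: {train_classes} vs {val_classes}"
        )
    return train_classes


# Demo on a temporary directory tree with hypothetical class names
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "val"):
        for class_name in ("sea", "forest"):
            os.makedirs(os.path.join(tmp, split, class_name))
    demo_classes = check_dataset(
        os.path.join(tmp, "train"), os.path.join(tmp, "val")
    )
    print(demo_classes)
```

With real data, the call would be `check_dataset(train_dir, val_dir)`, run before the generators are created.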