import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import VGG16

# Dataset locations and training hyperparameters.
train_dir = 'C:/Users/chaofan/Desktop/class/class/archive/seg_train/seg_train'  # training set directory
val_dir = 'C:/Users/chaofan/Desktop/class/class/archive/seg_test/seg_test'      # validation/test set directory
# BUG FIX: flow_from_directory assigns class indices in *sorted* (alphanumeric)
# order, but os.listdir() returns entries in an OS-dependent order. The label
# list must be sorted too, or the predicted-class names shown later can be
# mapped to the wrong folder names.
classes = sorted(os.listdir(train_dir))  # class names, index-aligned with the generators
batch_size = 64   # images per training/validation batch
IMG_HEIGHT = 150  # model input height (pixels)
IMG_WIDTH = 150   # model input width (pixels)
epochs = 10       # number of training epochs

# Training images get light augmentation (random horizontal flips);
# validation images are only rescaled to [0, 1].
train_image_generator = ImageDataGenerator(rescale=1.0 / 255, horizontal_flip=True)
val_image_generator = ImageDataGenerator(rescale=1.0 / 255)

# Stream batches of resized, one-hot-labelled images straight from the
# per-class directory trees; both generators share the same settings.
_flow_kwargs = dict(
    batch_size=batch_size,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='categorical',
    shuffle=True,
)
train_data_gen = train_image_generator.flow_from_directory(directory=train_dir, **_flow_kwargs)
val_data_gen = val_image_generator.flow_from_directory(directory=val_dir, **_flow_kwargs)

# len() of a Keras DirectoryIterator is the number of *batches* per epoch,
# not the number of samples; these feed steps_per_epoch/validation_steps.
total_train = len(train_data_gen)
total_val = len(val_data_gen)

print("总训练数据批次数:", total_train)
print("总验证数据批次数: ", total_val)

# Transfer learning: a frozen VGG16 convolutional base (ImageNet weights,
# classifier head removed) topped with a trainable fully-connected head.
base_model = VGG16(weights='imagenet', include_top=False,
                   input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))

# Freeze every convolutional layer so only the new head is trained.
for conv_layer in base_model.layers:
    conv_layer.trainable = False

model = Sequential([
    base_model,
    Flatten(),
    Dense(1024, activation='relu'),
    Dense(512, activation='relu'),
    Dense(64, activation='relu'),
    Dense(len(classes), activation='softmax'),  # one probability per class
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

# Train the classification head and keep the per-epoch metric curves
# for the learning-curve plots below.
history = model.fit(
    train_data_gen,
    epochs=epochs,
    steps_per_epoch=total_train,      # batches per training epoch
    validation_data=val_data_gen,
    validation_steps=total_val,       # batches per validation pass
)

hist = history.history
acc, val_acc = hist['accuracy'], hist['val_accuracy']
loss, val_loss = hist['loss'], hist['val_loss']

# Learning curves: accuracy on the top panel, cross-entropy loss below.
plt.figure(figsize=(8, 8))

plt.subplot(2, 1, 1)
for curve, label in ((acc, 'Training Accuracy'), (val_acc, 'Validation Accuracy')):
    plt.plot(curve, label=label)
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])  # keep auto lower bound, cap at 1.0
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
for curve, label in ((loss, 'Training Loss'), (val_loss, 'Validation Loss')):
    plt.plot(curve, label=label)
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, 1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')

plt.show()

# Pick 5 random validation images and predict their classes.
# BUG FIX: val_dir contains one sub-directory per class, not image files, so
# the original os.listdir(val_dir)[:5] selected *directories* and the later
# image.load_img() call would fail on them. Walk the class folders to collect
# actual image file paths instead.
random.seed(42)  # fixed seed so the same images are chosen on every run
all_image_paths = []
for class_name in sorted(os.listdir(val_dir)):
    class_dir = os.path.join(val_dir, class_name)
    if os.path.isdir(class_dir):  # skip any stray files at the top level
        for fname in os.listdir(class_dir):
            all_image_paths.append(os.path.join(class_dir, fname))
random.shuffle(all_image_paths)
selected_images_path = all_image_paths[:5]          # full paths, used for loading
selected_images = [os.path.basename(p) for p in selected_images_path]  # file names only

# Load each image, resize to the model's input size, scale to [0, 1]
# (matching the training-time rescale) and add a leading batch dimension.
selected_images_data = []
for img_path in selected_images_path:
    img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))
    arr = image.img_to_array(img) / 255.0
    selected_images_data.append(np.expand_dims(arr, axis=0))

# Stack into one (5, H, W, 3) batch; arg-max over the softmax output gives
# the predicted class index for each image.
predictions = model.predict(np.vstack(selected_images_data))
predicted_classes = np.argmax(predictions, axis=1)

# Show the sampled images side by side with their predicted class labels.
plt.figure(figsize=(15, 5))
for idx, img_path in enumerate(selected_images_path):
    plt.subplot(1, 5, idx + 1)
    plt.imshow(image.load_img(img_path))  # original-size image for display
    plt.title('Predicted: {}'.format(classes[predicted_classes[idx]]))
    plt.axis('off')
plt.show()

# Original source: https://www.cveoy.top/t/topic/pibA — copyright belongs to the author.