import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    GlobalAveragePooling2D,
    MaxPooling2D,
)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# --- Configuration -----------------------------------------------------------
train_dir = 'C:/Users/chaofan/Desktop/class/class/archive/seg_train/seg_train'  # training-set directory (one sub-folder per class)
val_dir = 'C:/Users/chaofan/Desktop/class/class/archive/seg_test/seg_test'  # test-set directory, used below as the validation set

# Class names in the same order Keras assigns label indices.
# flow_from_directory() infers the classes from the sub-directories of its
# `directory` argument and sorts them alphanumerically. A raw os.listdir()
# comes back in arbitrary order and may include stray files (desktop.ini,
# .DS_Store, ...), so it could disagree with the model's output indices and
# with the number of softmax units the head needs.
classes = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)

batch_size = 64  # images per batch
IMG_HEIGHT = 150  # input image height (pixels)
IMG_WIDTH = 150  # input image width (pixels)
epochs = 10  # number of training epochs

# --- Input pipelines ---------------------------------------------------------
# Both pipelines scale raw 0-255 pixel values into [0, 1]. Only the training
# pipeline adds random horizontal flips as light data augmentation.
train_image_generator = ImageDataGenerator(horizontal_flip=True, rescale=1. / 255)
val_image_generator = ImageDataGenerator(rescale=1. / 255)

# Each sub-folder of the directory is one class. Labels are one-hot encoded
# ('categorical'), and every image is resized to IMG_HEIGHT x IMG_WIDTH.
train_data_gen = train_image_generator.flow_from_directory(
    train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='categorical',
    batch_size=batch_size,
    shuffle=True,
)
# The test split doubles as the validation set during training.
val_data_gen = val_image_generator.flow_from_directory(
    val_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='categorical',
    batch_size=batch_size,
    shuffle=True,
)

# len() of a Keras directory iterator is the number of *batches* per epoch,
# i.e. ceil(number_of_images / batch_size), not the number of images.
total_train = len(train_data_gen)  # training batches per epoch
total_val = len(val_data_gen)  # validation batches per epoch

print('总训练数据批次数:', total_train)
print('总验证数据批次数: ', total_val)

# --- Model -------------------------------------------------------------------
# Transfer learning: the ImageNet-pretrained VGG16 convolutional base (loaded
# without its fully-connected top) acts as a frozen feature extractor, and a
# small dense classifier head is trained on top of it.
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))

# Freeze every pretrained layer so fit() only updates the new head.
for pretrained_layer in base_model.layers:
    pretrained_layer.trainable = False

model = Sequential([
    base_model,
    Flatten(),
    Dense(1024, activation='relu'),
    Dense(512, activation='relu'),
    Dense(64, activation='relu'),
    Dense(len(classes), activation='softmax'),  # one probability per class
])

# categorical_crossentropy pairs with the one-hot labels that
# class_mode='categorical' produces.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

model.summary()

# --- Training ----------------------------------------------------------------
# With steps equal to each generator's batch count, one epoch visits every
# batch of the training and validation generators exactly once.
history = model.fit(
    train_data_gen,
    epochs=epochs,
    steps_per_epoch=total_train,
    validation_data=val_data_gen,
    validation_steps=total_val,
)

# --- Demo: predict 5 randomly chosen images from the test set ----------------
# val_dir uses the flow_from_directory layout (one sub-folder per class), so
# the images live one level down. Listing val_dir itself yields only folders,
# which load_img() cannot open.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')

# Output index -> class name, taken from the generator the model was trained
# on. This is the mapping the softmax outputs correspond to.
index_to_class = {idx: name for name, idx in train_data_gen.class_indices.items()}

test_images = []
for class_dir in sorted(os.listdir(val_dir)):
    class_path = os.path.join(val_dir, class_dir)
    if not os.path.isdir(class_path):
        continue  # skip stray files next to the class folders
    for file_name in sorted(os.listdir(class_path)):
        if file_name.lower().endswith(_IMAGE_EXTENSIONS):
            test_images.append(os.path.join(class_path, file_name))

# Draw without replacement; min() keeps this safe for very small test sets.
for img_path in random.sample(test_images, min(5, len(test_images))):
    img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))
    # Add a batch axis and apply the same 1/255 rescaling as the generators.
    img_array = np.expand_dims(image.img_to_array(img), axis=0) / 255.0

    prediction = model.predict(img_array)
    predicted_class = int(np.argmax(prediction[0]))  # single-image batch -> row 0
    predicted_label = index_to_class[predicted_class]

    plt.imshow(img)
    plt.title(predicted_label)
    plt.axis('off')
    plt.show()

# 原文地址: https://www.cveoy.top/t/topic/pibT 著作权归作者所有。请勿转载和采集! (Source: the URL above; copyright belongs to the author — do not repost or scrape.)

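# 免费AI点我,无需注册和登录 (scraped web-page advertisement, not part of the program: "free AI, click here, no sign-up or login required")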