import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# 定义训练集和测试集路径
train_dir = 'C:/Users/28938/Desktop/image/image/train'
test_dir = 'C:/Users/28938/Desktop/image/image/test'

# 定义类别标签
class_names = ['cats', 'dogs']

# 定义图像尺寸和批次大小
img_height = 224
img_width = 224
batch_size = 32

# 从目录中读取训练集和测试集
train_ds = keras.preprocessing.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset='training',
    seed=42,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
val_ds = keras.preprocessing.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset='validation',
    seed=42,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
test_ds = keras.preprocessing.image_dataset_from_directory(
    test_dir,
    seed=42,
    image_size=(img_height, img_width),
    batch_size=batch_size
)

# 定义数据增强器
data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip('horizontal', input_shape=(img_height, img_width, 3)),
        layers.experimental.preprocessing.RandomRotation(0.1),
        layers.experimental.preprocessing.RandomZoom(0.1),
        layers.experimental.preprocessing.RandomCrop(img_height, img_width),
        layers.experimental.preprocessing.Rescaling(1./255),
        layers.experimental.preprocessing.RandomContrast(0.1),
        layers.experimental.preprocessing.RandomSaturation(0.1)
    ]
)

# 定义模型输入
input_shape = (img_height, img_width, 3)
model = keras.Sequential([
    data_augmentation,
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Dropout(0.2),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Dropout(0.2),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Dropout(0.2),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(len(class_names), activation='softmax')
])

# 编译模型
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 设定训练参数
epochs = 30

# 定义模型checkpoint
checkpoint_path = 'model_checkpoint/cp.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)
checkpoint = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, 
                                                save_weights_only=True, 
                                                save_best_only=True, 
                                                monitor='val_accuracy', 
                                                mode='max', 
                                                verbose=1)

# 定义学习率衰减策略
lr_decay = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', 
                                                factor=0.1, 
                                                patience=5, 
                                                verbose=1, 
                                                mode='auto', 
                                                min_delta=0.0001, 
                                                cooldown=0, 
                                                min_lr=0)

# 开始训练模型
history = model.fit(train_ds, 
                    validation_data=val_ds, 
                    epochs=epochs, 
                    callbacks=[checkpoint, lr_decay])

# 加载最佳模型权重
model.load_weights(checkpoint_path)

# 对模型进行评估
test_loss, test_acc = model.evaluate(test_ds)
print('Test accuracy:', test_acc)

# 对模型进行预测
predictions = model.predict(test_ds)

# 混淆矩阵和分类报告
from sklearn.metrics import confusion_matrix, classification_report

# 获取测试集真实标签
test_labels = []
for images, labels in test_ds:
    test_labels.append(labels.numpy())

test_labels = np.concatenate(test_labels)

# 获取预测标签
predicted_labels = np.argmax(predictions, axis=1)

# 计算混淆矩阵和分类报告
cm = confusion_matrix(test_labels, predicted_labels)
report = classification_report(test_labels, predicted_labels, target_names=class_names)

# 打印混淆矩阵和分类报告
print('Confusion Matrix:')
print(cm)
print('
Classification Report:')
print(report)

# 绘制模型准确度和损失随时间变化的曲线
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

# 选择一张测试图片
test_image_path = 'C:/Users/28938/Desktop/image/image/test/cats/cat.4004.jpg'

# 读取并预处理图片
img = keras.preprocessing.image.load_img(
    test_image_path, target_size=(img_height, img_width)
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # 创建一个批次维度

# 预测图片类别
predictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])

# 显示图片和预测结果
plt.imshow(img)
plt.axis('off')
plt.show()
print('预测结果: {}, 置信度: {:.2f}%'.format(class_names[np.argmax(score)], 100 * np.max(score)))

提高模型准确率的建议:

  1. 调整模型结构: 可以尝试增加更多的卷积层和全连接层,增加模型的复杂度和表达能力。
  2. 调整数据增强器: 可以尝试调整数据增强器的参数,如随机翻转、旋转和缩放的幅度,以增加数据的多样性和模型的鲁棒性。
  3. 调整训练参数: 可以尝试增加训练的轮数(epochs),以便模型更好地学习训练集的特征。
  4. 调整优化器: 可以尝试使用其他优化器,如SGD、RMSprop或AdamW,并调整学习率和其他超参数。
  5. 使用预训练模型: 可以尝试使用预训练模型,如ResNet、VGG或Inception等,将其作为特征提取器或调整其权重进行微调。
  6. 增加数据量: 可以尝试增加训练集的数量,通过数据增强器生成更多的训练样本,或者使用迁移学习和数据增强技术结合,从其他数据集中获取更多训练样本。
  7. 调整模型超参数: 可以尝试调整模型的超参数,如卷积核大小、卷积层的数量和大小、全连接层的大小等,以找到更适合数据集的模型结构。
  8. 进行模型集成: 可以尝试使用多个模型进行集成,例如使用不同的初始化、不同的数据增强方式或不同的模型结构,然后将它们的预测结果进行投票或平均。
  9. 使用更大的图像尺寸: 可以尝试使用更大的图像尺寸进行训练,以提供更多的细节和上下文信息。
  10. 尝试其他技术: 可以尝试其他一些技术,如Dropout、Batch Normalization、L1/L2正则化等,以减少过拟合并提高模型的泛化能力。

原文地址: https://www.cveoy.top/t/topic/pieB 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录