以下代码使用 TensorFlow 构建了一个动物图像分类模型,并提供了详细的步骤和解释。

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

# 2.1 数据集路径
train_dir = 'C:/Users/胖虎/Desktop/animals/train'
test_dir = 'C:/Users/胖虎/Desktop/animals/test'

# 2.2 数据预处理
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True
)

test_datagen = ImageDataGenerator(rescale=1./255)

# 2.3 加载数据集
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=32,
    class_mode='categorical'
)

validation_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(150, 150),
    batch_size=32,
    class_mode='categorical'
)

# 构建模型
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(256, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(5, activation='softmax'))

# 3.2 模型编译
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              metrics=['accuracy'])

# 4. 模型训练
history = model.fit(train_generator,
                    steps_per_epoch=len(train_generator),
                    epochs=50,  # 增加训练轮数
                    validation_data=validation_generator,
                    validation_steps=len(validation_generator))

# 保存模型权重
model.save_weights('model.h5')

# 5. 模型评估
test_loss, test_accuracy = model.evaluate(validation_generator, steps=len(validation_generator))
print('Test Loss:', test_loss)
print('Test Accuracy:', test_accuracy)

# 4.2 训练集和测试集的准确率变化
import matplotlib.pyplot as plt

plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

# 5.1 模型预测
predictions = model.predict(validation_generator)

import matplotlib.pyplot as plt
import numpy as np

# 5.2 预测展示
class_names = train_generator.class_indices
class_names = list(class_names.keys())

plt.figure(figsize=(10, 10))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)

    img = validation_generator[i][0][0]  # 获取输入图像
    img = np.expand_dims(img, axis=0)  # 添加批次维度
    img = img / 255.0  # 归一化像素值

    prediction = model.predict(img)  # 预测图像
    predicted_label = class_names[np.argmax(prediction[0])]  # 获取预测标签

    true_label = class_names[np.argmax(validation_generator[i][1][0])]  # 获取真实标签

    plt.imshow(validation_generator[i][0][0], cmap=plt.cm.binary)

    if predicted_label == true_label:
        color = 'green'
    else:
        color = 'red'

    plt.xlabel('Predicted: '{} '({})''.format(predicted_label, true_label), color=color)

plt.show()

提高模型准确性的建议:

  1. 增加数据集大小:如果数据集较小,可以尝试通过数据增强技术生成更多的训练样本。
  2. 调整模型架构:可以尝试添加更多的卷积层、池化层或全连接层,或者调整它们的大小和数量。
  3. 调整超参数:可以尝试调整学习率、批次大小、迭代次数等超参数来优化模型。
  4. 尝试不同的优化器:除了Adam优化器外,还可以尝试其他优化器,如SGD、RMSprop等。
  5. 尝试使用预训练模型:可以尝试使用预训练的模型,如ResNet、VGG等,然后进行微调以适应特定的任务。
  6. 调整数据预处理:可以尝试使用不同的图像预处理技术,如直方图均衡化、对比度增强等,以增强图像特征。
  7. 增加训练轮数:增加训练轮数可以让模型更好地学习数据集的特征,但要注意防止过拟合。
  8. 尝试正则化技术:可以尝试使用正则化技术,如L1正则化、L2正则化或dropout,以减少模型的过拟合。

记住,不同的数据集和任务可能需要不同的方法来提高准确性。尝试不同的方法,并根据实际情况进行调整和优化。


原文地址: https://www.cveoy.top/t/topic/piiZ 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录