# 基于VGG16的图像分类模型训练与评估
# -*- coding: utf-8 -*-
'使用VGG16预训练模型进行图像特征提取,并构建全连接神经网络进行分类'
import os

import matplotlib.pyplot as plt
import numpy as np
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
# Root of the image dataset; the train/ and test/ sub-directories below are
# read by flow_from_directory in extract_features.
data_dir = os.path.join('ss', 'image')
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')
# VGG16 convolutional base with ImageNet weights. include_top=False drops the
# original dense classifier, so the model is used only as a feature extractor
# for 150x150 RGB inputs.
vgg_model = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
vgg_model.summary()
# The only preprocessing applied to the images: multiply pixel values by 1/255.
datagen = ImageDataGenerator(rescale=1 / 255)
# Images per generator batch in extract_features; also the mini-batch size
# passed to model.fit.
batch_size: int = 20
# Feature-extraction helper.
def extract_features(directory, sample_count):
    """Extract VGG16 convolutional features and binary labels from *directory*.

    Images are streamed with ``datagen.flow_from_directory`` (resized to
    150x150, ``batch_size`` images per batch, ``class_mode='binary'``) and
    passed through ``vgg_model.predict``. Exactly ``sample_count`` samples
    are kept.

    Args:
        directory: Image folder accepted by ``flow_from_directory``
            (one sub-folder per class).
        sample_count: Number of samples to collect.

    Returns:
        A ``(features, labels)`` tuple, where ``features`` has shape
        ``(sample_count, 4, 4, 512)`` and ``labels`` has shape
        ``(sample_count,)``.
    """
    # Pre-allocate storage for the conv-base output; (4, 4, 512) is the
    # output shape of the VGG16 base (include_top=False) for 150x150 inputs.
    features = np.zeros(shape=(sample_count, 4, 4, 512))
    labels = np.zeros(shape=(sample_count,))
    if sample_count == 0:
        return features, labels

    # Load images from the directory in batches of batch_size.
    generator = datagen.flow_from_directory(
        directory,
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='binary',
    )
    filled = 0
    # The directory iterator cycles over the data indefinitely, so the loop
    # must stop explicitly once enough samples have been collected.
    for inputs_batch, labels_batch in generator:
        # A batch can be shorter than batch_size (end of an epoch) or overshoot
        # sample_count; copy only the rows that still fit. Features and labels
        # come from the same batch, so they stay aligned.
        take = min(len(inputs_batch), sample_count - filled)
        features[filled:filled + take] = vgg_model.predict(inputs_batch[:take])
        labels[filled:filled + take] = labels_batch[:take]
        filled += take
        if filled >= sample_count:
            break
    return features, labels
# Extract VGG16 features for 2000 training images and 1000 test images.
# NOTE: no separate validation split is extracted; the test split is later
# passed to model.fit as validation data.
train_features, train_labels = extract_features(train_dir, 2000)
test_features, test_labels = extract_features(test_dir, 1000)

# Flatten each (4, 4, 512) feature map into one 4*4*512 = 8192-element
# vector so it can be fed to a Dense layer.
train_features = np.reshape(train_features, (2000, 4 * 4 * 512))
test_features = np.reshape(test_features, (1000, 4 * 4 * 512))
from keras import models
from keras import layers
from keras import optimizers

# Small dense classifier trained on the flattened VGG16 features:
# 8192 inputs -> 256 ReLU units -> dropout -> 1 sigmoid output (binary task).
model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))
# `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
# versions reject.
model.compile(
    optimizer=optimizers.Adam(learning_rate=2e-5),
    loss='binary_crossentropy',
    metrics=['acc'],
)
# Train the dense head on the pre-extracted features.
# NOTE(review): the test split doubles as validation data here, so the
# reported validation metrics are not an independent held-out estimate.
history = model.fit(
    train_features,
    train_labels,
    epochs=30,
    batch_size=batch_size,
    validation_data=(test_features, test_labels),
)
# Plot training vs. validation accuracy per epoch.
# NOTE(review): the Chinese labels need a CJK-capable matplotlib font,
# otherwise they may render as empty boxes — confirm on the target machine.
acc = history.history['acc']
val_acc = history.history['val_acc']
epochs = range(1, len(acc) + 1)
plt.figure()
plt.plot(epochs, acc, 'bo', label='训练集准确率')
plt.plot(epochs, val_acc, 'b', label='验证集准确率')
plt.title('训练集与验证集准确率曲线')
plt.legend()
plt.show()
# Plot training vs. validation loss per epoch. Loss values get their own
# names instead of overwriting the accuracy lists in `acc`/`val_acc`.
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)
plt.figure()
plt.plot(epochs, loss, 'bo', label='训练集损失')
plt.plot(epochs, val_loss, 'b', label='验证集损失')
plt.title('训练集与验证集损失曲线')
plt.legend()
plt.show()
# 原文地址: https://www.cveoy.top/t/topic/n9AR
# 著作权归作者所有。请勿转载和采集!