# Fashion MNIST Classification with ResNet in Keras
import datetime

import numpy as np
from keras.datasets import fashion_mnist
from keras.layers import (
    Activation,
    Add,
    AveragePooling2D,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    Input,
)
from keras.models import Model
from keras.utils import to_categorical
def conv(channels, strides=1, kernel_size=(3, 3), padding='same'):
    """Create a bias-free 2-D convolution layer.

    The layer is returned un-applied; call it on a tensor to use it.

    Args:
        channels: Number of output filters.
        strides: Convolution stride (default 1).
        kernel_size: Kernel size (default 3x3).
        padding: Keras padding mode (default 'same').

    Returns:
        A ``Conv2D`` layer with ``use_bias=False``.
    """
    layer_kwargs = dict(
        filters=channels,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        use_bias=False,
    )
    return Conv2D(**layer_kwargs)
def res_block(inputs, base_channels):
    """Apply one residual block to ``inputs`` and return the output tensor.

    Two branches are computed from ``inputs`` and summed, then passed
    through a final ReLU:

    * shortcut branch: BatchNorm -> ReLU -> 1x1 conv to ``base_channels``
      (a projection, so ``inputs`` may have any channel count);
    * main branch: 1x1 conv (``base_channels``) -> BatchNorm -> ReLU ->
      3x3 conv (``base_channels * 2``) -> BatchNorm -> ReLU ->
      1x1 conv back to ``base_channels``.

    Args:
        inputs: Keras tensor to transform.
        base_channels: Number of output channels of the block.

    Returns:
        Keras tensor with ``base_channels`` channels and the same spatial
        size as ``inputs`` (every convolution here uses stride 1 with
        ``conv``'s default 'same' padding).
    """
    # Shortcut branch: pre-activation then a 1x1 projection, so its channel
    # count matches the main branch for the Add below.
    residual = BatchNormalization()(inputs)
    residual = Activation('relu')(residual)
    residual = conv(channels=base_channels, kernel_size=(1, 1))(residual)

    # Main branch: 1x1 -> 3x3 (widened to twice the block width) -> 1x1.
    x = conv(channels=base_channels, kernel_size=(1, 1))(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # The pasted original read ``base_channels2`` (undefined -> NameError);
    # the multiplication operator was evidently lost in formatting.
    x = conv(channels=base_channels * 2, strides=1, kernel_size=(3, 3))(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = conv(channels=base_channels, kernel_size=(1, 1))(x)

    outputs = Add()([x, residual])
    return Activation('relu')(outputs)
def ResNet(input_shape, base_channels, classes):
    """Build the small ResNet classifier used for Fashion MNIST.

    Architecture:
    1. Stride-2 3x3 stem convolution with ``base_channels`` filters.
    2. Three residual blocks with widths ``base_channels``,
       ``base_channels * 2`` and ``base_channels * 2``.
    3. 4x4 average pooling, flatten, and a softmax dense layer.

    Args:
        input_shape: Shape of one input sample, e.g. ``(28, 28, 1)``.
        base_channels: Channel width of the stem and the first residual block.
        classes: Number of output classes (softmax units).

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    inputs = Input(shape=input_shape)
    # Stem: stride 2 with 'same' padding halves the spatial size
    # (28x28 -> 14x14 for input_shape (28, 28, 1)).
    x = conv(channels=base_channels, strides=2, kernel_size=(3, 3))(inputs)
    x = res_block(x, base_channels=base_channels)
    # The pasted original read ``base_channels2`` (undefined -> NameError);
    # restored as a doubling, matching the following block.
    x = res_block(x, base_channels=base_channels * 2)
    x = res_block(x, base_channels=base_channels * 2)
    # With Keras defaults (stride = pool size, 'valid' padding) a 14x14 map
    # pools to 3x3; the trailing 2 rows/columns are not covered.
    x = AveragePooling2D(pool_size=(4, 4))(x)
    x = Flatten()(x)
    x = Dense(units=classes, activation='softmax')(x)
    return Model(inputs=inputs, outputs=x)
# Load the Fashion MNIST dataset.
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Preprocess the data:
# - add a trailing channel axis so images match a (H, W, 1) input shape;
# - convert to float32 and divide by 255 to scale 8-bit pixel values to [0, 1];
# - one-hot encode the integer labels for categorical cross-entropy.
train_images = np.expand_dims(train_images, axis=-1)
test_images = np.expand_dims(test_images, axis=-1)
train_images = train_images.astype('float32') / 255.
test_images = test_images.astype('float32') / 255.
train_labels = to_categorical(train_labels, num_classes=10)
test_labels = to_categorical(test_labels, num_classes=10)
# Define hyperparameters.
input_shape = (28, 28, 1)  # one image: height, width, single channel (channels-last)
base_channels = 64  # width of the stem conv and the first residual block
classes = 10  # number of output classes
epochs = 10  # passes over the training set
batch_size = 64  # samples per gradient step
# Build the ResNet model.
model = ResNet(input_shape=input_shape, base_channels=base_channels, classes=classes)

# Compile the model: Adam optimizer and categorical cross-entropy
# (the labels were one-hot encoded with to_categorical).
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model.
# NOTE(review): the test split doubles as validation data, so the per-epoch
# "val_" metrics are test-set metrics. No callbacks are used, so they only
# monitor training and do not influence the learned weights.
model.fit(
    train_images,
    train_labels,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(test_images, test_labels),
)
# Evaluate the trained model on the test set and report accuracy as a percentage.
loss, accuracy = model.evaluate(test_images, test_labels)
print('Test accuracy: {:.2f}%'.format(accuracy * 100))
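# 原文地址: https://www.cveoy.top/t/topic/ogGf 著作权归作者所有。请勿转载和采集!
# (Original source: https://www.cveoy.top/t/topic/ogGf — copyright belongs to the author; please do not repost or scrape.)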