使用 TensorFlow 构建卷积神经网络 (CNN) 进行图像分类
import tensorflow as tf
import numpy as np
import os
# Load the pre-split train/test arrays from .npy files in the current working
# directory; the path order in the tuple matches the unpacking order.
# NOTE(review): shapes/dtypes are not validated here. The graph feeds images
# into a [None, 28, 28, 1] float32 placeholder and labels into an int32 [None]
# placeholder, so the files presumably match that layout; confirm.
train_data, train_labels, test_data, test_labels = [
    np.load(npy_path)
    for npy_path in (
        'train_data.npy',
        'train_labels.npy',
        'test_data.npy',
        'test_labels.npy',
    )
]
# Training hyperparameters.
num_classes = 2        # number of units in the final (logits) dense layer
learning_rate = 1e-3   # step size passed to the Adam optimizer
batch_size = 32        # samples per mini-batch (training and evaluation)
num_epochs = 10        # number of passes of the training loop
# Convolutional network definition.
def conv_net(x, num_classes):
    """Build two conv/max-pool stages followed by two dense layers.

    Args:
        x: float image tensor in NHWC layout (channels last, the tf.layers
            default), e.g. the [None, 28, 28, 1] placeholder used by this
            script. Height, width and channels must be statically known,
            because the flatten size is computed from them.
        num_classes: number of output units (one logit per class).

    Returns:
        Logits tensor of shape [batch, num_classes]. No activation is applied,
        so it can be fed directly to a softmax cross-entropy loss.
    """
    # padding='same' preserves spatial size through each 5x5 convolution, so
    # only the 2x2/stride-2 pools shrink it (28 -> 14 -> 7 for 28x28 input).
    # With the default padding='valid' the final map would be 4x4, not the
    # 7x7 this architecture was designed around.
    conv1 = tf.layers.conv2d(inputs=x, filters=32, kernel_size=[5, 5],
                             padding='same', activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
    conv2 = tf.layers.conv2d(inputs=pool1, filters=64, kernel_size=[5, 5],
                             padding='same', activation=tf.nn.relu)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    # Flatten using the feature map's actual static size rather than a
    # hard-coded constant. A wrong constant makes reshape either fail or
    # silently change the batch dimension.
    flat_size = int(np.prod(pool2.get_shape().as_list()[1:]))
    pool2_flat = tf.reshape(pool2, [-1, flat_size])

    # Fully connected layers: hidden ReLU layer, then linear logits.
    fc1 = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
    logits = tf.layers.dense(inputs=fc1, units=num_classes)
    return logits
# Graph inputs: a batch of 28x28 single-channel images (NHWC) and one integer
# class label per image.
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.int32, [None])

# Forward pass, then the mean of the per-example sparse softmax cross-entropy.
logits = conv_net(x, num_classes)
per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=y, logits=logits)
loss = tf.reduce_mean(per_example_loss)

# Note: `optimizer` holds the training op returned by minimize() (running it
# performs one Adam update step), not the AdamOptimizer object itself.
adam = tf.train.AdamOptimizer(learning_rate=learning_rate)
optimizer = adam.minimize(loss)
# Train the model, then measure classification accuracy on the test set.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # --- Training ---
    num_train = len(train_data)
    for epoch in range(num_epochs):
        # Reshuffle every epoch so mini-batches are not drawn in the same
        # fixed order each time.
        order = np.random.permutation(num_train)
        # Step by batch_size so the final, possibly smaller, batch is also
        # used. Floor division would drop up to batch_size - 1 samples, and
        # all of them when num_train < batch_size.
        for batch, start in enumerate(range(0, num_train, batch_size)):
            idx = order[start:start + batch_size]
            _, batch_loss = sess.run(
                [optimizer, loss],
                feed_dict={x: train_data[idx], y: train_labels[idx]})
            print('Epoch:', epoch + 1, 'Batch:', batch + 1, 'Loss:', batch_loss)

    # --- Evaluation ---
    # Evaluate in batches instead of issuing one sess.run call per image.
    num_test = len(test_data)
    num_correct = 0
    for start in range(0, num_test, batch_size):
        stop = start + batch_size
        batch_logits = sess.run(logits, feed_dict={x: test_data[start:stop]})
        predicted = batch_logits.argmax(axis=1)
        # ravel() guards against (N, 1)-shaped label arrays, which would
        # otherwise broadcast against `predicted` into a (B, B) comparison.
        num_correct += np.count_nonzero(predicted == np.ravel(test_labels[start:stop]))
    # Accuracy is undefined for an empty test set; report NaN rather than
    # raising ZeroDivisionError.
    accuracy = num_correct / num_test if num_test else float('nan')
    print('Accuracy:', accuracy)
由于未提供数据集，无法运行代码并给出准确的运行结果。
原文地址: https://www.cveoy.top/t/topic/owd0 著作权归作者所有。请勿转载和采集!