# 使用迁移学习和 TensorFlow 训练宝可梦分类模型 (Train a Pokémon image classifier with transfer learning and TensorFlow)
import os

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
# 读取图像路径和标签 (Read image paths and labels)
def load_data(data_dir='pokemon_data',
              class_names=('bulbasaur', 'charmander', 'squirtle',
                           'pikachu', 'mewtwo')):
    """Collect ``.jpg`` image paths under *data_dir* and their integer labels.

    A file's label is the index in *class_names* of the first class name that
    occurs (case-sensitively) as a substring of its file name. Files whose names
    match no class are skipped, so the two returned lists always have the same
    length and stay aligned index-by-index.

    Args:
        data_dir: Root directory searched recursively.
        class_names: Ordered class names; the name at position ``i`` gets label ``i``.

    Returns:
        ``(img_path, label)``: a list of file paths and a parallel list of int labels.
    """
    img_path = []
    label = []
    for root, dirs, files in os.walk(data_dir):
        # Sort in place so the traversal order (and thus the output) is
        # deterministic across file systems.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('.jpg'):
                continue
            # First matching class wins, mirroring the original if/elif order.
            file_label = next(
                (idx for idx, name in enumerate(class_names) if name in file),
                None,
            )
            if file_label is None:
                # Previously the path was appended without a label, which
                # desynchronised the two lists and mislabelled later images.
                continue
            img_path.append(os.path.join(root, file))
            label.append(file_label)
    return img_path, label
# 随机打乱数据 (Shuffle the data)
def shuffle_data(img_path, label):
    """Reorder image paths and labels with one shared random permutation.

    The permutation comes from NumPy's global random state, so calling
    ``np.random.seed`` beforehand makes the result reproducible.

    Args:
        img_path: Sequence of image paths.
        label: Sequence of labels with the same length as *img_path*.

    Returns:
        ``(img_path, label)`` as NumPy arrays, both reordered identically.
    """
    order = np.random.permutation(len(img_path))
    shuffled_paths = np.array(img_path)[order]
    shuffled_labels = np.array(label)[order]
    return shuffled_paths, shuffled_labels
# 数据预处理 (Data preprocessing)
def preprocess(img_path, label):
    """Load one JPEG file and prepare it as ResNet50 input.

    Args:
        img_path: String tensor holding one image file path (as produced by
            mapping over a ``from_tensor_slices`` dataset of paths).
        label: The image's label; passed through unchanged.

    Returns:
        ``(img, label)`` where ``img`` is a 224x224x3 float tensor processed by
        ``tf.keras.applications.resnet50.preprocess_input``.
    """
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=3)
    # tf.image.resize returns float32 values still in the 0-255 range.
    img = tf.image.resize(img, [224, 224])
    # The ImageNet ResNet50 weights expect caffe-style inputs (RGB->BGR plus
    # per-channel mean subtraction, no scaling). The previous `/ 255.` fed the
    # frozen backbone a distribution it was never trained on.
    img = tf.keras.applications.resnet50.preprocess_input(img)
    return img, label
# 迁移学习模型 (Transfer-learning model)
def transfer_model(num_classes=5, input_shape=(224, 224, 3)):
    """Build a classifier head on top of a frozen ImageNet-pretrained ResNet50.

    Args:
        num_classes: Number of output classes (default 5, the class count used
            by this script).
        input_shape: Input image shape ``(height, width, channels)``.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model whose last layer is a
        softmax over *num_classes* outputs.
    """
    base_model = tf.keras.applications.ResNet50(
        include_top=False, weights='imagenet', input_shape=input_shape)
    # Freeze the whole backbone. Setting the flag on the model propagates to
    # all of its layers, equivalent to the previous per-layer loop.
    base_model.trainable = False
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    return model
# 训练模型 (Train the model)
def train_model(model, train_data, valid_data, epochs=10, learning_rate=0.001):
    """Compile *model* for integer-label classification and train it.

    Args:
        model: An uncompiled Keras model whose output is assumed to be
            softmax probabilities (hence ``from_logits=False`` below).
        train_data: Training dataset yielding ``(image, label)`` batches.
        valid_data: Validation dataset evaluated after each epoch.
        epochs: Number of training epochs.
        learning_rate: Adam learning rate.

    Returns:
        The same *model*, now compiled and trained.
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        # Labels are integer class ids, so use the sparse variant of the loss.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['accuracy'],
    )
    model.fit(train_data, epochs=epochs, validation_data=valid_data)
    return model
# 测试模型 (Evaluate the model)
def test_model(model, test_data): loss, acc = model.evaluate(test_data) print('Test loss: {}, Test accuracy: {}'.format(loss, acc))
# 保存模型 (Save the model)
def save_model(model, path='pokemon_model.h5'):
    """Save *model* to *path* using ``model.save``.

    Args:
        model: The trained Keras model.
        path: Destination file. Keras picks the format from the extension;
            ``.h5`` is the legacy HDF5 format, and newer Keras versions
            recommend ``.keras``.
    """
    model.save(path)
if __name__ == '__main__':
    # Fixed: the guard previously read `if name == 'main':`, which raises
    # NameError (`name` is undefined) instead of running the script.
    autotune = tf.data.experimental.AUTOTUNE
    batch_size = 32

    img_path, label = load_data()
    if len(img_path) == 0:
        # Fail early with a clear message rather than inside train_test_split.
        raise SystemExit('No labelled .jpg images found under pokemon_data/')
    img_path, label = shuffle_data(img_path, label)

    # 80/20 train/test split, then an 80/20 train/validation split of the
    # remainder; stratify keeps class proportions consistent across splits.
    img_path_train, img_path_test, label_train, label_test = train_test_split(
        img_path, label, test_size=0.2, stratify=label)
    img_path_train, img_path_valid, label_train, label_valid = train_test_split(
        img_path_train, label_train, test_size=0.2, stratify=label_train)

    train_data = tf.data.Dataset.from_tensor_slices((img_path_train, label_train))
    valid_data = tf.data.Dataset.from_tensor_slices((img_path_valid, label_valid))
    test_data = tf.data.Dataset.from_tensor_slices((img_path_test, label_test))

    # Decode and preprocess images in parallel; only the training set is shuffled.
    train_data = (train_data.shuffle(len(img_path_train))
                  .map(preprocess, num_parallel_calls=autotune)
                  .batch(batch_size)
                  .prefetch(autotune))
    valid_data = (valid_data.map(preprocess, num_parallel_calls=autotune)
                  .batch(batch_size)
                  .prefetch(autotune))
    test_data = (test_data.map(preprocess, num_parallel_calls=autotune)
                 .batch(batch_size)
                 .prefetch(autotune))

    model = transfer_model()
    model = train_model(model, train_data, valid_data)
    test_model(model, test_data)
    save_model(model)
# 原文地址 (original article): https://www.cveoy.top/t/topic/nTKv
# 著作权归作者所有。请勿转载和采集！(Copyright belongs to the author. Do not repost or scrape.)