771832331010310310025210027821008666856384277341254771832331010310310025219080201021008649800872802731100701830903103002512390802010010101007 在训练集、验证集和测试集三个 txt 文件夹中,有如上格式的数据集:每条样本的前 23 位为特征值,最后一位为分类标签,标签共有 8 个类别。
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
class CNN(nn.Module):
    """Two-stage convolutional feature extractor.

    Applies two Conv2d -> ReLU -> MaxPool2d stages followed by a fully
    connected layer, turning a single-channel input into a 128-dimensional
    feature vector per sample.

    The first linear layer is sized ``32 * 7 * 7``, so the input must be
    28x28 (each of the two 2x2 poolings halves the spatial size: 28 -> 14 -> 7).
    """

    def __init__(self):
        # Must be the dunder ``__init__``: a method merely called ``init`` is
        # never invoked by ``CNN()``, which would leave the module without layers.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return a ``(batch, 128)`` feature tensor for a ``(batch, 1, 28, 28)`` input."""
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu3(x)
        return x
class BiGRU(nn.Module):
    """Bidirectional GRU classifier head.

    Runs a multi-layer bidirectional GRU over the input sequence and maps
    the output at the last time step to ``num_classes`` scores.

    Args:
        input_size: number of features per time step.
        hidden_size: GRU hidden width per direction.
        num_layers: number of stacked GRU layers.
        num_classes: size of the output produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        # Must be the dunder ``__init__`` so that ``BiGRU(...)`` actually
        # builds the layers.
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers,
                          batch_first=True, bidirectional=True)
        # Forward and backward directions are concatenated, hence ``* 2``.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Return ``(batch, num_classes)`` scores.

        ``x`` is ``(batch, seq_len, input_size)``. A 2-D ``(batch, input_size)``
        tensor is also accepted and treated as a sequence of length 1.
        """
        if x.dim() == 2:
            # Without this, nn.GRU would read a 2-D tensor as one unbatched
            # sequence and the indexing below would fail.
            x = x.unsqueeze(1)
        # Build the initial state on the input's own device instead of relying
        # on a module-level ``device`` global.
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)
        out, _ = self.gru(x, h0)
        # Keep only the last time step of the sequence output.
        return self.fc(out[:, -1, :])
class FusionModel(nn.Module):
    """Late-fusion classifier combining a CNN branch and a BiGRU branch.

    The outputs of both branches are concatenated along the feature axis and
    passed through one linear layer that produces ``num_classes`` scores.

    Args:
        cnn: module whose last linear layer is exposed as ``fc1``.
        bigru: module whose last linear layer is exposed as ``fc``.
        num_classes: number of output classes.
    """

    def __init__(self, cnn, bigru, num_classes):
        # Must be the dunder ``__init__`` so the sub-modules get registered.
        super().__init__()
        self.cnn = cnn
        self.bigru = bigru
        # Size the fusion layer from what the branches actually emit.
        # Hard-coding ``128 + bigru.hidden_size * 2`` does not match a BiGRU
        # branch whose forward() ends in its own ``fc`` layer (it returns
        # ``fc.out_features`` values, not ``hidden_size * 2``), and the
        # mismatch makes ``self.fc`` raise a shape error on the first batch.
        fused_width = cnn.fc1.out_features + bigru.fc.out_features
        self.fc = nn.Linear(fused_width, num_classes)

    def forward(self, x1, x2):
        """Return ``(batch, num_classes)`` scores.

        ``x1`` is fed to the CNN branch and ``x2`` to the BiGRU branch.
        """
        x1 = self.cnn(x1)
        x2 = self.bigru(x2)
        x = torch.cat((x1, x2), dim=1)
        x = self.fc(x)
        return x
# Hyperparameters
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 10        # full passes over the training set
batch_size = 32        # samples per mini-batch
learning_rate = 0.001  # step size passed to the optimizer
# Load data
class TxtFolderDataset(data.Dataset):
    """Tabular samples read from every ``.txt`` file inside one folder.

    ``torch.utils.data.Dataset`` has no ``from_folder`` constructor, so the
    files are parsed here. Each non-empty line must contain ``num_features``
    numeric values followed by one integer class label.

    NOTE(review): values are assumed to be separated by whitespace and/or
    commas -- confirm against the real files.
    NOTE(review): labels are used as-is; nn.CrossEntropyLoss needs them in
    ``0 .. num_classes - 1`` -- confirm, and shift them if they start at 1.

    Each item is an ``(image, sequence, label)`` triple:
        image:    the feature vector viewed as a ``(1, 1, num_features)``
                  tensor and passed through ``transform`` when one is given.
        sequence: the feature vector as ``(1, num_features)``, i.e. a
                  sequence of length 1.
        label:    0-dim ``torch.long`` tensor.
    """

    def __init__(self, folder, transform=None, num_features=23):
        self.transform = transform
        self.num_features = num_features

        rows = []
        # Sorted so the sample order does not depend on directory listing order.
        for file_name in sorted(os.listdir(folder)):
            if not file_name.endswith('.txt'):
                continue
            path = os.path.join(folder, file_name)
            with open(path, 'r', encoding='utf-8') as handle:
                for line_no, line in enumerate(handle, start=1):
                    fields = line.replace(',', ' ').split()
                    if not fields:
                        continue  # skip blank lines
                    if len(fields) != num_features + 1:
                        raise ValueError(
                            f"{path}:{line_no}: expected {num_features + 1} "
                            f"values, found {len(fields)}"
                        )
                    rows.append([float(value) for value in fields])
        if not rows:
            raise ValueError(f"no samples found in folder {folder!r}")

        table = torch.tensor(rows, dtype=torch.float32)
        self.features = table[:, :num_features]
        self.labels = table[:, num_features].long()

    def __len__(self):
        return self.labels.size(0)

    def __getitem__(self, index):
        vector = self.features[index]
        image = vector.view(1, 1, -1)  # (channels, height, width)
        if self.transform is not None:
            image = self.transform(image)
        sequence = vector.unsqueeze(0)  # (seq_len, input_size)
        return image, sequence, self.labels[index]


# Resize and Normalize operate directly on float tensors. ToPILImage/ToTensor
# are dropped because they route the data through an 8-bit image, which is
# not meaningful for raw tabular feature values.
# NOTE(review): features are not rescaled before Normalize((0.5,), (0.5,)) --
# confirm whether per-feature scaling is wanted.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = TxtFolderDataset('train_txt_folder', transform=transform)
val_dataset = TxtFolderDataset('val_txt_folder', transform=transform)
test_dataset = TxtFolderDataset('test_txt_folder', transform=transform)

# Only the training set is shuffled; evaluation order stays fixed.
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Initialize models
# Both branches and the fused model are moved to the selected device.
cnn = CNN().to(device)
bigru = BiGRU(input_size=23, hidden_size=64, num_layers=2, num_classes=8).to(device)
model = FusionModel(cnn, bigru, num_classes=8).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
# model.parameters() includes the parameters of both registered branches.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training loop
# Each epoch trains over train_loader, then reports accuracy on val_loader.
total_step = len(train_loader)
for epoch in range(num_epochs):
    model.train()
    for i, (images, features, labels) in enumerate(train_loader):
        images = images.to(device)
        features = features.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images, features)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the loss every 100 mini-batches.
        if (i+1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}")

    # Validation
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, features, labels in val_loader:
            images = images.to(device)
            features = features.to(device)
            labels = labels.to(device)
            outputs = model(images, features)
            # The predicted class is the index of the largest score.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # Guard against an empty validation loader.
        accuracy = 100 * correct / total if total else 0.0
        print(f"Validation Accuracy: {accuracy:.2f}%")
# Testing
# Final accuracy on test_loader, with gradients disabled.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, features, labels in test_loader:
        images = images.to(device)
        features = features.to(device)
        labels = labels.to(device)
        outputs = model(images, features)
        # The predicted class is the index of the largest score.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # Guard against an empty test loader.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Test Accuracy: {accuracy:.2f}%")
原文地址: https://www.cveoy.top/t/topic/i8D2 著作权归作者所有。请勿转载和采集!