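The data is in a txt file, formatted as shown below. The last value on each line is the class label, and there are eight classes in total. Using PyTorch, write a GRU-based classification algorithm that separates these eight classes, with a configurable batch_size, and give the complete code.

701830903103002512390802010010101017701830903103002512390802010010101017771832339103103100252190802010010110842021724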
```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
```
Define the model architecture
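```python
class GRUClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(GRUClassifier, self).__init__()
        self.hidden_size = hidden_size
        # batch_first=True: the GRU expects input of shape (batch, seq_len, input_size)
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # The DataLoader yields (batch, input_size); add a length-1 time axis
        # so the shape becomes (batch, seq_len=1, input_size)
        if x.dim() == 2:
            x = x.unsqueeze(1)
        # Initial hidden state: (num_layers=1, batch, hidden_size)
        h0 = torch.zeros(1, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.gru(x, h0)
        # Classify from the output at the last time step
        out = self.fc(out[:, -1, :])
        return out
```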
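When each sample is one row of 23 features, a GRU that receives it as a single time step (`input_size=23`, sequence length 1) barely uses its recurrence. One alternative, sketched here as an assumption rather than the original design, treats the feature columns as a 23-step sequence of scalars (`input_size=1`). The class name `GRUSeqClassifier` is hypothetical.

```python
import torch
import torch.nn as nn


class GRUSeqClassifier(nn.Module):
    """Hypothetical variant: each feature column is one time step of a scalar sequence."""

    def __init__(self, hidden_size, num_classes):
        super().__init__()
        # input_size=1: one scalar per time step
        self.gru = nn.GRU(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch, num_features) -> (batch, seq_len=num_features, 1)
        x = x.unsqueeze(-1)
        # h_n: (num_layers=1, batch, hidden_size); no h0 passed, so it defaults to zeros
        _, h_n = self.gru(x)
        return self.fc(h_n[-1])  # logits: (batch, num_classes)


model = GRUSeqClassifier(hidden_size=64, num_classes=8)
logits = model(torch.randn(4, 23))
print(logits.shape)  # torch.Size([4, 8])
```

Whether the column order carries sequential meaning depends on the data; if it does not, the single-step model is the simpler choice.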
Define the dataset class
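```python
class CustomDataset(Dataset):
    def __init__(self, data_file):
        self.data = []
        with open(data_file, 'r') as file:
            for line in file:
                line = line.strip()
                if not line:  # skip blank lines
                    continue
                values = line.split(',')
                # All columns except the last are features
                features = [float(x) for x in values[:-1]]
                # The last column is the class label; CrossEntropyLoss needs 0..7,
                # so subtract 1 here if the file numbers classes 1..8
                label = int(values[-1])
                self.data.append((features, label))

    def __getitem__(self, index):
        features, label = self.data[index]
        return (torch.tensor(features, dtype=torch.float32),
                torch.tensor(label, dtype=torch.long))

    def __len__(self):
        return len(self.data)
```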
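`nn.GRU` with `batch_first=True` expects `(batch, seq_len, input_size)`, so it helps to check what the DataLoader actually yields. This sketch uses random tensors in a `TensorDataset` in place of `data.txt` (an assumption for illustration): each sample is a 1-D feature vector, so batches come out as `(batch_size, num_features)` with no sequence dimension.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
num_samples, num_features, num_classes = 10, 23, 8

# Stand-in for CustomDataset: float32 features and int64 labels in 0..7
features = torch.randn(num_samples, num_features)
labels = torch.randint(0, num_classes, (num_samples,))
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch_shapes = [(x.shape, y.shape) for x, y in loader]
for x_shape, y_shape in batch_shapes:
    print(x_shape, y_shape)
# torch.Size([4, 23]) torch.Size([4])
# torch.Size([4, 23]) torch.Size([4])
# torch.Size([2, 23]) torch.Size([2])
```

The last batch is smaller because 10 is not a multiple of 4; passing `drop_last=True` to `DataLoader` discards it if equal-sized batches are needed.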
Set hyperparameters
```python
input_size = 23    # number of feature columns per line (every column except the label)
hidden_size = 64
num_classes = 8    # eight classes
batch_size = 4
num_epochs = 10
```
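`input_size` is hard-coded to 23, which only works if every line has exactly 23 feature columns plus the label. A small helper can read it from the file instead; `infer_input_size` is a hypothetical name, and the sketch assumes comma-separated lines, as `CustomDataset` does.

```python
import os
import tempfile


def infer_input_size(data_file):
    """Hypothetical helper: number of feature columns (all but the last) on the first non-blank line."""
    with open(data_file, 'r') as file:
        for line in file:
            line = line.strip()
            if line:
                return len(line.split(',')) - 1
    raise ValueError(f'{data_file} contains no data lines')


# Demo with a throwaway file holding one made-up line: 3 features + 1 label
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
    tmp.write('0.5,1.0,2.0,3\n')
    path = tmp.name

n_features = infer_input_size(path)
print(n_features)  # 3
os.remove(path)
```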
Create the model
```python
model = GRUClassifier(input_size, hidden_size, num_classes)
```
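As a quick sanity check on model size: a single-layer `nn.GRU` has three gates, each with an input weight matrix, a hidden weight matrix and two bias vectors, so it holds 3H(I + H) + 6H parameters, and the linear head adds H·C + C. With I = 23, H = 64, C = 8 that is 17,088 + 520 = 17,608. The sketch below rebuilds the same two layers (without the custom class) to confirm the count.

```python
import torch.nn as nn

I, H, C = 23, 64, 8  # input_size, hidden_size, num_classes from the hyperparameters

gru = nn.GRU(I, H, batch_first=True)
fc = nn.Linear(H, C)

gru_params = sum(p.numel() for p in gru.parameters())
fc_params = sum(p.numel() for p in fc.parameters())

# weight_ih (3H x I) + weight_hh (3H x H) + bias_ih (3H) + bias_hh (3H)
expected_gru = 3 * H * (I + H) + 6 * H
print(gru_params, expected_gru)  # 17088 17088
print(gru_params + fc_params)    # 17608
```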
Define the loss function and optimizer
```python
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 1e-3 is Adam's default
```
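`nn.CrossEntropyLoss` applies log-softmax internally, so the model should output raw logits (no softmax layer), and the targets must be `long` class indices in the range 0 to 7. The check below, on random logits, confirms it equals `log_softmax` followed by `nll_loss`, and shows that a label of 8 is rejected, which is what happens if the file numbers its classes 1..8.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 8)               # (batch, num_classes), raw scores
targets = torch.tensor([0, 3, 7, 2])     # class indices, dtype int64

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))  # True

# A target outside 0..7 is an error, not a silently wrong loss
raised = False
try:
    nn.CrossEntropyLoss()(logits, torch.tensor([0, 3, 8, 2]))
except (IndexError, RuntimeError):
    raised = True
print(raised)  # True
```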
Create the dataset and data loader
```python
dataset = CustomDataset('data.txt')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
```
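This loader covers the whole file, so any accuracy measured with it reflects data the model was trained on. A common approach, assumed here rather than taken from the original, is to hold out part of the dataset with `torch.utils.data.random_split`; the 80/20 ratio and the seed are arbitrary choices, and random tensors stand in for `CustomDataset`.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

# Stand-in for CustomDataset('data.txt'): 50 samples, 23 features, labels 0..7
dataset = TensorDataset(torch.randn(50, 23), torch.randint(0, 8, (50,)))

n_train = int(0.8 * len(dataset))
n_test = len(dataset) - n_train
train_set, test_set = random_split(
    dataset, [n_train, n_test], generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
test_loader = DataLoader(test_set, batch_size=4, shuffle=False)  # no need to shuffle for evaluation
print(len(train_set), len(test_set))  # 40 10
```

With this split, training would iterate over `train_loader` and evaluation over `test_loader`.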
Training loop
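```python
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_features, batch_labels in dataloader:
        optimizer.zero_grad()
        outputs = model(batch_features)          # (batch, num_classes) logits
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per sample
        running_loss += loss.item() * batch_labels.size(0)

    # Print the average loss for each epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(dataset):.4f}')
```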
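The training loop runs on the CPU. Below is a sketch of the same loop wrapped in a function that also moves the model and each batch to a GPU when one is available. `train_one_epoch` and `TinyGRU` are hypothetical names, and synthetic data whose label is the argmax of the first eight features stands in for the real file so the loss has something learnable to decrease on.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader


class TinyGRU(nn.Module):
    """Stand-in model: one GRU layer over a length-1 sequence plus a linear head."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.gru(x.unsqueeze(1))  # (batch, 1, hidden)
        return self.fc(out[:, -1, :])


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Hypothetical helper: one pass over loader, returns the mean per-sample loss."""
    model.train()
    total_loss, total_samples = 0.0, 0
    for features, labels in loader:
        features, labels = features.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(features), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        total_samples += labels.size(0)
    return total_loss / total_samples


torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Synthetic data: the label is the index of the largest of the first 8 features
x = torch.randn(64, 23)
y = x[:, :8].argmax(dim=1)
loader = DataLoader(TensorDataset(x, y), batch_size=8, shuffle=True)

model = TinyGRU(23, 64, 8).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = [train_one_epoch(model, loader, criterion, optimizer, device) for _ in range(20)]
print(f'first epoch {losses[0]:.3f}, last epoch {losses[-1]:.3f}')
```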
Evaluate the model (this reuses the training data loader, so it measures training accuracy, not accuracy on unseen data)

```python
model.eval()
with torch.no_grad():
    total_correct = 0
    total_samples = 0
    for batch_features, batch_labels in dataloader:
        outputs = model(batch_features)
        predicted = outputs.argmax(dim=1)  # index of the highest logit = predicted class
        total_samples += batch_labels.size(0)
        total_correct += (predicted == batch_labels).sum().item()

accuracy = total_correct / total_samples
print(f'Train Accuracy: {accuracy:.4f}')
```
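With eight classes, a single accuracy number can hide classes the model rarely gets right. A per-class breakdown can be computed from the collected predictions with `torch.bincount`; the helper name `confusion_matrix` and the small three-class example tensors below are illustrative.

```python
import torch


def confusion_matrix(labels, predicted, num_classes):
    """Rows = true class, columns = predicted class."""
    # Encode each (true, pred) pair as a single index, then count occurrences
    idx = labels * num_classes + predicted
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


labels = torch.tensor([0, 0, 1, 2, 2, 2])
predicted = torch.tensor([0, 1, 1, 2, 2, 0])
cm = confusion_matrix(labels, predicted, num_classes=3)
print(cm)
# tensor([[1, 1, 0],
#         [0, 1, 0],
#         [1, 0, 2]])

# Per-class recall: correct predictions on the diagonal / true samples per row
per_class = cm.diag().float() / cm.sum(dim=1).clamp(min=1).float()
print(per_class)  # tensor([0.5000, 1.0000, 0.6667])
```

In the evaluation loop, the `predicted` and `batch_labels` tensors from every batch could be concatenated with `torch.cat` and passed to this function with `num_classes=8`.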
Original article: http://www.cveoy.top/t/topic/hMpP
Copyright belongs to the author. Do not reprint or scrape.