import torch
from torch_geometric.data import InMemoryDataset, Data

class My_dataset(InMemoryDataset):
    """In-memory graph dataset built from four whitespace-separated text files.

    Line i across the raw files describes one small graph:
      - edges.txt:     "source target" node indices (one edge per line)
      - features1.txt / features2.txt: float feature vector for each endpoint
      - label.txt:     integer label value(s) for the graph

    NOTE(review): each graph is built from exactly one edge and two feature
    rows — confirm this matches the intended data layout.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # Load the collated tensors written by process().
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Files expected under <root>/raw/.
        return ['edges.txt', 'features1.txt', 'features2.txt', 'label.txt']

    @property
    def processed_file_names(self):
        # File written under <root>/processed/; its presence skips process().
        return ['data.pt']

    def download(self):
        # Raw files are supplied locally; nothing to download.
        pass

    def process(self):
        """Build one Data object per raw line and save the collated result."""
        edges = self.read_edges(self.raw_paths[0])
        features1 = self.read_features(self.raw_paths[1])
        features2 = self.read_features(self.raw_paths[2])
        labels = self.read_labels(self.raw_paths[3])

        data_list = []
        # zip() stops at the shortest list instead of raising IndexError
        # when the raw files have mismatched line counts.
        for edge, feature1, feature2, label in zip(edges, features1,
                                                   features2, labels):
            # Two nodes per graph, one feature row per node.
            x = torch.tensor([feature1, feature2], dtype=torch.float)
            # Shape (2, 1): a single edge in COO format.
            edge_index = torch.tensor([edge], dtype=torch.long).t().contiguous()
            y = torch.tensor(label, dtype=torch.float)
            data_list.append(Data(x=x, edge_index=edge_index, y=y))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def read_edges(self, path):
        """Return a list of (source, target) int tuples, one per non-blank line."""
        edges = []
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate trailing blank lines
                # split() handles any run of whitespace, not just single spaces.
                source, target = line.split()
                edges.append((int(source), int(target)))
        return edges

    def read_features(self, path):
        """Return a list of float feature vectors, one per non-blank line."""
        features = []
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if line:
                    features.append([float(v) for v in line.split()])
        return features

    def read_labels(self, path):
        """Return a list of int label vectors, one per non-blank line."""
        labels = []
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if line:
                    labels.append([int(v) for v in line.split()])
        return labels

# Initialize the dataset (root directory containing raw/ and processed/).
dataset = My_dataset('C:/Users/jh/Desktop/data')

# Boolean mask selecting the training split over the dataset's graphs.
# torch.zeros already initializes every entry to False, so only the
# training prefix needs to be set.
mask = torch.zeros(len(dataset), dtype=torch.bool)
mask[:16] = True  # first 16 graphs: training; the rest stay False: validation

train_dataset = dataset[mask]
val_dataset = dataset[~mask]

# --- GCN Model, Training, and Testing ---

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader

# Define the GCN model
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network with a ReLU in between.

    Emits raw logits (no final activation); pair with a
    *_with_logits loss or apply sigmoid/softmax downstream.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # conv -> ReLU -> conv; the second convolution produces the logits.
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# Model: 2 input features -> 16 hidden units -> 1 output logit per node.
model = GCN(in_channels=2, hidden_channels=16, out_channels=1)

# Mini-batch loaders; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Adam optimizer over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# --- Training loop ---
model.train()
for epoch in range(100):
    running_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index)
        # BCE-with-logits: model emits raw scores, targets are 0/1 floats.
        loss = F.binary_cross_entropy_with_logits(logits, batch.y)
        loss.backward()
        optimizer.step()
        # Weight by graphs in the batch so the epoch average is per-graph.
        running_loss += loss.item() * batch.num_graphs
    avg = running_loss / len(train_dataset)
    print('Epoch: {}, Loss: {:.4f}'.format(epoch, avg))

# --- Validation: accuracy of thresholded sigmoid predictions ---
model.eval()
correct = 0
with torch.no_grad():  # no gradients needed during evaluation
    for batch in val_loader:
        logits = model(batch.x, batch.edge_index)
        # Round sigmoid output to 0/1 and count exact matches.
        preds = torch.round(torch.sigmoid(logits))
        correct += (preds == batch.y).sum().item()
accuracy = correct / len(val_dataset)
print('Accuracy: {:.4f}'.format(accuracy))

This code creates a custom PyTorch Geometric in-memory dataset, defines a simple two-layer GCN, trains it on the training split, and evaluates its accuracy on the validation split. Adjust the model architecture, training hyperparameters, and evaluation metrics to suit your own data.

PyTorch Geometric InMemoryDataset: Custom Dataset for GCN Training

原文地址: https://www.cveoy.top/t/topic/mUgz 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录