-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathCNN_testing.py
More file actions
28 lines (23 loc) · 954 Bytes
/
CNN_testing.py
File metadata and controls
28 lines (23 loc) · 954 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from data.data_loader import load_MNIST
from model.CNN_model import CNN, train
# Hyperparameters for this smoke-test training run.
TRAIN_SIZE = 50000
TEST_SIZE = 10000
VAL_SIZE = 100
BATCH_SIZE = 32
# Positional arguments for CNN(...); their meaning is defined by the CNN class
# in model/CNN_model.py. NOTE(review): give these names once that signature is
# confirmed.
CNN_ARGS = (32, 32, 3, 5, 2, 0.25, 128)
LEARNING_RATE = 0.001
# Small number of epochs; adjust as needed for testing.
NUM_EPOCHS = 10


def main():
    """Train the CNN on MNIST and report the accuracy returned by ``train``.

    Loads the MNIST loaders, builds the model on GPU when available, trains it
    with Adam, and prints the accuracy figure that ``train`` returns.

    Returns:
        The value returned by ``train``. The print below multiplies it by 100
        and appends "%", so it is presumably a fraction in [0, 1]
        (TODO confirm against model/CNN_model.train).
    """
    # Load data.
    train_loader, val_loader, test_loader = load_MNIST(
        train_size=TRAIN_SIZE,
        test_size=TEST_SIZE,
        val_size=VAL_SIZE,
        batch_size=BATCH_SIZE,
    )
    # NOTE(review): val_loader is never used, and test_loader is handed to
    # train(). If train() evaluates that loader during training (e.g. per
    # epoch or for model selection), the test set is being used for
    # monitoring. Consider passing val_loader there and evaluating
    # test_loader separately afterwards. Verify against train()'s contract.

    # Set device (use GPU if available).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Instantiate the model and move it to the device.
    model = CNN(*CNN_ARGS).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    test_accuracy = train(
        model, device, train_loader, test_loader, optimizer, NUM_EPOCHS
    )
    print(f"Final Test Accuracy: {test_accuracy * 100:.2f}%")
    return test_accuracy


# Guard the entry point so importing this module does not start a training run.
if __name__ == "__main__":
    main()