# Code
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision
import pathlib
import glob
import matplotlib
import matplotlib.pyplot as plt
# Dataset directories.
# The defaults are the original author's local Windows paths; set the
# WEATHER_TRAIN_DIR / WEATHER_TEST_DIR environment variables to point at the
# dataset elsewhere without editing this file.
train_dir = os.environ.get(
    'WEATHER_TRAIN_DIR',
    r'c:\Users\User\Downloads\Weather\dataset2\dataset2\train',
)
test_dir = os.environ.get(
    'WEATHER_TEST_DIR',
    r'c:\Users\User\Downloads\Weather\dataset2\dataset2\test',
)
# Training configuration.
num_epochs = 10
# Highest test accuracy recorded so far; starts at zero.
best_accuracy = 0.0
# Accuracy values collected during training (appended to below) and plotted
# at the end of the script. Two separate, initially empty lists.
train_accuracy_history, test_accuracy_history = [], []
# Number of images in each split, used below to turn summed metrics into
# averages. Without recursive=True, '**' behaves like '*', so this matches
# exactly one directory level: <split_dir>/<subdir>/*.jpg. Only .jpg files
# are counted.
train_count = len(glob.glob(train_dir + '/**/*.jpg'))
test_count = len(glob.glob(test_dir + '/**/*.jpg'))
# Fail fast: dividing by a zero count would otherwise raise later or, for
# tensors, silently yield inf/nan metrics.
if train_count == 0:
    raise FileNotFoundError(f"No .jpg images found under {train_dir!r}")
if test_count == 0:
    raise FileNotFoundError(f"No .jpg images found under {test_dir!r}")

# NOTE(review): the statements below read `epoch`, `training_accuracy`,
# `training_loss`, `test_accuracy` and `Model`, none of which are defined in
# this chunk. They look like the tail of a per-epoch loop whose body is not
# shown here — TODO confirm and restore the enclosing loop/indentation.
training_accuracy /= train_count
training_loss /= train_count
train_accuracy_history.append(training_accuracy)
test_accuracy /= test_count
test_accuracy_history.append(test_accuracy)
print(
    f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {training_loss:.4f}, "
    f"Train Accuracy: {training_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}"
)
# Checkpoint only when test accuracy improves, so 'best_checkpoint.model'
# really holds the best weights seen so far and best_accuracy tracks them.
# NOTE(review): `Model` must be a model instance here (state_dict() is called
# on it directly) — confirm it is not the model class.
if test_accuracy > best_accuracy:
    torch.save(Model.state_dict(), 'best_checkpoint.model')
    best_accuracy = test_accuracy
# Plot the recorded accuracy curves.
# The x-values are derived from how many values were actually recorded rather
# than from num_epochs: plt.plot raises if x and y lengths differ, which
# happens whenever fewer than num_epochs values were appended (e.g. training
# stopped early).
epochs = range(1, len(train_accuracy_history) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_accuracy_history, label='Training Accuracy', marker='o')
plt.plot(range(1, len(test_accuracy_history) + 1), test_accuracy_history,
         label='Testing Accuracy', marker='o')
plt.title('Training and Testing Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.xticks(epochs)
plt.legend()
plt.grid(True)
plt.show()
# Draw one batch from the test loader.
# NOTE(review): `images` and `labels` are not used below; the plotted
# `original_image`, `original_label` and `predicted_label` come from elsewhere
# — TODO confirm they are meant to be taken from this batch.
images, labels = next(iter(test_loader))

# Two side-by-side panels over the same image: the left one is titled with
# `original_label`, the right one with `predicted_label`. The right panel
# reuses the original image because there is no separate predicted image to
# show.
plt.figure(figsize=(10, 5))
for panel in (1, 2):
    plt.subplot(1, 2, panel)
    # Move axis 0 to the end (presumably channels-first -> channels-last,
    # the layout imshow expects — confirm against the dataset transforms).
    plt.imshow(np.transpose(original_image, (1, 2, 0)))
    if panel == 1:
        plt.title(f'Original: {classes[original_label]}')
    else:
        plt.title(f'Predicted: {classes[predicted_label]}')
plt.tight_layout()
plt.show()