# Import necessary libraries
import os
import copy
import pandas as pd
import random
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torchvision import transforms, models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.preprocessing import LabelEncoder

# Set random seeds for reproducibility.
# Every random number generator the script draws from is seeded with the same value:
# Python's `random`, NumPy's global generator, and PyTorch. Per the PyTorch docs,
# torch.manual_seed seeds the default generator on all devices; the explicit CUDA call
# is kept as a harmless extra.
random_seed = 99
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

# Data Preparation

# Directory path
base_path = '/Users/adityadeshmukh/Desktop/RJSC'

# Alloy folders: maps each top-level folder under base_path to a short alloy tag.
alloy_folders = {
    'hr_alloys20220208': 'hr',
    'CPJ_alloys': 'cpj',
    'P92 OTHER': 'p92'
}

# Creating a dataframe.
# Column order matters: downstream code reads the image path and label columns by position.
data = {'alloy': [], 'image': [], 'alloy type': []}

# Looping through each main folder and its subfolders.
# Each subfolder name becomes the 'alloy type' of every .bmp file inside it.
# Both directory listings are sorted because os.listdir returns entries in arbitrary,
# filesystem-dependent order; a fixed row order is required for the seeded
# train/val/test split to be reproducible across machines.
for folder, alloy in alloy_folders.items():
    folder_path = os.path.join(base_path, folder)
    for subfolder in sorted(os.listdir(folder_path)):
        subfolder_path = os.path.join(folder_path, subfolder)
        if not os.path.isdir(subfolder_path):
            continue  # loose files directly under the alloy folder are ignored
        for file in sorted(os.listdir(subfolder_path)):
            # Case-insensitive match so files saved as '.BMP' are not silently skipped.
            if file.lower().endswith('.bmp'):
                file_path = os.path.join(subfolder_path, file)
                data['alloy'].append(alloy)
                data['image'].append(file_path)
                data['alloy type'].append(subfolder)

# Converting to pandas DataFrame
df = pd.DataFrame(data)
print(df.head())

# Label Encoding: Convert string labels to numerical labels (which is what the model expects).
# LabelEncoder assigns the integer codes 0..K-1 to the sorted unique subfolder names.
# NOTE(review): classes are keyed only by subfolder name, so identically named subfolders under
# different alloy folders would be merged into a single class — confirm the names are unique.
label_encoder = LabelEncoder()
df['alloy type'] = label_encoder.fit_transform(df['alloy type'])

# Splitting data
# The first call splits the data into a training set and a held-out set in the ratio 70:30.
# The second call splits the held-out set into two equal halves: a validation set and a test set,
# so that ultimately we have train:valid:test = 70:15:15.
# Both splits are stratified on the encoded label and seeded with random_seed, so each subset keeps
# approximately the overall class proportions and the split repeats for a fixed row order.
train_df, test_df = train_test_split(df, test_size=0.3, stratify=df['alloy type'], random_state=random_seed)
val_df, test_df = train_test_split(test_df, test_size=0.5, stratify=test_df['alloy type'], random_state=random_seed)

# Data augmentation and normalization.
# Two different sets of transformations are necessary because data augmentation transformations
# only apply to the training set, not test and validation.
# Both pipelines start with ToTensor, so every later step operates on a (C, H, W) float tensor.
data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
        # Crop the top and bottom 25 pixels to remove miscellaneous SEM annotations
        transforms.Lambda(lambda x: x[:, 25:-25, :]),
        # Randomly crop (default scale/aspect-ratio ranges) and resize to 224x224,
        # the input size the ImageNet-pretrained backbones were trained on.
        transforms.RandomResizedCrop(224),
        # Randomly apply horizontal flipping and rotation for data augmentation:
        # Each time an image is loaded during training, the DataLoader applies the transformations randomly.
        # Essentially, this means the model sees slightly different versions of the training images throughout the
        # training process, which helps it generalize better by learning from a more diverse set of data representations.
        transforms.RandomHorizontalFlip(),
        # Rotate by a random angle drawn from [-15, 15] degrees.
        transforms.RandomRotation(15),
        # Normalize images based on pre-defined mean and standard deviation of ImageNet dataset.
        # Three values each, so the tensor must have exactly 3 channels at this point.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    # Deterministic pipeline (also used for the test set): same annotation crop and normalization,
    # but no random augmentation.
    'val': transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x[:, 25:-25, :]),
        # Resizes straight to 224x224 without preserving aspect ratio, unlike the random crops used
        # in training. NOTE(review): consider Resize + CenterCrop if this mismatch hurts accuracy.
        transforms.Resize((224, 224)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Custom dataset class to load images and apply transformations
class SteelDataset(torch.utils.data.Dataset):
    """Map-style dataset over a DataFrame of image file paths and integer labels.

    Args:
        dataframe: table whose column at position 1 holds image file paths and whose column
            at position 2 holds the class label. Columns are read by position, not by name.
        transform: optional callable applied to each loaded PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is decoded as 3-channel RGB and then passed through ``transform`` if one was given.
        """
        img_name = self.dataframe.iloc[idx, 1]
        label = self.dataframe.iloc[idx, 2]

        # Image.open is lazy and keeps the file open until the pixels are read; the context
        # manager guarantees the handle is released even if decoding fails. convert('RGB')
        # forces the decode and yields 3 channels, which the pretrained backbones and the
        # 3-value normalization expect. Grayscale, palette and RGBA BMPs would otherwise
        # produce 1 or 4 channels.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label


# Datasets
# Training rows use the augmenting 'train' pipeline; validation and test rows both use the
# deterministic 'val' pipeline.
train_dataset = SteelDataset(train_df, transform=data_transforms['train'])
val_dataset = SteelDataset(val_df, transform=data_transforms['val'])
test_dataset = SteelDataset(test_df, transform=data_transforms['val'])

# DataLoaders in PyTorch are iterators that enable efficient loading of data during the training, validation,
# and testing phases of a machine learning model.
# Only the training loader shuffles; validation/test batches follow DataFrame row order.
# NOTE(review): num_workers is not set, so DataLoader's default of 0 loads data in the main process and
# worker_init_fn is never called. With workers enabled, this lambda would give every worker the same NumPy
# seed (the torchvision random transforms draw from torch's RNG, which DataLoader seeds per worker), and
# lambdas cannot be pickled under the 'spawn' start method — revisit before raising num_workers.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, worker_init_fn=lambda _: np.random.seed(random_seed))
val_loader = DataLoader(val_dataset, batch_size=batch_size, worker_init_fn=lambda _: np.random.seed(random_seed))
test_loader = DataLoader(test_dataset, batch_size=batch_size, worker_init_fn=lambda _: np.random.seed(random_seed))

# Device configuration
# Pick the best available accelerator automatically: an NVIDIA GPU (CUDA) first, then the
# Apple-silicon GPU (MPS), otherwise the CPU. On a Mac without CUDA this selects exactly what
# the previous mps-or-cpu rule did, and Windows/Linux users no longer need to edit this line.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load and modify pretrained models
def load_model(model_name, num_classes):
    """Return an ImageNet-pretrained backbone with a fresh ``num_classes``-way output layer.

    Args:
        model_name: ``"resnet"`` for ResNet-50 or ``"densenet"`` for DenseNet-121.
        num_classes: number of output classes (here, the number of steel alloy types).

    Returns:
        The model with pretrained weights everywhere except the newly created final
        ``nn.Linear``, which is randomly initialised.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names. Previously an
            unknown name fell through to ``return model`` and raised an opaque UnboundLocalError.
    """
    if model_name == "resnet":
        # IMAGENET1K_V1 is the weight set the deprecated ``pretrained=True`` flag selected.
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Replace the last fully connected layer with one that keeps its input width but
        # outputs one score per alloy class.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_name == "densenet":
        model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    else:
        raise ValueError("Unknown model_name {!r}; expected 'resnet' or 'densenet'".format(model_name))
    return model


# Number of classes
# One output per distinct encoded 'alloy type' label.
num_classes = len(df['alloy type'].unique())

# Load models
# Both backbones are moved to the selected device before their optimizers are built.
resnet = load_model("resnet", num_classes).to(device)
densenet = load_model("densenet", num_classes).to(device)
# Re-seed after model construction. Initialising the new output layers consumes torch RNG draws;
# presumably this keeps the later shuffling/augmentation stream independent of that — confirm intent.
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

# Loss and optimizer
# The loss quantifies how well the model's predictions match labels in the training data.
# The optimizer, here stochastic gradient descent, is an algorithm that adjusts the parameters (weights and biases)
# of the neural network during training to minimize the loss.
criterion = nn.CrossEntropyLoss()
# Momentum is generally set to 0.9. LR, the learning rate, will be tuned.
optimizer_resnet = optim.SGD(resnet.parameters(), lr=0.001, momentum=0.9)
optimizer_densenet = optim.SGD(densenet.parameters(), lr=0.001, momentum=0.9)

# Learning rate scheduler.
# In 'min' mode ReduceLROnPlateau watches the value passed to scheduler.step(). Once that value has
# failed to improve for more than patience=5 consecutive steps (one step per epoch here), the LR is
# multiplied by factor=0.1.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; drop it and log the optimizer's
# param_groups[...]['lr'] instead if a deprecation warning appears.
scheduler_resnet = optim.lr_scheduler.ReduceLROnPlateau(optimizer_resnet, 'min', factor=0.1, patience=5, verbose=True)
scheduler_densenet = optim.lr_scheduler.ReduceLROnPlateau(optimizer_densenet,  'min', factor=0.1, patience=5, verbose=True)

# Training function
def _run_phase(model, loader, criterion, optimizer, device, is_train):
    """Run ``model`` once over every batch of ``loader``; return ``(mean_loss, accuracy)``.

    When ``is_train`` is true, the model is put in training mode and the optimizer steps after
    every batch. Otherwise the model is in eval mode and gradients are not tracked. Both metrics
    are averaged over ``len(loader.dataset)`` samples.

    Raises:
        ValueError: if ``loader.dataset`` is empty (the averages would be undefined).
    """
    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError('Cannot run a training/validation phase over an empty dataset')
    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Track gradients only while training; validation is forward-only.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        # Assuming a mean-reducing criterion (CrossEntropyLoss's default), scale by the batch
        # size so the total is a per-sample sum; the last batch may be smaller.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels).item()

    return running_loss / num_samples, running_corrects / num_samples


def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25, device=None):
    """Train ``model`` and return it with the weights of its best validation epoch.

    Each epoch runs a ``'train'`` phase (forward, backward, optimizer step) and then a ``'val'``
    phase (forward only). After every validation phase the scheduler is stepped with the
    validation loss. The weights are snapshotted whenever validation accuracy strictly improves
    on the best seen so far.

    Args:
        model: the network to train; it is updated in place.
        dataloaders: dict with ``'train'`` and ``'val'`` DataLoaders yielding ``(inputs, labels)``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: scheduler whose ``step`` takes a metric (e.g. ReduceLROnPlateau).
        num_epochs: number of train+val epochs to run.
        device: device each batch is moved to. Defaults to the device of the model's first
            parameter, or the CPU if the model has no parameters.

    Returns:
        ``(model, train_losses, val_losses, train_accs, val_accs)``: the model with its best
        weights loaded, plus per-epoch metric lists of Python floats.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Snapshot the starting weights so there is always something to restore, even when
    # validation accuracy never exceeds 0. The previous version left this unbound in that case
    # and crashed at load_state_dict.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_losses, val_losses, train_accs, val_accs = [], [], [], []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            epoch_loss, epoch_acc = _run_phase(
                model, dataloaders[phase], criterion, optimizer, device, is_train)

            # Record metrics
            if is_train:
                train_losses.append(epoch_loss)
                train_accs.append(epoch_acc)
            else:
                val_losses.append(epoch_loss)
                val_accs.append(epoch_acc)
                scheduler.step(epoch_loss)  # Adjust learning rate based on validation loss

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Deep copy the best-performing model (strict improvement, so ties keep the earlier epoch)
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    # '{:.4f}' = 4 decimal places; the old '{:4f}' meant minimum width 4 with 6 decimals.
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Load best model weights
    model.load_state_dict(best_model_wts)

    return model, train_losses, val_losses, train_accs, val_accs

# Plot training curves: loss on the left, accuracy on the right.
def plot_performance(train_losses, val_losses, train_accs, val_accs, file_name):
    """Save side-by-side loss and accuracy curves to ``file_name``.

    The left subplot shows training vs validation loss per epoch; the right subplot shows
    training vs validation accuracy per epoch. The figure is written with ``plt.savefig`` and
    is not closed. Also re-seeds NumPy's global RNG with the module-level ``random_seed``.
    """
    np.random.seed(random_seed)
    plt.figure(figsize=(12, 5))

    # (subplot index, train series, val series, train label, val label, title, y-axis label)
    panels = [
        (1, train_losses, val_losses, 'Training loss', 'Validation loss',
         'Training and Validation Loss', 'Loss'),
        (2, train_accs, val_accs, 'Training accuracy', 'Validation accuracy',
         'Training and Validation Accuracy', 'Accuracy'),
    ]
    for index, train_series, val_series, train_label, val_label, title, y_label in panels:
        plt.subplot(1, 2, index)
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.savefig(file_name)


# Train models
# Both networks share the same train/val loaders and loss; each has its own optimizer and scheduler.
# train_model returns the model with its best-validation-accuracy weights loaded, plus per-epoch histories.
dataloaders = {'train': train_loader, 'val': val_loader}
print('\n RESNET-50________________________________ \n')
resnet, resnet_train_loss, resnet_val_loss, resnet_train_acc, resnet_val_acc  = train_model(resnet, dataloaders, criterion, optimizer_resnet, scheduler_resnet, num_epochs=500)
print('\n DENSENET-121______________________________ \n')
densenet, densenet_train_loss, densenet_val_loss, densenet_train_acc, densenet_val_acc = train_model(densenet, dataloaders, criterion, optimizer_densenet, scheduler_densenet, num_epochs=500)

# Persist only the state_dicts (parameters and buffers) to the current working directory.
# Reloading requires rebuilding the architecture with load_model and calling load_state_dict.
torch.save(resnet.state_dict(), 'resnet_model.pth')
torch.save(densenet.state_dict(), 'densenet_model.pth')

# Ensemble
# Combine two networks by averaging their raw outputs and predicting from that mean.
class AveragingEnsemble(nn.Module):
    """Wrap two models and return the element-wise mean of their outputs.

    No softmax is applied, so the raw outputs (unnormalized scores) are averaged, not probabilities.

    Args:
        modelA: first wrapped model; evaluated first in ``forward``.
        modelB: second wrapped model.
    """

    def __init__(self, modelA, modelB):
        super().__init__()
        # Registered in this order, so state_dict keys are prefixed 'modelA.' then 'modelB.'.
        self.modelA, self.modelB = modelA, modelB

    def forward(self, x):
        """Return ``(modelA(x) + modelB(x)) / 2``."""
        summed = self.modelA(x) + self.modelB(x)
        return summed / 2

# Averaging ensemble of the two trained networks (ResNet-50 as modelA, DenseNet-121 as modelB).
ensemble_model = AveragingEnsemble(resnet, densenet)


# Evaluate Ensemble
def evaluate_model(model, dataloader):
    """Return the classification accuracy of ``model`` over every batch in ``dataloader``.

    The model is switched to eval mode and run without gradient tracking. Each batch is moved
    to the module-level ``device``, and the predicted class is the index of the maximum output
    along dimension 1. The result is the fraction of correct predictions (sklearn ``accuracy_score``).
    """
    model.eval()
    predicted, expected = [], []
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs.to(device))
            predicted += torch.max(logits, 1)[1].tolist()
            expected += batch_labels.tolist()

    return accuracy_score(expected, predicted)


# Score the ensemble on the held-out test split; the fraction is printed as a percentage.
test_accuracy = evaluate_model(ensemble_model, test_loader)
print('Test Accuracy of Ensemble: {:.4f}%'.format(test_accuracy * 100))

# Plotting functions
# Save each model's loss/accuracy curves: ResNet-50 to res1.png, DenseNet-121 to res2.png.
plot_performance(resnet_train_loss, resnet_val_loss, resnet_train_acc, resnet_val_acc, "res1.png")
plot_performance(densenet_train_loss, densenet_val_loss, densenet_train_acc, densenet_val_acc, "res2.png")

