import torch
import torch.utils.model_zoo as model_zoo
from segmentation_models_pytorch import UnetPlusPlus
from segmentation_models_pytorch.encoders import get_preprocessing_params

# get_pretrained_microscopynet_url is defined directly below, so it does not
# need to be imported from an external module (e.g. custom_utils).

def get_pretrained_microscopynet_url(encoder, encoder_weights, version=1.1,
                                     self_supervision=''):
    """Get the url to download the specified pretrained encoder.

    Args:
        encoder (str): pretrained encoder model name (e.g. resnet50)
        encoder_weights (str): pretraining dataset, either 'micronet' or
            'imagenet-micronet', with the latter indicating the encoder
            was first pretrained on imagenet and then finetuned on
            microscopynet. The legacy spelling 'image-micronet' is accepted
            as an alias of 'imagenet-micronet'.
        version (float): model version to use, defaults to latest.
            Current options are 1.0 or 1.1. The value is forced to 1.0
            unless encoder is 'resnet50', encoder_weights is 'micronet' and
            no self-supervision is requested.
        self_supervision (str): self-supervision method used. If self-supervision
            was not used set to '' (which is default).

    Returns:
        str: url to download the pretrained model

    Raises:
        ValueError: if encoder_weights is not one of the accepted names.
            (Not checked for 'resnext101_32x8d', which always maps to a
            single fixed URL.)
    """
    url_base = 'https://nasa-public-data.s3.amazonaws.com/microscopy_segmentation_models/'

    # there is an error with the name for resnext101_32x8d so catch and return
    # (currently there is only version 1.0 for this model so don't need to check version.)
    if encoder == 'resnext101_32x8d':
        return url_base + 'resnext101_pretrained_microscopynet_v1.0.pth.tar'

    # Map the user-facing dataset names to the names used in the URLs.
    # Both the documented 'imagenet-micronet' and the historically accepted
    # 'image-micronet' resolve to the same file name component.
    url_weight_names = {
        'micronet': 'microscopynet',
        'image-micronet': 'imagenet-microscopynet',
        'imagenet-micronet': 'imagenet-microscopynet',
    }
    try:
        url_weights = url_weight_names[encoder_weights]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which are equally invalid.
        raise ValueError(
            "encoder_weights must be 'micronet', 'imagenet-micronet' "
            "or 'image-micronet'"
        ) from None

    # only resnet50/micronet has version 1.1 so every other combination
    # (including any self-supervised model) falls back to 1.0.
    if encoder != 'resnet50' or encoder_weights != 'micronet':
        version = 1.0

    # setup self-supervision
    if self_supervision != '':
        version = 1.0
        self_supervision = '_' + self_supervision

    # get url
    return (
        f'{url_base}{encoder}{self_supervision}_pretrained_{url_weights}'
        f'_v{version}.pth.tar'
    )

# Model Setup

def setup_segmentation_model(encoder_name='resnet50', class_values=None, encoder_weights='micronet'):
    """Build a U-Net++ model and load a pretrained encoder into it.

    Args:
        encoder_name (str): encoder architecture name, forwarded both to
            UnetPlusPlus and to get_pretrained_microscopynet_url.
        class_values: sized collection with one entry per output class;
            only its length is used here.
        encoder_weights (str): pretraining dataset name forwarded to
            get_pretrained_microscopynet_url.

    Returns:
        tuple: (model, device) where the model has already been moved to
        the device ('cuda' when available, otherwise 'cpu').

    Raises:
        ValueError: if class_values is None.
    """
    if class_values is None:
        raise ValueError("class_values must be provided and should not be None")

    # One output channel per class; several classes use a softmax over the
    # channel dimension, a single class uses a sigmoid.
    n_outputs = len(class_values)
    if n_outputs > 1:
        final_activation = 'softmax2d'
    else:
        final_activation = 'sigmoid'

    # encoder_weights=None keeps UnetPlusPlus from loading its own weights;
    # the encoder state is loaded from the URL below instead.
    net = UnetPlusPlus(
        encoder_name=encoder_name,
        encoder_weights=None,
        in_channels=3,
        classes=n_outputs,
        activation=final_activation,
    )

    # The same backend string serves as map_location for the download and
    # as the target device for the assembled model.
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    target_device = torch.device(backend)

    weights_url = get_pretrained_microscopynet_url(encoder_name, encoder_weights)
    encoder_state = model_zoo.load_url(weights_url, map_location=backend)
    net.encoder.load_state_dict(encoder_state)

    return net.to(target_device), target_device

# Loss and IoU metric

import torch
import torch.nn.functional as F
import numpy as np
import segmentation_models_pytorch as smp

# Combined Dice and BCE loss function
def dice_bce_loss(inputs, targets, bce_weight=0.5):
    """Combined soft-Dice and weighted binary cross-entropy loss.

    Args:
        inputs (torch.Tensor): raw, un-activated logits. Sigmoid is applied
            inside this function, so the tensor must not already have been
            passed through a sigmoid/softmax.
            NOTE(review): if the model producing `inputs` has a built-in
            output activation, confirm it is disabled before using this loss.
        targets (torch.Tensor): binary ground truth with the same number of
            elements as `inputs`; it is cast to float internally.
        bce_weight (float): multiplier applied to the BCE term before it is
            added to the Dice term.

    Returns:
        torch.Tensor: scalar tensor ``dice_loss + bce_weight * bce_loss``.
    """
    # Flatten once; keep the logits around for the BCE term.
    logits = inputs.reshape(-1)
    targets = targets.reshape(-1).float()

    # Dice works on probabilities. The +1 smoothing terms keep the ratio
    # defined when both prediction and target are empty.
    probs = torch.sigmoid(logits)
    intersection = (probs * targets).sum()
    dice_loss = 1 - (2. * intersection + 1) / (probs.sum() + targets.sum() + 1)

    # binary_cross_entropy_with_logits applies the sigmoid itself, so it
    # must receive the raw logits rather than the probabilities; feeding it
    # probabilities would apply the sigmoid twice.
    bce_loss = F.binary_cross_entropy_with_logits(logits, targets)

    # Combine Dice and BCE losses
    return dice_loss + bce_weight * bce_loss

# IoU metric function using smp
def compute_iou(output, target):
    """Compute the IoU score between predictions and ground truth.

    The predictions are binarised at a 0.5 threshold in 'multilabel' mode
    and the resulting tp/fp/fn/tn statistics are reduced with the "micro"
    strategy of smp.metrics.iou_score.

    Args:
        output: model predictions passed to smp.metrics.get_stats.
        target: ground-truth masks passed to smp.metrics.get_stats.
            NOTE(review): smp presumably expects an integer-typed target
            here — confirm against the caller.

    Returns:
        The value returned by smp.metrics.iou_score.
    """
    # get_stats yields (tp, fp, fn, tn) in exactly the positional order
    # iou_score expects, so the tuple is forwarded unchanged.
    confusion_stats = smp.metrics.get_stats(
        output, target, mode='multilabel', threshold=0.5
    )
    return smp.metrics.iou_score(*confusion_stats, reduction="micro")