import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Parameters
# NOTE: no random seed is set anywhere in this script, so the initial grain
# placement (and therefore every result) differs between runs.
Nx, Ny = 64, 64  # Grid size: number of points along array axis 0 and axis 1
dx, dy = 0.5, 0.5  # Grid spacing used by the finite-difference Laplacian
ngrains = 3  # Number of grains (one order-parameter channel per grain)
L = 5.0  # Mobility coefficient: prefactor of the Allen-Cahn rate d(eta)/dt
A = B = 1  # Coefficients of the -A*eta and +B*eta**3 terms in the driving force
kappa = 0.1  # Gradient energy coefficient (multiplies the Laplacian term)
dt = 0.005  # Explicit Euler time step
nsteps = 1000  # Number of time steps to integrate
threshold = 0.5  # A grid point counts toward a grain's area when eta > threshold
radius_range = (7, 14)  # Grain radii drawn with np.random.randint(7, 14): upper bound is exclusive
volume_fraction_threshold = 0.001  # Grain is marked extinct when mean(eta) over the grid drops below this

# Initialize the domain with circular grains
def initialize_grains(Nx, Ny, ngrains, radius_range):
    """Create an (Nx, Ny, ngrains) field containing one filled disc per grain.

    Each grain ``i`` gets a random center drawn uniformly from
    ``[0, Nx) x [0, Ny)`` and an integer radius drawn with
    ``np.random.randint(radius_range[0], radius_range[1])``, so the upper
    bound is exclusive. Grid point ``(x, y)`` is set to 1.0 in channel ``i``
    when it lies inside that disc; every other value is 0.0. Discs are not
    wrapped periodically, so they may be clipped by the domain edge, and
    discs of different grains may overlap.

    Parameters
    ----------
    Nx, Ny : int
        Number of grid points along axes 0 and 1.
    ngrains : int
        Number of grains (size of the last axis of the result).
    radius_range : tuple of int
        ``(low, high)`` bounds for the random radii; ``high`` is exclusive.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(Nx, Ny, ngrains)`` with values in {0.0, 1.0}.
    """
    eta = np.zeros((Nx, Ny, ngrains))
    centers = np.random.rand(ngrains, 2) * np.array([Nx, Ny])
    radii = np.random.randint(radius_range[0], radius_range[1], ngrains)

    # Broadcast a column of x indices against a row of y indices instead of
    # visiting every grid point in Python; x indexes axis 0, y indexes axis 1.
    x = np.arange(Nx)[:, np.newaxis]
    y = np.arange(Ny)[np.newaxis, :]
    for i in range(ngrains):
        inside = (x - centers[i, 0])**2 + (y - centers[i, 1])**2 <= radii[i]**2
        eta[:, :, i][inside] = 1.0
    return eta

# Function to approximate Laplacian using a five-point stencil
def laplacian(eta, dx, dy):
    """Approximate the 2-D Laplacian of ``eta`` with a periodic five-point stencil.

    ``np.roll`` wraps values around the array edges, so the domain is treated
    as periodic along axes 0 and 1. ``dx`` is the spacing along axis 0 and
    ``dy`` the spacing along axis 1.
    """
    def second_difference(axis, spacing):
        # Central second difference along one axis: (f[k+1] + f[k-1] - 2 f[k]) / h^2
        ahead = np.roll(eta, -1, axis=axis)
        behind = np.roll(eta, 1, axis=axis)
        return (ahead + behind - 2 * eta) / spacing**2

    return second_difference(0, dx) + second_difference(1, dy)

# Phase field model evolution function
def phase_field_evolve(Nx, Ny, ngrains, dt, nsteps, eta_initial):
    """Integrate the multi-grain Allen-Cahn equations with explicit Euler steps.

    On every step each surviving grain's order parameter is advanced using
    only the previous state, clipped into [0.0001, 0.999], and its volume
    fraction (mean of eta over the grid) is recorded. A grain whose volume
    fraction falls below ``volume_fraction_threshold`` is marked extinct.
    From then on it is no longer evolved, and its field is carried forward
    unchanged.

    This function reads the module-level parameters ``dx``, ``dy``, ``L``,
    ``A``, ``B``, ``kappa`` and ``volume_fraction_threshold``.

    Parameters
    ----------
    Nx, Ny : int
        Grid dimensions; used to normalize the volume fraction.
    ngrains : int
        Number of order parameters (size of the last axis of ``eta_initial``).
    dt : float
        Time step.
    nsteps : int
        Number of steps to take.
    eta_initial : numpy.ndarray
        Initial fields of shape ``(Nx, Ny, ngrains)``.

    Returns
    -------
    eta_hist : list of numpy.ndarray
        ``nsteps + 1`` snapshots: ``eta_initial`` followed by one full copy per
        step, so memory grows linearly with ``nsteps``.
    volume_fractions : numpy.ndarray
        ``(nsteps, ngrains)``; entries stay 0 for steps in which a grain was
        skipped as extinct.
    grain_status : numpy.ndarray
        ``(nsteps, ngrains)`` float array: 1.0 if the grain is alive after that
        step, 0.0 once it has gone extinct.
    """
    eta_hist = [eta_initial]
    volume_fractions = np.zeros((nsteps, ngrains))
    grain_status = np.ones((nsteps, ngrains))
    # Persistent extinction flags. grain_status rows are pre-filled with ones,
    # so deriving "is extinct" from the previous row alone would let a grain
    # come back to life one step after it was marked extinct.
    alive = np.ones(ngrains, dtype=bool)

    for step in range(nsteps):
        eta_prev = eta_hist[-1]
        eta_next = eta_prev.copy()
        for i in range(ngrains):
            if not alive[i]:  # Skip extinct grains
                continue
            eta_i = eta_prev[:, :, i]
            # Coupling term: sum of squares of all *other* order parameters
            # (0 when there is only one grain).
            sum_eta_squared = sum(eta_prev[:, :, j]**2 for j in range(ngrains) if j != i)
            laplace_eta_i = laplacian(eta_i, dx, dy)

            # Discretized Allen-Cahn equation: bulk term -A*eta + B*eta**3,
            # grain-grain coupling 2*eta_i*sum_{j!=i} eta_j**2, gradient term.
            d_eta_dt = -L * (-A * eta_i + B * eta_i**3 + 2 * eta_i * sum_eta_squared - kappa * laplace_eta_i)
            eta_next[:, :, i] += d_eta_dt * dt

            # Enforce bounds on order parameters
            eta_next[:, :, i] = np.clip(eta_next[:, :, i], 0.0001, 0.999)

            # Volume fraction = mean of eta_i over the grid
            volume_fraction = np.sum(eta_next[:, :, i]) / (Nx * Ny)
            volume_fractions[step, i] = volume_fraction
            if volume_fraction < volume_fraction_threshold:
                alive[i] = False  # Mark grain as extinct for all later steps

        grain_status[step] = alive
        eta_hist.append(eta_next)

    return eta_hist, volume_fractions, grain_status

# Function to compute area fraction for each grain at each time step
def compute_area_fraction(eta_hist, Nx, Ny, threshold, ngrains=None):
    """Return, per snapshot and grain, the fraction of grid points above ``threshold``.

    Parameters
    ----------
    eta_hist : sequence of numpy.ndarray
        Snapshots of shape ``(Nx, Ny, n)``, e.g. as returned by
        ``phase_field_evolve``.
    Nx, Ny : int
        Grid dimensions; the number of points above threshold is divided by
        ``Nx * Ny``.
    threshold : float
        A point belongs to grain ``i`` when ``eta[:, :, i] > threshold``
        (strict inequality).
    ngrains : int, optional
        Number of grains (leading channels) to measure. Defaults to the size
        of the last axis of the first snapshot (0 when ``eta_hist`` is empty).
        The count is taken from the data rather than from a module-level
        global.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(eta_hist), ngrains)``.
    """
    if ngrains is None:
        ngrains = eta_hist[0].shape[2] if len(eta_hist) > 0 else 0

    area_fractions = np.zeros((len(eta_hist), ngrains))
    for step, eta in enumerate(eta_hist):
        # Count points above threshold for all grains at once (sum over the grid axes).
        above = eta[:, :, :ngrains] > threshold
        area_fractions[step] = np.sum(above, axis=(0, 1)) / (Nx * Ny)

    return area_fractions

# Build the initial microstructure, evolve it, and measure each grain's area.
eta_initial = initialize_grains(Nx, Ny, ngrains, radius_range)
eta_hist, volume_fractions, grain_status = phase_field_evolve(
    Nx, Ny, ngrains, dt, nsteps, eta_initial
)
area_fractions = compute_area_fraction(eta_hist, Nx, Ny, threshold)

# Plot each grain's area fraction against the time-step index.
# eta_hist holds the initial state plus one entry per step: nsteps + 1 points.
plt.figure(figsize=(10, 6))
time_steps = np.arange(nsteps + 1)
for i, grain_series in enumerate(area_fractions.T):
    plt.plot(time_steps, grain_series, label=f'Grain {i+1}')
plt.xlabel('Time step')
plt.ylabel('Area fraction')
plt.title('Area fraction vs. Time for each grain')
plt.legend()
plt.show()

# The create_animation_with_colorbar function below was generated separately and then lightly modified.
from matplotlib.colors import Normalize

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import Normalize

def _normalized_intensity(eta):
    """Return sum_i eta[:, :, i]**2 scaled so that its maximum equals 1.

    An all-zero field is divided by 1 instead of 0, so it stays all zeros.
    """
    data = np.sum(eta**2, axis=2)
    max_val = np.max(data)
    if max_val == 0:
        max_val = 1  # Avoid division by zero
    return data / max_val


def create_animation_with_colorbar(eta_hist, interval=100, nprint=10):
    """
    Create an animation of grain growth over time, with a colorbar.

    Each frame shows the sum of squared order parameters,
    ``sum_i eta[:, :, i]**2``, divided by that frame's maximum so the image
    always spans [0, 1] (all-zero frames are left as zeros).

    Parameters:
    - eta_hist: Sequence of numpy arrays of shape (Nx, Ny, ngrains), one per time step.
    - interval: Delay between animation frames in milliseconds.
    - nprint: Show every ``nprint``-th entry of ``eta_hist`` (starting at 0).

    Returns:
    - anim: The Matplotlib ``FuncAnimation`` object. Keep a reference to it,
      or it may be garbage-collected before it renders.

    Raises:
    - ValueError: if ``eta_hist`` is empty or ``nprint`` is not positive.
    """
    if len(eta_hist) == 0:
        raise ValueError("eta_hist must contain at least one frame")
    if nprint < 1:
        raise ValueError("nprint must be a positive integer")

    fig, ax = plt.subplots()
    ax.set_title('Grain Growth Over Time')

    # Data are pre-normalized to [0, 1], so a fixed Normalize keeps the
    # colorbar scale constant across frames.
    im = ax.imshow(_normalized_intensity(eta_hist[0]), animated=True, cmap='viridis',
                   norm=Normalize(vmin=0, vmax=1))
    fig.colorbar(im, ax=ax)

    def update(frame):
        """Replace the image data with the normalized intensity of ``eta_hist[frame]``."""
        im.set_array(_normalized_intensity(eta_hist[frame]))
        return (im,)

    anim = animation.FuncAnimation(fig, update, frames=range(0, len(eta_hist), nprint),
                                   interval=interval, blit=True)

    return anim

# Example usage: animate the eta_hist computed by phase_field_evolve above.
anim = create_animation_with_colorbar(eta_hist, interval=50, nprint=10)

# Display inline in a Jupyter notebook. display() is used instead of a bare
# HTML(...) expression: a bare expression is rendered only when it is the last
# expression of a cell and is otherwise discarded after the (expensive) video
# encoding. to_html5_video() relies on Matplotlib's movie writer (ffmpeg by default).
from IPython.display import HTML, display
display(HTML(anim.to_html5_video()))

# To save the animation as a file (uncomment and use in a local environment):
# anim.save('grain_growth_with_colorbar.mp4', writer='ffmpeg', dpi=300)




