# -*- coding: utf-8 -*-
"""
Created on Fri Jul 12 19:15:36 2024

@author: Saad Qureshi
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation, rc

# --- Material properties ---------------------------------------------------
# Material 1, occupying the first half of the grid: chromium
E1 = 280e9    # Young's modulus [Pa]
nu1 = 0.2     # Poisson's ratio [-]
rho1 = 7100   # density [kg/m^3]

# Material 2, occupying the second half of the grid: aluminium
E2 = 70e9     # Young's modulus [Pa]
nu2 = 0.33    # Poisson's ratio [-]
rho2 = 2700   # density [kg/m^3]

# Wave speed c = sqrt(E / rho) of each material
c1, c2 = np.sqrt(E1 / rho1), np.sqrt(E2 / rho2)

# --- Grid and time step ------------------------------------------------------
domain_size = 200   # grid points per side
dx = dy = 1.0       # grid spacing along each axis
# Courant number c*dt/dx = 0.5 for the faster material (below the 1/sqrt(2)
# limit of the explicit 2D 5-point scheme).
dt = dx / (2 * max(c1, c2))

# --- Run configuration -------------------------------------------------------
calculate_principal_stresses = True
boundary_type = 'Mur'         # one of: 'Mur', 'Dirichlet'
excitation_type = 'gaussian'  # one of: 'sinusoidal', 'gaussian'

# Grid coordinates (node indices as floats) and zero-initialised state arrays,
# each domain_size x domain_size.
x = np.linspace(0, domain_size - 1, domain_size)
y = x.copy()
u, u_prev, u_next = (np.zeros((domain_size, domain_size)) for _ in range(3))
stress_xx, stress_yy, stress_xy = (np.zeros((domain_size, domain_size)) for _ in range(3))

# Point sources: grid position (i, j), frequency (used only by the sinusoidal
# excitation) and amplitude.
sources = [
    {"position": (100, 100), "frequency": 0.04, "amplitude": 80},
]

def source(t, frequency, amplitude, excitation_type, t0=5, sigma=0.5):
    """Return the value injected by a point source at time ``t``.

    Parameters
    ----------
    t : float
        Time value; ``update`` passes the animation frame number here.
    frequency : float
        Oscillation frequency for ``'sinusoidal'`` excitation (ignored by
        ``'gaussian'``).
    amplitude : float
        Peak value of the excitation.
    excitation_type : str
        ``'sinusoidal'`` -> ``amplitude * sin(2*pi*frequency*t)``;
        ``'gaussian'``   -> ``amplitude * exp(-(t - t0)**2 / (2*sigma**2))``.
    t0, sigma : float, optional
        Centre and width of the Gaussian pulse (defaults 5 and 0.5).

    Raises
    ------
    ValueError
        If ``excitation_type`` is not recognised (previously ``None`` was
        returned, which crashed the caller with an unrelated TypeError).
    """
    if excitation_type == 'sinusoidal':
        omega = 2 * np.pi * frequency
        return amplitude * np.sin(omega * t)
    if excitation_type == 'gaussian':
        return amplitude * np.exp(-((t - t0) ** 2) / (2 * (sigma ** 2)))
    raise ValueError(
        f"unknown excitation_type {excitation_type!r}; "
        "expected 'sinusoidal' or 'gaussian'"
    )

# Material lookup by grid row
def get_properties(i):
    """Return ``(E, nu, c)`` for grid row ``i``.

    Rows strictly below ``domain_size // 2`` use material 1; the midpoint
    row and everything after it use material 2.
    """
    in_first_half = i < domain_size // 2
    return (E1, nu1, c1) if in_first_half else (E2, nu2, c2)

# Time stepping: explicit leapfrog wave solver with Mur (absorbing) or
# Dirichlet edges, followed by stress evaluation.


def _row_properties():
    """Return per-row material properties as ``(E, nu, c)`` column arrays.

    Each array has shape ``(domain_size, 1)``: one entry per grid row ``i``,
    taken from ``get_properties`` so the material layout is defined in one
    place, and broadcastable across the columns ``j``.
    """
    props = np.array([get_properties(i) for i in range(domain_size)], dtype=float)
    return props[:, 0:1], props[:, 1:2], props[:, 2:3]


def _step_interior(u_cur, u_old, u_new, c_rows):
    """Write the leapfrog update of every interior node of ``u_new`` in place.

    u^{n+1} = 2 u^n - u^{n-1} + c^2 dt^2 (d2u/dx2 + d2u/dy2), with axis 0
    discretised with spacing ``dx`` and axis 1 with spacing ``dy``.
    """
    centre = u_cur[1:-1, 1:-1]
    d2x = (u_cur[2:, 1:-1] - 2 * centre + u_cur[:-2, 1:-1]) / dx**2
    d2y = (u_cur[1:-1, 2:] - 2 * centre + u_cur[1:-1, :-2]) / dy**2
    u_new[1:-1, 1:-1] = (2 * centre - u_old[1:-1, 1:-1]
                         + c_rows[1:-1] ** 2 * dt**2 * (d2x + d2y))


def _mur_coefficient(c, h):
    """Return the first-order Mur coefficient ``(c*dt - h) / (c*dt + h)``."""
    return (c * dt - h) / (c * dt + h)


def _apply_mur(u_cur, u_new, c_rows):
    """Apply first-order Mur absorbing conditions to the edges of ``u_new``.

    For an edge node b with inward neighbour b+1:
        u_b^{n+1} = u_{b+1}^n + k * (u_{b+1}^{n+1} - u_b^n)
    Both time-level-n terms must come from the current field ``u_cur``,
    not from the previous (n-1) field. Corner nodes are not updated.
    """
    k_left = _mur_coefficient(get_properties(1)[2], dx)
    k_right = _mur_coefficient(get_properties(domain_size - 2)[2], dx)
    u_new[0, 1:-1] = u_cur[1, 1:-1] + k_left * (u_new[1, 1:-1] - u_cur[0, 1:-1])
    u_new[-1, 1:-1] = u_cur[-2, 1:-1] + k_right * (u_new[-2, 1:-1] - u_cur[-1, 1:-1])

    # On the j = 0 / j = -1 edges the wave speed changes with the row, so the
    # coefficient is a vector over the interior rows.
    k_rows = _mur_coefficient(c_rows[1:-1, 0], dy)
    u_new[1:-1, 0] = u_cur[1:-1, 1] + k_rows * (u_new[1:-1, 1] - u_cur[1:-1, 0])
    u_new[1:-1, -1] = u_cur[1:-1, -2] + k_rows * (u_new[1:-1, -2] - u_cur[1:-1, -1])


def _interior_stresses(u_cur, E_rows, nu_rows):
    """Return ``(sxx, syy, sxy)`` evaluated on the interior nodes ``[1:-1, 1:-1]``.

    Uses the plane-stress style expressions of the original script with the
    scalar field's gradients along axis 0 (``dx``) and axis 1 (``dy``).
    NOTE(review): ``u`` is a single scalar field, so using its two gradients
    as both displacement components is a modelling simplification; confirm
    it is intended.
    """
    grad_u_x = np.gradient(u_cur, dx, axis=0)[1:-1, 1:-1]
    grad_u_y = np.gradient(u_cur, dy, axis=1)[1:-1, 1:-1]
    E = E_rows[1:-1]
    nu = nu_rows[1:-1]
    sxx = E / (1 - nu**2) * (grad_u_x + nu * grad_u_y)
    syy = E / (1 - nu**2) * (grad_u_y + nu * grad_u_x)
    sxy = E / (2 * (1 + nu)) * (grad_u_x + grad_u_y)
    return sxx, syy, sxy


def update(t):
    """Advance the wave field by one time step and return the current stresses.

    Steps: leapfrog update of the interior, add the point sources evaluated at
    ``t``, apply the configured boundary condition, rotate the time levels,
    then recompute the stress fields (edges of the stress arrays stay 0).

    Returns
    -------
    tuple of ndarray
        ``(principal_stress_1, principal_stress_2)`` when
        ``calculate_principal_stresses`` is true, otherwise the module-level
        ``(stress_xx, stress_yy, stress_xy)`` arrays.

    Raises
    ------
    ValueError
        If ``boundary_type`` is neither ``'Mur'`` nor ``'Dirichlet'``
        (previously no boundary condition was applied silently).
    """
    global u, u_prev, u_next, stress_xx, stress_yy, stress_xy
    E_rows, nu_rows, c_rows = _row_properties()

    _step_interior(u, u_prev, u_next, c_rows)

    # Apply sources
    for source_info in sources:
        src_x, src_y = source_info["position"]
        u_next[src_x, src_y] += source(t, source_info["frequency"], source_info["amplitude"], excitation_type)

    # Boundary conditions: `u` is still time level n and `u_next` is n+1 here.
    if boundary_type == 'Mur':
        _apply_mur(u, u_next, c_rows)
    elif boundary_type == 'Dirichlet':
        u_next[0, :] = 0
        u_next[-1, :] = 0
        u_next[:, 0] = 0
        u_next[:, -1] = 0
    else:
        raise ValueError(
            f"unknown boundary_type {boundary_type!r}; expected 'Mur' or 'Dirichlet'"
        )

    # Rotate time levels; copy so the next step can overwrite u_next freely.
    u_prev, u = u, u_next.copy()

    (stress_xx[1:-1, 1:-1],
     stress_yy[1:-1, 1:-1],
     stress_xy[1:-1, 1:-1]) = _interior_stresses(u, E_rows, nu_rows)

    if calculate_principal_stresses:
        # Principal stresses from the 2D stress state (Mohr's circle).
        sigma_avg = (stress_xx + stress_yy) / 2
        sigma_diff = (stress_xx - stress_yy) / 2
        R = np.sqrt(sigma_diff**2 + stress_xy**2)
        principal_stress_1 = sigma_avg + R
        principal_stress_2 = sigma_avg - R
        return principal_stress_1, principal_stress_2
    return stress_xx, stress_yy, stress_xy

# Plot setup: one panel per stress field shown by the animation
fig, ax = plt.subplots(1, 3, figsize=(18, 6))


def _init_panel(axis, field, title):
    """Draw ``field`` on ``axis`` with a colorbar and labels; return ``(image, colorbar)``.

    NOTE(review): the stress arrays are all zeros at this point, so
    vmin == vmax == 0 here. Unless the colour limits are updated each frame
    (e.g. via ``set_clim``), every frame renders as a single colour.
    NOTE(review): the axis labels say mm, but the wave speeds and ``dt`` are
    computed from SI inputs (Pa, kg/m^3), which makes ``dx = 1.0`` one metre.
    Confirm the intended unit.
    """
    image = axis.imshow(field, origin='lower',
                        extent=[0, domain_size * dx, 0, domain_size * dy],
                        cmap='viridis', vmin=np.min(field), vmax=np.max(field))
    cbar = fig.colorbar(image, ax=axis)
    cbar.set_label('Stress')
    axis.set_title(title)
    axis.set_xlabel('x axis (mm)')
    axis.set_ylabel('y axis (mm)')
    return image, cbar


im1, cbar1 = _init_panel(ax[0], stress_xx, 'Stress 1')
im2, cbar2 = _init_panel(ax[1], stress_yy, 'Stress 2')
im3, cbar3 = _init_panel(ax[2], stress_xy, 'Stress 3')

# Main title for the figure
fig.suptitle('Stress Distribution Over Time', fontsize=16)

# Per-frame callback for FuncAnimation: advance the solver and redraw the panels
def animate(t):
    """Advance the simulation by one step (frame ``t``) and refresh the images.

    With principal stresses enabled, panels 1 and 2 show the two principal
    stresses and panel 3 is blanked. Otherwise the panels show
    stress_xx, stress_yy and stress_xy.

    Returns the three image artists, as required by ``blit=True``.
    """
    fields = update(t)
    if calculate_principal_stresses:
        # Only two principal stresses exist; blank the third panel.
        fields = (*fields, np.zeros_like(fields[0]))
    for image, field in zip((im1, im2, im3), fields):
        image.set_data(field)
        # The images were created from all-zero stress arrays, so their colour
        # limits started at [0, 0] and would render a flat image forever.
        # Rescale to this frame's range. With blit=True the colorbars may only
        # refresh on full redraws.
        image.set_clim(np.min(field), np.max(field))
    return im1, im2, im3

def _init_animation():
    """Return the image artists unchanged for FuncAnimation's initial draw.

    Without an ``init_func``, FuncAnimation draws its initial frame by calling
    ``animate`` with the first frame. That would run extra ``update(0)`` steps
    and advance the simulation before recording starts.
    """
    return im1, im2, im3


ani = animation.FuncAnimation(fig, animate, frames=200, init_func=_init_animation,
                              interval=50, blit=True)
plt.tight_layout()
ani.save('stress_ani2.mp4', writer='ffmpeg')
# NOTE(review): showing after saving replays frames 0..199 on top of the
# already-advanced simulation state (and re-fires the sources); confirm intended.
plt.show()

