import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import pandas as pd

# Function to get obstacle parameters from the user
def get_obstacle_parameters():
    """
    Ask the user which obstacle shape to simulate and collect its geometry.
    The answer is compared after stripping surrounding whitespace and
    lower-casing it.
    Returns:
    tuple: The result of the shape-specific prompt, i.e.
           ("square", cx, cy, side) or ("circle", cx, cy, r).
    Raises:
    ValueError: If the answer is neither 'square' nor 'circle'.
    """
    choice = input("Enter the type of obstacle (square/circle): ").strip().lower()
    # Map every accepted answer to the routine that prompts for its parameters.
    prompts = {
        "square": get_square_parameters,
        "circle": get_circle_parameters,
    }
    if choice not in prompts:
        raise ValueError("Invalid obstacle type. Please enter 'square' or 'circle'.")
    return prompts[choice]()

# Function to get parameters for a square obstacle from the user
def get_square_parameters():
    """
    Prompt for the centre and side length of a square obstacle until the
    values are valid.
    Any invalid entry (non-integer text, non-positive side, or a square that
    does not fit inside the nx-by-ny lattice) prints the reason and restarts
    the prompt sequence.
    Returns:
    tuple: ("square", cx, cy, side) with integer cx, cy and side > 0.
    """
    while True:
        try:
            cx = int(input("Enter the x-coordinate of the square center (cx): "))
            cy = int(input("Enter the y-coordinate of the square center (cy): "))
            side = int(input("Enter the side length of the square: "))
            # The side length is later used as the characteristic length for
            # the viscosity (uLB * size / Re); zero or negative values would
            # give a zero/negative viscosity, and negative values could slip
            # through the bounds check below.
            if side <= 0:
                raise ValueError("Side length must be a positive integer.")
            # Ensure the square is within the simulation domain
            half = side // 2
            if cx - half < 0 or cx + half >= nx or cy - half < 0 or cy + half >= ny:
                raise ValueError("Square parameters are outside the region or not feasible.")
            return "square", cx, cy, side
        except ValueError as e:
            # int() failures and the explicit checks above all land here;
            # report the problem and ask again.
            print(e)

# Function to get parameters for a circular obstacle from the user
def get_circle_parameters():
    """
    Prompt for the centre and radius of a circular obstacle until the values
    are valid.
    Any invalid entry (non-integer text, non-positive radius, or a circle that
    does not fit inside the nx-by-ny lattice) prints the reason and restarts
    the prompt sequence.
    Returns:
    tuple: ("circle", cx, cy, r) with integer cx, cy and r > 0.
    """
    while True:
        try:
            cx = int(input("Enter the x-coordinate of the circle center (cx): "))
            cy = int(input("Enter the y-coordinate of the circle center (cy): "))
            r = int(input("Enter the radius of the circle (r): "))
            # The radius is later used as the characteristic length for the
            # viscosity (uLB * size / Re); zero or negative values would give
            # a zero/negative viscosity, and negative values could slip
            # through the bounds check below.
            if r <= 0:
                raise ValueError("Radius must be a positive integer.")
            # Ensure the circle is within the simulation domain
            if cx - r < 0 or cx + r >= nx or cy - r < 0 or cy + r >= ny:
                raise ValueError("Circle parameters are outside the region or not feasible.")
            return "circle", cx, cy, r
        except ValueError as e:
            # int() failures and the explicit checks above all land here;
            # report the problem and ask again.
            print(e)

###### Flow definition #########################################################
maxIter = 4000  # How many time steps the simulation runs.
nx = 400  # Lattice nodes along x.
ny = 100  # Lattice nodes along y.
ly = ny - 1  # Domain height in lattice units.
Re = 50.0  # Reynolds number.
uLB = 0.04  # Inflow velocity in lattice units.

# Ask the user for the obstacle shape, its centre and its size
# (side length for a square, radius for a circle).
obstacle_type, cx, cy, size = get_obstacle_parameters()
# Lattice viscosity from the Reynolds number, using `size` as the
# characteristic length.
nulb = uLB * size / Re
# Relaxation parameter used by the collision step.
omega = 1. / (3. * nulb + 0.5)

###### Lattice Constants #######################################################
# D2Q9 velocity set: every (ex, ey) pair with components in {1, 0, -1},
# ordered with ex varying slowest.
v = np.array([[ex, ey] for ex in (1, 0, -1) for ey in (1, 0, -1)])
# Weight per direction, selected by how many components are non-zero:
# rest (0) -> 4/9, axis-aligned (1) -> 1/9, diagonal (2) -> 1/36.
t = np.array([(4./9., 1./9., 1./36.)[abs(ex) + abs(ey)] for ex, ey in v])

# Direction indices grouped by x-velocity: +1 (col1), 0 (col2), -1 (col3).
col1 = np.arange(0, 3)
col2 = np.arange(3, 6)
col3 = np.arange(6, 9)

###### Function Definitions ####################################################
# Function to compute macroscopic variables (density and velocity) from distribution functions
def macroscopic(fin):
    """
    Compute macroscopic variables (density rho and velocity u) from the
    distribution functions (fin).
    The spatial shape is taken from fin itself, so the function does not
    depend on the module-level lattice size.
    Args:
    fin (ndarray): Distribution functions; the first axis indexes the 9
                   lattice directions, the remaining axes are spatial.
    Returns:    rho (ndarray): Density field (sum of fin over directions).
                u (ndarray): Velocity field, shape (2,) + spatial shape.
    """
    # Density is the sum of all populations at each node
    rho = np.sum(fin, axis=0)
    # Momentum accumulator sized from the data rather than from globals
    u = np.zeros((2,) + rho.shape)
    # Accumulate direction by direction, in index order
    for i, (ex, ey) in enumerate(v):
        u[0] += ex * fin[i]
        u[1] += ey * fin[i]
    # Convert momentum to velocity
    u = u / rho
    return rho, u

# Function to compute the equilibrium distribution function
def equilibrium(rho, u):
    """
    Compute the equilibrium distribution function based on density and velocity.
    The spatial shape is taken from u, so the function does not depend on the
    module-level lattice size.
    Args:
    rho (float or ndarray): Density; a scalar or an array broadcastable to
                            the spatial shape of u.
    u (ndarray): Velocity field; u[0] and u[1] are the x and y components.
    Returns:  feq (ndarray): Equilibrium distribution functions,
                             shape (9,) + spatial shape of u.
    """
    # 3/2 * |u|^2, the direction-independent term of the expansion
    usqr = 3. / 2. * (u[0]**2 + u[1]**2)
    # One equilibrium population per lattice direction, sized from the input
    feq = np.zeros((9,) + u.shape[1:])
    for i in range(9):
        # 3 * (e_i . u) for direction i
        cu = 3. * (v[i, 0] * u[0] + v[i, 1] * u[1])
        feq[i] = rho * t[i] * (1. + cu + 0.5 * cu**2 - usqr)
    return feq

###### Setup: obstacle and velocity inlet with perturbation ########
# Function to create a mask for a square obstacle
def obstacle_fun_square(x, y):
    """Return True where (x, y) lies within size // 2 of (cx, cy) along both axes."""
    half_width = size // 2
    within_x = np.abs(x - cx) <= half_width
    within_y = np.abs(y - cy) <= half_width
    return np.logical_and(within_x, within_y)

# Function to create a mask for a circular obstacle
def obstacle_fun_circle(x, y):
    """Return True where (x, y) is at most `size` away from the centre (cx, cy)."""
    dx = x - cx
    dy = y - cy
    # Compare squared distances so no square root is needed.
    return dx**2 + dy**2 <= size**2

# Build the boolean obstacle mask over the lattice; anything other than
# "square" falls back to the circular mask.
obstacle = np.fromfunction(
    obstacle_fun_square if obstacle_type == "square" else obstacle_fun_circle,
    (nx, ny),
)

# Function to initialize the velocity field with a slight perturbation
def inivel(d, x, y):
    """
    Initial velocity profile: uLB along x with a small sinusoidal ripple in y
    to trigger instabilities, and zero along y.
    Intended to be driven by np.fromfunction over shape (2, nx, ny), which
    supplies d, x and y as index arrays.
    Args:
    d (ndarray): Component index (0 for the x-component, 1 for the y-component).
    x (ndarray): x-coordinates of the grid points (unused).
    y (ndarray): y-coordinates of the grid points.
    Returns:
    velocity (ndarray): Initial velocity field.
    """
    # One full sine period across the channel height ly.
    phase = y / ly * 2. * np.pi
    ripple = 1. + 1.e-4 * np.sin(phase)
    # (1 - d) keeps the x-component (d == 0) and zeroes the y-component (d == 1).
    return (1. - d) * uLB * ripple

# Evaluate the perturbed inflow profile on every node, for both components
vel = np.fromfunction(inivel, shape=(2, nx, ny))

# Start all populations at the equilibrium for unit density and that velocity
fin = equilibrium(rho=1., u=vel)

###### Main time loop ##########################################################
for time in range(maxIter):
    # Outflow at the right edge: the populations with negative x-velocity
    # (col3) in the last column are copied from the column next to it.
    fin[col3, -1, :] = fin[col3, -2, :]

    # Density and velocity of the current populations
    rho, u = macroscopic(fin)

    # Inflow at the left edge: prescribe the velocity, then derive the
    # density from the populations with zero (col2) and negative (col3)
    # x-velocity in the first column.
    u[:, 0, :] = vel[:, 0, :]
    rho[0, :] = 1. / (1. - u[0, 0, :]) * (
        np.sum(fin[col2, 0, :], axis=0) + 2. * np.sum(fin[col3, 0, :], axis=0)
    )

    # Equilibrium for the corrected macroscopic fields
    feq = equilibrium(rho, u)
    # Rebuild the populations with positive x-velocity (0, 1, 2) in the first
    # column from their opposites (8, 7, 6) plus the difference in equilibrium.
    fin[col1, 0, :] = feq[col1, 0, :] + fin[col3[::-1], 0, :] - feq[col3[::-1], 0, :]

    # Collision: relax every population towards equilibrium
    fout = fin - omega * (fin - feq)

    # Bounce-back inside the obstacle: direction i receives the pre-collision
    # population of direction 8 - i, for all nine directions at once.
    fout[:, obstacle] = fin[::-1][:, obstacle]

    # Streaming: shift each population by its lattice velocity (wraps around)
    for i in range(9):
        fin[i, :, :] = np.roll(fout[i, :, :], shift=(v[i, 0], v[i, 1]), axis=(0, 1))

    # Live preview of the velocity magnitude every 100th step
    if not time % 100:
        plt.clf()
        plt.imshow(np.sqrt(u[0]**2 + u[1]**2).T, cmap=cm.Reds)
        plt.colorbar()
        plt.title(f'Velocity magnitude at time step {time}')
        plt.pause(0.1)

###### Post-Processing #########################################################
# Compute final macroscopic variables (density and velocity) from distribution functions
rho, u = macroscopic(fin)

# Set both velocity components to zero inside the obstacle (no flow there)
u[:, obstacle] = 0

# Pressure field calculation (assuming ideal gas law: p = rho * RT, with RT = 1).
# Work on a copy: a plain `pressure = rho` would alias the density array, so
# masking the pressure would also write NaN into rho.
pressure = rho.copy()
pressure[obstacle] = np.nan  # Mask the pressure inside the obstacle

###### Plotting ##########################################################
def _show_field(position, field, title):
    """Draw one lattice field as an image in the given subplot slot and return its colorbar."""
    plt.subplot(position)
    # Transpose so that x runs horizontally in the image
    plt.imshow(field.T, cmap=cm.viridis)
    bar = plt.colorbar(location='bottom')
    bar.ax.tick_params(labelsize=13)
    plt.title(title, size='16')
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    return bar


# Speed at every node, blanked out inside the obstacle
velocity_magnitude = np.sqrt(u[0]**2 + u[1]**2)
velocity_magnitude[obstacle] = np.nan

# Side-by-side panels: velocity magnitude on the left, pressure on the right
plt.figure(figsize=(12, 6))
cbar = _show_field(121, velocity_magnitude, 'Velocity Magnitude(m/s)')
cbar = _show_field(122, pressure, 'Pressure(Pa)')

plt.tight_layout()
plt.show()

###### Streamline Plot #########################################################
# Streamlines of the final velocity field, coloured by the velocity magnitude
plt.figure(figsize=(12, 6))
plt.streamplot(
    np.arange(nx),
    np.arange(ny),
    u[0].T,
    u[1].T,
    color=velocity_magnitude.T,
    density=2,
    cmap=cm.viridis,
)
cbar = plt.colorbar(location='bottom')
cbar.ax.tick_params(labelsize=18)
plt.title('Streamlines plot', size='23')
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
# Equal scaling on both axes so the channel is not stretched
plt.gca().set_aspect('equal')
plt.tight_layout()
plt.show()
