import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import PillowWriter

def get_desktop_path():
    """Return the path of the ``Desktop`` folder inside the user's home directory."""
    return os.path.join(os.path.expanduser("~"), 'Desktop')

# Physical constants and simulation parameters (SI units)
kB = 1.38064852e-23  # Boltzmann constant (J/K)
epsilon = 1.65e-21  # Lennard-Jones potential well depth (J)
sigma = 3.4e-10  # Lennard-Jones distance parameter (m); the force cutoff is 3 * sigma
mass = 6.63e-26  # Mass of one argon atom (kg)
T = 300  # Temperature used to draw the initial velocities (K)
L = 10e-9  # Side length of the periodic simulation box (m)
num_particles = 100  # Number of particles
dt = 2e-15  # Integration time step (s)
num_steps = 5000  # Number of integration steps

def initialize_positions(num_particles, L):
    """Place particles on a 2-D square grid inside a box of side ``L``.

    The box is divided into ``ceil(sqrt(num_particles))`` cells per side and
    the particles fill the cell centres in row-major order (first coordinate
    varies slowest). Using cell centres keeps every particle off the box edge.

    Args:
        num_particles: Number of particles to place (non-negative).
        L: Side length of the square box.

    Returns:
        Array of shape ``(num_particles, 2)`` with the particle coordinates.

    Raises:
        ValueError: If ``num_particles`` is negative.
    """
    if num_particles < 0:
        raise ValueError("num_particles must be non-negative")
    if num_particles == 0:
        # Nothing to place; also avoids dividing the box into zero cells.
        return np.zeros((0, 2))

    cells_per_side = int(np.ceil(np.sqrt(num_particles)))
    cell_size = L / cells_per_side

    # Particle k sits in grid cell (k // cells_per_side, k % cells_per_side).
    rows, cols = np.divmod(np.arange(num_particles), cells_per_side)
    return np.column_stack((rows, cols)) * cell_size + cell_size / 2

def initialize_velocities(num_particles, T, mass):
    """Draw 2-D velocities from a normal distribution and remove the mean drift.

    Each component is sampled with zero mean and standard deviation
    ``sqrt(kB * T / mass)``; the per-component average is then subtracted so
    the system has zero net momentum.
    """
    component_std = np.sqrt(kB * T / mass)
    sampled = np.random.normal(0, component_std, (num_particles, 2))
    drift = np.mean(sampled, axis=0)
    return sampled - drift

def compute_forces(positions):
    """Sum pairwise Lennard-Jones forces and the total potential energy.

    Every unordered pair of particles is visited once. Separations are
    wrapped to the nearest periodic image using the box length ``L``, and
    pairs farther apart than ``3 * sigma`` are skipped.

    Returns:
        ``(forces, potential_energy)`` where ``forces`` has the same shape as
        ``positions``.
    """
    count = len(positions)
    cutoff_sq = (3 * sigma) ** 2
    forces = np.zeros_like(positions)
    potential_energy = 0.0

    for a in range(count):
        for b in range(a + 1, count):
            # Separation vector folded back to the nearest periodic image
            delta = positions[a] - positions[b]
            delta -= L * np.round(delta / L)
            dist_sq = np.dot(delta, delta)

            # Outside the cutoff radius: this pair contributes nothing
            if not (dist_sq < cutoff_sq):
                continue

            s2 = sigma ** 2 / dist_sq
            s6 = s2 ** 3
            s12 = s6 ** 2

            # Equal and opposite forces on the two particles
            pair_force = (24 * epsilon * (2 * s12 - s6) / dist_sq) * delta
            forces[a] += pair_force
            forces[b] -= pair_force
            potential_energy += 4 * epsilon * (s12 - s6)

    return forces, potential_energy

def velocity_verlet(positions, velocities, forces, dt):
    """Advance the system by one Velocity Verlet step.

    The input arrays are left untouched; fresh arrays are returned. (An
    in-place update would leave the caller's ``positions`` advanced but not
    wrapped into the box, and would fail for integer-typed input arrays.)

    Args:
        positions: Particle coordinates at the start of the step.
        velocities: Particle velocities at the start of the step.
        forces: Forces evaluated at ``positions``.
        dt: Time step.

    Returns:
        ``(positions, velocities, forces, potential_energy)`` at the end of
        the step, with positions wrapped into ``[0, L)``.
    """
    # Drift with the current velocities/forces, then wrap into the periodic box
    displacement = velocities * dt + 0.5 * forces * dt ** 2 / mass
    new_positions = (positions + displacement) % L

    # Forces at the new positions complete the velocity update
    new_forces, potential_energy = compute_forces(new_positions)
    new_velocities = velocities + 0.5 * (forces + new_forces) * dt / mass
    return new_positions, new_velocities, new_forces, potential_energy

def run_simulation(num_particles, L, T, mass, dt, num_steps):
    """Run a simulation from a fresh lattice and record the energy history.

    Args:
        num_particles: Number of particles to simulate.
        L: Box side length used to lay out the initial lattice.
        T: Temperature used to draw the initial velocities.
        mass: Particle mass used for the initial velocities and the kinetic
            energy.
        dt: Integration time step.
        num_steps: Number of integration steps.

    Note:
        ``compute_forces`` and ``velocity_verlet`` read the box length and
        particle mass from the module-level constants, so ``L`` and ``mass``
        given here should match those constants.

    Returns:
        ``(positions, kinetic_energies, potential_energies, total_energies)``.
        Each energy list has one entry per step; the initial state is not
        recorded.
    """
    positions = initialize_positions(num_particles, L)
    velocities = initialize_velocities(num_particles, T, mass)
    # Only the starting forces are needed here; every step returns its own
    # potential energy.
    forces, _ = compute_forces(positions)

    kinetic_energies = []
    potential_energies = []
    total_energies = []

    for _ in range(num_steps):
        positions, velocities, forces, potential_energy = velocity_verlet(
            positions, velocities, forces, dt
        )
        kinetic_energy = 0.5 * mass * np.sum(velocities ** 2)

        kinetic_energies.append(kinetic_energy)
        potential_energies.append(potential_energy)
        total_energies.append(kinetic_energy + potential_energy)

    return positions, kinetic_energies, potential_energies, total_energies

def plot_energies(kinetic_energies, potential_energies, total_energies, filename):
    """Plot the three energy histories against step index, save, then show.

    The figure is written to ``filename`` before it is displayed.
    """
    labelled_series = (
        (kinetic_energies, 'Kinetic Energy'),
        (potential_energies, 'Potential Energy'),
        (total_energies, 'Total Energy'),
    )

    plt.figure()
    for series, label in labelled_series:
        plt.plot(series, label=label)
    plt.xlabel('Time Step')
    plt.ylabel('Energy (J)')
    plt.legend()
    plt.savefig(filename)
    plt.show()

def save_snapshot(positions, filename):
    """Scatter-plot the particle positions in the box, save, then show.

    Both axes span ``0`` to the module-level box length ``L``; the figure is
    written to ``filename`` before it is displayed.
    """
    plt.figure()
    x_coords = positions[:, 0]
    y_coords = positions[:, 1]
    plt.scatter(x_coords, y_coords)

    # Same limits on both axes: the simulation box
    box_limits = (0, L)
    plt.xlim(*box_limits)
    plt.ylim(*box_limits)

    plt.xlabel('X Position (m)')
    plt.ylabel('Y Position (m)')
    plt.title('Particle Positions')
    plt.savefig(filename)
    plt.show()

# Test 1: Periodicity Test with Animation
def test_periodicity(L):
    """Animate one force-free atom crossing the box and save it as a GIF.

    The atom starts on the left edge at mid-height and moves in +x; each
    frame advances it by one ``velocity_verlet`` step. The GIF is written to
    the Desktop folder as ``periodicity_test.gif``.
    """
    # Mutable state shared with the animation callback
    state = {
        'positions': np.array([[0, L/2]]),  # left edge, mid-height
        'velocities': np.array([[1e4, 0]]),  # moving in +x only
    }

    # Enough frames for the atom to travel one full box length
    num_frames = int(L / (state['velocities'][0, 0] * dt)) + 1

    fig, ax = plt.subplots()
    ax.set_xlim(0, L)
    ax.set_ylim(0, L)
    ax.set_xlabel('X Position (m)')
    ax.set_ylabel('Y Position (m)')
    marker, = ax.plot([], [], 'ro', markersize=12)

    def init():
        marker.set_data([], [])
        return marker,

    def update(frame):
        # A lone particle feels no force
        zero_forces = np.zeros_like(state['positions'])
        stepped = velocity_verlet(state['positions'], state['velocities'], zero_forces, dt)
        state['positions'] = stepped[0]
        state['velocities'] = stepped[1]
        current = state['positions']
        # set_data expects sequences, so wrap the scalars in lists
        marker.set_data([current[0, 0]], [current[0, 1]])
        return marker,

    ani = animation.FuncAnimation(fig, update, frames=num_frames, init_func=init, blit=True)
    gif_path = os.path.join(get_desktop_path(), 'periodicity_test.gif')
    ani.save(gif_path, writer=PillowWriter(fps=30))
    plt.close(fig)
    print(f"Periodicity Test animation saved as {gif_path}")

# Test 2: Energy Conservation Test with Animation
def _simulate_head_on_collision(x1, x2, v1, v2, n_steps):
    """Move two equal-mass atoms along x and record positions and kinetic energy.

    Atoms exchange velocities when they meet and bounce off the walls at
    ``0`` and ``L``. A bounce or exchange is only applied while the atoms are
    still moving towards the obstacle, so an atom that has not yet left the
    contact region after a flip is never flipped a second time.

    Returns:
        ``(positions1, positions2, kinetic_energies)``, one entry per step.
    """
    positions1, positions2, kinetic_energies = [], [], []

    for _ in range(n_steps):
        x1 += v1 * dt
        x2 += v2 * dt

        # Equal masses: an elastic head-on collision swaps the velocities
        if x1 >= x2 and v1 > v2:
            v1, v2 = v2, v1

        positions1.append(x1)
        positions2.append(x2)
        kinetic_energies.append(0.5 * mass * (v1**2 + v2**2))

        # Elastic reflection at the walls, only while heading outwards
        if (x1 < 0 and v1 < 0) or (x1 > L and v1 > 0):
            v1 = -v1
        if (x2 < 0 and v2 < 0) or (x2 > L and v2 > 0):
            v2 = -v2

    return positions1, positions2, kinetic_energies


def test_energy_conservation():
    """Simulate two atoms colliding elastically, report energy drift, save a GIF.

    Two atoms start at ``L/4`` and ``3L/4`` and approach each other with equal
    and opposite velocities for ``num_steps`` steps. The largest deviation of
    the kinetic energy from its initial value is printed, and the motion is
    animated and written to the Desktop folder as
    ``atomic_collision_simulation.gif``.
    """
    v_initial = 1e5  # Initial speed of each atom (m/s)
    initial_energy = 0.5 * mass * (v_initial**2 + (-v_initial)**2)

    positions1, positions2, kinetic_energies = _simulate_head_on_collision(
        L / 4, 3 * L / 4, v_initial, -v_initial, num_steps
    )

    # The actual conservation check: how far did the kinetic energy stray?
    max_deviation = max(
        (abs(energy - initial_energy) for energy in kinetic_energies),
        default=0.0,
    )
    print(
        f"Maximum kinetic energy deviation: {max_deviation:.3e} J "
        f"(initial kinetic energy {initial_energy:.3e} J)"
    )

    # Create animation
    fig, ax = plt.subplots()
    ax.set_xlim(0, L)
    ax.set_ylim(0, L)
    ax.set_xlabel('X Position (m)')
    ax.set_ylabel('Y Position (m)')

    line1, = ax.plot([], [], 'ro', label='Atom 1')
    line2, = ax.plot([], [], 'bo', label='Atom 2')

    def init():
        line1.set_data([], [])
        line2.set_data([], [])
        return line1, line2

    def update(frame):
        # set_data expects sequences; both atoms are drawn at mid-height
        line1.set_data([positions1[frame]], [L / 2])
        line2.set_data([positions2[frame]], [L / 2])
        return line1, line2

    ani = animation.FuncAnimation(fig, update, frames=num_steps, init_func=init, blit=True)

    # Save the animation as a GIF, using the same helper and writer as
    # test_periodicity
    gif_path = os.path.join(get_desktop_path(), "atomic_collision_simulation.gif")
    ani.save(gif_path, writer=PillowWriter(fps=60))

    plt.show()

    print(f"Simulation complete. GIF saved to: {gif_path}")

# Main simulation: build the initial state, integrate, then report
positions = initialize_positions(num_particles, L)
velocities = initialize_velocities(num_particles, T, mass)
forces, potential_energy = compute_forces(positions)

# Picture of the starting configuration
save_snapshot(
    positions,
    os.path.join(get_desktop_path(), 'initial_snapshot.png'),
)

# Energy histories, one entry per integration step
kinetic_energies, potential_energies, total_energies = [], [], []

for step in range(num_steps):
    positions, velocities, forces, potential_energy = velocity_verlet(
        positions, velocities, forces, dt
    )
    kinetic_energy = 0.5 * mass * np.sum(velocities ** 2)
    total_energy = kinetic_energy + potential_energy

    kinetic_energies.append(kinetic_energy)
    potential_energies.append(potential_energy)
    total_energies.append(total_energy)

# Energy curves over the whole run
plot_energies(
    kinetic_energies,
    potential_energies,
    total_energies,
    os.path.join(get_desktop_path(), 'energy_plot.png'),
)

# Picture of the final configuration
save_snapshot(
    positions,
    os.path.join(get_desktop_path(), 'final_snapshot.png'),
)

# Energies after the last step
print(f"Final Kinetic Energy: {kinetic_energies[-1]}")
print(f"Final Potential Energy: {potential_energies[-1]}")
print(f"Final Total Energy: {total_energies[-1]}")

# Animated checks: periodic wrapping and two-atom collision
test_periodicity(L)
test_energy_conservation()