diff --git a/.gitignore b/.gitignore index 2c9377b..e0074e6 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,4 @@ examples/objective_function_onerun_twostream.pdf jaxincell/version.py *.xml .vscode/settings.json +*.npz \ No newline at end of file diff --git a/README.md b/README.md index e278006..e6aa9f5 100644 --- a/README.md +++ b/README.md @@ -182,8 +182,8 @@ pytest . ## Project Roadmap - [X] **`Task 1`**: Run PIC simulation using several field solvers. -- [ ] **`Task 2`**: Finalize example scripts and their documentation. -- [ ] **`Task 3`**: Implement a relativistic equation of motion. +- [X] **`Task 2`**: Finalize example scripts and their documentation. +- [X] **`Task 3`**: Implement a relativistic equation of motion. - [ ] **`Task 4`**: Implement collisions to allow the plasma to relax to a Maxwellian. - [ ] **`Task 5`**: Implement guiding-center equations of motion. - [ ] **`Task 6`**: Implement an implicit time-stepping algorithm. diff --git a/example_input.toml b/example_input.toml index 57cd09b..f13fcab 100644 --- a/example_input.toml +++ b/example_input.toml @@ -14,6 +14,7 @@ external_electric_field_wavenumber = 0 electron_charge_over_elementary_charge = -1 ion_charge_over_elementary_charge = 1 ion_mass_over_proton_mass = 1 +relativistic = false [solver_parameters] field_solver = 0 diff --git a/examples/Weibel_instability.py b/examples/Weibel_instability.py index 6766534..aaa1622 100644 --- a/examples/Weibel_instability.py +++ b/examples/Weibel_instability.py @@ -5,22 +5,23 @@ from jax import block_until_ready input_parameters = { -"length" : 3e-1, # dimensions of the simulation box in (x, y, z) -"amplitude_perturbation_x" : 0, # amplitude of sinusoidal perturbation in x -"wavenumber_electrons_x" : 0, # wavenumber of perturbation in z +"length" : 3e-1, # dimensions of the simulation box in (x, y, z) +"amplitude_perturbation_x" : 0, # amplitude of sinusoidal perturbation in x +"wavenumber_electrons_x" : 0, # wavenumber of perturbation in z "grid_points_per_Debye_length" : 0.6, # dx over Debye length -"velocity_plus_minus_electrons_z": False, # create two groups of electrons moving in opposite directions -"velocity_plus_minus_electrons_x": False, # create two groups of electrons moving in opposite directions +"velocity_plus_minus_electrons_z": False, # create two groups of electrons moving in opposite directions +"velocity_plus_minus_electrons_x": False, # create two groups of electrons moving in opposite directions "random_positions_x": True, # Use random positions in x for particles "random_positions_y": True, # Use random positions in y for particles "random_positions_z": True, # Use random positions in z for particles -"electron_drift_speed_x": 0, # Drift speed of electrons in x direction +"electron_drift_speed_x": 0, # Drift speed of electrons in x direction "electron_drift_speed_z": 0, # Drift speed of electrons in z direction "ion_temperature_over_electron_temperature_x": 10, # Temperature of ions over temperature of electrons -"print_info" : True, # print information about the simulation -"vth_electrons_over_c_x": 0.01, # Thermal velocity of electrons over speed of light -"vth_electrons_over_c_z": 0.10, # Thermal velocity of electrons over speed of light -"timestep_over_spatialstep_times_c": 0.5, # dt * speed_of_light / dx +"print_info" : True, # print information about the simulation +"vth_electrons_over_c_x": 0.01, # Thermal velocity of electrons over speed of light +"vth_electrons_over_c_z": 0.10, # Thermal velocity of electrons over speed of light +"timestep_over_spatialstep_times_c": 0.5, # dt * speed_of_light / dx +"relativistic": False, # Use relativistic equations of motion } solver_parameters = { diff --git a/examples/input.toml b/examples/input.toml index f4f89ae..ea7fe3f 100644 --- a/examples/input.toml +++ b/examples/input.toml @@ -1,20 +1,21 @@ [input_parameters] -length = 0.01 -amplitude_perturbation_x = 1e-7 -wavenumber_electrons_x = 1 +length = 1.0 +amplitude_perturbation_x = 0 +wavenumber_electrons_x = 0 wavenumber_ions_x = 0 -grid_points_per_Debye_length = 1.5 +grid_points_per_Debye_length = 2.0 vth_electrons_over_c_x = 0.05 ion_temperature_over_electron_temperature_x = 0.01 timestep_over_spatialstep_times_c = 1.0 -electron_drift_speed_x = 50000000.0 +electron_drift_speed_x = 150000000.0 velocity_plus_minus_electrons_x = true print_info = true external_electric_field_amplitude = 0 external_electric_field_wavenumber = 0 +relativistic = false [solver_parameters] field_solver = 0 -number_grid_points = 51 -number_pseudoelectrons = 2500 -total_steps = 1500 +number_grid_points = 81 +number_pseudoelectrons = 3000 +total_steps = 1000 diff --git a/examples/two-stream_instability.py b/examples/two-stream_instability.py index 6df61b8..354c651 100644 --- a/examples/two-stream_instability.py +++ b/examples/two-stream_instability.py @@ -2,6 +2,9 @@ import time from jax import block_until_ready from jaxincell import plot, simulation, load_parameters +import numpy as np +import pickle +import json # Read from input.toml input_parameters, solver_parameters = load_parameters('input.toml') @@ -17,3 +20,10 @@ # Plot the results plot(output) + +# # Save the output to a file +# np.savez("simulation_output.npz", **output) + +# # Load the output from the file +# data = np.load("simulation_output.npz", allow_pickle=True) +# output2 = dict(data) diff --git a/jaxincell/_particles.py b/jaxincell/_particles.py index db5d936..2d8c67d 100644 --- a/jaxincell/_particles.py +++ b/jaxincell/_particles.py @@ -1,8 +1,9 @@ from jax import vmap, jit import jax.numpy as jnp from ._boundary_conditions import field_2_ghost_cells +from ._constants import speed_of_light as c -__all__ = ['fields_to_particles_grid', 'rotation', 'boris_step'] +__all__ = ['fields_to_particles_grid', 'rotation', 'boris_step', 'boris_step_relativistic'] @jit def fields_to_particles_grid(x_n, field, dx, grid, grid_start, field_BC_left, field_BC_right): @@ -107,3 +108,73 @@ def boris_step(dt, xs_nplushalf, vs_n, q_ms, E_fields_at_x, B_fields_at_x): # vs_nplus1 = vs_n + (q_ms) * E_fields_at_x * dt # xs_nplus1 = xs_nplushalf + dt * vs_nplus1 # return xs_nplus1, vs_nplus1 + +@jit +def relativistic_rotation(dt, B, p_minus, q, m): + """ + Rotate momentum vector in magnetic field (relativistic Boris step). + """ + # gamma_minus from p_minus + gamma_minus = jnp.sqrt(1 + jnp.sum(p_minus ** 2) / (m ** 2 * c ** 2)) + + # t vector (rotation vector) + t = (q * dt) / (2 * m * gamma_minus) * B + p_dot_t = jnp.dot(p_minus, t) + p_cross_t = jnp.cross(p_minus, t) + t_squared = jnp.dot(t, t) + + p_plus = (p_minus*(1-t_squared) + 2*(p_dot_t * t + p_cross_t)) / (1 + t_squared) + + return p_plus + +@jit +def boris_step_relativistic(dt, xs_nplushalf, vs_n, q_s, m_s, E_fields_at_x, B_fields_at_x): + """ + Relativistic Boris pusher for N particles. + + Args: + dt: Time step + xs_nplushalf: Particle positions at t = n + 1/2, shape (N, 3) + vs_n: Velocities at time t = n, shape (N, 3) + q_s: Charges, shape (N,) + m_s: Masses, shape (N,) + E_fields_at_x: Electric fields at particle positions, shape (N, 3) + B_fields_at_x: Magnetic fields at particle positions, shape (N, 3) + c: Speed of light (default = 1.0 for normalized units) + Returns: + xs_nplus3_2: Updated positions at t = n + 3/2, shape (N, 3) + vs_nplus1: Updated velocities at t = n + 1, shape (N, 3) + """ + + def single_particle_step(x, v, q, m, E, B): + # Compute initial momentum + gamma_n = 1/jnp.sqrt(1.0 - jnp.sum((v / c) ** 2)) + + p_n = gamma_n * m * v + + # Half electric field acceleration + p_minus = p_n + q * E * dt / 2 + + # Magnetic rotation + p_plus = relativistic_rotation(dt, B, p_minus, q, m) + + # Second half electric field acceleration + p_nplus1 = p_plus + q * E * dt / 2 + + # Compute new gamma + gamma_nplus1 = jnp.sqrt(1.0 + jnp.sum((p_nplus1 / (m * c)) ** 2)) + + # Recover new velocity + v_nplus1 = p_nplus1 / (gamma_nplus1 * m) + + # Update position using new velocity + x_nplus3_2 = x + dt * v_nplus1 + + return x_nplus3_2, v_nplus1 + + # Vectorize over particles + xs_nplus3_2, vs_nplus1 = vmap(single_particle_step)( + xs_nplushalf, vs_n, q_s, m_s, E_fields_at_x, B_fields_at_x + ) + + return xs_nplus3_2, vs_nplus1 \ No newline at end of file diff --git a/jaxincell/_plot.py b/jaxincell/_plot.py index fd6735b..a6cb48d 100644 --- a/jaxincell/_plot.py +++ b/jaxincell/_plot.py @@ -1,9 +1,10 @@ +import numpy as np import jax.numpy as jnp import matplotlib.pyplot as plt -from matplotlib.animation import FuncAnimation from jax import vmap +from jax.debug import print as jprint from ._constants import speed_of_light -import numpy as np +from matplotlib.animation import FuncAnimation __all__ = ['plot'] @@ -122,7 +123,7 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): electron_plot = electron_ax.imshow( jnp.zeros((len(grid), bins_velocity)), aspect="auto", origin="lower", cmap="twilight", - extent=[-box_size_x / 2, box_size_x / 2, -max_velocity_ions_1, max_velocity_ions_1], + extent=[-box_size_x / 2, box_size_x / 2, -max_velocity_electrons_1, max_velocity_electrons_1], vmin=jnp.min(electron_phase_histograms), vmax=jnp.max(electron_phase_histograms)) electron_ax.set(xlabel=f"Electron Position {direction1} (m)", ylabel=f"Electron Velocity {direction1} (m/s)", @@ -158,7 +159,7 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): energy_ax.plot(time, output["electric_field_energy"], label="Electric field energy") if jnp.max(output["magnetic_field_energy"]) > 1e-10: energy_ax.plot(time, output["magnetic_field_energy"], label="Magnetic field energy") - energy_ax.plot(time[2:], jnp.abs(jnp.mean(output["charge_density"][2:], axis=-1))*1e12, label=r"Mean $\rho \times 10^{12}$") + energy_ax.plot(time[2:], jnp.abs(jnp.mean(output["charge_density"][2:], axis=-1))*1e15, label=r"Mean $\rho \times 10^{15}$") energy_ax.plot(time[2:], jnp.abs(output["total_energy"][2:] - output["total_energy"][2]) / output["total_energy"][2], label="Relative energy error") energy_ax.set(title="Energy", xlabel=r"Time ($\omega_{pe}^{-1}$)", ylabel="Energy (J)", yscale="log", ylim=[1e-7, None]) diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py index a727602..22d1837 100644 --- a/jaxincell/_simulation.py +++ b/jaxincell/_simulation.py @@ -4,9 +4,9 @@ from jax.debug import print as jprint from jax import lax, jit, vmap, config from jax.random import PRNGKey, uniform, normal -from ._particles import fields_to_particles_grid, boris_step from ._sources import current_density, calculate_charge_density from ._boundary_conditions import set_BC_positions, set_BC_particles +from ._particles import fields_to_particles_grid, boris_step, boris_step_relativistic from ._fields import field_update, E_from_Gauss_1D_Cartesian, E_from_Gauss_1D_FFT, E_from_Poisson_1D_FFT, field_update1, field_update2 from ._constants import speed_of_light, epsilon_0, elementary_charge, mass_electron, mass_proton from ._diagnostics import diagnostics @@ -101,6 +101,7 @@ def initialize_simulation_parameters(user_parameters={}): "electron_charge_over_elementary_charge": -1, # Electron charge in units of the elementary charge "ion_charge_over_elementary_charge": 1, # Ion charge in units of the elementary charge "ion_mass_over_proton_mass": 1, # Ion mass in units of the proton mass + "relativistic": False, # Use relativistic Boris pusher # Boundary conditions "particle_BC_left": 0, # Left boundary condition for particles @@ -273,6 +274,9 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb # Combine electron and ion velocities velocities = jnp.concatenate((electron_velocities, ion_velocities)) + # Cap velocities at 99% the speed of light + speed_limit = 0.99 * speed_of_light + velocities = jnp.where(jnp.abs(velocities) >= speed_limit, jnp.sign(velocities) * speed_limit, velocities) # Grid setup dx = length / number_grid_points @@ -281,6 +285,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb # Print information about the simulation plasma_frequency = jnp.sqrt(number_pseudoelectrons * weight * charge_electrons**2)/jnp.sqrt(mass_electrons)/jnp.sqrt(epsilon_0)/jnp.sqrt(length) + relativistic_gamma_factor = 1 / jnp.sqrt(1 - jnp.sum(velocities**2, axis=1) / speed_of_light**2) cond(parameters["print_info"], lambda _: jprint(( @@ -294,11 +299,12 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb "Skin depth: {} m\n" "Wavenumber * Debye length: {}\n" "Pseudoparticles per cell: {}\n" + "Pseudoparticle weight: {}\n" "Steps at each plasma frequency: {}\n" "Total time: {} / plasma frequency\n" "Number of particles on a Debye cube: {}\n" + "Relativistic gamma factor: Maximum {}, Average {}\n" "Charge x External electric field x Debye Length / Temperature: {}\n" - "Pseudoparticle weight: {}\n" ),length/(Debye_length_per_dx*dx), length/(speed_of_light/plasma_frequency), number_pseudoelectrons * weight / length, @@ -308,11 +314,12 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb speed_of_light/plasma_frequency, wavenumber_perturbation_x_electrons*Debye_length_per_dx*dx, number_pseudoelectrons / number_grid_points, + weight, 1/(plasma_frequency * dt), dt * plasma_frequency * total_steps, number_pseudoelectrons * weight / length * (Debye_length_per_dx*dx)**3, + jnp.max(relativistic_gamma_factor), jnp.mean(relativistic_gamma_factor), -charge_electrons * parameters["external_electric_field_amplitude"] * Debye_length_per_dx*dx / (mass_electrons * vth_electrons**2 / 2), - weight, ), lambda _: None, operand=None) # **Fields Initialization** @@ -427,15 +434,20 @@ def simulation_step(carry, step_index): total_B = B_field + parameters["external_magnetic_field"] # Interpolate fields to particle positions - E_field_at_x = vmap(lambda x_n: fields_to_particles_grid( - x_n, total_E, dx, grid + dx / 2, grid[0], field_BC_left, field_BC_right))(positions_plus1_2) - - B_field_at_x = vmap(lambda x_n: fields_to_particles_grid( - x_n, total_B, dx, grid, grid[0] - dx / 2, field_BC_left, field_BC_right))(positions_plus1_2) + def interpolate_fields(x_n): + E = fields_to_particles_grid(x_n, total_E, dx, grid + dx/2, grid[0], field_BC_left, field_BC_right) + B = fields_to_particles_grid(x_n, total_B, dx, grid, grid[0] - dx/2, field_BC_left, field_BC_right) + return E, B + + E_field_at_x, B_field_at_x = vmap(interpolate_fields)(positions_plus1_2) # Particle update: Boris pusher - positions_plus3_2, velocities_plus1 = boris_step( - dt, positions_plus1_2, velocities, q_ms, E_field_at_x, B_field_at_x) + positions_plus3_2, velocities_plus1 = lax.cond( + parameters["relativistic"], + lambda _: boris_step_relativistic(dt, positions_plus1_2, velocities, qs, ms, E_field_at_x, B_field_at_x), + lambda _: boris_step(dt, positions_plus1_2, velocities, q_ms, E_field_at_x, B_field_at_x), + operand=None + ) # Apply boundary conditions positions_plus3_2, velocities_plus1, qs, ms, q_ms = set_BC_particles(