diff --git a/.gitignore b/.gitignore
index 2c9377b..e0074e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,3 +169,4 @@ examples/objective_function_onerun_twostream.pdf
jaxincell/version.py
*.xml
.vscode/settings.json
+*.npz
\ No newline at end of file
diff --git a/README.md b/README.md
index e278006..e6aa9f5 100644
--- a/README.md
+++ b/README.md
@@ -182,8 +182,8 @@ pytest .
## Project Roadmap
- [X] **`Task 1`**: Run PIC simulation using several field solvers.
-- [ ] **`Task 2`**: Finalize example scripts and their documentation.
-- [ ] **`Task 3`**: Implement a relativistic equation of motion.
+- [X] **`Task 2`**: Finalize example scripts and their documentation.
+- [X] **`Task 3`**: Implement a relativistic equation of motion.
- [ ] **`Task 4`**: Implement collisions to allow the plasma to relax to a Maxwellian.
- [ ] **`Task 5`**: Implement guiding-center equations of motion.
- [ ] **`Task 6`**: Implement an implicit time-stepping algorithm.
diff --git a/example_input.toml b/example_input.toml
index 57cd09b..f13fcab 100644
--- a/example_input.toml
+++ b/example_input.toml
@@ -14,6 +14,7 @@ external_electric_field_wavenumber = 0
electron_charge_over_elementary_charge = -1
ion_charge_over_elementary_charge = 1
ion_mass_over_proton_mass = 1
+relativistic = false
[solver_parameters]
field_solver = 0
diff --git a/examples/Weibel_instability.py b/examples/Weibel_instability.py
index 6766534..aaa1622 100644
--- a/examples/Weibel_instability.py
+++ b/examples/Weibel_instability.py
@@ -5,22 +5,23 @@
from jax import block_until_ready
input_parameters = {
-"length" : 3e-1, # dimensions of the simulation box in (x, y, z)
-"amplitude_perturbation_x" : 0, # amplitude of sinusoidal perturbation in x
-"wavenumber_electrons_x" : 0, # wavenumber of perturbation in z
+"length" : 3e-1, # dimensions of the simulation box in (x, y, z)
+"amplitude_perturbation_x" : 0, # amplitude of sinusoidal perturbation in x
+"wavenumber_electrons_x" : 0, # wavenumber of perturbation in z
"grid_points_per_Debye_length" : 0.6, # dx over Debye length
-"velocity_plus_minus_electrons_z": False, # create two groups of electrons moving in opposite directions
-"velocity_plus_minus_electrons_x": False, # create two groups of electrons moving in opposite directions
+"velocity_plus_minus_electrons_z": False, # create two groups of electrons moving in opposite directions
+"velocity_plus_minus_electrons_x": False, # create two groups of electrons moving in opposite directions
"random_positions_x": True, # Use random positions in x for particles
"random_positions_y": True, # Use random positions in y for particles
"random_positions_z": True, # Use random positions in z for particles
-"electron_drift_speed_x": 0, # Drift speed of electrons in x direction
+"electron_drift_speed_x": 0, # Drift speed of electrons in x direction
"electron_drift_speed_z": 0, # Drift speed of electrons in z direction
"ion_temperature_over_electron_temperature_x": 10, # Temperature of ions over temperature of electrons
-"print_info" : True, # print information about the simulation
-"vth_electrons_over_c_x": 0.01, # Thermal velocity of electrons over speed of light
-"vth_electrons_over_c_z": 0.10, # Thermal velocity of electrons over speed of light
-"timestep_over_spatialstep_times_c": 0.5, # dt * speed_of_light / dx
+"print_info" : True, # print information about the simulation
+"vth_electrons_over_c_x": 0.01, # Thermal velocity of electrons over speed of light
+"vth_electrons_over_c_z": 0.10, # Thermal velocity of electrons over speed of light
+"timestep_over_spatialstep_times_c": 0.5, # dt * speed_of_light / dx
+"relativistic": False, # Use relativistic equations of motion
}
solver_parameters = {
diff --git a/examples/input.toml b/examples/input.toml
index f4f89ae..ea7fe3f 100644
--- a/examples/input.toml
+++ b/examples/input.toml
@@ -1,20 +1,21 @@
[input_parameters]
-length = 0.01
-amplitude_perturbation_x = 1e-7
-wavenumber_electrons_x = 1
+length = 1.0
+amplitude_perturbation_x = 0
+wavenumber_electrons_x = 0
wavenumber_ions_x = 0
-grid_points_per_Debye_length = 1.5
+grid_points_per_Debye_length = 2.0
vth_electrons_over_c_x = 0.05
ion_temperature_over_electron_temperature_x = 0.01
timestep_over_spatialstep_times_c = 1.0
-electron_drift_speed_x = 50000000.0
+electron_drift_speed_x = 150000000.0
velocity_plus_minus_electrons_x = true
print_info = true
external_electric_field_amplitude = 0
external_electric_field_wavenumber = 0
+relativistic = false
[solver_parameters]
field_solver = 0
-number_grid_points = 51
-number_pseudoelectrons = 2500
-total_steps = 1500
+number_grid_points = 81
+number_pseudoelectrons = 3000
+total_steps = 1000
diff --git a/examples/two-stream_instability.py b/examples/two-stream_instability.py
index 6df61b8..354c651 100644
--- a/examples/two-stream_instability.py
+++ b/examples/two-stream_instability.py
@@ -2,6 +2,9 @@
import time
from jax import block_until_ready
from jaxincell import plot, simulation, load_parameters
+import numpy as np
+import pickle
+import json
# Read from input.toml
input_parameters, solver_parameters = load_parameters('input.toml')
@@ -17,3 +20,10 @@
# Plot the results
plot(output)
+
+# # Save the output to a file
+# np.savez("simulation_output.npz", **output)
+
+# # Load the output from the file
+# data = np.load("simulation_output.npz", allow_pickle=True)
+# output2 = dict(data)
diff --git a/jaxincell/_particles.py b/jaxincell/_particles.py
index db5d936..2d8c67d 100644
--- a/jaxincell/_particles.py
+++ b/jaxincell/_particles.py
@@ -1,8 +1,9 @@
from jax import vmap, jit
import jax.numpy as jnp
from ._boundary_conditions import field_2_ghost_cells
+from ._constants import speed_of_light as c
-__all__ = ['fields_to_particles_grid', 'rotation', 'boris_step']
+__all__ = ['fields_to_particles_grid', 'rotation', 'boris_step', 'boris_step_relativistic']
@jit
def fields_to_particles_grid(x_n, field, dx, grid, grid_start, field_BC_left, field_BC_right):
@@ -107,3 +108,73 @@ def boris_step(dt, xs_nplushalf, vs_n, q_ms, E_fields_at_x, B_fields_at_x):
# vs_nplus1 = vs_n + (q_ms) * E_fields_at_x * dt
# xs_nplus1 = xs_nplushalf + dt * vs_nplus1
# return xs_nplus1, vs_nplus1
+
+@jit
+def relativistic_rotation(dt, B, p_minus, q, m):
+ """
+ Rotate momentum vector in magnetic field (relativistic Boris step).
+ """
+ # gamma_minus from p_minus
+ gamma_minus = jnp.sqrt(1 + jnp.sum(p_minus ** 2) / (m ** 2 * c ** 2))
+
+ # t vector (rotation vector)
+ t = (q * dt) / (2 * m * gamma_minus) * B
+ p_dot_t = jnp.dot(p_minus, t)
+ p_cross_t = jnp.cross(p_minus, t)
+ t_squared = jnp.dot(t, t)
+
+ p_plus = (p_minus*(1-t_squared) + 2*(p_dot_t * t + p_cross_t)) / (1 + t_squared)
+
+ return p_plus
+
+@jit
+def boris_step_relativistic(dt, xs_nplushalf, vs_n, q_s, m_s, E_fields_at_x, B_fields_at_x):
+ """
+ Relativistic Boris pusher for N particles.
+
+ Args:
+ dt: Time step
+ xs_nplushalf: Particle positions at t = n + 1/2, shape (N, 3)
+ vs_n: Velocities at time t = n, shape (N, 3)
+ q_s: Charges, shape (N,)
+ m_s: Masses, shape (N,)
+ E_fields_at_x: Electric fields at particle positions, shape (N, 3)
+ B_fields_at_x: Magnetic fields at particle positions, shape (N, 3)
+ c: Speed of light (default = 1.0 for normalized units)
+ Returns:
+ xs_nplus3_2: Updated positions at t = n + 3/2, shape (N, 3)
+ vs_nplus1: Updated velocities at t = n + 1, shape (N, 3)
+ """
+
+ def single_particle_step(x, v, q, m, E, B):
+ # Compute initial momentum
+ gamma_n = 1/jnp.sqrt(1.0 - jnp.sum((v / c) ** 2))
+
+ p_n = gamma_n * m * v
+
+ # Half electric field acceleration
+ p_minus = p_n + q * E * dt / 2
+
+ # Magnetic rotation
+ p_plus = relativistic_rotation(dt, B, p_minus, q, m)
+
+ # Second half electric field acceleration
+ p_nplus1 = p_plus + q * E * dt / 2
+
+ # Compute new gamma
+ gamma_nplus1 = jnp.sqrt(1.0 + jnp.sum((p_nplus1 / (m * c)) ** 2))
+
+ # Recover new velocity
+ v_nplus1 = p_nplus1 / (gamma_nplus1 * m)
+
+ # Update position using new velocity
+ x_nplus3_2 = x + dt * v_nplus1
+
+ return x_nplus3_2, v_nplus1
+
+ # Vectorize over particles
+ xs_nplus3_2, vs_nplus1 = vmap(single_particle_step)(
+ xs_nplushalf, vs_n, q_s, m_s, E_fields_at_x, B_fields_at_x
+ )
+
+ return xs_nplus3_2, vs_nplus1
\ No newline at end of file
diff --git a/jaxincell/_plot.py b/jaxincell/_plot.py
index fd6735b..a6cb48d 100644
--- a/jaxincell/_plot.py
+++ b/jaxincell/_plot.py
@@ -1,9 +1,10 @@
+import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
-from matplotlib.animation import FuncAnimation
from jax import vmap
+from jax.debug import print as jprint
from ._constants import speed_of_light
-import numpy as np
+from matplotlib.animation import FuncAnimation
__all__ = ['plot']
@@ -122,7 +123,7 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label):
electron_plot = electron_ax.imshow(
jnp.zeros((len(grid), bins_velocity)), aspect="auto", origin="lower", cmap="twilight",
- extent=[-box_size_x / 2, box_size_x / 2, -max_velocity_ions_1, max_velocity_ions_1],
+ extent=[-box_size_x / 2, box_size_x / 2, -max_velocity_electrons_1, max_velocity_electrons_1],
vmin=jnp.min(electron_phase_histograms), vmax=jnp.max(electron_phase_histograms))
electron_ax.set(xlabel=f"Electron Position {direction1} (m)",
ylabel=f"Electron Velocity {direction1} (m/s)",
@@ -158,7 +159,7 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label):
energy_ax.plot(time, output["electric_field_energy"], label="Electric field energy")
if jnp.max(output["magnetic_field_energy"]) > 1e-10:
energy_ax.plot(time, output["magnetic_field_energy"], label="Magnetic field energy")
- energy_ax.plot(time[2:], jnp.abs(jnp.mean(output["charge_density"][2:], axis=-1))*1e12, label=r"Mean $\rho \times 10^{12}$")
+ energy_ax.plot(time[2:], jnp.abs(jnp.mean(output["charge_density"][2:], axis=-1))*1e15, label=r"Mean $\rho \times 10^{15}$")
energy_ax.plot(time[2:], jnp.abs(output["total_energy"][2:] - output["total_energy"][2]) / output["total_energy"][2], label="Relative energy error")
energy_ax.set(title="Energy", xlabel=r"Time ($\omega_{pe}^{-1}$)",
ylabel="Energy (J)", yscale="log", ylim=[1e-7, None])
diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py
index a727602..22d1837 100644
--- a/jaxincell/_simulation.py
+++ b/jaxincell/_simulation.py
@@ -4,9 +4,9 @@
from jax.debug import print as jprint
from jax import lax, jit, vmap, config
from jax.random import PRNGKey, uniform, normal
-from ._particles import fields_to_particles_grid, boris_step
from ._sources import current_density, calculate_charge_density
from ._boundary_conditions import set_BC_positions, set_BC_particles
+from ._particles import fields_to_particles_grid, boris_step, boris_step_relativistic
from ._fields import field_update, E_from_Gauss_1D_Cartesian, E_from_Gauss_1D_FFT, E_from_Poisson_1D_FFT, field_update1, field_update2
from ._constants import speed_of_light, epsilon_0, elementary_charge, mass_electron, mass_proton
from ._diagnostics import diagnostics
@@ -101,6 +101,7 @@ def initialize_simulation_parameters(user_parameters={}):
"electron_charge_over_elementary_charge": -1, # Electron charge in units of the elementary charge
"ion_charge_over_elementary_charge": 1, # Ion charge in units of the elementary charge
"ion_mass_over_proton_mass": 1, # Ion mass in units of the proton mass
+ "relativistic": False, # Use relativistic Boris pusher
# Boundary conditions
"particle_BC_left": 0, # Left boundary condition for particles
@@ -273,6 +274,9 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb
# Combine electron and ion velocities
velocities = jnp.concatenate((electron_velocities, ion_velocities))
+ # Cap velocities at 99% the speed of light
+ speed_limit = 0.99 * speed_of_light
+ velocities = jnp.where(jnp.abs(velocities) >= speed_limit, jnp.sign(velocities) * speed_limit, velocities)
# Grid setup
dx = length / number_grid_points
@@ -281,6 +285,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb
# Print information about the simulation
plasma_frequency = jnp.sqrt(number_pseudoelectrons * weight * charge_electrons**2)/jnp.sqrt(mass_electrons)/jnp.sqrt(epsilon_0)/jnp.sqrt(length)
+ relativistic_gamma_factor = 1 / jnp.sqrt(1 - jnp.sum(velocities**2, axis=1) / speed_of_light**2)
cond(parameters["print_info"],
lambda _: jprint((
@@ -294,11 +299,12 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb
"Skin depth: {} m\n"
"Wavenumber * Debye length: {}\n"
"Pseudoparticles per cell: {}\n"
+ "Pseudoparticle weight: {}\n"
"Steps at each plasma frequency: {}\n"
"Total time: {} / plasma frequency\n"
"Number of particles on a Debye cube: {}\n"
+ "Relativistic gamma factor: Maximum {}, Average {}\n"
"Charge x External electric field x Debye Length / Temperature: {}\n"
- "Pseudoparticle weight: {}\n"
),length/(Debye_length_per_dx*dx),
length/(speed_of_light/plasma_frequency),
number_pseudoelectrons * weight / length,
@@ -308,11 +314,12 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb
speed_of_light/plasma_frequency,
wavenumber_perturbation_x_electrons*Debye_length_per_dx*dx,
number_pseudoelectrons / number_grid_points,
+ weight,
1/(plasma_frequency * dt),
dt * plasma_frequency * total_steps,
number_pseudoelectrons * weight / length * (Debye_length_per_dx*dx)**3,
+ jnp.max(relativistic_gamma_factor), jnp.mean(relativistic_gamma_factor),
-charge_electrons * parameters["external_electric_field_amplitude"] * Debye_length_per_dx*dx / (mass_electrons * vth_electrons**2 / 2),
- weight,
), lambda _: None, operand=None)
# **Fields Initialization**
@@ -427,15 +434,20 @@ def simulation_step(carry, step_index):
total_B = B_field + parameters["external_magnetic_field"]
# Interpolate fields to particle positions
- E_field_at_x = vmap(lambda x_n: fields_to_particles_grid(
- x_n, total_E, dx, grid + dx / 2, grid[0], field_BC_left, field_BC_right))(positions_plus1_2)
-
- B_field_at_x = vmap(lambda x_n: fields_to_particles_grid(
- x_n, total_B, dx, grid, grid[0] - dx / 2, field_BC_left, field_BC_right))(positions_plus1_2)
+ def interpolate_fields(x_n):
+ E = fields_to_particles_grid(x_n, total_E, dx, grid + dx/2, grid[0], field_BC_left, field_BC_right)
+ B = fields_to_particles_grid(x_n, total_B, dx, grid, grid[0] - dx/2, field_BC_left, field_BC_right)
+ return E, B
+
+ E_field_at_x, B_field_at_x = vmap(interpolate_fields)(positions_plus1_2)
# Particle update: Boris pusher
- positions_plus3_2, velocities_plus1 = boris_step(
- dt, positions_plus1_2, velocities, q_ms, E_field_at_x, B_field_at_x)
+ positions_plus3_2, velocities_plus1 = lax.cond(
+ parameters["relativistic"],
+ lambda _: boris_step_relativistic(dt, positions_plus1_2, velocities, qs, ms, E_field_at_x, B_field_at_x),
+ lambda _: boris_step(dt, positions_plus1_2, velocities, q_ms, E_field_at_x, B_field_at_x),
+ operand=None
+ )
# Apply boundary conditions
positions_plus3_2, velocities_plus1, qs, ms, q_ms = set_BC_particles(