Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,4 @@ examples/objective_function_onerun_twostream.pdf
jaxincell/version.py
*.xml
.vscode/settings.json
*.npz
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ pytest .
## Project Roadmap

- [X] **`Task 1`**: <strike>Run PIC simulation using several field solvers.</strike>
- [ ] **`Task 2`**: Finalize example scripts and their documentation.
- [ ] **`Task 3`**: Implement a relativistic equation of motion.
- [X] **`Task 2`**: <strike>Finalize example scripts and their documentation.</strike>
- [X] **`Task 3`**: <strike>Implement a relativistic equation of motion.</strike>
- [ ] **`Task 4`**: Implement collisions to allow the plasma to relax to a Maxwellian.
- [ ] **`Task 5`**: Implement guiding-center equations of motion.
- [ ] **`Task 6`**: Implement an implicit time-stepping algorithm.
Expand Down
1 change: 1 addition & 0 deletions example_input.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ external_electric_field_wavenumber = 0
electron_charge_over_elementary_charge = -1
ion_charge_over_elementary_charge = 1
ion_mass_over_proton_mass = 1
relativistic = false

[solver_parameters]
field_solver = 0
Expand Down
21 changes: 11 additions & 10 deletions examples/Weibel_instability.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,23 @@
from jax import block_until_ready

input_parameters = {
"length" : 3e-1, # dimensions of the simulation box in (x, y, z)
"amplitude_perturbation_x" : 0, # amplitude of sinusoidal perturbation in x
"wavenumber_electrons_x" : 0, # wavenumber of perturbation in z
"length" : 3e-1, # dimensions of the simulation box in (x, y, z)
"amplitude_perturbation_x" : 0, # amplitude of sinusoidal perturbation in x
"wavenumber_electrons_x" : 0, # wavenumber of perturbation in z
"grid_points_per_Debye_length" : 0.6, # dx over Debye length
"velocity_plus_minus_electrons_z": False, # create two groups of electrons moving in opposite directions
"velocity_plus_minus_electrons_x": False, # create two groups of electrons moving in opposite directions
"velocity_plus_minus_electrons_z": False, # create two groups of electrons moving in opposite directions
"velocity_plus_minus_electrons_x": False, # create two groups of electrons moving in opposite directions
"random_positions_x": True, # Use random positions in x for particles
"random_positions_y": True, # Use random positions in y for particles
"random_positions_z": True, # Use random positions in z for particles
"electron_drift_speed_x": 0, # Drift speed of electrons in x direction
"electron_drift_speed_x": 0, # Drift speed of electrons in x direction
"electron_drift_speed_z": 0, # Drift speed of electrons in z direction
"ion_temperature_over_electron_temperature_x": 10, # Temperature of ions over temperature of electrons
"print_info" : True, # print information about the simulation
"vth_electrons_over_c_x": 0.01, # Thermal velocity of electrons over speed of light
"vth_electrons_over_c_z": 0.10, # Thermal velocity of electrons over speed of light
"timestep_over_spatialstep_times_c": 0.5, # dt * speed_of_light / dx
"print_info" : True, # print information about the simulation
"vth_electrons_over_c_x": 0.01, # Thermal velocity of electrons over speed of light
"vth_electrons_over_c_z": 0.10, # Thermal velocity of electrons over speed of light
"timestep_over_spatialstep_times_c": 0.5, # dt * speed_of_light / dx
"relativistic": False, # Use relativistic equations of motion
}

solver_parameters = {
Expand Down
17 changes: 9 additions & 8 deletions examples/input.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
[input_parameters]
length = 0.01
amplitude_perturbation_x = 1e-7
wavenumber_electrons_x = 1
length = 1.0
amplitude_perturbation_x = 0
wavenumber_electrons_x = 0
wavenumber_ions_x = 0
grid_points_per_Debye_length = 1.5
grid_points_per_Debye_length = 2.0
vth_electrons_over_c_x = 0.05
ion_temperature_over_electron_temperature_x = 0.01
timestep_over_spatialstep_times_c = 1.0
electron_drift_speed_x = 50000000.0
electron_drift_speed_x = 150000000.0
velocity_plus_minus_electrons_x = true
print_info = true
external_electric_field_amplitude = 0
external_electric_field_wavenumber = 0
relativistic = false

[solver_parameters]
field_solver = 0
number_grid_points = 51
number_pseudoelectrons = 2500
total_steps = 1500
number_grid_points = 81
number_pseudoelectrons = 3000
total_steps = 1000
10 changes: 10 additions & 0 deletions examples/two-stream_instability.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import time
from jax import block_until_ready
from jaxincell import plot, simulation, load_parameters
import numpy as np
import pickle
import json

# Read from input.toml
input_parameters, solver_parameters = load_parameters('input.toml')
Expand All @@ -17,3 +20,10 @@

# Plot the results
plot(output)

# # Save the output to a file
# np.savez("simulation_output.npz", **output)

# # Load the output from the file
# data = np.load("simulation_output.npz", allow_pickle=True)
# output2 = dict(data)
73 changes: 72 additions & 1 deletion jaxincell/_particles.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from jax import vmap, jit
import jax.numpy as jnp
from ._boundary_conditions import field_2_ghost_cells
from ._constants import speed_of_light as c

__all__ = ['fields_to_particles_grid', 'rotation', 'boris_step']
__all__ = ['fields_to_particles_grid', 'rotation', 'boris_step', 'boris_step_relativistic']

@jit
def fields_to_particles_grid(x_n, field, dx, grid, grid_start, field_BC_left, field_BC_right):
Expand Down Expand Up @@ -107,3 +108,73 @@ def boris_step(dt, xs_nplushalf, vs_n, q_ms, E_fields_at_x, B_fields_at_x):
# vs_nplus1 = vs_n + (q_ms) * E_fields_at_x * dt
# xs_nplus1 = xs_nplushalf + dt * vs_nplus1
# return xs_nplus1, vs_nplus1

@jit
def relativistic_rotation(dt, B, p_minus, q, m):
"""
Rotate momentum vector in magnetic field (relativistic Boris step).
"""
# gamma_minus from p_minus
gamma_minus = jnp.sqrt(1 + jnp.sum(p_minus ** 2) / (m ** 2 * c ** 2))

# t vector (rotation vector)
t = (q * dt) / (2 * m * gamma_minus) * B
p_dot_t = jnp.dot(p_minus, t)
p_cross_t = jnp.cross(p_minus, t)
t_squared = jnp.dot(t, t)

p_plus = (p_minus*(1-t_squared) + 2*(p_dot_t * t + p_cross_t)) / (1 + t_squared)

return p_plus

@jit
def boris_step_relativistic(dt, xs_nplushalf, vs_n, q_s, m_s, E_fields_at_x, B_fields_at_x):
"""
Relativistic Boris pusher for N particles.

Args:
dt: Time step
xs_nplushalf: Particle positions at t = n + 1/2, shape (N, 3)
vs_n: Velocities at time t = n, shape (N, 3)
q_s: Charges, shape (N,)
m_s: Masses, shape (N,)
E_fields_at_x: Electric fields at particle positions, shape (N, 3)
B_fields_at_x: Magnetic fields at particle positions, shape (N, 3)
c: Speed of light (default = 1.0 for normalized units)
Returns:
xs_nplus3_2: Updated positions at t = n + 3/2, shape (N, 3)
vs_nplus1: Updated velocities at t = n + 1, shape (N, 3)
"""

def single_particle_step(x, v, q, m, E, B):
# Compute initial momentum
gamma_n = 1/jnp.sqrt(1.0 - jnp.sum((v / c) ** 2))

p_n = gamma_n * m * v

# Half electric field acceleration
p_minus = p_n + q * E * dt / 2

# Magnetic rotation
p_plus = relativistic_rotation(dt, B, p_minus, q, m)

# Second half electric field acceleration
p_nplus1 = p_plus + q * E * dt / 2

# Compute new gamma
gamma_nplus1 = jnp.sqrt(1.0 + jnp.sum((p_nplus1 / (m * c)) ** 2))

# Recover new velocity
v_nplus1 = p_nplus1 / (gamma_nplus1 * m)

# Update position using new velocity
x_nplus3_2 = x + dt * v_nplus1

return x_nplus3_2, v_nplus1

# Vectorize over particles
xs_nplus3_2, vs_nplus1 = vmap(single_particle_step)(
xs_nplushalf, vs_n, q_s, m_s, E_fields_at_x, B_fields_at_x
)

return xs_nplus3_2, vs_nplus1
9 changes: 5 additions & 4 deletions jaxincell/_plot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from jax import vmap
from jax.debug import print as jprint
from ._constants import speed_of_light
import numpy as np
from matplotlib.animation import FuncAnimation

__all__ = ['plot']

Expand Down Expand Up @@ -122,7 +123,7 @@

electron_plot = electron_ax.imshow(
jnp.zeros((len(grid), bins_velocity)), aspect="auto", origin="lower", cmap="twilight",
extent=[-box_size_x / 2, box_size_x / 2, -max_velocity_ions_1, max_velocity_ions_1],
extent=[-box_size_x / 2, box_size_x / 2, -max_velocity_electrons_1, max_velocity_electrons_1],
vmin=jnp.min(electron_phase_histograms), vmax=jnp.max(electron_phase_histograms))
electron_ax.set(xlabel=f"Electron Position {direction1} (m)",
ylabel=f"Electron Velocity {direction1} (m/s)",
Expand Down Expand Up @@ -158,7 +159,7 @@
energy_ax.plot(time, output["electric_field_energy"], label="Electric field energy")
if jnp.max(output["magnetic_field_energy"]) > 1e-10:
energy_ax.plot(time, output["magnetic_field_energy"], label="Magnetic field energy")
energy_ax.plot(time[2:], jnp.abs(jnp.mean(output["charge_density"][2:], axis=-1))*1e12, label=r"Mean $\rho \times 10^{12}$")
energy_ax.plot(time[2:], jnp.abs(jnp.mean(output["charge_density"][2:], axis=-1))*1e15, label=r"Mean $\rho \times 10^{15}$")

Check warning on line 162 in jaxincell/_plot.py

View check run for this annotation

Codecov / codecov/patch

jaxincell/_plot.py#L162

Added line #L162 was not covered by tests
energy_ax.plot(time[2:], jnp.abs(output["total_energy"][2:] - output["total_energy"][2]) / output["total_energy"][2], label="Relative energy error")
energy_ax.set(title="Energy", xlabel=r"Time ($\omega_{pe}^{-1}$)",
ylabel="Energy (J)", yscale="log", ylim=[1e-7, None])
Expand Down
32 changes: 22 additions & 10 deletions jaxincell/_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from jax.debug import print as jprint
from jax import lax, jit, vmap, config
from jax.random import PRNGKey, uniform, normal
from ._particles import fields_to_particles_grid, boris_step
from ._sources import current_density, calculate_charge_density
from ._boundary_conditions import set_BC_positions, set_BC_particles
from ._particles import fields_to_particles_grid, boris_step, boris_step_relativistic
from ._fields import field_update, E_from_Gauss_1D_Cartesian, E_from_Gauss_1D_FFT, E_from_Poisson_1D_FFT, field_update1, field_update2
from ._constants import speed_of_light, epsilon_0, elementary_charge, mass_electron, mass_proton
from ._diagnostics import diagnostics
Expand Down Expand Up @@ -101,6 +101,7 @@ def initialize_simulation_parameters(user_parameters={}):
"electron_charge_over_elementary_charge": -1, # Electron charge in units of the elementary charge
"ion_charge_over_elementary_charge": 1, # Ion charge in units of the elementary charge
"ion_mass_over_proton_mass": 1, # Ion mass in units of the proton mass
"relativistic": False, # Use relativistic Boris pusher

# Boundary conditions
"particle_BC_left": 0, # Left boundary condition for particles
Expand Down Expand Up @@ -273,6 +274,9 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb

# Combine electron and ion velocities
velocities = jnp.concatenate((electron_velocities, ion_velocities))
# Cap velocities at 99% the speed of light
speed_limit = 0.99 * speed_of_light
velocities = jnp.where(jnp.abs(velocities) >= speed_limit, jnp.sign(velocities) * speed_limit, velocities)

# Grid setup
dx = length / number_grid_points
Expand All @@ -281,6 +285,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb

# Print information about the simulation
plasma_frequency = jnp.sqrt(number_pseudoelectrons * weight * charge_electrons**2)/jnp.sqrt(mass_electrons)/jnp.sqrt(epsilon_0)/jnp.sqrt(length)
relativistic_gamma_factor = 1 / jnp.sqrt(1 - jnp.sum(velocities**2, axis=1) / speed_of_light**2)

cond(parameters["print_info"],
lambda _: jprint((
Expand All @@ -294,11 +299,12 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb
"Skin depth: {} m\n"
"Wavenumber * Debye length: {}\n"
"Pseudoparticles per cell: {}\n"
"Pseudoparticle weight: {}\n"
"Steps at each plasma frequency: {}\n"
"Total time: {} / plasma frequency\n"
"Number of particles on a Debye cube: {}\n"
"Relativistic gamma factor: Maximum {}, Average {}\n"
"Charge x External electric field x Debye Length / Temperature: {}\n"
"Pseudoparticle weight: {}\n"
),length/(Debye_length_per_dx*dx),
length/(speed_of_light/plasma_frequency),
number_pseudoelectrons * weight / length,
Expand All @@ -308,11 +314,12 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb
speed_of_light/plasma_frequency,
wavenumber_perturbation_x_electrons*Debye_length_per_dx*dx,
number_pseudoelectrons / number_grid_points,
weight,
1/(plasma_frequency * dt),
dt * plasma_frequency * total_steps,
number_pseudoelectrons * weight / length * (Debye_length_per_dx*dx)**3,
jnp.max(relativistic_gamma_factor), jnp.mean(relativistic_gamma_factor),
-charge_electrons * parameters["external_electric_field_amplitude"] * Debye_length_per_dx*dx / (mass_electrons * vth_electrons**2 / 2),
weight,
), lambda _: None, operand=None)

# **Fields Initialization**
Expand Down Expand Up @@ -427,15 +434,20 @@ def simulation_step(carry, step_index):
total_B = B_field + parameters["external_magnetic_field"]

# Interpolate fields to particle positions
E_field_at_x = vmap(lambda x_n: fields_to_particles_grid(
x_n, total_E, dx, grid + dx / 2, grid[0], field_BC_left, field_BC_right))(positions_plus1_2)

B_field_at_x = vmap(lambda x_n: fields_to_particles_grid(
x_n, total_B, dx, grid, grid[0] - dx / 2, field_BC_left, field_BC_right))(positions_plus1_2)
def interpolate_fields(x_n):
E = fields_to_particles_grid(x_n, total_E, dx, grid + dx/2, grid[0], field_BC_left, field_BC_right)
B = fields_to_particles_grid(x_n, total_B, dx, grid, grid[0] - dx/2, field_BC_left, field_BC_right)
return E, B

E_field_at_x, B_field_at_x = vmap(interpolate_fields)(positions_plus1_2)

# Particle update: Boris pusher
positions_plus3_2, velocities_plus1 = boris_step(
dt, positions_plus1_2, velocities, q_ms, E_field_at_x, B_field_at_x)
positions_plus3_2, velocities_plus1 = lax.cond(
parameters["relativistic"],
lambda _: boris_step_relativistic(dt, positions_plus1_2, velocities, qs, ms, E_field_at_x, B_field_at_x),
lambda _: boris_step(dt, positions_plus1_2, velocities, q_ms, E_field_at_x, B_field_at_x),
operand=None
)

# Apply boundary conditions
positions_plus3_2, velocities_plus1, qs, ms, q_ms = set_BC_particles(
Expand Down