From df5f3e27ed572f95dc72c33bea7a92ad30a338d0 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Thu, 24 Jul 2025 12:37:57 -0500 Subject: [PATCH 01/19] Stub bump-on-tail example, Maxwellian electrons No bump distribution yet --- examples/bump-on-tail.py | 29 ++++++++++++++++ examples/bump-on-tail.toml | 69 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 examples/bump-on-tail.py create mode 100644 examples/bump-on-tail.toml diff --git a/examples/bump-on-tail.py b/examples/bump-on-tail.py new file mode 100644 index 0000000..2e868bd --- /dev/null +++ b/examples/bump-on-tail.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +""" +bump-on-tail.py +Weak electron beam on tail of bulk Maxwellian electron distribution drives +slowly-growing Langmuir waves with Im(omega) << omega_pe. +""" + +import numpy as np +from datetime import datetime + +from jax import block_until_ready +from jaxincell import plot, simulation, load_parameters + +input_parameters, solver_parameters = load_parameters('bump-on-tail.toml') + +# Run the simulation +started = datetime.now() +output = block_until_ready(simulation(input_parameters, **solver_parameters)) +print("Simulation done, elapsed:", datetime.now()-started) + +# Plot the results +plot(output) + +# # Save the output to a file +# np.savez("simulation_output.npz", **output) + +# # Load the output from the file +# data = np.load("simulation_output.npz", allow_pickle=True) +# output2 = dict(data) diff --git a/examples/bump-on-tail.toml b/examples/bump-on-tail.toml new file mode 100644 index 0000000..41292a3 --- /dev/null +++ b/examples/bump-on-tail.toml @@ -0,0 +1,69 @@ +[input_parameters] +# ----------------------------- +# species intrinsic properties +# ----------------------------- +ion_mass_over_proton_mass = 1e9 +ion_charge_over_elementary_charge = 1 +electron_charge_over_elementary_charge = -1 +# ----------------------------- +# particle spatial distribution +# ----------------------------- +random_positions_x = true +random_positions_y = false +random_positions_z = false +amplitude_perturbation_x = 0 # Amplitude of sinusoidal perturbation +amplitude_perturbation_y = 0 +amplitude_perturbation_z = 0 +wavenumber_electrons_x = 0 # Wavenumber of sinusoidal density perturbation (factor of 2pi/length) +wavenumber_electrons_y = 0 +wavenumber_electrons_z = 0 +wavenumber_ions_x = 0 +wavenumber_ions_y = 0 +wavenumber_ions_z = 0 +# ------------------------------ +# electron velocity distribution +# ------------------------------ +vth_electrons_over_c_x = 0.014142135624 # sqrt(2*T/m)/c +vth_electrons_over_c_y = 0 +vth_electrons_over_c_z = 0 +electron_drift_speed_x = 0 # Drift speed of electrons in the x direction +electron_drift_speed_y = 0 +electron_drift_speed_z = 0 +velocity_plus_minus_electrons_x = false # create two groups of electrons moving in opposite directions in the x direction +velocity_plus_minus_electrons_y = false +velocity_plus_minus_electrons_z = false +# ------------------------------ +# ion velocity distribution +# ------------------------------ +ion_drift_speed_x = 0 # Drift speed of ions in the x direction +ion_drift_speed_y = 0 +ion_drift_speed_z = 0 +ion_temperature_over_electron_temperature_x = 1e-9 # Temperature ratio of ions to electrons in the x direction +ion_temperature_over_electron_temperature_y = 1e-9 +ion_temperature_over_electron_temperature_z = 1e-9 +velocity_plus_minus_ions_x = false # create two groups of ions moving in opposite directions in the x direction +velocity_plus_minus_ions_y = false +velocity_plus_minus_ions_z = false +# ------------------------------ +# domain, spacetime gridding +# ------------------------------ +length = 1 # Simulation box size in meters (sets dimensionful normalization, but does not change dimensionless ratios?) +grid_points_per_Debye_length = 5 # dx over Debye length (ignore the conflicting variable name) +timestep_over_spatialstep_times_c = 1.0 # dt * speed_of_light / dx +particle_BC_left = 0 # 0: periodic, 1: reflective, 2: absorbing +particle_BC_right = 0 +field_BC_left = 0 # 0: periodic, 1: reflective, 2: absorbing +field_BC_right = 0 +# ------------------------------ +# other controls +# ------------------------------ +print_info = true +relativistic = false # Use relativistic Boris pusher +seed = 250724 # Random seed for reproducibility +tolerance_Picard_iterations_implicit_CN = 1e-6 # Tolerance for Picard iterations in implicit Crank-Nicholson method + +[solver_parameters] +number_grid_points = 60 # number of grid CELLS, not edges/vertices +field_solver = 0 # 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, +number_pseudoelectrons = 1000 # Total in entire domain +total_steps = 10000 # Total number of time steps From e91a3d4c4cf233e3ed5325c83452fe1337fbf22b Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Thu, 24 Jul 2025 12:43:02 -0500 Subject: [PATCH 02/19] rm trailing whitespace across codebase --- examples/Landau_damping.py | 2 +- examples/Langmuir_wave.py | 2 +- examples/Weibel_instability.py | 2 +- examples/scaling_energy_time.py | 4 +- jaxincell/__init__.py | 2 +- jaxincell/__main__.py | 2 +- jaxincell/_algorithms.py | 22 +++++------ jaxincell/_boundary_conditions.py | 22 +++++------ jaxincell/_constants.py | 2 +- jaxincell/_diagnostics.py | 16 ++++---- jaxincell/_fields.py | 22 +++++------ jaxincell/_particles.py | 42 ++++++++++----------- jaxincell/_plot.py | 16 ++++---- jaxincell/_simulation.py | 62 +++++++++++++++---------------- jaxincell/_sources.py | 8 ++-- 15 files changed, 113 insertions(+), 113 deletions(-) diff --git a/examples/Landau_damping.py b/examples/Landau_damping.py index 9c58eef..5b3a42e 100644 --- a/examples/Landau_damping.py +++ b/examples/Landau_damping.py @@ -19,7 +19,7 @@ } solver_parameters = { - "field_solver" : 0, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, + "field_solver" : 0, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, "number_grid_points" : 81, # Number of grid points "number_pseudoelectrons" : 3000, # Number of pseudoelectrons "total_steps" : 200, # Total number of time steps diff --git a/examples/Langmuir_wave.py b/examples/Langmuir_wave.py index d1ec5a6..ae441f6 100644 --- a/examples/Langmuir_wave.py +++ b/examples/Langmuir_wave.py @@ -18,7 +18,7 @@ } solver_parameters = { - "field_solver" : 0, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, + "field_solver" : 0, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, "number_grid_points" : 33, # Number of grid points "number_pseudoelectrons" : 3000, # Number of pseudoelectrons "total_steps" : 1000, # Total number of time steps diff --git a/examples/Weibel_instability.py b/examples/Weibel_instability.py index d539267..f2ba71b 100644 --- a/examples/Weibel_instability.py +++ b/examples/Weibel_instability.py @@ -25,7 +25,7 @@ } solver_parameters = { - "field_solver" : 0, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, + "field_solver" : 0, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, "number_grid_points" : 201, # Number of grid points "number_pseudoelectrons" : 3000, # Number of pseudoelectrons "total_steps" : 6000, # Total number of time steps diff --git a/examples/scaling_energy_time.py b/examples/scaling_energy_time.py index f793fa3..edf6809 100644 --- a/examples/scaling_energy_time.py +++ b/examples/scaling_energy_time.py @@ -28,7 +28,7 @@ } solver_parameters = { - "field_solver" : 1, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, + "field_solver" : 1, # Algorithm to solve E and B fields - 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, "number_grid_points" : 50, # Number of grid points "number_pseudoelectrons" : 1500, # Number of pseudoelectrons "total_steps" : 2000, # Total number of time steps @@ -147,4 +147,4 @@ def plot_and_save_results(x_data_list, y_data_list, x_labels, y_labels, file_nam # Plotting the relative energy error plot_and_save_results(x_data_list, error_y_data_list, x_labels, [error_y_label] * 4, 'scaling_energy_error.pdf') -plt.show() \ No newline at end of file +plt.show() diff --git a/jaxincell/__init__.py b/jaxincell/__init__.py index 0c82ced..18a18f1 100644 --- a/jaxincell/__init__.py +++ b/jaxincell/__init__.py @@ -5,4 +5,4 @@ from ._particles import * from ._plot import * from ._simulation import * -from ._sources import * \ No newline at end of file +from ._sources import * diff --git a/jaxincell/__main__.py b/jaxincell/__main__.py index d71db8a..19a9698 100644 --- a/jaxincell/__main__.py +++ b/jaxincell/__main__.py @@ -19,4 +19,4 @@ def main(cl_args=sys.argv[1:]): plot(output) if __name__ == "__main__": - main(sys.argv[1:]) \ No newline at end of file + main(sys.argv[1:]) diff --git a/jaxincell/_algorithms.py b/jaxincell/_algorithms.py index 03a8a85..a23f94f 100644 --- a/jaxincell/_algorithms.py +++ b/jaxincell/_algorithms.py @@ -22,11 +22,11 @@ def Boris_step(carry, step_index, parameters, dx, dt, grid, box_size, (E_field, B_field, positions_minus1_2, positions, positions_plus1_2, velocities, qs, ms, q_ms) = carry - + J = current_density(positions_minus1_2, positions, positions_plus1_2, velocities, qs, dx, dt, grid, grid[0] - dx / 2, particle_BC_left, particle_BC_right) E_field, B_field = field_update1(E_field, B_field, dx, dt/2, J, field_BC_left, field_BC_right) - + # Add external fields total_E = E_field + parameters["external_electric_field"] total_B = B_field + parameters["external_magnetic_field"] @@ -51,14 +51,14 @@ def interpolate_fields(x_n): positions_plus3_2, velocities_plus1, qs, ms, q_ms = set_BC_particles( positions_plus3_2, velocities_plus1, qs, ms, q_ms, dx, grid, *box_size, particle_BC_left, particle_BC_right) - + positions_plus1 = set_BC_positions(positions_plus3_2 - (dt / 2) * velocities_plus1, qs, dx, grid, *box_size, particle_BC_left, particle_BC_right) J = current_density(positions_plus1_2, positions_plus1, positions_plus3_2, velocities_plus1, qs, dx, dt, grid, grid[0] - dx / 2, particle_BC_left, particle_BC_right) E_field, B_field = field_update2(E_field, B_field, dx, dt/2, J, field_BC_left, field_BC_right) - + if field_solver != 0: charge_density = calculate_charge_density(positions, qs, dx, grid + dx / 2, particle_BC_left, particle_BC_right) switcher = { @@ -80,7 +80,7 @@ def interpolate_fields(x_n): # Collect data for storage charge_density = calculate_charge_density(positions, qs, dx, grid, particle_BC_left, particle_BC_right) step_data = (positions, velocities, E_field, B_field, J, charge_density) - + return carry, step_data # Implicit Crank-Nicolson step @@ -141,7 +141,7 @@ def substep_loop(sub_carry, step_idx): return (pos_new, vel_new, qs_new, ms_new, q_ms_new, pos_stag_arr), J_sub * dtau - # initial substep carry + # initial substep carry sub_init = ( pos_fix, vel_fix, qs_prev, ms_prev, q_ms_prev, @@ -178,14 +178,14 @@ def substep_loop(sub_carry, step_idx): picard_init = (E_old, E_new, positions, positions_new, velocities, velocities_new, qs, ms, q_ms, positions_sub1_2_all_init) state0 = (picard_init, jnp.zeros_like(E_new), delta_E0, iter_idx0) - + def cond_fn(state): _, _, delta_E, i = state return jnp.logical_and(delta_E > tol, i < max_iter) def body_fn(state): carry, _, _, i = state - + E_old = carry[0] new_carry, J_iter = picard_step(carry, None) @@ -202,12 +202,12 @@ def body_fn(state): B_field = B_new positions_plus1= positions_new velocities_plus1 = velocities_new - + charge_density = calculate_charge_density(positions_new, qs, dx, grid, particle_BC_left, particle_BC_right) carry = (E_field, B_field, positions_plus1, velocities_plus1, qs, ms, q_ms) - + # Collect data step_data = (positions_plus1, velocities_plus1, E_field, B_field, J, charge_density) - + return carry, step_data diff --git a/jaxincell/_boundary_conditions.py b/jaxincell/_boundary_conditions.py index e794fb6..5ac0564 100644 --- a/jaxincell/_boundary_conditions.py +++ b/jaxincell/_boundary_conditions.py @@ -117,7 +117,7 @@ def set_BC_single_particle_positions(x_n, dx, grid, box_size_x, box_size_y, box_ x_n0 = jnp.where( x_n[0] < -box_size_x / 2, - jnp.where(BC_left == 0, (x_n[0] + box_size_x / 2) % box_size_x - box_size_x / 2, + jnp.where(BC_left == 0, (x_n[0] + box_size_x / 2) % box_size_x - box_size_x / 2, jnp.where(BC_left == 1, -box_size_x - x_n[0], grid[0] - 1.5 * dx)), # Absorbing jnp.where( x_n[0] > box_size_x / 2, @@ -148,8 +148,8 @@ def set_BC_positions(xs_n, qs, dx, grid, box_size_x, box_size_y, box_size_z, BC_ @jit def field_ghost_cells_E(field_BC_left, field_BC_right, E_field, B_field): """ - Set the ghost cells for the electric field at the boundaries of the simulation grid. - The ghost cells are used to apply boundary conditions and extend the field in the + Set the ghost cells for the electric field at the boundaries of the simulation grid. + The ghost cells are used to apply boundary conditions and extend the field in the simulation domain based on the selected boundary conditions. Args: @@ -181,8 +181,8 @@ def field_ghost_cells_E(field_BC_left, field_BC_right, E_field, B_field): @jit def field_ghost_cells_B(field_BC_left, field_BC_right, B_field, E_field): """ - Set the ghost cells for the magnetic field at the boundaries of the simulation grid. - The ghost cells are used to apply boundary conditions and extend the magnetic field + Set the ghost cells for the magnetic field at the boundaries of the simulation grid. + The ghost cells are used to apply boundary conditions and extend the magnetic field in the simulation domain based on the selected boundary conditions. Args: @@ -209,13 +209,13 @@ def field_ghost_cells_B(field_BC_left, field_BC_right, B_field, E_field): @jit def field_2_ghost_cells(field_BC_left, field_BC_right, field): """ - This function adds ghost cells to the field array, which is used for interpolation when - accessing field values at particle positions. Ghost cells are added to the left and + This function adds ghost cells to the field array, which is used for interpolation when + accessing field values at particle positions. Ghost cells are added to the left and right boundaries based on the specified boundary conditions for the particles. - Ghost cells are needed for simulations to handle boundary effects by using the appropriate - field values at the boundaries. This is especially important in simulations where particles - can cross boundary regions, and the electric and magnetic fields must be extended beyond + Ghost cells are needed for simulations to handle boundary effects by using the appropriate + field values at the boundaries. This is especially important in simulations where particles + can cross boundary regions, and the electric and magnetic fields must be extended beyond the simulation domain. Args: @@ -238,7 +238,7 @@ def field_2_ghost_cells(field_BC_left, field_BC_right, field): jnp.where(field_BC_left==1,field[0], jnp.where(field_BC_left==2,jnp.array([0,0,0]), jnp.array([0,0,0])))) - + field_ghost_cell_R = jnp.where(field_BC_right==0,field[0], jnp.where(field_BC_right==1,field[-1], jnp.where(field_BC_right==2,jnp.array([0,0,0]), diff --git a/jaxincell/_constants.py b/jaxincell/_constants.py index 44277d0..043a7a3 100644 --- a/jaxincell/_constants.py +++ b/jaxincell/_constants.py @@ -4,4 +4,4 @@ elementary_charge = 1.60217663e-19 # Elementary charge mass_electron = 9.10938371e-31 # Electron mass mass_proton = 1.67262193e-27 # Proton mass -boltzmann_constant = 1.380649e-23 # Boltzmann constant \ No newline at end of file +boltzmann_constant = 1.380649e-23 # Boltzmann constant diff --git a/jaxincell/_diagnostics.py b/jaxincell/_diagnostics.py index 182cc23..ae550d1 100644 --- a/jaxincell/_diagnostics.py +++ b/jaxincell/_diagnostics.py @@ -12,7 +12,7 @@ def diagnostics(output): total_steps = output['total_steps'] mass_electrons = output["mass_electrons"][0] mass_ions = output["mass_ions"][0] - + # array_to_do_fft_on = charge_density_over_time[:,len(grid)//2] array_to_do_fft_on = E_field_over_time[:,len(grid)//2,0] array_to_do_fft_on = (array_to_do_fft_on-jnp.mean(array_to_do_fft_on))/jnp.max(array_to_do_fft_on) @@ -26,20 +26,20 @@ def diagnostics(output): def integrate(y, dx): return 0.5 * (jnp.asarray(dx) * (y[..., 1:] + y[..., :-1])).sum(-1) # def integrate(y, dx): return jnp.sum(y, axis=-1) * dx - + abs_E_squared = jnp.sum(output['electric_field']**2, axis=-1) abs_externalE_squared = jnp.sum(output['external_electric_field']**2, axis=-1) integral_E_squared = integrate(abs_E_squared, dx=output['dx']) integral_externalE_squared = integrate(abs_externalE_squared, dx=output['dx']) - + abs_B_squared = jnp.sum(output['magnetic_field']**2, axis=-1) abs_externalB_squared = jnp.sum(output['external_magnetic_field']**2, axis=-1) integral_B_squared = integrate(abs_B_squared, dx=output['dx']) integral_externalB_squared = integrate(abs_externalB_squared, dx=output['dx']) - + v_electrons_squared = jnp.sum(jnp.sum(output['velocity_electrons']**2, axis=-1), axis=-1) v_ions_squared = jnp.sum(jnp.sum(output['velocity_ions']**2 , axis=-1), axis=-1) - + output.update({ 'electric_field_energy_density': (epsilon_0/2) * abs_E_squared, @@ -56,9 +56,9 @@ def integrate(y, dx): return 0.5 * (jnp.asarray(dx) * (y[..., 1:] + y[..., :-1]) 'external_magnetic_field_energy_density': 1/(2*mu_0) * abs_externalB_squared, 'external_magnetic_field_energy': 1/(2*mu_0) * integral_externalB_squared }) - + total_energy = (output["electric_field_energy"] + output["external_electric_field_energy"] + output["magnetic_field_energy"] + output["external_magnetic_field_energy"] + output["kinetic_energy"]) - - output.update({'total_energy': total_energy}) \ No newline at end of file + + output.update({'total_energy': total_energy}) diff --git a/jaxincell/_fields.py b/jaxincell/_fields.py index be98a46..ec66539 100644 --- a/jaxincell/_fields.py +++ b/jaxincell/_fields.py @@ -9,7 +9,7 @@ @jit def E_from_Gauss_1D_FFT(charge_density, dx): """ - Solve for the electric field E = -d(phi)/dx using FFT, + Solve for the electric field E = -d(phi)/dx using FFT, where phi is derived from the 1D Gauss' law equation. Parameters: charge_density : 1D numpy array, source term (right-hand side of Poisson equation) @@ -34,7 +34,7 @@ def E_from_Gauss_1D_FFT(charge_density, dx): @jit def E_from_Poisson_1D_FFT(charge_density, dx): """ - Solve for the electric field E = -d(phi)/dx using FFT, + Solve for the electric field E = -d(phi)/dx using FFT, where phi is derived from the 1D Poisson equation. Parameters: charge_density : 1D numpy array, source term (right-hand side of Poisson equation) @@ -63,20 +63,20 @@ def E_from_Poisson_1D_FFT(charge_density, dx): @jit def E_from_Gauss_1D_Cartesian(charge_density, dx): """ - Solve for the electric field at t=0 (E0) using the charge density distribution + Solve for the electric field at t=0 (E0) using the charge density distribution and applying Gauss's law in a 1D system. Args: charge_density : 1D numpy array, source term (right-hand side of Gauss equation) dx : float, grid spacing in the x-direction - + Returns: array: The electric field at each grid point due to the particles, shape (G,). """ # Construct divergence matrix for solving Gauss' Law divergence_matrix = jnp.diag(jnp.ones(len(charge_density)))-jnp.diag(jnp.ones(len(charge_density)-1),k=-1) divergence_matrix.at[0,-1].set(-1) - + # Solve for the electric field using Gauss' law in the 1D case E_field_from_Gauss = (dx / epsilon_0) * jnp.linalg.solve(divergence_matrix, charge_density) return E_field_from_Gauss @@ -85,7 +85,7 @@ def E_from_Gauss_1D_Cartesian(charge_density, dx): @jit def curlE(E_field, B_field, dx, dt, field_BC_left, field_BC_right): """ - Compute the curl of the electric field, which is related to the time derivative of + Compute the curl of the electric field, which is related to the time derivative of the magnetic field in Maxwell's equations (Faraday's law). Args: @@ -103,7 +103,7 @@ def curlE(E_field, B_field, dx, dt, field_BC_left, field_BC_right): ghost_cell_L, ghost_cell_R = field_ghost_cells_E(field_BC_left, field_BC_right, E_field, B_field) E_field = jnp.insert(E_field, 0, ghost_cell_L, axis=0) E_field = jnp.append(E_field, jnp.array([ghost_cell_R]), axis=0) - + # Compute the curl using the finite difference approximation for 1D (only d/dx) dFz_dx = (E_field[1:-1, 2] - E_field[0:-2, 2]) / dx dFy_dx = (E_field[1:-1, 1] - E_field[0:-2, 1]) / dx @@ -115,7 +115,7 @@ def curlE(E_field, B_field, dx, dt, field_BC_left, field_BC_right): @jit def curlB(B_field, E_field, dx, dt, field_BC_left, field_BC_right): """ - Compute the curl of the magnetic field, which is related to the time derivative of + Compute the curl of the magnetic field, which is related to the time derivative of the electric field in Maxwell's equations (Ampère's law with Maxwell correction). Args: @@ -134,7 +134,7 @@ def curlB(B_field, E_field, dx, dt, field_BC_left, field_BC_right): B_field = jnp.insert(B_field, 0, ghost_cell_L, axis=0) B_field = jnp.append(B_field, jnp.array([ghost_cell_R]), axis=0) - #If taking E_i = B_(i+1) - B_i (since B-fields defined on centers), roll by -1 first. + #If taking E_i = B_(i+1) - B_i (since B-fields defined on centers), roll by -1 first. B_field = jnp.roll(B_field, -1, axis=0) # Compute the curl using the finite difference approximation for 1D (only d/dx) @@ -166,7 +166,7 @@ def field_update(E_fields, B_fields, dx, dt, j, field_BC_left, field_BC_right): # Faraday's law curl_B = curlB(B_fields, E_fields, dx, dt, field_BC_left, field_BC_right) - + # Update the Fields B_fields -= dt*curl_E E_fields += dt*((speed_of_light**2)*curl_B-(j/epsilon_0)) @@ -191,4 +191,4 @@ def field_update2(E_fields, B_fields, dx, dt, j, field_BC_left, field_BC_right): #Then, update E (Ampere's) curl_B = curlB(B_fields, E_fields, dx, dt, field_BC_left, field_BC_right) E_fields += dt*((speed_of_light**2)*curl_B-(j/epsilon_0)) - return E_fields,B_fields \ No newline at end of file + return E_fields,B_fields diff --git a/jaxincell/_particles.py b/jaxincell/_particles.py index 2d8c67d..eafd279 100644 --- a/jaxincell/_particles.py +++ b/jaxincell/_particles.py @@ -8,9 +8,9 @@ @jit def fields_to_particles_grid(x_n, field, dx, grid, grid_start, field_BC_left, field_BC_right): """ - This function retrieves the electric or magnetic field values at particle positions - using a field interpolation scheme. The function first adds ghost cells to the field - array to handle boundary conditions, then interpolates the field based on the + This function retrieves the electric or magnetic field values at particle positions + using a field interpolation scheme. The function first adds ghost cells to the field + array to handle boundary conditions, then interpolates the field based on the particle's position in the grid. Args: @@ -31,25 +31,25 @@ def fields_to_particles_grid(x_n, field, dx, grid, grid_start, field_BC_left, fi field = jnp.insert(field,0,ghost_cell_L1,axis=0) field = jnp.append(field,jnp.array([ghost_cell_R]),axis=0) x = x_n[0] - + # Adjust the grid to accommodate particles at the first half grid cell (staggered grid) #If using a staggered grid, particles at first half cell will be out of grid, so add extra cell - grid = jnp.insert(grid,0,grid[0]-dx,axis=0) - + grid = jnp.insert(grid,0,grid[0]-dx,axis=0) + # Calculate the index of the field grid corresponding to the particle position i = ((x-grid_start+dx)//dx).astype(int) #new grid_start = grid_start-dx due to extra cell - + # Interpolate the field at the particle position using a quadratic interpolation fields_n = 0.5*field[i]*(0.5+(grid[i]-x)/dx)**2 + field[i+1]*(0.75-(grid[i]-x)**2/dx**2) + 0.5*field[i+2]*(0.5-(grid[i]-x)/dx)**2 - + return fields_n @jit def rotation(dt, B, vsub, q_m): """ - This function implements the Boris algorithm to rotate the particle velocity vector - in the magnetic field for one time step. This step is part of the numerical solution + This function implements the Boris algorithm to rotate the particle velocity vector + in the magnetic field for one time step. This step is part of the numerical solution of the Lorentz force equation. Args: @@ -63,20 +63,20 @@ def rotation(dt, B, vsub, q_m): """ # First part of the Boris algorithm: calculate intermediate velocity Rvec = vsub + 0.5 * dt * q_m * jnp.cross(vsub, B) - + # Magnetic field vector term for the rotation step Bvec = 0.5 * q_m * dt * B - + # Apply the Boris rotation step to the velocity vector vplus = (jnp.cross(Rvec, Bvec) + jnp.dot(Rvec, Bvec) * Bvec + Rvec) / (1 + jnp.dot(Bvec, Bvec)) - + return vplus @jit def boris_step(dt, xs_nplushalf, vs_n, q_ms, E_fields_at_x, B_fields_at_x): """ - This function performs one step of the Boris algorithm for particle motion. - The particle velocity is updated using the electric and magnetic fields at its position, + This function performs one step of the Boris algorithm for particle motion. + The particle velocity is updated using the electric and magnetic fields at its position, and the particle position is updated using the new velocity. Args: @@ -94,16 +94,16 @@ def boris_step(dt, xs_nplushalf, vs_n, q_ms, E_fields_at_x, B_fields_at_x): """ # First half step update for velocity due to electric field vs_n_int = vs_n + (q_ms) * E_fields_at_x * dt / 2 - + # Apply the Boris rotation step for the magnetic field vs_n_rot = vmap(lambda B_n, v_n, q_m: rotation(dt, B_n, v_n, q_m))(B_fields_at_x, vs_n_int, q_ms[:, 0]) - + # Second half step update for velocity due to electric field vs_nplus1 = vs_n_rot + (q_ms) * E_fields_at_x * dt / 2 - + # Update the particle positions using the new velocities xs_nplus3_2 = xs_nplushalf + dt * vs_nplus1 - + return xs_nplus3_2, vs_nplus1 # vs_nplus1 = vs_n + (q_ms) * E_fields_at_x * dt # xs_nplus1 = xs_nplushalf + dt * vs_nplus1 @@ -131,7 +131,7 @@ def relativistic_rotation(dt, B, p_minus, q, m): def boris_step_relativistic(dt, xs_nplushalf, vs_n, q_s, m_s, E_fields_at_x, B_fields_at_x): """ Relativistic Boris pusher for N particles. - + Args: dt: Time step xs_nplushalf: Particle positions at t = n + 1/2, shape (N, 3) @@ -177,4 +177,4 @@ def single_particle_step(x, v, q, m, E, B): xs_nplushalf, vs_n, q_s, m_s, E_fields_at_x, B_fields_at_x ) - return xs_nplus3_2, vs_nplus1 \ No newline at end of file + return xs_nplus3_2, vs_nplus1 diff --git a/jaxincell/_plot.py b/jaxincell/_plot.py index 7f52d66..bf2bea2 100644 --- a/jaxincell/_plot.py +++ b/jaxincell/_plot.py @@ -164,19 +164,19 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): energy_ax.set(title="Energy", xlabel=r"Time ($\omega_{pe}^{-1}$)", ylabel="Energy (J)", yscale="log", ylim=[1e-7, None]) energy_ax.legend(fontsize=7) - + if second_direction: electron_ax2 = axes_flat[used_axes + 1] ion_ax2 = axes_flat[used_axes + 2] positions_ax = axes_flat[used_axes + 3] - + # Phase space in second direction max_velocity_electrons_2 = max(1.0 * jnp.max(output["velocity_electrons"][:, :, direction_index2]), 2.5 * jnp.abs(vth_e_2) + jnp.abs(output[f"electron_drift_speed_{direction2}"])) max_velocity_ions_2 = max(1.0 * jnp.max(output["velocity_ions"][:, :, direction_index2]), sqrtmemi * 0.3 * jnp.abs(vth_i_2) * jnp.sqrt(output[f"ion_temperature_over_electron_temperature_{direction2}"]) + jnp.abs(output[f"ion_drift_speed_{direction2}"])) - + max_velocity_electrons_12 = max(max_velocity_electrons_1, max_velocity_electrons_2) max_velocity_ions_12 = max(max_velocity_ions_1, max_velocity_ions_2) electron_phase_histograms2 = vmap(lambda pos, vel: jnp.histogram2d( @@ -204,7 +204,7 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): ion_ax2.set(xlabel=f"Ion Velocity {direction1} (m/s)", ylabel=f"Ion Velocity {direction2} (m/s)", title=f"Ion Phase Space v{direction1} vs v{direction2}") - + B_field_densities = output["magnetic_field_energy_density"] B0 = np.asarray(B_field_densities[0]) global_max = np.asarray(B_field_densities).max() @@ -226,12 +226,12 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): scat_r.set_animated(True) scat_b.set_animated(True) im.set_animated(True) - + # Time label animated_time_text2 = positions_ax.text(0.5, 0.9, "", transform=positions_ax.transAxes, ha="center", va="top", fontsize=12, bbox=dict(facecolor='white', alpha=0.7, edgecolor='none')) - + # Add marker explanation text positions_ax.text(0.02, 0.98, "Electrons: '<'", color="red", transform=positions_ax.transAxes, ha="left", va="top", @@ -239,7 +239,7 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): positions_ax.text(0.02, 0.92, "Ions: '>'", color="blue", transform=positions_ax.transAxes, ha="left", va="top", fontsize=10, bbox=dict(facecolor='white', alpha=0.7, edgecolor='none')) - + def update(frame): electron_plot.set_array(electron_phase_histograms[frame].T) ion_plot.set_array(ion_phase_histograms[frame].T) @@ -247,7 +247,7 @@ def update(frame): if second_direction: electron_plot2.set_array(electron_phase_histograms2[frame].T) ion_plot2.set_array(ion_phase_histograms2[frame].T) - + x_electrons = np.asarray(output["position_electrons"][frame, subset_electrons, direction_index2]) z_electrons = np.asarray(output["position_electrons"][frame, subset_electrons, direction_index1]) x_ions = np.asarray(output["position_ions"][frame, subset_ions, direction_index2]) diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py index da5d9c2..dbe0281 100644 --- a/jaxincell/_simulation.py +++ b/jaxincell/_simulation.py @@ -44,14 +44,14 @@ def load_parameters(input_file): def initialize_simulation_parameters(user_parameters={}): """ - Initialize the simulation parameters for a particle-in-cell simulation, - combining user-provided values with predefined defaults. This function - ensures all required parameters are set and automatically calculates + Initialize the simulation parameters for a particle-in-cell simulation, + combining user-provided values with predefined defaults. This function + ensures all required parameters are set and automatically calculates derived parameters based on the inputs. - The function uses lambda functions to define derived parameters that - depend on other parameters. These lambda functions are evaluated after - merging user-provided parameters with the defaults, ensuring derived + The function uses lambda functions to define derived parameters that + depend on other parameters. These lambda functions are evaluated after + merging user-provided parameters with the defaults, ensuring derived parameters are consistent with any overrides. Parameters: @@ -115,19 +115,19 @@ def initialize_simulation_parameters(user_parameters={}): "particle_BC_right": 0, # Right boundary condition for particles "field_BC_left": 0, # Left boundary condition for fields "field_BC_right": 0, # Right boundary condition for fields - + # External fields (initialized to zero) "external_electric_field_amplitude": 0, # Amplitude of sinusoidal (cos) perturbation in x "external_electric_field_wavenumber": 0, # Wavenumber of sinusoidal (cos) perturbation in x (factor of 2pi/length) "external_magnetic_field_amplitude": 0, # Amplitude of sinusoidal (cos) perturbation in x "external_magnetic_field_wavenumber": 0, # Wavenumber of sinusoidal (cos) perturbation in x (factor of 2pi/length) - + "weight": 0, } # Merge user-provided parameters into the default dictionary parameters = {**default_parameters, **user_parameters} - + # Compute derived parameters based on user-provided or default values for key, value in parameters.items(): if callable(value): # If the value is a lambda function, compute it @@ -139,9 +139,9 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb ,max_number_of_Picard_iterations_implicit_CN=7, number_of_particle_substeps_implicit_CN=2): """ Initialize particles and electromagnetic fields for a Particle-in-Cell simulation. - - This function generates particle positions, velocities, charges, masses, and - charge-to-mass ratios, as well as the initial electric and magnetic fields. It + + This function generates particle positions, velocities, charges, masses, and + charge-to-mass ratios, as well as the initial electric and magnetic fields. It combines user-provided parameters with default values and calculates derived quantities. Parameters: @@ -169,7 +169,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb # Random key generator for reproducibility seed = parameters["seed"] - + # **Particle Positions** # Use random positions in x if requested, otherwise linspace electron_xs = lax.cond(parameters["random_positions_x"], @@ -247,7 +247,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb / number_grid_points / (vth_electrons_over_c) )) - + charges = jnp.concatenate(( charge_electrons * weight * jnp.ones((number_pseudoelectrons, 1)), charge_ions * weight * jnp.ones((number_pseudoelectrons, 1)) @@ -267,7 +267,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb v_electrons_z = parameters["vth_electrons_over_c_z"] * speed_of_light / jnp.sqrt(2) * normal(PRNGKey(seed+9), shape=(number_pseudoelectrons, )) + parameters["electron_drift_speed_z"] v_electrons_z = jnp.where(parameters["velocity_plus_minus_electrons_z"], v_electrons_z * (-1) ** jnp.arange(0, number_pseudoelectrons), v_electrons_z) electron_velocities = jnp.stack((v_electrons_x, v_electrons_y, v_electrons_z), axis=1) - + # Ion thermal velocities and drift speeds vth_ions_x = jnp.sqrt(jnp.abs(parameters["ion_temperature_over_electron_temperature_x"])) * parameters["vth_electrons_over_c_x"] * speed_of_light * jnp.sqrt(jnp.abs(mass_electrons / mass_ions)) vth_ions_y = jnp.sqrt(jnp.abs(parameters["ion_temperature_over_electron_temperature_y"])) * parameters["vth_electrons_over_c_y"] * speed_of_light * jnp.sqrt(jnp.abs(mass_electrons / mass_ions)) @@ -279,7 +279,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb v_ions_z = vth_ions_z / jnp.sqrt(2) * normal(PRNGKey(seed+12), shape=(number_pseudoelectrons, )) + parameters["ion_drift_speed_z"] v_ions_z = jnp.where(parameters["velocity_plus_minus_ions_z"], v_ions_z * (-1) ** jnp.arange(0, number_pseudoelectrons), v_ions_z) ion_velocities = jnp.stack((v_ions_x, v_ions_y, v_ions_z), axis=1) - + # Combine electron and ion velocities velocities = jnp.concatenate((electron_velocities, ion_velocities)) # Cap velocities at 99% the speed of light @@ -305,7 +305,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb "Ion temperature / Electron temperature: {}\n" "Debye length: {} m\n" "Skin depth: {} m\n" - "Wavenumber * Debye length: {}\n" + "Wavenumber * Debye length: {}\n" "Pseudoparticles per cell: {}\n" "Pseudoparticle weight: {}\n" "Steps at each plasma frequency: {}\n" @@ -329,13 +329,13 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb jnp.max(relativistic_gamma_factor), jnp.mean(relativistic_gamma_factor), -charge_electrons * parameters["external_electric_field_amplitude"] * Debye_length_per_dx*dx / (mass_electrons * vth_electrons**2 / 2), ), lambda _: None, operand=None) - + # **Fields Initialization** B_field = jnp.zeros((grid.size, 3)) E_field = jnp.zeros((grid.size, 3)) - + fields = (E_field, B_field) - + external_E_field_x = parameters["external_electric_field_amplitude"] * jnp.cos(parameters["external_electric_field_wavenumber"] * jnp.linspace(-jnp.pi, jnp.pi, number_grid_points)) external_B_field_x = parameters["external_magnetic_field_amplitude"] * jnp.cos(parameters["external_magnetic_field_wavenumber"] * jnp.linspace(-jnp.pi, jnp.pi, number_grid_points)) @@ -358,22 +358,22 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb "number_grid_points": number_grid_points, "plasma_frequency": plasma_frequency, "max_initial_vth_electrons": vth_electrons, - "max_number_of_Picard_iterations_implicit_CN": max_number_of_Picard_iterations_implicit_CN, + "max_number_of_Picard_iterations_implicit_CN": max_number_of_Picard_iterations_implicit_CN, "number_of_particle_substeps_implicit_CN": number_of_particle_substeps_implicit_CN, }) - + return parameters @partial(jit, static_argnames=['number_grid_points', 'number_pseudoelectrons', 'total_steps', 'field_solver', "time_evolution_algorithm", "max_number_of_Picard_iterations_implicit_CN","number_of_particle_substeps_implicit_CN"]) -def simulation(input_parameters={}, number_grid_points=100, number_pseudoelectrons=3000, total_steps=1000, +def simulation(input_parameters={}, number_grid_points=100, number_pseudoelectrons=3000, total_steps=1000, field_solver=0,positions=None, velocities=None,time_evolution_algorithm=0,max_number_of_Picard_iterations_implicit_CN=7, number_of_particle_substeps_implicit_CN=2): """ Run a plasma physics simulation using a Particle-In-Cell (PIC) method in JAX. This function simulates the evolution of a plasma system by solving for particle motion - (electrons and ions) and self-consistent electromagnetic fields on a grid. It uses the + (electrons and ions) and self-consistent electromagnetic fields on a grid. It uses the Boris algorithm for particle updates and a leapfrog scheme for field updates. Parameters: @@ -423,12 +423,12 @@ def simulation(input_parameters={}, number_grid_points=100, number_pseudoelectro positions + (dt / 2) * velocities, velocities, parameters["charges"], parameters["masses"], parameters["charge_to_mass_ratios"], dx, grid, *box_size, particle_BC_left, particle_BC_right) - + positions_minus1_2 = set_BC_positions( positions - (dt / 2) * velocities, parameters["charges"], dx, grid, *box_size, particle_BC_left, particle_BC_right) - + if time_evolution_algorithm == 0: initial_carry = ( E_field, B_field, positions_minus1_2, positions, @@ -452,7 +452,7 @@ def simulation(input_parameters={}, number_grid_points=100, number_pseudoelectro @scan_tqdm(total_steps) def simulation_step(carry, step_index): return step_func(carry, step_index) - + # Run simulation _, results = lax.scan(simulation_step, initial_carry, jnp.arange(total_steps)) @@ -460,7 +460,7 @@ def simulation_step(carry, step_index): # Unpack results positions_over_time, velocities_over_time, electric_field_over_time, \ magnetic_field_over_time, current_density_over_time, charge_density_over_time = results - + # **Output results** temporary_output = { "position_electrons": positions_over_time[ :, :number_pseudoelectrons, :], @@ -480,9 +480,9 @@ def simulation_step(carry, step_index): "total_steps": total_steps, "time_array": jnp.linspace(0, total_steps * dt, total_steps), } - + output = {**temporary_output, **parameters} diagnostics(output) - - return output \ No newline at end of file + + return output diff --git a/jaxincell/_sources.py b/jaxincell/_sources.py index e4a441f..0653074 100644 --- a/jaxincell/_sources.py +++ b/jaxincell/_sources.py @@ -47,7 +47,7 @@ def charge_density_BCs(particle_BC_left, particle_BC_right, position, dx, grid, @jit def single_particle_charge_density(x, q, dx, grid, particle_BC_left, particle_BC_right): """ - Computes the charge density contribution of a single particle to the grid using a + Computes the charge density contribution of a single particle to the grid using a quadratic particle shape function. Args: @@ -91,7 +91,7 @@ def calculate_charge_density(xs_n, qs, dx, grid, particle_BC_left, particle_BC_r """ # Vectorize over particles chargedens_contrib = vmap(single_particle_charge_density, in_axes=(0, 0, None, None, None, None)) - + # Compute charge density for all particles chargedens = chargedens_contrib(xs_n[:, 0], qs[:, 0], dx, grid, particle_BC_left, particle_BC_right) @@ -155,10 +155,10 @@ def compute_current(i): j_grid_z = chargedens * vz_n return j_grid_x, j_grid_y, j_grid_z # Each output has shape (grid_size,) - + current_dens_x, current_dens_y, current_dens_z = vmap(compute_current)(jnp.arange(len(xs_nminushalf))) current_dens_x = jnp.sum(current_dens_x, axis=0) current_dens_y = jnp.sum(current_dens_y, axis=0) current_dens_z = jnp.sum(current_dens_z, axis=0) - return jnp.stack([current_dens_x, current_dens_y, current_dens_z], axis=0).T \ No newline at end of file + return jnp.stack([current_dens_x, current_dens_y, current_dens_z], axis=0).T From 04bedeaadeb9fed6a22806cc50a6d97480e09ce7 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Fri, 25 Jul 2025 10:42:21 -0500 Subject: [PATCH 03/19] Additional particle populations Allow user to specify and add any number of particles with a drifting maxwellian distribution, optionally with sinusoidal density perturbation --- examples/bump-on-tail.toml | 75 ++++++++++++++++++++- jaxincell/_simulation.py | 133 +++++++++++++++++++++++++++++++++++-- 2 files changed, 201 insertions(+), 7 deletions(-) diff --git a/examples/bump-on-tail.toml b/examples/bump-on-tail.toml index 41292a3..1e50182 100644 --- a/examples/bump-on-tail.toml +++ b/examples/bump-on-tail.toml @@ -26,7 +26,7 @@ wavenumber_ions_z = 0 vth_electrons_over_c_x = 0.014142135624 # sqrt(2*T/m)/c vth_electrons_over_c_y = 0 vth_electrons_over_c_z = 0 -electron_drift_speed_x = 0 # Drift speed of electrons in the x direction +electron_drift_speed_x = -1.125e5 # -0.125c * 3e-3 # drift in m/s; compensate bump drift so that j(t)=0 electron_drift_speed_y = 0 electron_drift_speed_z = 0 velocity_plus_minus_electrons_x = false # create two groups of electrons moving in opposite directions in the x direction @@ -62,8 +62,79 @@ relativistic = false # Use relativistic Boris pusher seed = 250724 # Random seed for reproducibility tolerance_Picard_iterations_implicit_CN = 1e-6 # Tolerance for Picard iterations in implicit Crank-Nicholson method +# ------------------------------ +# Additional Maxwellian particle populations with bulk drift. +# To add new populations, copy this block and modify parameter values, +# but don't change the header "[[species]]" or parameter names. +# To only use default ion/electron plasma, comment out all [[species]] blocks. +# All parameters must be specified; internal code doesn't set default values. +# ------------------------------ +[[species]] +mass_over_proton_mass = 0.0005446170199382714 # me/mp using values from jaxincell/_constants.py +charge_over_elementary_charge = -1 +weight_ratio = 3e-3 # sets density ratio between bulk/beam, in combination with "number_pseudoparticles_species" +# spatial distribution +random_positions_x = true +random_positions_y = false +random_positions_z = false +amplitude_perturbation_x = 0 # Amplitude of sinusoidal perturbation +amplitude_perturbation_y = 0 +amplitude_perturbation_z = 0 +wavenumber_perturbation_x = 0 # Wavenumber of sinusoidal density perturbation (factor of 2pi/length) +wavenumber_perturbation_y = 0 +wavenumber_perturbation_z = 0 +# velocity distribution +vth_over_c_x = 0.014142135624 # sqrt(2*T/m)/c +vth_over_c_y = 0 +vth_over_c_z = 0 +drift_speed_x = 3.75e7 # drift speed 0.125c converted to m/s +drift_speed_y = 0 +drift_speed_z = 0 +# RNG control +# "seed_position", "seed_position_override" parameters allow you to put +# different particle species at identical positions, so as to ensure exact +# charge neutrality at t=0 for multiple ion/electron populations. +# WARNING: to avoid unphysically correlated coordinates, choose the position +# seed to not coincide with RNG seeds used elsewhere in the program. +seed_position_override = false +seed_position = 10 + +# enforce charge neutrality for the bump distribution +[[species]] +mass_over_proton_mass = 1e9 +charge_over_elementary_charge = 1 +weight_ratio = 3e-3 +# spatial distribution +random_positions_x = true +random_positions_y = false +random_positions_z = false +amplitude_perturbation_x = 0 # Amplitude of sinusoidal perturbation +amplitude_perturbation_y = 0 +amplitude_perturbation_z = 0 +wavenumber_perturbation_x = 0 # Wavenumber of sinusoidal density perturbation (factor of 2pi/length) +wavenumber_perturbation_y = 0 +wavenumber_perturbation_z = 0 +# velocity distribution +vth_over_c_x = 1e-100 +vth_over_c_y = 1e-100 +vth_over_c_z = 1e-100 +drift_speed_x = 0 +drift_speed_y = 0 +drift_speed_z = 0 +# RNG control +# "seed_position", "seed_position_override" parameters allow you to put +# different particle species at identical positions, so as to ensure exact +# charge neutrality at t=0 for multiple ion/electron populations. +# WARNING: to avoid unphysically correlated coordinates, choose the position +# seed to not coincide with RNG seeds used elsewhere in the program. +seed_position_override = false +seed_position = 10 + [solver_parameters] number_grid_points = 60 # number of grid CELLS, not edges/vertices field_solver = 0 # 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, number_pseudoelectrons = 1000 # Total in entire domain -total_steps = 10000 # Total number of time steps +total_steps = 5000 # Total number of time steps +# number of particles for each additional species +# provide one list value per [[species]] block +number_pseudoparticles_species = [1000,1000,] diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py index dbe0281..99c7938 100644 --- a/jaxincell/_simulation.py +++ b/jaxincell/_simulation.py @@ -40,6 +40,19 @@ def load_parameters(input_file): parameters = tomllib.load(open(input_file, "rb")) input_parameters = parameters['input_parameters'] solver_parameters = parameters['solver_parameters'] + # Interface for additional species and/or particle populations + try: + # Nest within main struct to avoid changing top-level internal API + input_parameters['species'] = parameters['species'] + except: + input_parameters['species'] = [] + # Convert TOML array -> Python tuple to make hashable static argument, as + # required by Jax + try: + solver_parameters['number_pseudoparticles_species'] = tuple(solver_parameters['number_pseudoparticles_species']) + except: + solver_parameters['number_pseudoparticles_species'] = () + assert len(solver_parameters['number_pseudoparticles_species']) == len(input_parameters['species']) return input_parameters, solver_parameters def initialize_simulation_parameters(user_parameters={}): @@ -135,7 +148,8 @@ def initialize_simulation_parameters(user_parameters={}): return parameters -def initialize_particles_fields(input_parameters={}, number_grid_points=50, number_pseudoelectrons=500, total_steps=350 +def initialize_particles_fields(input_parameters={}, number_grid_points=50, number_pseudoelectrons=500, + number_pseudoparticles_species=None, total_steps=350 ,max_number_of_Picard_iterations_implicit_CN=7, number_of_particle_substeps_implicit_CN=2): """ Initialize particles and electromagnetic fields for a Particle-in-Cell simulation. @@ -256,7 +270,6 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb mass_electrons * weight * jnp.ones((number_pseudoelectrons, 1)), mass_ions * weight * jnp.ones((number_pseudoelectrons, 1)) ), axis=0) - charge_to_mass_ratios = charges / masses # **Particle Velocities** # Electron thermal velocities and drift speeds @@ -282,6 +295,21 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb # Combine electron and ion velocities velocities = jnp.concatenate((electron_velocities, ion_velocities)) + + # Introduce additional particle species + for ii, species in enumerate(parameters['species']): + plists = make_particles(species_parameters=species, + Nprt=number_pseudoparticles_species[ii], + box_size=box_size, weight=weight, seed=seed, + rng_index=ii) + positions = jnp.concatenate((positions, plists['positions'])) + velocities = jnp.concatenate((velocities, plists['velocities'])) + charges = jnp.concatenate((charges, plists['charges']), axis=0) + masses = jnp.concatenate((masses, plists['masses']), axis=0) + + # After done adding all species + charge_to_mass_ratios = charges / masses + # Cap velocities at 99% the speed of light speed_limit = 0.99 * speed_of_light velocities = jnp.where(jnp.abs(velocities) >= speed_limit, jnp.sign(velocities) * speed_limit, velocities) @@ -364,10 +392,103 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb return parameters +def make_particles(species_parameters, Nprt, box_size, weight, seed, rng_index): + """ + Generate Nprt total particles of a user-requested species with specified + charge, mass, and space/velocity distribution. + + Parameters: + ---------- + species_parameters : dict + Dictionary of user-specified species parameters. + Nprt : int + Total number of pseudoparticles in the domain + box_size : tuple + Domain size in x,y,z + weight : float + Top-level pseudoelectron weight + seed : int + Top-level random number generator seed used for entire simulation + rng_index : int + Species or particle population index in [0,1,2,3,...] + Use a unique index value for each population. + This index is used to advance the random seed and so avoid spurious + correlation between different particle positions and velocities. + See https://docs.jax.dev/en/latest/random-numbers.html + + Returns: + ------- + plist : dict + Dictionary with lists of positions, velocities, charges, masses. + """ + _p = species_parameters + charge = _p["charge_over_elementary_charge"] * elementary_charge + mass = _p["mass_over_proton_mass"] * mass_proton + vth_x = _p["vth_over_c_x"] * speed_of_light + vth_y = _p["vth_over_c_y"] * speed_of_light + vth_z = _p["vth_over_c_z"] * speed_of_light + + # This code is brittle; it depends on hard-coded offsets to the RNG seed + # within initialize_particles_fields(...) + assert rng_index >= 0 + local_seed = seed+12 + rng_index*6 + # Separate position/velocity seeds allow different ion and electron + # populations to be inited with identical space positions, but + # uncorrelated velocity distributions + seed_pos = lax.cond(_p['seed_position_override'], + lambda _: _p['seed_position'], + lambda _: local_seed, operand=None) + seed_vel = local_seed + + out = dict() + + # **Particle Positions** + + xs = lax.cond(_p["random_positions_x"], + lambda _: uniform(PRNGKey(seed_pos+1), shape=(Nprt,), + minval=-box_size[0] / 2, maxval=box_size[0] / 2), + lambda _: jnp.linspace(-box_size[0] / 2, box_size[0] / 2, Nprt), operand=None) + wavenumber_perturbation_x = _p["wavenumber_perturbation_x"] * 2 * jnp.pi / box_size[0] + xs += _p["amplitude_perturbation_x"] * jnp.sin(wavenumber_perturbation_x * xs) + + ys = lax.cond(_p["random_positions_y"], + lambda _: uniform(PRNGKey(seed_pos+2), shape=(Nprt,), + minval=-box_size[1] / 2, maxval=box_size[1] / 2), + lambda _: jnp.linspace(-box_size[1] / 2, box_size[1] / 2, Nprt), operand=None) + wavenumber_perturbation_y = _p["wavenumber_perturbation_y"] * 2 * jnp.pi / box_size[1] + ys += _p["amplitude_perturbation_y"] * jnp.sin(wavenumber_perturbation_y * ys) + + zs = lax.cond(_p["random_positions_z"], + lambda _: uniform(PRNGKey(seed_pos+3), shape=(Nprt,), + minval=-box_size[2] / 2, maxval=box_size[2] / 2), + lambda _: jnp.linspace(-box_size[2] / 2, box_size[2] / 2, Nprt), operand=None) + wavenumber_perturbation_z = _p["wavenumber_perturbation_z"] * 2 * jnp.pi / box_size[2] + zs += _p["amplitude_perturbation_z"] * jnp.sin(wavenumber_perturbation_z * zs) + + out['positions'] = jnp.stack((xs, ys, zs), axis=1) + + # **Particle Charges and Masses** + + out['charges'] = charge * weight * _p['weight_ratio'] * jnp.ones((Nprt, 1)) + out['masses'] = mass * weight * _p['weight_ratio'] * jnp.ones((Nprt, 1)) + + # **Particle Velocities** + + v_x = vth_x/jnp.sqrt(2) * normal(PRNGKey(seed_vel+4), shape=(Nprt,)) + v_y = vth_y/jnp.sqrt(2) * normal(PRNGKey(seed_vel+5), shape=(Nprt,)) + v_z = vth_z/jnp.sqrt(2) * normal(PRNGKey(seed_vel+6), shape=(Nprt,)) + v_x += _p["drift_speed_x"] + v_y += _p["drift_speed_y"] + v_z += _p["drift_speed_z"] + + out['velocities'] = jnp.stack((v_x, v_y, v_z), axis=1) + + return out -@partial(jit, static_argnames=['number_grid_points', 'number_pseudoelectrons', 'total_steps', 'field_solver', "time_evolution_algorithm", +@partial(jit, static_argnames=['number_grid_points', 'number_pseudoelectrons', 'number_pseudoparticles_species', 'total_steps', 'field_solver', "time_evolution_algorithm", "max_number_of_Picard_iterations_implicit_CN","number_of_particle_substeps_implicit_CN"]) -def simulation(input_parameters={}, number_grid_points=100, number_pseudoelectrons=3000, total_steps=1000, +def simulation(input_parameters={}, number_grid_points=100, number_pseudoelectrons=3000, + number_pseudoparticles_species=None, total_steps=1000, field_solver=0,positions=None, velocities=None,time_evolution_algorithm=0,max_number_of_Picard_iterations_implicit_CN=7, number_of_particle_substeps_implicit_CN=2): """ Run a plasma physics simulation using a Particle-In-Cell (PIC) method in JAX. @@ -391,7 +512,9 @@ def simulation(input_parameters={}, number_grid_points=100, number_pseudoelectro """ # **Initialize simulation parameters** parameters = initialize_particles_fields(input_parameters, number_grid_points=number_grid_points, - number_pseudoelectrons=number_pseudoelectrons, total_steps=total_steps, + number_pseudoelectrons=number_pseudoelectrons, + number_pseudoparticles_species=number_pseudoparticles_species, + total_steps=total_steps, max_number_of_Picard_iterations_implicit_CN=max_number_of_Picard_iterations_implicit_CN, number_of_particle_substeps_implicit_CN=number_of_particle_substeps_implicit_CN) From 1d747e61cb47baf84ef1953d74fa108bf7e89a6a Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Fri, 25 Jul 2025 11:49:23 -0500 Subject: [PATCH 04/19] Separate diagnostics(...) from simulation(...) to separate ion/electron populations of unknown location in the particle array, it's useful to construct masked arrays which must be done in a non-jitted subroutine --- examples/Landau_damping.py | 5 ++++- examples/Langmuir_wave.py | 5 ++++- examples/Weibel_instability.py | 5 ++++- examples/auto-differentiability.py | 4 +++- examples/bump-on-tail.py | 5 ++++- examples/optimize_two_stream_saturation.py | 6 +++++- examples/scaling_energy_time.py | 6 ++++-- examples/two-stream_instability.py | 5 ++++- jaxincell/_diagnostics.py | 19 +++++++++++++++++ jaxincell/_simulation.py | 24 ++++++++++++++-------- 10 files changed, 66 insertions(+), 18 deletions(-) diff --git a/examples/Landau_damping.py b/examples/Landau_damping.py index 5b3a42e..b2c31ed 100644 --- a/examples/Landau_damping.py +++ b/examples/Landau_damping.py @@ -1,7 +1,7 @@ ## Landau_damping.py # Example of electric field damping in a plasma from jaxincell import plot -from jaxincell import simulation +from jaxincell import simulation, diagnostics import jax.numpy as jnp from jax import block_until_ready @@ -30,6 +30,9 @@ output = block_until_ready(simulation(input_parameters, **solver_parameters)) +# Post-process: segregate ions/electrons, compute energies, compute FFT +diagnostics(output) + print(f"Dominant FFT frequency (f): {output['dominant_frequency']} Hz") print(f"Plasma frequency (w_p): {output['plasma_frequency']} Hz") print(f"Error: {jnp.abs(output['dominant_frequency'] - output['plasma_frequency']) / output['plasma_frequency'] * 100:.2f}%") diff --git a/examples/Langmuir_wave.py b/examples/Langmuir_wave.py index ae441f6..6cbcbd8 100644 --- a/examples/Langmuir_wave.py +++ b/examples/Langmuir_wave.py @@ -1,7 +1,7 @@ ## Langmuir_wave.py # Example of plasma oscillations of electrons from jaxincell import plot -from jaxincell import simulation +from jaxincell import simulation, diagnostics import jax.numpy as jnp from jax import block_until_ready @@ -26,6 +26,9 @@ output = block_until_ready(simulation(input_parameters, **solver_parameters)) +# Post-process: segregate ions/electrons, compute energies, compute FFT +diagnostics(output) + print(f"Dominant FFT frequency (f): {output['dominant_frequency']} Hz") print(f"Plasma frequency (w_p): {output['plasma_frequency']} Hz") print(f"Error: {jnp.abs(output['dominant_frequency'] - output['plasma_frequency']) / output['plasma_frequency'] * 100:.2f}%") diff --git a/examples/Weibel_instability.py b/examples/Weibel_instability.py index f2ba71b..b29ab59 100644 --- a/examples/Weibel_instability.py +++ b/examples/Weibel_instability.py @@ -1,7 +1,7 @@ ## Weibel_instability.py # Example of plasma oscillations of electrons from jaxincell import plot -from jaxincell import simulation +from jaxincell import simulation, diagnostics from jax import block_until_ready input_parameters = { @@ -34,4 +34,7 @@ output = block_until_ready(simulation(input_parameters, **solver_parameters)) +# Post-process: segregate ions/electrons, compute energies, compute FFT +diagnostics(output) + plot(output, direction="xz") # Plot the results in x and z direction diff --git a/examples/auto-differentiability.py b/examples/auto-differentiability.py index 10b6a05..74a0950 100644 --- a/examples/auto-differentiability.py +++ b/examples/auto-differentiability.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import matplotlib.pyplot as plt from jax import jit, grad, lax, block_until_ready, debug -from jaxincell import simulation, load_parameters +from jaxincell import simulation, load_parameters, diagnostics # Read from input.toml input_parameters, solver_parameters = load_parameters('input.toml') @@ -15,6 +15,8 @@ def mean_electric_field(electron_drift_speed): input_parameters["electron_drift_speed_x"] = electron_drift_speed output = block_until_ready(simulation(input_parameters, **solver_parameters)) + # Post-process: segregate ions/electrons, compute energies, compute FFT + diagnostics(output) electric_field = jnp.mean(output['electric_field_x'][:, :, 0], axis=1) mean_E = jnp.mean(lax.slice(electric_field, [solver_parameters["total_steps"]//2], [solver_parameters["total_steps"]])) return mean_E diff --git a/examples/bump-on-tail.py b/examples/bump-on-tail.py index 2e868bd..d692d6d 100644 --- a/examples/bump-on-tail.py +++ b/examples/bump-on-tail.py @@ -9,7 +9,7 @@ from datetime import datetime from jax import block_until_ready -from jaxincell import plot, simulation, load_parameters +from jaxincell import plot, simulation, load_parameters, diagnostics input_parameters, solver_parameters = load_parameters('bump-on-tail.toml') @@ -18,6 +18,9 @@ output = block_until_ready(simulation(input_parameters, **solver_parameters)) print("Simulation done, elapsed:", datetime.now()-started) +# Post-process: segregate ions/electrons, compute energies, compute FFT +diagnostics(output) + # Plot the results plot(output) diff --git a/examples/optimize_two_stream_saturation.py b/examples/optimize_two_stream_saturation.py index 3451fca..3e36b56 100644 --- a/examples/optimize_two_stream_saturation.py +++ b/examples/optimize_two_stream_saturation.py @@ -4,7 +4,7 @@ from jax import jit, grad import matplotlib.pyplot as plt from scipy.optimize import least_squares -from jaxincell import simulation, epsilon_0, load_parameters +from jaxincell import simulation, epsilon_0, load_parameters, diagnostics # Read from input.toml input_parameters, solver_parameters = load_parameters('input.toml') @@ -24,6 +24,8 @@ def objective_function(Ti): Ti = jnp.atleast_1d(Ti)[0] params["ion_temperature_over_electron_temperature_x"] = Ti output = simulation(params, **solver_parameters) + # Post-process: segregate ions/electrons, compute energies, compute FFT + diagnostics(output) abs_E_squared = jnp.sum(output['electric_field']**2, axis=-1) integral_E_squared = jnp.trapezoid(abs_E_squared, dx=output['dx'], axis=-1) energy = (epsilon_0/2) * integral_E_squared @@ -33,6 +35,8 @@ def objective_function(Ti): print(f'Perform a first run to see one objective function') input_parameters["ion_temperature_over_electron_temperature_x"] = x0_optimization output = simulation(input_parameters, **solver_parameters) +# Post-process: segregate ions/electrons, compute energies, compute FFT +diagnostics(output) objective = objective_function(x0_optimization) plt.figure(figsize=(8,6)) plt.plot(output['time_array']*output['plasma_frequency'],output['electric_field_energy'], label='Electric Field Energy') diff --git a/examples/scaling_energy_time.py b/examples/scaling_energy_time.py index edf6809..28368e5 100644 --- a/examples/scaling_energy_time.py +++ b/examples/scaling_energy_time.py @@ -1,7 +1,7 @@ ## scaling_time.py import time from jaxincell import simulation -from jaxincell import diagnostics +from jaxincell import diagnostics, diagnostics from jax import block_until_ready import matplotlib.pyplot as plt import jax.numpy as jnp @@ -46,7 +46,7 @@ # Function to compute the maximum relative energy error def max_relative_energy_error(output): - diagnostics(output) + #diagnostics(output) # moved to outer caller level to be more consistent with other jaxincell example problems relative_energy_error = jnp.abs((output["total_energy"] - output["total_energy"][0]) / output["total_energy"][0]) return jnp.max(relative_energy_error) @@ -78,6 +78,8 @@ def measure_time_and_error(parameter_list, param_name): # solver_parameters['number_grid_points'] = old_grid_points elapsed_time = time.time() - start times.append(elapsed_time) + # Post-process: segregate ions/electrons, compute energies, compute FFT + diagnostics(output) max_relative_errors.append(max_relative_energy_error(output)) print(f"{param_name.capitalize()}: {param}, Time: {elapsed_time}s") return times, max_relative_errors diff --git a/examples/two-stream_instability.py b/examples/two-stream_instability.py index 354c651..1c9fa52 100644 --- a/examples/two-stream_instability.py +++ b/examples/two-stream_instability.py @@ -1,7 +1,7 @@ # Example script to run the simulation and plot the results import time from jax import block_until_ready -from jaxincell import plot, simulation, load_parameters +from jaxincell import plot, simulation, load_parameters, diagnostics import numpy as np import pickle import json @@ -18,6 +18,9 @@ output = block_until_ready(simulation(input_parameters, **solver_parameters)) print(f"Run #{i+1}: Wall clock time: {time.time()-start}s") +# Post-process: segregate ions/electrons, compute energies, compute FFT +diagnostics(output) + # Plot the results plot(output) diff --git a/jaxincell/_diagnostics.py b/jaxincell/_diagnostics.py index ae550d1..2349e5b 100644 --- a/jaxincell/_diagnostics.py +++ b/jaxincell/_diagnostics.py @@ -6,6 +6,25 @@ __all__ = ['diagnostics'] def diagnostics(output): + + isel = (output["charges"] >= 0)[:,0] # cannot use masks in jitted functions + esel = (output["charges"] < 0)[:,0] + segregated = { + "position_electrons": output["positions"] [:, esel, :], + "velocity_electrons": output["velocities"][:, esel, :], + "mass_electrons": output["masses"] [ esel], + "charge_electrons": output["charges"] [ esel], + "position_ions": output["positions"] [:, isel, :], + "velocity_ions": output["velocities"][:, isel, :], + "mass_ions": output["masses"] [ isel], + "charge_ions": output["charges"] [ isel], + } + output.update(**segregated) + del output["positions"] + del output["velocities"] + del output["masses"] + del output["charges"] + E_field_over_time = output['electric_field'] grid = output['grid'] dt = output['dt'] diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py index 99c7938..9c0fb3d 100644 --- a/jaxincell/_simulation.py +++ b/jaxincell/_simulation.py @@ -586,14 +586,20 @@ def simulation_step(carry, step_index): # **Output results** temporary_output = { - "position_electrons": positions_over_time[ :, :number_pseudoelectrons, :], - "velocity_electrons": velocities_over_time[:, :number_pseudoelectrons, :], - "mass_electrons": parameters["masses"][ :number_pseudoelectrons], - "charge_electrons": parameters["charges"][ :number_pseudoelectrons], - "position_ions": positions_over_time[ :, number_pseudoelectrons:, :], - "velocity_ions": velocities_over_time[:, number_pseudoelectrons:, :], - "mass_ions": parameters["masses"][ number_pseudoelectrons:], - "charge_ions": parameters["charges"][ number_pseudoelectrons:], + ## segregate ions/electrons in non-jitted method outside simulation(...) + ## so we can make use of dynamically constructed arrays + #"position_electrons": positions_over_time[ :, :number_pseudoelectrons, :], + #"velocity_electrons": velocities_over_time[:, :number_pseudoelectrons, :], + #"mass_electrons": parameters["masses"][ :number_pseudoelectrons], + #"charge_electrons": parameters["charges"][ :number_pseudoelectrons], + #"position_ions": positions_over_time[ :, number_pseudoelectrons:, :], + #"velocity_ions": velocities_over_time[:, number_pseudoelectrons:, :], + #"mass_ions": parameters["masses"][ number_pseudoelectrons:], + #"charge_ions": parameters["charges"][ number_pseudoelectrons:], + "positions": positions_over_time, + "velocities": velocities_over_time, + "masses": parameters["masses"], + "charges": parameters["charges"], "electric_field": electric_field_over_time, "magnetic_field": magnetic_field_over_time, "current_density": current_density_over_time, @@ -606,6 +612,6 @@ def simulation_step(carry, step_index): output = {**temporary_output, **parameters} - diagnostics(output) + #diagnostics(output) return output From f3bb6a4de46168afef98485dad4a0d90be4386f7 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Fri, 25 Jul 2025 11:58:19 -0500 Subject: [PATCH 05/19] symmetrize interactive plot colorbars --- jaxincell/_plot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jaxincell/_plot.py b/jaxincell/_plot.py index bf2bea2..2ed922e 100644 --- a/jaxincell/_plot.py +++ b/jaxincell/_plot.py @@ -100,8 +100,11 @@ def add_field_components(field, unit, label_prefix): fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 2.5 * nrows), squeeze=False) def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): + print('plotting', title, 'field_data.shape', field_data.shape) + vbnd = np.max(np.abs(field_data[np.isfinite(field_data)])) # enforce symmetric colormap im = ax.imshow(field_data, aspect="auto", cmap="RdBu", origin="lower", - extent=[grid[0], grid[-1], time[0], time[-1]]) + extent=[grid[0], grid[-1], time[0], time[-1]], + vmin=-vbnd, vmax=vbnd) ax.set(title=title, xlabel=xlabel, ylabel=ylabel) fig.colorbar(im, ax=ax, label=cbar_label) return im From 9fb878f95099cb9fe3223dca926961987ba55c52 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Fri, 25 Jul 2025 12:30:55 -0500 Subject: [PATCH 06/19] Handle non-TOML params with no additional species --- jaxincell/_simulation.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py index 9c0fb3d..7f81e60 100644 --- a/jaxincell/_simulation.py +++ b/jaxincell/_simulation.py @@ -52,7 +52,6 @@ def load_parameters(input_file): solver_parameters['number_pseudoparticles_species'] = tuple(solver_parameters['number_pseudoparticles_species']) except: solver_parameters['number_pseudoparticles_species'] = () - assert len(solver_parameters['number_pseudoparticles_species']) == len(input_parameters['species']) return input_parameters, solver_parameters def initialize_simulation_parameters(user_parameters={}): @@ -510,6 +509,16 @@ def simulation(input_parameters={}, number_grid_points=100, number_pseudoelectro ------- output : dict """ + + # For simulation(...) parameters specified via Python, not parsed TOML file + if 'species' not in input_parameters: + input_parameters['species'] = [] + if not number_pseudoparticles_species: + number_pseudoparticles_species = () + else: + number_pseudoparticles_species = tuple(number_pseudoparticles_species) + assert len(number_pseudoparticles_species) == len(input_parameters['species']) + # **Initialize simulation parameters** parameters = initialize_particles_fields(input_parameters, number_grid_points=number_grid_points, number_pseudoelectrons=number_pseudoelectrons, From f4759ef837f83d28b4910e31ef67b1320caf24f0 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Fri, 25 Jul 2025 13:29:28 -0500 Subject: [PATCH 07/19] Hotter bump-on-tail params Try to overcome PIC noise limit so that we can see E-field growth Four wavelengths across box very visible in e- phase mixing --- examples/bump-on-tail.toml | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/bump-on-tail.toml b/examples/bump-on-tail.toml index 1e50182..d5e3bc5 100644 --- a/examples/bump-on-tail.toml +++ b/examples/bump-on-tail.toml @@ -23,10 +23,10 @@ wavenumber_ions_z = 0 # ------------------------------ # electron velocity distribution # ------------------------------ -vth_electrons_over_c_x = 0.014142135624 # sqrt(2*T/m)/c +vth_electrons_over_c_x = 0.07071067812 # sqrt(2*T/m)/c = sqrt(2)*0.05 vth_electrons_over_c_y = 0 vth_electrons_over_c_z = 0 -electron_drift_speed_x = -1.125e5 # -0.125c * 3e-3 # drift in m/s; compensate bump drift so that j(t)=0 +electron_drift_speed_x = -2.25e6 # -0.25c * 3e-2 # drift in m/s; compensate bump drift so that j(t)=0 electron_drift_speed_y = 0 electron_drift_speed_z = 0 velocity_plus_minus_electrons_x = false # create two groups of electrons moving in opposite directions in the x direction @@ -48,7 +48,7 @@ velocity_plus_minus_ions_z = false # domain, spacetime gridding # ------------------------------ length = 1 # Simulation box size in meters (sets dimensionful normalization, but does not change dimensionless ratios?) -grid_points_per_Debye_length = 5 # dx over Debye length (ignore the conflicting variable name) +grid_points_per_Debye_length = 2.565 # dx over Debye length (ignore the conflicting variable name) timestep_over_spatialstep_times_c = 1.0 # dt * speed_of_light / dx particle_BC_left = 0 # 0: periodic, 1: reflective, 2: absorbing particle_BC_right = 0 @@ -72,7 +72,7 @@ tolerance_Picard_iterations_implicit_CN = 1e-6 # Tolerance for Picard iteration [[species]] mass_over_proton_mass = 0.0005446170199382714 # me/mp using values from jaxincell/_constants.py charge_over_elementary_charge = -1 -weight_ratio = 3e-3 # sets density ratio between bulk/beam, in combination with "number_pseudoparticles_species" +weight_ratio = 3e-2 # sets density ratio between bulk/beam, in combination with "number_pseudoparticles_species" # spatial distribution random_positions_x = true random_positions_y = false @@ -84,10 +84,10 @@ wavenumber_perturbation_x = 0 # Wavenumber of sinusoidal density perturbation ( wavenumber_perturbation_y = 0 wavenumber_perturbation_z = 0 # velocity distribution -vth_over_c_x = 0.014142135624 # sqrt(2*T/m)/c +vth_over_c_x = 0.07071067812 # sqrt(2*T/m)/c = sqrt(2)*0.05 vth_over_c_y = 0 vth_over_c_z = 0 -drift_speed_x = 3.75e7 # drift speed 0.125c converted to m/s +drift_speed_x = 7.5e7 # drift speed 0.25c converted to m/s drift_speed_y = 0 drift_speed_z = 0 # RNG control @@ -96,14 +96,14 @@ drift_speed_z = 0 # charge neutrality at t=0 for multiple ion/electron populations. # WARNING: to avoid unphysically correlated coordinates, choose the position # seed to not coincide with RNG seeds used elsewhere in the program. -seed_position_override = false +seed_position_override = true seed_position = 10 # enforce charge neutrality for the bump distribution [[species]] mass_over_proton_mass = 1e9 charge_over_elementary_charge = 1 -weight_ratio = 3e-3 +weight_ratio = 3e-2 # spatial distribution random_positions_x = true random_positions_y = false @@ -127,14 +127,14 @@ drift_speed_z = 0 # charge neutrality at t=0 for multiple ion/electron populations. # WARNING: to avoid unphysically correlated coordinates, choose the position # seed to not coincide with RNG seeds used elsewhere in the program. -seed_position_override = false +seed_position_override = true seed_position = 10 [solver_parameters] -number_grid_points = 60 # number of grid CELLS, not edges/vertices +number_grid_points = 40 # number of grid CELLS, not edges/vertices field_solver = 0 # 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, -number_pseudoelectrons = 1000 # Total in entire domain -total_steps = 5000 # Total number of time steps +number_pseudoelectrons = 40000 # Total in entire domain +total_steps = 2000 # Total number of time steps # number of particles for each additional species # provide one list value per [[species]] block -number_pseudoparticles_species = [1000,1000,] +number_pseudoparticles_species = [40000,40000,] From ee98631cc310ec5b95bbd7c56fbd412c65948a09 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Fri, 25 Jul 2025 13:37:26 -0500 Subject: [PATCH 08/19] Minor bump-on-tail prob tweaks --- examples/bump-on-tail.py | 4 ++-- examples/bump-on-tail.toml | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/bump-on-tail.py b/examples/bump-on-tail.py index d692d6d..58e5867 100644 --- a/examples/bump-on-tail.py +++ b/examples/bump-on-tail.py @@ -24,8 +24,8 @@ # Plot the results plot(output) -# # Save the output to a file -# np.savez("simulation_output.npz", **output) +# Save the output to a file +np.savez("simulation_output.npz", **output) # # Load the output from the file # data = np.load("simulation_output.npz", allow_pickle=True) diff --git a/examples/bump-on-tail.toml b/examples/bump-on-tail.toml index d5e3bc5..fce7ae3 100644 --- a/examples/bump-on-tail.toml +++ b/examples/bump-on-tail.toml @@ -131,6 +131,9 @@ seed_position_override = true seed_position = 10 [solver_parameters] +time_evolution_algorithm = 0 # 0: Boris solver, 1: implicit Crank-Nicholson +max_number_of_Picard_iterations_implicit_CN = 30 +number_of_particle_substeps_implicit_CN = 2 number_grid_points = 40 # number of grid CELLS, not edges/vertices field_solver = 0 # 0: Curl_EB, 1: Gauss_1D_FFT, 2: Gauss_1D_Cartesian, 3: Poisson_1D_FFT, number_pseudoelectrons = 40000 # Total in entire domain From dc025278bda3843e2ed64159bba0f15dae34429a Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Fri, 25 Jul 2025 14:11:08 -0500 Subject: [PATCH 09/19] bump-on-tail document expected linear theory --- examples/bump-on-tail.toml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/bump-on-tail.toml b/examples/bump-on-tail.toml index fce7ae3..1d4f9f3 100644 --- a/examples/bump-on-tail.toml +++ b/examples/bump-on-tail.toml @@ -1,3 +1,20 @@ +# --------------------------------------------------- +# bump-on-tail.toml +# --------------------------------------------------- +# For the following plasma: +# nbeam/n0 = 0.03 +# vbeam = 0.25*c +# sqrt(Te/me) = 0.05 for both bulk and beam populations +# Fastest-growing mode predicted by non-relativistic electrostatic dispersion +# has parameters: +# Re(omega) = 1.0*omega_pe +# Im(omega) = 0.075*omega_pe +# k = 4.9*omega_pe/c +# where omega_pe is the electron plasma frequency, k = 2*pi/wavelength is the +# wavenumber, c is speed of light, vbeam is (non-relativistic) three-velocity, +# nbeam is beam density, n0 = background (bulk) density. +# --------------------------------------------------- + [input_parameters] # ----------------------------- # species intrinsic properties From 7a8fea2192389de6c6f4bdf91c5541b7bf732904 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Mon, 11 Aug 2025 16:20:31 -0400 Subject: [PATCH 10/19] Fix redundant import flagged by copilot AI Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/scaling_energy_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scaling_energy_time.py b/examples/scaling_energy_time.py index 28368e5..b5b3ca9 100644 --- a/examples/scaling_energy_time.py +++ b/examples/scaling_energy_time.py @@ -1,7 +1,7 @@ ## scaling_time.py import time from jaxincell import simulation -from jaxincell import diagnostics, diagnostics +from jaxincell import diagnostics from jax import block_until_ready import matplotlib.pyplot as plt import jax.numpy as jnp From d79fd993ff4bb8abf695d44e1a7a871c4603e304 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Mon, 11 Aug 2025 15:27:28 -0500 Subject: [PATCH 11/19] rm debug print statement (thx copilot) --- jaxincell/_plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jaxincell/_plot.py b/jaxincell/_plot.py index 2ed922e..5f0ac99 100644 --- a/jaxincell/_plot.py +++ b/jaxincell/_plot.py @@ -100,7 +100,6 @@ def add_field_components(field, unit, label_prefix): fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 2.5 * nrows), squeeze=False) def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): - print('plotting', title, 'field_data.shape', field_data.shape) vbnd = np.max(np.abs(field_data[np.isfinite(field_data)])) # enforce symmetric colormap im = ax.imshow(field_data, aspect="auto", cmap="RdBu", origin="lower", extent=[grid[0], grid[-1], time[0], time[-1]], From b695715c548188bc697ccd7a9938291d015da947 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Mon, 11 Aug 2025 16:32:11 -0400 Subject: [PATCH 12/19] cleanup bare except clauses Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- jaxincell/_simulation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py index 7f81e60..e6bbcb1 100644 --- a/jaxincell/_simulation.py +++ b/jaxincell/_simulation.py @@ -50,7 +50,13 @@ def load_parameters(input_file): # required by Jax try: solver_parameters['number_pseudoparticles_species'] = tuple(solver_parameters['number_pseudoparticles_species']) - except: + except KeyError: + input_parameters['species'] = [] + # Convert TOML array -> Python tuple to make hashable static argument, as + # required by Jax + try: + solver_parameters['number_pseudoparticles_species'] = tuple(solver_parameters['number_pseudoparticles_species']) + except KeyError: solver_parameters['number_pseudoparticles_species'] = () return input_parameters, solver_parameters From 6f2f8dab3fbd3e1217dbac820047107730a9b5e1 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Mon, 29 Sep 2025 15:42:29 -0500 Subject: [PATCH 13/19] cleanup bare except clauses better github copilot edit introduced garbage that i didn't catch... --- jaxincell/_simulation.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py index e6bbcb1..91aca7c 100644 --- a/jaxincell/_simulation.py +++ b/jaxincell/_simulation.py @@ -44,12 +44,6 @@ def load_parameters(input_file): try: # Nest within main struct to avoid changing top-level internal API input_parameters['species'] = parameters['species'] - except: - input_parameters['species'] = [] - # Convert TOML array -> Python tuple to make hashable static argument, as - # required by Jax - try: - solver_parameters['number_pseudoparticles_species'] = tuple(solver_parameters['number_pseudoparticles_species']) except KeyError: input_parameters['species'] = [] # Convert TOML array -> Python tuple to make hashable static argument, as From de7a032a49c7324ab6b8e7f00b784d364ff6b1bd Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Mon, 29 Sep 2025 17:44:55 -0500 Subject: [PATCH 14/19] Add species ID to particle data structures --- jaxincell/_simulation.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py index 91aca7c..92d11ca 100644 --- a/jaxincell/_simulation.py +++ b/jaxincell/_simulation.py @@ -269,6 +269,10 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb mass_electrons * weight * jnp.ones((number_pseudoelectrons, 1)), mass_ions * weight * jnp.ones((number_pseudoelectrons, 1)) ), axis=0) + species_ids = jnp.concatenate(( + 0. * jnp.ones((number_pseudoelectrons, 1)), # "default" electrons hardcoded ID = 0 + 1. * jnp.ones((number_pseudoelectrons, 1)) # "default" ions harcoded ID = 1 + ), axis=0) # **Particle Velocities** # Electron thermal velocities and drift speeds @@ -296,15 +300,18 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb velocities = jnp.concatenate((electron_velocities, ion_velocities)) # Introduce additional particle species + # starting from id=2 (range inclusive); id=0,1 already used by "default" + # ion and electron populations for ii, species in enumerate(parameters['species']): plists = make_particles(species_parameters=species, Nprt=number_pseudoparticles_species[ii], box_size=box_size, weight=weight, seed=seed, - rng_index=ii) + species_id=ii+2) positions = jnp.concatenate((positions, plists['positions'])) velocities = jnp.concatenate((velocities, plists['velocities'])) charges = jnp.concatenate((charges, plists['charges']), axis=0) masses = jnp.concatenate((masses, plists['masses']), axis=0) + species_ids= jnp.concatenate((species_ids,plists['species_ids']), axis=0) # After done adding all species charge_to_mass_ratios = charges / masses @@ -373,6 +380,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb "initial_velocities": velocities, "charges": charges, "masses": masses, + "species_ids": species_ids, "charge_to_mass_ratios": charge_to_mass_ratios, "fields": fields, "grid": grid, @@ -391,7 +399,7 @@ def initialize_particles_fields(input_parameters={}, number_grid_points=50, numb return parameters -def make_particles(species_parameters, Nprt, box_size, weight, seed, rng_index): +def make_particles(species_parameters, Nprt, box_size, weight, seed, species_id): """ Generate Nprt total particles of a user-requested species with specified charge, mass, and space/velocity distribution. @@ -408,9 +416,11 @@ def make_particles(species_parameters, Nprt, box_size, weight, seed, rng_index): Top-level pseudoelectron weight seed : int Top-level random number generator seed used for entire simulation - rng_index : int - Species or particle population index in [0,1,2,3,...] - Use a unique index value for each population. + species_id : int + Species ID (index) in [2,3,4,...]; note that 0,1 are reserved for the + default electron/ion population that is created separately from the + [[species]] TOML blocks. + Use a unique ID for each population. This index is used to advance the random seed and so avoid spurious correlation between different particle positions and velocities. See https://docs.jax.dev/en/latest/random-numbers.html @@ -429,8 +439,8 @@ def make_particles(species_parameters, Nprt, box_size, weight, seed, rng_index): # This code is brittle; it depends on hard-coded offsets to the RNG seed # within initialize_particles_fields(...) - assert rng_index >= 0 - local_seed = seed+12 + rng_index*6 + assert species_id >= 2 + local_seed = seed + species_id*6 # Separate position/velocity seeds allow different ion and electron # populations to be inited with identical space positions, but # uncorrelated velocity distributions @@ -470,6 +480,7 @@ def make_particles(species_parameters, Nprt, box_size, weight, seed, rng_index): out['charges'] = charge * weight * _p['weight_ratio'] * jnp.ones((Nprt, 1)) out['masses'] = mass * weight * _p['weight_ratio'] * jnp.ones((Nprt, 1)) + out['species_ids'] = species_id * jnp.ones((Nprt, 1)) # **Particle Velocities** @@ -609,6 +620,7 @@ def simulation_step(carry, step_index): "velocities": velocities_over_time, "masses": parameters["masses"], "charges": parameters["charges"], + "species_ids": parameters["species_ids"], "electric_field": electric_field_over_time, "magnetic_field": magnetic_field_over_time, "current_density": current_density_over_time, From cc75b7d1db15e62a8248751c5de4d77924b0c60f Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Mon, 29 Sep 2025 18:17:03 -0500 Subject: [PATCH 15/19] Diagnostic KE, hist2d handle variable prtl weights --- jaxincell/_diagnostics.py | 24 ++++++++++++++++-------- jaxincell/_plot.py | 35 +++++++++++++++++++++-------------- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/jaxincell/_diagnostics.py b/jaxincell/_diagnostics.py index 2349e5b..8048139 100644 --- a/jaxincell/_diagnostics.py +++ b/jaxincell/_diagnostics.py @@ -7,6 +7,13 @@ def diagnostics(output): + # weight arrays are redundant w.r.t. charge/mass arrays, + # but they are convenient for histogram plotting + output["weights"] = jnp.ones_like(output["charges"]) + for ii, species in enumerate(output["species"]): + ssel = (output['species_ids'] == ii+2) # hardcoded offset here is not great!! + output["weights"] = output["weights"].at[ssel].set( species["weight_ratio"] ) + isel = (output["charges"] >= 0)[:,0] # cannot use masks in jitted functions esel = (output["charges"] < 0)[:,0] segregated = { @@ -14,10 +21,12 @@ def diagnostics(output): "velocity_electrons": output["velocities"][:, esel, :], "mass_electrons": output["masses"] [ esel], "charge_electrons": output["charges"] [ esel], + "weight_electrons": output["weights"] [ esel], "position_ions": output["positions"] [:, isel, :], "velocity_ions": output["velocities"][:, isel, :], "mass_ions": output["masses"] [ isel], "charge_ions": output["charges"] [ isel], + "weight_ions": output["weights"] [ isel], } output.update(**segregated) del output["positions"] @@ -29,8 +38,6 @@ def diagnostics(output): grid = output['grid'] dt = output['dt'] total_steps = output['total_steps'] - mass_electrons = output["mass_electrons"][0] - mass_ions = output["mass_ions"][0] # array_to_do_fft_on = charge_density_over_time[:,len(grid)//2] array_to_do_fft_on = E_field_over_time[:,len(grid)//2,0] @@ -56,9 +63,10 @@ def integrate(y, dx): return 0.5 * (jnp.asarray(dx) * (y[..., 1:] + y[..., :-1]) integral_B_squared = integrate(abs_B_squared, dx=output['dx']) integral_externalB_squared = integrate(abs_externalB_squared, dx=output['dx']) - v_electrons_squared = jnp.sum(jnp.sum(output['velocity_electrons']**2, axis=-1), axis=-1) - v_ions_squared = jnp.sum(jnp.sum(output['velocity_ions']**2 , axis=-1), axis=-1) - + KE_electrons = (1/2) * jnp.expand_dims(output['mass_electrons'][...,0], 0) * jnp.sum(output['velocity_electrons']**2, axis=-1) + KE_ions = (1/2) * jnp.expand_dims(output['mass_ions'] [...,0], 0) * jnp.sum(output['velocity_ions']**2, axis=-1) + KE_electrons = jnp.sum(KE_electrons, axis=-1) + KE_ions = jnp.sum(KE_ions, axis=-1) output.update({ 'electric_field_energy_density': (epsilon_0/2) * abs_E_squared, @@ -67,9 +75,9 @@ def integrate(y, dx): return 0.5 * (jnp.asarray(dx) * (y[..., 1:] + y[..., :-1]) 'magnetic_field_energy': 1/(2*mu_0) * integral_B_squared, 'dominant_frequency': dominant_frequency, 'plasma_frequency': plasma_frequency, - 'kinetic_energy': (1/2) * mass_electrons * v_electrons_squared + (1/2) * mass_ions * v_ions_squared, - 'kinetic_energy_electrons': (1/2) * mass_electrons * v_electrons_squared, - 'kinetic_energy_ions': (1/2) * mass_ions * v_ions_squared, + 'kinetic_energy': KE_electrons + KE_ions, + 'kinetic_energy_electrons': KE_electrons, + 'kinetic_energy_ions': KE_ions, 'external_electric_field_energy_density': (epsilon_0/2) * abs_externalE_squared, 'external_electric_field_energy': (epsilon_0/2) * integral_externalE_squared, 'external_magnetic_field_energy_density': 1/(2*mu_0) * abs_externalB_squared, diff --git a/jaxincell/_plot.py b/jaxincell/_plot.py index 5f0ac99..be6b414 100644 --- a/jaxincell/_plot.py +++ b/jaxincell/_plot.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt from jax import vmap from jax.debug import print as jprint -from ._constants import speed_of_light +from ._constants import speed_of_light, mass_electron, mass_proton from matplotlib.animation import FuncAnimation __all__ = ['plot'] @@ -69,7 +69,10 @@ def add_field_components(field, unit, label_prefix): }) # Compute phase space histograms - sqrtmemi = jnp.sqrt(output["mass_electrons"][0] / output["mass_ions"][0]) + #sqrtmemi = jnp.sqrt(output["mass_electrons"][0] / output["mass_ions"][0]) + # we cannot assume constant mass for all pseudoparticles with varying weights; hardcode + # using the "default" ion species + sqrtmemi = jnp.sqrt( mass_electron / (mass_proton * output["ion_mass_over_proton_mass"]) ) max_velocity_electrons_1 = max(1.0 * jnp.max(output["velocity_electrons"][:, :, direction_index1]), 2.5 * jnp.abs(vth_e_1) + jnp.abs(output[f"electron_drift_speed_{direction1}"])) @@ -79,15 +82,17 @@ def add_field_components(field, unit, label_prefix): max_velocity_ions_1 = float(jnp.asarray(max_velocity_ions_1)) bins_velocity = max(min(len(grid), 111), 71) - electron_phase_histograms = vmap(lambda pos, vel: jnp.histogram2d( - pos, vel, bins=[len(grid), bins_velocity], + electron_phase_histograms = vmap(lambda pos, vel, wt: jnp.histogram2d( + pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-box_size_x / 2, box_size_x / 2], [-max_velocity_electrons_1, max_velocity_electrons_1]])[0] - )(output["position_electrons"][:, :, direction_index1], output["velocity_electrons"][:, :, direction_index1]) + )(output["position_electrons"][:, :, direction_index1], output["velocity_electrons"][:, :, direction_index1], + jnp.tile(jnp.expand_dims(output["weight_electrons"][...,0],0), (output["position_electrons"].shape[0], 1))) - ion_phase_histograms = vmap(lambda pos, vel: jnp.histogram2d( - pos, vel, bins=[len(grid), bins_velocity], + ion_phase_histograms = vmap(lambda pos, vel, wt: jnp.histogram2d( + pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-box_size_x / 2, box_size_x / 2], [-max_velocity_ions_1, max_velocity_ions_1]])[0] - )(output["position_ions"][:, :, direction_index1], output["velocity_ions"][:, :, direction_index1]) + )(output["position_ions"][:, :, direction_index1], output["velocity_ions"][:, :, direction_index1], + jnp.tile(jnp.expand_dims(output["weight_ions"][...,0],0), (output["position_ions"].shape[0], 1))) # Grid layout ncols = 3 @@ -181,15 +186,17 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): max_velocity_electrons_12 = max(max_velocity_electrons_1, max_velocity_electrons_2) max_velocity_ions_12 = max(max_velocity_ions_1, max_velocity_ions_2) - electron_phase_histograms2 = vmap(lambda pos, vel: jnp.histogram2d( - pos, vel, bins=[len(grid), bins_velocity], + electron_phase_histograms2 = vmap(lambda pos, vel, wt: jnp.histogram2d( + pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-max_velocity_electrons_12, max_velocity_electrons_12], [-max_velocity_electrons_12, max_velocity_electrons_12]])[0] - )(output["velocity_electrons"][:, :, direction_index1], output["velocity_electrons"][:, :, direction_index2]) + )(output["velocity_electrons"][:, :, direction_index1], output["velocity_electrons"][:, :, direction_index2], + jnp.expand_dims(output["weight_electrons"][...,0],0)) - ion_phase_histograms2 = vmap(lambda pos, vel: jnp.histogram2d( - pos, vel, bins=[len(grid), bins_velocity], + ion_phase_histograms2 = vmap(lambda pos, vel, wt: jnp.histogram2d( + pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-max_velocity_ions_12, max_velocity_ions_12], [-max_velocity_ions_12, max_velocity_ions_12]])[0] - )(output["velocity_ions"][:, :, direction_index1], output["velocity_ions"][:, :, direction_index2]) + )(output["velocity_ions"][:, :, direction_index1], output["velocity_ions"][:, :, direction_index2], + jnp.expand_dims(output["weight_ions"][...,0],0)) electron_plot2 = electron_ax2.imshow( jnp.zeros((len(grid), bins_velocity)), aspect="auto", origin="lower", cmap="twilight", From 5f01ca8d148a013d85a8aa1125d900149c038d12 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Tue, 30 Sep 2025 03:10:29 -0500 Subject: [PATCH 16/19] Plot individual species IDs histograms --- examples/bump-on-tail.py | 1 + jaxincell/_diagnostics.py | 2 ++ jaxincell/_plot.py | 56 +++++++++++++++++++++++++-------------- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/examples/bump-on-tail.py b/examples/bump-on-tail.py index 58e5867..f315495 100644 --- a/examples/bump-on-tail.py +++ b/examples/bump-on-tail.py @@ -23,6 +23,7 @@ # Plot the results plot(output) +plot(output, species_id=2) # Save the output to a file np.savez("simulation_output.npz", **output) diff --git a/jaxincell/_diagnostics.py b/jaxincell/_diagnostics.py index 8048139..26c3ce7 100644 --- a/jaxincell/_diagnostics.py +++ b/jaxincell/_diagnostics.py @@ -22,11 +22,13 @@ def diagnostics(output): "mass_electrons": output["masses"] [ esel], "charge_electrons": output["charges"] [ esel], "weight_electrons": output["weights"] [ esel], + "species_id_electrons": output["species_ids"][esel], "position_ions": output["positions"] [:, isel, :], "velocity_ions": output["velocities"][:, isel, :], "mass_ions": output["masses"] [ isel], "charge_ions": output["charges"] [ isel], "weight_ions": output["weights"] [ isel], + "species_id_ions": output["species_ids"][ isel], } output.update(**segregated) del output["positions"] diff --git a/jaxincell/_plot.py b/jaxincell/_plot.py index be6b414..5c378f3 100644 --- a/jaxincell/_plot.py +++ b/jaxincell/_plot.py @@ -5,10 +5,11 @@ from jax.debug import print as jprint from ._constants import speed_of_light, mass_electron, mass_proton from matplotlib.animation import FuncAnimation +import matplotlib as mpl __all__ = ['plot'] -def plot(output, direction="x", threshold=1e-12): +def plot(output, direction="x", threshold=1e-12, species_id=None): def is_nonzero(field): return jnp.max(jnp.abs(field)) > threshold @@ -39,6 +40,14 @@ def is_nonzero(field): raise ValueError("direction must be one or two of 'x', 'y', or 'z'") # direction_index = {"x": 0, "y": 1, "z": 2}[direction] + if species_id is None: # show all particles + # is_finite(...) is just a lazy way to get boolean array of all True + ssele = jnp.isfinite(output['species_id_electrons'])[:,0] + sseli = jnp.isfinite(output['species_id_ions'])[:,0] + else: # only show one species + ssele = (output['species_id_electrons'] == species_id)[:,0] + sseli = (output['species_id_ions'] == species_id)[:,0] + # Determine which vector fields have nonzero components def add_field_components(field, unit, label_prefix): components = [] @@ -73,26 +82,29 @@ def add_field_components(field, unit, label_prefix): # we cannot assume constant mass for all pseudoparticles with varying weights; hardcode # using the "default" ion species sqrtmemi = jnp.sqrt( mass_electron / (mass_proton * output["ion_mass_over_proton_mass"]) ) - max_velocity_electrons_1 = max(1.0 * jnp.max(output["velocity_electrons"][:, :, direction_index1]), - 2.5 * jnp.abs(vth_e_1) - + jnp.abs(output[f"electron_drift_speed_{direction1}"])) - max_velocity_ions_1 = max(1.0 * jnp.max(output["velocity_ions"][:, :, direction_index1]), - sqrtmemi * 0.3 * jnp.abs(vth_i_1) * jnp.sqrt(output[f"ion_temperature_over_electron_temperature_{direction1}"]) - + jnp.abs(output[f"ion_drift_speed_{direction1}"])) + max_velocity_electrons_1 = 2.5 * jnp.abs(vth_e_1) + jnp.abs(output[f"electron_drift_speed_{direction1}"]) + if np.any(ssele): + max_velocity_electrons_1 = max(1.0 * jnp.max(output["velocity_electrons"][:, ssele, direction_index1]), + max_velocity_electrons_1) + max_velocity_ions_1 = (sqrtmemi * 0.3 * jnp.abs(vth_i_1) * jnp.sqrt(output[f"ion_temperature_over_electron_temperature_{direction1}"]) + + jnp.abs(output[f"ion_drift_speed_{direction1}"]) ) + if np.any(sseli): + max_velocity_ions_1 = max(1.0 * jnp.max(output["velocity_ions"][:, sseli, direction_index1]), + max_velocity_ions_1) max_velocity_ions_1 = float(jnp.asarray(max_velocity_ions_1)) bins_velocity = max(min(len(grid), 111), 71) electron_phase_histograms = vmap(lambda pos, vel, wt: jnp.histogram2d( pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-box_size_x / 2, box_size_x / 2], [-max_velocity_electrons_1, max_velocity_electrons_1]])[0] - )(output["position_electrons"][:, :, direction_index1], output["velocity_electrons"][:, :, direction_index1], - jnp.tile(jnp.expand_dims(output["weight_electrons"][...,0],0), (output["position_electrons"].shape[0], 1))) + )(output["position_electrons"][:, ssele, direction_index1], output["velocity_electrons"][:, ssele, direction_index1], + jnp.tile(jnp.expand_dims(output["weight_electrons"][ssele,0],0), (output["position_electrons"].shape[0], 1))) ion_phase_histograms = vmap(lambda pos, vel, wt: jnp.histogram2d( pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-box_size_x / 2, box_size_x / 2], [-max_velocity_ions_1, max_velocity_ions_1]])[0] - )(output["position_ions"][:, :, direction_index1], output["velocity_ions"][:, :, direction_index1], - jnp.tile(jnp.expand_dims(output["weight_ions"][...,0],0), (output["position_ions"].shape[0], 1))) + )(output["position_ions"][:, sseli, direction_index1], output["velocity_ions"][:, sseli, direction_index1], + jnp.tile(jnp.expand_dims(output["weight_ions"][sseli,0],0), (output["position_ions"].shape[0], 1))) # Grid layout ncols = 3 @@ -178,25 +190,29 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): positions_ax = axes_flat[used_axes + 3] # Phase space in second direction - max_velocity_electrons_2 = max(1.0 * jnp.max(output["velocity_electrons"][:, :, direction_index2]), - 2.5 * jnp.abs(vth_e_2) + jnp.abs(output[f"electron_drift_speed_{direction2}"])) - max_velocity_ions_2 = max(1.0 * jnp.max(output["velocity_ions"][:, :, direction_index2]), - sqrtmemi * 0.3 * jnp.abs(vth_i_2) * jnp.sqrt(output[f"ion_temperature_over_electron_temperature_{direction2}"]) + - jnp.abs(output[f"ion_drift_speed_{direction2}"])) + max_velocity_electrons_2 = 2.5 * jnp.abs(vth_e_2) + jnp.abs(output[f"electron_drift_speed_{direction2}"]) + if np.any(ssele): + max_velocity_electrons_2 = max(1.0 * jnp.max(output["velocity_electrons"][:, ssele, direction_index2]), + max_velocity_electrons_2) + max_velocity_ions_2 = (sqrtmemi * 0.3 * jnp.abs(vth_i_2) * jnp.sqrt(output[f"ion_temperature_over_electron_temperature_{direction2}"]) + + jnp.abs(output[f"ion_drift_speed_{direction2}"]) ) + if np.any(sseli): + max_velocity_ions_2 = max(1.0 * jnp.max(output["velocity_ions"][:, sseli, direction_index2]), + max_velocity_ions_2) max_velocity_electrons_12 = max(max_velocity_electrons_1, max_velocity_electrons_2) max_velocity_ions_12 = max(max_velocity_ions_1, max_velocity_ions_2) electron_phase_histograms2 = vmap(lambda pos, vel, wt: jnp.histogram2d( pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-max_velocity_electrons_12, max_velocity_electrons_12], [-max_velocity_electrons_12, max_velocity_electrons_12]])[0] - )(output["velocity_electrons"][:, :, direction_index1], output["velocity_electrons"][:, :, direction_index2], - jnp.expand_dims(output["weight_electrons"][...,0],0)) + )(output["velocity_electrons"][:, ssele, direction_index1], output["velocity_electrons"][:, ssele, direction_index2], + jnp.tile(jnp.expand_dims(output["weight_electrons"][ssele,0],0), (output["position_electrons"].shape[0], 1))) ion_phase_histograms2 = vmap(lambda pos, vel, wt: jnp.histogram2d( pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-max_velocity_ions_12, max_velocity_ions_12], [-max_velocity_ions_12, max_velocity_ions_12]])[0] - )(output["velocity_ions"][:, :, direction_index1], output["velocity_ions"][:, :, direction_index2], - jnp.expand_dims(output["weight_ions"][...,0],0)) + )(output["velocity_ions"][:, sseli, direction_index1], output["velocity_ions"][:, sseli, direction_index2], + jnp.tile(jnp.expand_dims(output["weight_ions"][sseli,0],0), (output["position_ions"].shape[0], 1))) electron_plot2 = electron_ax2.imshow( jnp.zeros((len(grid), bins_velocity)), aspect="auto", origin="lower", cmap="twilight", From 26373c8a9cca3eb5f0396753ead6f81c80d267c7 Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Tue, 30 Sep 2025 10:56:47 -0500 Subject: [PATCH 17/19] Updated _diagnostics.py to run inside _simulation.py --- examples/Landau_damping.py | 3 - examples/Langmuir_wave.py | 3 - examples/two-stream_instability.py | 10 +- jaxincell/_diagnostics.py | 194 +++++++++++++++++------------ jaxincell/_simulation.py | 2 +- 5 files changed, 118 insertions(+), 94 deletions(-) diff --git a/examples/Landau_damping.py b/examples/Landau_damping.py index b2c31ed..4e9f3b8 100644 --- a/examples/Landau_damping.py +++ b/examples/Landau_damping.py @@ -30,9 +30,6 @@ output = block_until_ready(simulation(input_parameters, **solver_parameters)) -# Post-process: segregate ions/electrons, compute energies, compute FFT -diagnostics(output) - print(f"Dominant FFT frequency (f): {output['dominant_frequency']} Hz") print(f"Plasma frequency (w_p): {output['plasma_frequency']} Hz") print(f"Error: {jnp.abs(output['dominant_frequency'] - output['plasma_frequency']) / output['plasma_frequency'] * 100:.2f}%") diff --git a/examples/Langmuir_wave.py b/examples/Langmuir_wave.py index 6cbcbd8..7095557 100644 --- a/examples/Langmuir_wave.py +++ b/examples/Langmuir_wave.py @@ -26,9 +26,6 @@ output = block_until_ready(simulation(input_parameters, **solver_parameters)) -# Post-process: segregate ions/electrons, compute energies, compute FFT -diagnostics(output) - print(f"Dominant FFT frequency (f): {output['dominant_frequency']} Hz") print(f"Plasma frequency (w_p): {output['plasma_frequency']} Hz") print(f"Error: {jnp.abs(output['dominant_frequency'] - output['plasma_frequency']) / output['plasma_frequency'] * 100:.2f}%") diff --git a/examples/two-stream_instability.py b/examples/two-stream_instability.py index 1c9fa52..43301d0 100644 --- a/examples/two-stream_instability.py +++ b/examples/two-stream_instability.py @@ -1,10 +1,7 @@ # Example script to run the simulation and plot the results import time from jax import block_until_ready -from jaxincell import plot, simulation, load_parameters, diagnostics -import numpy as np -import pickle -import json +from jaxincell import plot, simulation, load_parameters # Read from input.toml input_parameters, solver_parameters = load_parameters('input.toml') @@ -18,15 +15,14 @@ output = block_until_ready(simulation(input_parameters, **solver_parameters)) print(f"Run #{i+1}: Wall clock time: {time.time()-start}s") -# Post-process: segregate ions/electrons, compute energies, compute FFT -diagnostics(output) - # Plot the results plot(output) # # Save the output to a file +# import numpy as np # np.savez("simulation_output.npz", **output) # # Load the output from the file +# import numpy as np # data = np.load("simulation_output.npz", allow_pickle=True) # output2 = dict(data) diff --git a/jaxincell/_diagnostics.py b/jaxincell/_diagnostics.py index 26c3ce7..5e8bce9 100644 --- a/jaxincell/_diagnostics.py +++ b/jaxincell/_diagnostics.py @@ -5,89 +5,123 @@ __all__ = ['diagnostics'] -def diagnostics(output): - - # weight arrays are redundant w.r.t. charge/mass arrays, - # but they are convenient for histogram plotting - output["weights"] = jnp.ones_like(output["charges"]) - for ii, species in enumerate(output["species"]): - ssel = (output['species_ids'] == ii+2) # hardcoded offset here is not great!! - output["weights"] = output["weights"].at[ssel].set( species["weight_ratio"] ) - - isel = (output["charges"] >= 0)[:,0] # cannot use masks in jitted functions - esel = (output["charges"] < 0)[:,0] - segregated = { - "position_electrons": output["positions"] [:, esel, :], - "velocity_electrons": output["velocities"][:, esel, :], - "mass_electrons": output["masses"] [ esel], - "charge_electrons": output["charges"] [ esel], - "weight_electrons": output["weights"] [ esel], - "species_id_electrons": output["species_ids"][esel], - "position_ions": output["positions"] [:, isel, :], - "velocity_ions": output["velocities"][:, isel, :], - "mass_ions": output["masses"] [ isel], - "charge_ions": output["charges"] [ isel], - "weight_ions": output["weights"] [ isel], - "species_id_ions": output["species_ids"][ isel], - } - output.update(**segregated) - del output["positions"] - del output["velocities"] - del output["masses"] - del output["charges"] - - E_field_over_time = output['electric_field'] - grid = output['grid'] - dt = output['dt'] - total_steps = output['total_steps'] - - # array_to_do_fft_on = charge_density_over_time[:,len(grid)//2] - array_to_do_fft_on = E_field_over_time[:,len(grid)//2,0] - array_to_do_fft_on = (array_to_do_fft_on-jnp.mean(array_to_do_fft_on))/jnp.max(array_to_do_fft_on) - plasma_frequency = output['plasma_frequency'] - - fft_values = lax.slice(fft(array_to_do_fft_on), (0,), (total_steps//2,)) - freqs = fftfreq(total_steps, d=dt)[:total_steps//2]*2*jnp.pi # d=dt specifies the time step - magnitude = jnp.abs(fft_values) - peak_index = jnp.argmax(magnitude) - dominant_frequency = jnp.abs(freqs[peak_index]) - - def integrate(y, dx): return 0.5 * (jnp.asarray(dx) * (y[..., 1:] + y[..., :-1])).sum(-1) - # def integrate(y, dx): return jnp.sum(y, axis=-1) * dx - - abs_E_squared = jnp.sum(output['electric_field']**2, axis=-1) - abs_externalE_squared = jnp.sum(output['external_electric_field']**2, axis=-1) - integral_E_squared = integrate(abs_E_squared, dx=output['dx']) - integral_externalE_squared = integrate(abs_externalE_squared, dx=output['dx']) - - abs_B_squared = jnp.sum(output['magnetic_field']**2, axis=-1) - abs_externalB_squared = jnp.sum(output['external_magnetic_field']**2, axis=-1) - integral_B_squared = integrate(abs_B_squared, dx=output['dx']) - integral_externalB_squared = integrate(abs_externalB_squared, dx=output['dx']) - - KE_electrons = (1/2) * jnp.expand_dims(output['mass_electrons'][...,0], 0) * jnp.sum(output['velocity_electrons']**2, axis=-1) - KE_ions = (1/2) * jnp.expand_dims(output['mass_ions'] [...,0], 0) * jnp.sum(output['velocity_ions']**2, axis=-1) - KE_electrons = jnp.sum(KE_electrons, axis=-1) - KE_ions = jnp.sum(KE_ions, axis=-1) +def diagnostics(output, *, jittable: bool = False): + """ + If jittable=True: avoid any boolean-mask indexing that changes shapes. + Compute diagnostics via mask-weighted reductions and keep arrays intact. + If jittable=False: (old behavior) you may split arrays into electrons/ions. + """ + # ---------- locals ---------- + E_t = output['electric_field'] # (T, Ng, 3) + B_t = output['magnetic_field'] # (T, Ng, 3) + grid = output['grid'] # (Ng,) + dt = output['dt'] # scalar + T = output['total_steps'] # int + dx = output['dx'] # scalar + N = output['number_pseudoelectrons'] # electrons per default pop + # ---------- FFT-based dominant frequency ---------- + arr = E_t[:, len(grid)//2, 0] + arr = (arr - jnp.mean(arr)) / jnp.max(arr) + fft_vals = lax.slice(fft(arr), (0,), (T//2,)) + freqs = fftfreq(T, d=dt)[:T//2] * 2*jnp.pi + mag = jnp.abs(fft_vals) + idx = jnp.argmax(mag) + dom_omega = jnp.abs(freqs[idx]) + + # ---------- trapz integrate ---------- + def integrate_trap(y, dx): + # y: (..., Ng) + return 0.5 * (jnp.asarray(dx) * (y[..., 1:] + y[..., :-1])).sum(-1) + + # ---------- field energies ---------- + absE2 = jnp.sum(E_t**2, axis=-1) # (T, Ng) + absB2 = jnp.sum(B_t**2, axis=-1) # (T, Ng) + + # external fields are static in time; make (T, Ng) for energy time series + absE2_ext = jnp.sum(output['external_electric_field']**2, axis=-1) # (Ng,) + absB2_ext = jnp.sum(output['external_magnetic_field']**2, axis=-1) # (Ng,) + absE2_ext_T = jnp.broadcast_to(absE2_ext, (absE2.shape[0], absE2.shape[1])) + absB2_ext_T = jnp.broadcast_to(absB2_ext, (absB2.shape[0], absB2.shape[1])) + + intE2 = integrate_trap(absE2, dx) + intB2 = integrate_trap(absB2, dx) + intE2_ext = integrate_trap(absE2_ext_T, dx) + intB2_ext = integrate_trap(absB2_ext_T, dx) + + # ---------- kinetic energies via mask-weighted reductions ---------- + # unified arrays + pos = output['positions'] # (T, Ntot, 3) + vel = output['velocities'] # (T, Ntot, 3) + m = output['masses'][...,0] # (Ntot,) + q = output['charges'][...,0] # (Ntot,) + + # masks (elementwise use only; no slicing by mask) + is_e = (q < 0) # (Ntot,) + is_i = ~is_e + me = is_e.astype(m.dtype) + mi = is_i.astype(m.dtype) + + v2 = jnp.sum(vel**2, axis=-1) # (T, Ntot) + KE_particle = 0.5 * v2 * m[None,:] # (T, Ntot) + KE_e = jnp.sum(KE_particle * me[None,:], axis=-1) # (T,) + KE_i = jnp.sum(KE_particle * mi[None,:], axis=-1) # (T,) + KE = KE_e + KE_i + + # ---------- pack scalars/time series ---------- output.update({ - 'electric_field_energy_density': (epsilon_0/2) * abs_E_squared, - 'electric_field_energy': (epsilon_0/2) * integral_E_squared, - 'magnetic_field_energy_density': 1/(2*mu_0) * abs_B_squared, - 'magnetic_field_energy': 1/(2*mu_0) * integral_B_squared, - 'dominant_frequency': dominant_frequency, - 'plasma_frequency': plasma_frequency, - 'kinetic_energy': KE_electrons + KE_ions, - 'kinetic_energy_electrons': KE_electrons, - 'kinetic_energy_ions': KE_ions, - 'external_electric_field_energy_density': (epsilon_0/2) * abs_externalE_squared, - 'external_electric_field_energy': (epsilon_0/2) * integral_externalE_squared, - 'external_magnetic_field_energy_density': 1/(2*mu_0) * abs_externalB_squared, - 'external_magnetic_field_energy': 1/(2*mu_0) * integral_externalB_squared + 'electric_field_energy_density': (epsilon_0/2) * absE2, # (T, Ng) + 'electric_field_energy': (epsilon_0/2) * intE2, # (T,) + 'magnetic_field_energy_density': 1/(2*mu_0) * absB2, # (T, Ng) + 'magnetic_field_energy': 1/(2*mu_0) * intB2, # (T,) + 'external_electric_field_energy_density': (epsilon_0/2) * absE2_ext, # (Ng,) + 'external_electric_field_energy': (epsilon_0/2) * intE2_ext, # (T,) + 'external_magnetic_field_energy_density': 1/(2*mu_0) * absB2_ext, # (Ng,) + 'external_magnetic_field_energy': 1/(2*mu_0) * intB2_ext, # (T,) + 'dominant_frequency': dom_omega, + 'kinetic_energy': KE, # (T,) + 'kinetic_energy_electrons': KE_e, # (T,) + 'kinetic_energy_ions': KE_i, # (T,) }) - total_energy = (output["electric_field_energy"] + output["external_electric_field_energy"] + - output["magnetic_field_energy"] + output["external_magnetic_field_energy"] + - output["kinetic_energy"]) + # ---------- JIT-safe default species split via static slices ---------- + # Your particle ordering is: [electrons (N), ions (N), then any extra species...] + # We only split the *default* two populations for plotting. + e_slice = slice(0, N) + i_slice = slice(N, 2*N) + output.update({ + "position_electrons": pos[:, e_slice, :], + "velocity_electrons": vel[:, e_slice, :], + "mass_electrons": output["masses"][e_slice, :], + "charge_electrons": output["charges"][e_slice, :], + "species_id_electrons": output["species_ids"][e_slice, :], + + "position_ions": pos[:, i_slice, :], + "velocity_ions": vel[:, i_slice, :], + "mass_ions": output["masses"][i_slice, :], + "charge_ions": output["charges"][i_slice, :], + "species_id_ions": output["species_ids"][i_slice, :], + }) + + # weights for histograms (default pops → 1.0 is fine; keeps plot working) + # If later you want real per-species weight ratios, emit a parallel array + # from initialize and slice here similarly. + ones_N = jnp.ones((N,1), dtype=output["charges"].dtype) + output.update({ + "weight_electrons": ones_N, + "weight_ions": ones_N, + }) + + # ---------- total energy ---------- + total_energy = (output["electric_field_energy"] + + output["external_electric_field_energy"] + + output["magnetic_field_energy"] + + output["external_magnetic_field_energy"] + + output["kinetic_energy"]) output.update({'total_energy': total_energy}) + + # In jittable=True, keep unified arrays; do NOT delete anything. + if not jittable: + # (optional non-JIT path: you could delete unified arrays here) + pass \ No newline at end of file diff --git a/jaxincell/_simulation.py b/jaxincell/_simulation.py index 92d11ca..b32fc99 100644 --- a/jaxincell/_simulation.py +++ b/jaxincell/_simulation.py @@ -633,6 +633,6 @@ def simulation_step(carry, step_index): output = {**temporary_output, **parameters} - #diagnostics(output) + diagnostics(output, jittable=True) return output From 0df3dca7f9c8ad19f9c97a520332488473a70b81 Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Tue, 30 Sep 2025 13:06:45 -0500 Subject: [PATCH 18/19] Disable creating species-specific histograms Partly revert commit 5f01ca8 as the kwarg species_id no longer works with Rogerio's jittable diagnostics(...) function --- examples/bump-on-tail.py | 1 - jaxincell/_plot.py | 56 ++++++++++++++-------------------------- 2 files changed, 20 insertions(+), 37 deletions(-) diff --git a/examples/bump-on-tail.py b/examples/bump-on-tail.py index f315495..58e5867 100644 --- a/examples/bump-on-tail.py +++ b/examples/bump-on-tail.py @@ -23,7 +23,6 @@ # Plot the results plot(output) -plot(output, species_id=2) # Save the output to a file np.savez("simulation_output.npz", **output) diff --git a/jaxincell/_plot.py b/jaxincell/_plot.py index 5c378f3..ffe109a 100644 --- a/jaxincell/_plot.py +++ b/jaxincell/_plot.py @@ -5,11 +5,10 @@ from jax.debug import print as jprint from ._constants import speed_of_light, mass_electron, mass_proton from matplotlib.animation import FuncAnimation -import matplotlib as mpl __all__ = ['plot'] -def plot(output, direction="x", threshold=1e-12, species_id=None): +def plot(output, direction="x", threshold=1e-12): def is_nonzero(field): return jnp.max(jnp.abs(field)) > threshold @@ -40,14 +39,6 @@ def is_nonzero(field): raise ValueError("direction must be one or two of 'x', 'y', or 'z'") # direction_index = {"x": 0, "y": 1, "z": 2}[direction] - if species_id is None: # show all particles - # is_finite(...) is just a lazy way to get boolean array of all True - ssele = jnp.isfinite(output['species_id_electrons'])[:,0] - sseli = jnp.isfinite(output['species_id_ions'])[:,0] - else: # only show one species - ssele = (output['species_id_electrons'] == species_id)[:,0] - sseli = (output['species_id_ions'] == species_id)[:,0] - # Determine which vector fields have nonzero components def add_field_components(field, unit, label_prefix): components = [] @@ -82,29 +73,26 @@ def add_field_components(field, unit, label_prefix): # we cannot assume constant mass for all pseudoparticles with varying weights; hardcode # using the "default" ion species sqrtmemi = jnp.sqrt( mass_electron / (mass_proton * output["ion_mass_over_proton_mass"]) ) - max_velocity_electrons_1 = 2.5 * jnp.abs(vth_e_1) + jnp.abs(output[f"electron_drift_speed_{direction1}"]) - if np.any(ssele): - max_velocity_electrons_1 = max(1.0 * jnp.max(output["velocity_electrons"][:, ssele, direction_index1]), - max_velocity_electrons_1) - max_velocity_ions_1 = (sqrtmemi * 0.3 * jnp.abs(vth_i_1) * jnp.sqrt(output[f"ion_temperature_over_electron_temperature_{direction1}"]) - + jnp.abs(output[f"ion_drift_speed_{direction1}"]) ) - if np.any(sseli): - max_velocity_ions_1 = max(1.0 * jnp.max(output["velocity_ions"][:, sseli, direction_index1]), - max_velocity_ions_1) + max_velocity_electrons_1 = max(1.0 * jnp.max(output["velocity_electrons"][:, :, direction_index1]), + 2.5 * jnp.abs(vth_e_1) + + jnp.abs(output[f"electron_drift_speed_{direction1}"])) + max_velocity_ions_1 = max(1.0 * jnp.max(output["velocity_ions"][:, :, direction_index1]), + sqrtmemi * 0.3 * jnp.abs(vth_i_1) * jnp.sqrt(output[f"ion_temperature_over_electron_temperature_{direction1}"]) + + jnp.abs(output[f"ion_drift_speed_{direction1}"])) max_velocity_ions_1 = float(jnp.asarray(max_velocity_ions_1)) bins_velocity = max(min(len(grid), 111), 71) electron_phase_histograms = vmap(lambda pos, vel, wt: jnp.histogram2d( pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-box_size_x / 2, box_size_x / 2], [-max_velocity_electrons_1, max_velocity_electrons_1]])[0] - )(output["position_electrons"][:, ssele, direction_index1], output["velocity_electrons"][:, ssele, direction_index1], - jnp.tile(jnp.expand_dims(output["weight_electrons"][ssele,0],0), (output["position_electrons"].shape[0], 1))) + )(output["position_electrons"][:, :, direction_index1], output["velocity_electrons"][:, :, direction_index1], + jnp.tile(jnp.expand_dims(output["weight_electrons"][:,0],0), (output["position_electrons"].shape[0], 1))) ion_phase_histograms = vmap(lambda pos, vel, wt: jnp.histogram2d( pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-box_size_x / 2, box_size_x / 2], [-max_velocity_ions_1, max_velocity_ions_1]])[0] - )(output["position_ions"][:, sseli, direction_index1], output["velocity_ions"][:, sseli, direction_index1], - jnp.tile(jnp.expand_dims(output["weight_ions"][sseli,0],0), (output["position_ions"].shape[0], 1))) + )(output["position_ions"][:, :, direction_index1], output["velocity_ions"][:, :, direction_index1], + jnp.tile(jnp.expand_dims(output["weight_ions"][:,0],0), (output["position_ions"].shape[0], 1))) # Grid layout ncols = 3 @@ -190,29 +178,25 @@ def plot_field(ax, field_data, title, xlabel, ylabel, cbar_label): positions_ax = axes_flat[used_axes + 3] # Phase space in second direction - max_velocity_electrons_2 = 2.5 * jnp.abs(vth_e_2) + jnp.abs(output[f"electron_drift_speed_{direction2}"]) - if np.any(ssele): - max_velocity_electrons_2 = max(1.0 * jnp.max(output["velocity_electrons"][:, ssele, direction_index2]), - max_velocity_electrons_2) - max_velocity_ions_2 = (sqrtmemi * 0.3 * jnp.abs(vth_i_2) * jnp.sqrt(output[f"ion_temperature_over_electron_temperature_{direction2}"]) - + jnp.abs(output[f"ion_drift_speed_{direction2}"]) ) - if np.any(sseli): - max_velocity_ions_2 = max(1.0 * jnp.max(output["velocity_ions"][:, sseli, direction_index2]), - max_velocity_ions_2) + max_velocity_electrons_2 = max(1.0 * jnp.max(output["velocity_electrons"][:, :, direction_index2]), + 2.5 * jnp.abs(vth_e_2) + jnp.abs(output[f"electron_drift_speed_{direction2}"])) + max_velocity_ions_2 = max(1.0 * jnp.max(output["velocity_ions"][:, :, direction_index2]), + sqrtmemi * 0.3 * jnp.abs(vth_i_2) * jnp.sqrt(output[f"ion_temperature_over_electron_temperature_{direction2}"]) + + jnp.abs(output[f"ion_drift_speed_{direction2}"])) max_velocity_electrons_12 = max(max_velocity_electrons_1, max_velocity_electrons_2) max_velocity_ions_12 = max(max_velocity_ions_1, max_velocity_ions_2) electron_phase_histograms2 = vmap(lambda pos, vel, wt: jnp.histogram2d( pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-max_velocity_electrons_12, max_velocity_electrons_12], [-max_velocity_electrons_12, max_velocity_electrons_12]])[0] - )(output["velocity_electrons"][:, ssele, direction_index1], output["velocity_electrons"][:, ssele, direction_index2], - jnp.tile(jnp.expand_dims(output["weight_electrons"][ssele,0],0), (output["position_electrons"].shape[0], 1))) + )(output["velocity_electrons"][:, :, direction_index1], output["velocity_electrons"][:, :, direction_index2], + jnp.tile(jnp.expand_dims(output["weight_electrons"][:,0],0), (output["position_electrons"].shape[0], 1))) ion_phase_histograms2 = vmap(lambda pos, vel, wt: jnp.histogram2d( pos, vel, weights=wt, bins=[len(grid), bins_velocity], range=[[-max_velocity_ions_12, max_velocity_ions_12], [-max_velocity_ions_12, max_velocity_ions_12]])[0] - )(output["velocity_ions"][:, sseli, direction_index1], output["velocity_ions"][:, sseli, direction_index2], - jnp.tile(jnp.expand_dims(output["weight_ions"][sseli,0],0), (output["position_ions"].shape[0], 1))) + )(output["velocity_ions"][:, :, direction_index1], output["velocity_ions"][:, :, direction_index2], + jnp.tile(jnp.expand_dims(output["weight_ions"][:,0],0), (output["position_ions"].shape[0], 1))) electron_plot2 = electron_ax2.imshow( jnp.zeros((len(grid), bins_velocity)), aspect="auto", origin="lower", cmap="twilight", From a2a13ac2079a949fd6c208c96b49c8c993eb464d Mon Sep 17 00:00:00 2001 From: Aaron Tran Date: Tue, 30 Sep 2025 13:16:18 -0500 Subject: [PATCH 19/19] Fix failing test for unified prtl array output --- tests/test_diagnostics.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py index d19edfd..d6e7dc9 100644 --- a/tests/test_diagnostics.py +++ b/tests/test_diagnostics.py @@ -9,16 +9,20 @@ def test_diagnostics(): 'external_electric_field': jnp.array([[[0.5, 0.0, 0.0], [0.0, 0.5, 0.0]], [[0.0, 0.0, 0.5], [0.5, 0.5, 0.5]]]), 'magnetic_field': jnp.array([[[0.1, 0.0, 0.0], [0.0, 0.1, 0.0]], [[0.0, 0.0, 0.1], [0.1, 0.1, 0.1]]]), 'external_magnetic_field': jnp.array([[[0.05, 0.0, 0.0], [0.0, 0.05, 0.0]], [[0.0, 0.0, 0.05], [0.05, 0.05, 0.05]]]), - 'velocity_electrons': jnp.array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]]), - 'velocity_ions': jnp.array([[[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]], [[0.8, 0.9, 1.0], [1.1, 1.2, 1.3]]]), + 'velocities': jnp.array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.2, 0.3, 0.4], [0.5, 0.6, 0.7]], + [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2], [0.8, 0.9, 1.0], [1.1, 1.2, 1.3]]]), + 'positions': jnp.array([[[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], [0.5, 0.0, 0.0], [0.5, 0.0, 0.0]], + [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], [0.5, 0.0, 0.0], [0.5, 0.0, 0.0]]]), 'grid': jnp.array([0.0, 1.0]), 'dt': 0.1, 'total_steps': 2, 'plasma_frequency': 1.0, 'dx': 0.1, 'weight': 1.0, - 'mass_electrons': jnp.array([1.0]), - 'mass_ions': jnp.array([1.0]) + 'masses': jnp.array([[1.0], [1.0], [1836.], [1836.]]), + 'charges': jnp.array([[-1.0], [-1.0], [1.], [1.]]), + 'species_ids': jnp.array([[0], [0], [1], [1]]), + 'number_pseudoelectrons': 2, } diagnostics(output)