diff --git a/PyPIC3D/J.py b/PyPIC3D/J.py index b283f38..d056092 100644 --- a/PyPIC3D/J.py +++ b/PyPIC3D/J.py @@ -6,6 +6,7 @@ # import external libraries from PyPIC3D.utils import digital_filter, wrap_around, bilinear_filter +from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights @partial(jit, static_argnames=("filter",)) def J_from_rhov(particles, J, constants, world, grid, filter='bilinear'): @@ -540,66 +541,4 @@ def z_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_ # determine which dimension is active and calculate weights accordingly - return Wx_, Wy_, Wz_ - -@jit -def get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz): - """ - Calculate the second-order weights for particle current distribution. - - Args: - deltax, deltay, deltaz (float): Particle position offsets from grid points. - dx, dy, dz (float): Grid spacings in x, y, and z directions. - - Returns: - tuple: Weights for x, y, and z directions. - """ - Sx0 = (3/4) - (deltax/dx)**2 - Sy0 = (3/4) - (deltay/dy)**2 - Sz0 = (3/4) - (deltaz/dz)**2 - - Sx1 = (1/2) * ((1/2) + (deltax/dx))**2 - Sy1 = (1/2) * ((1/2) + (deltay/dy))**2 - Sz1 = (1/2) * ((1/2) + (deltaz/dz))**2 - - Sx_minus1 = (1/2) * ((1/2) - (deltax/dx))**2 - Sy_minus1 = (1/2) * ((1/2) - (deltay/dy))**2 - Sz_minus1 = (1/2) * ((1/2) - (deltaz/dz))**2 - # second order weights - - x_weights = [Sx_minus1, Sx0, Sx1] - y_weights = [Sy_minus1, Sy0, Sy1] - z_weights = [Sz_minus1, Sz0, Sz1] - - return x_weights, y_weights, z_weights - -@jit -def get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz): - """ - Calculate the first-order weights for particle current distribution. - - Args: - deltax, deltay, deltaz (float): Particle position offsets from grid points. - dx, dy, dz (float): Grid spacings in x, y, and z directions. - - Returns: - tuple: Weights for x, y, and z directions. - """ - Sx0 = jnp.asarray(1 - deltax / dx) - Sy0 = jnp.asarray(1 - deltay / dy) - Sz0 = jnp.asarray(1 - deltaz / dz) - - Sx1 = jnp.asarray(deltax / dx) - Sy1 = jnp.asarray(deltay / dy) - Sz1 = jnp.asarray(deltaz / dz) - - Sx_minus1 = jnp.zeros_like(Sx0) - Sy_minus1 = jnp.zeros_like(Sy0) - Sz_minus1 = jnp.zeros_like(Sz0) - # No second-order weights for first-order weighting - - x_weights = [Sx_minus1, Sx0, Sx1] - y_weights = [Sy_minus1, Sy0, Sy1] - z_weights = [Sz_minus1, Sz0, Sz1] - - return x_weights, y_weights, z_weights \ No newline at end of file + return Wx_, Wy_, Wz_ \ No newline at end of file diff --git a/PyPIC3D/__init__.py b/PyPIC3D/__init__.py index 124f937..555aa2d 100644 --- a/PyPIC3D/__init__.py +++ b/PyPIC3D/__init__.py @@ -12,7 +12,7 @@ from . import boundaryconditions from . import initialization from . import particle -from . import plotting +from .diagnostics import plotting from . import utils from .solvers import pstd from .solvers import fdtd diff --git a/PyPIC3D/__main__.py b/PyPIC3D/__main__.py index 61fa4b5..bf35c8d 100644 --- a/PyPIC3D/__main__.py +++ b/PyPIC3D/__main__.py @@ -11,23 +11,36 @@ from jax import block_until_ready import jax.numpy as jnp from tqdm import tqdm + #from memory_profiler import profile # Importing relevant libraries -from PyPIC3D.plotting import ( - write_particles_phase_space, write_data, plot_vtk_particles, plot_field_slice_vtk, - plot_vectorfield_slice_vtk +from PyPIC3D.diagnostics.plotting import ( + write_particles_phase_space, write_data +) + +from PyPIC3D.diagnostics.openPMD import ( + write_openpmd_particles, write_openpmd_fields +) + +from PyPIC3D.diagnostics.vtk import ( + plot_field_slice_vtk, plot_vectorfield_slice_vtk, plot_vtk_particles ) from PyPIC3D.utils import ( - dump_parameters_to_toml, load_config_file, compute_energy + dump_parameters_to_toml, load_config_file, compute_energy, + setup_pmd_files ) from PyPIC3D.initialization import ( initialize_simulation ) -from PyPIC3D.rho import compute_rho, compute_mass_density, compute_velocity_field +from PyPIC3D.diagnostics.fluid_quantities import ( + compute_mass_density +) + +from PyPIC3D.rho import compute_rho # Importing functions from the PyPIC3D package @@ -57,6 +70,10 @@ def run_PyPIC3D(config_file): # Compute the energy of the system initial_energy = e_energy + b_energy + kinetic_energy + if plotting_parameters['plot_openpmd_fields']: setup_pmd_files( os.path.join(output_dir, "data"), "fields", ".h5") + if plotting_parameters['plot_openpmd_particles']: setup_pmd_files( os.path.join(output_dir, "data"), "particles", ".h5") + # setup the openPMD files if needed + ############################################################################################################ ###################################################### SIMULATION LOOP ##################################### @@ -66,6 +83,9 @@ def run_PyPIC3D(config_file): # plot the data if t % plotting_parameters['plotting_interval'] == 0: + plot_num = t // plotting_parameters['plotting_interval'] + # determine the plot number + E, B, J, rho, *rest = fields # unpack the fields @@ -82,34 +102,45 @@ def run_PyPIC3D(config_file): write_data(f"{output_dir}/data/total_momentum.txt", t * dt, total_momentum) # Write the total momentum to a file - for species in particles: - write_data(f"{output_dir}/data/{species.name}_kinetic_energy.txt", t * dt, species.kinetic_energy()) + # for species in particles: + # write_data(f"{output_dir}/data/{species.name}_kinetic_energy.txt", t * dt, species.kinetic_energy()) if plotting_parameters['plot_phasespace']: write_particles_phase_space(particles, t, output_dir) - rho = compute_rho(particles, rho, world, constants) - # calculate the charge density based on the particle positions - mass_density = compute_mass_density(particles, rho, world) - # calculate the mass density based on the particle positions - fields_mag = [rho[:,world['Ny']//2,:], mass_density[:,world['Ny']//2,:]] - plot_field_slice_vtk(fields_mag, scalar_field_names, 1, E_grid, t, "scalar_field", output_dir, world) - # Plot the scalar fields in VTK format + if plotting_parameters['plot_vtk_scalars']: + rho = compute_rho(particles, rho, world, constants) + # calculate the charge density based on the particle positions + mass_density = compute_mass_density(particles, rho, world) + # calculate the mass density based on the particle positions + fields_mag = [rho[:,world['Ny']//2,:], mass_density[:,world['Ny']//2,:]] + plot_field_slice_vtk(fields_mag, scalar_field_names, 1, E_grid, t, "scalar_field", output_dir, world) + # Plot the scalar fields in VTK format - vector_field_slices = [ [E[0][:,world['Ny']//2,:], E[1][:,world['Ny']//2,:], E[2][:,world['Ny']//2,:]], - [B[0][:,world['Ny']//2,:], B[1][:,world['Ny']//2,:], B[2][:,world['Ny']//2,:]], - [J[0][:,world['Ny']//2,:], J[1][:,world['Ny']//2,:], J[2][:,world['Ny']//2,:]]] - plot_vectorfield_slice_vtk(vector_field_slices, vector_field_names, 1, E_grid, t, 'vector_field', output_dir, world) - # Plot the vector fields in VTK format + + if plotting_parameters['plot_vtk_vectors']: + vector_field_slices = [ [E[0][:,world['Ny']//2,:], E[1][:,world['Ny']//2,:], E[2][:,world['Ny']//2,:]], + [B[0][:,world['Ny']//2,:], B[1][:,world['Ny']//2,:], B[2][:,world['Ny']//2,:]], + [J[0][:,world['Ny']//2,:], J[1][:,world['Ny']//2,:], J[2][:,world['Ny']//2,:]]] + plot_vectorfield_slice_vtk(vector_field_slices, vector_field_names, 1, E_grid, t, 'vector_field', output_dir, world) + # Plot the vector fields in VTK format if plotting_parameters['plot_vtk_particles']: - plot_vtk_particles(particles, t, output_dir) + plot_vtk_particles(particles, plot_num, output_dir) # Plot the particles in VTK format + if plotting_parameters['plot_openpmd_particles']: + write_openpmd_particles(particles, world, constants, os.path.join(output_dir, "data"), plot_num, t, "particles", ".h5") + # Write the particles in openPMD format + + if plotting_parameters['plot_openpmd_fields']: + write_openpmd_fields(fields, world, os.path.join(output_dir, "data"), plot_num, t, "fields", ".h5") + # Write the fields in openPMD format + fields = (E, B, J, rho, *rest) # repack the fields @@ -172,4 +203,4 @@ def main(): if __name__ == "__main__": main() - # run the main function \ No newline at end of file + # run the main function diff --git a/PyPIC3D/diagnostics/fluid_quantities.py b/PyPIC3D/diagnostics/fluid_quantities.py new file mode 100644 index 0000000..781f275 --- /dev/null +++ b/PyPIC3D/diagnostics/fluid_quantities.py @@ -0,0 +1,249 @@ + +from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights +from PyPIC3D.utils import wrap_around + +import jax +import jax.numpy as jnp +from jax import jit + + +@jit +def compute_mass_density(particles, rho, world): + """ + Compute the mass density (rho) for a given set of particles in a simulation world. + Parameters: + particles (list): A list of particle species, each containing methods to get the number of particles, + their positions, and their mass. + rho (ndarray): The initial mass density array to be updated. + world (dict): A dictionary containing the simulation world parameters, including: + - 'dx': Grid spacing in the x-direction. + - 'dy': Grid spacing in the y-direction. + - 'dz': Grid spacing in the z-direction. + - 'x_wind': Window size in the x-direction. + - 'y_wind': Window size in the y-direction. + - 'z_wind': Window size in the z-direction. + Returns: + ndarray: The updated charge density array. + """ + dx = world['dx'] + dy = world['dy'] + dz = world['dz'] + x_wind = world['x_wind'] + y_wind = world['y_wind'] + z_wind = world['z_wind'] + Nx, Ny, Nz = rho.shape + # get the shape of the charge density array + + rho = jnp.zeros_like(rho) + # reset rho to zero + + for species in particles: + shape_factor = species.get_shape() + # get the shape factor of the species, which determines the weighting function + N_particles = species.get_number_of_particles() + mass = species.get_mass() + # get the number of particles and their mass + dm = mass / dx / dy / dz + # calculate the mass per unit volume + x, y, z = species.get_position() + # get the position of the particles in the species + + x0 = jnp.floor((x + x_wind / 2) / dx).astype(int) + y0 = jnp.floor((y + y_wind / 2) / dy).astype(int) + z0 = jnp.floor((z + z_wind / 2) / dz).astype(int) + # Calculate the nearest grid points + + deltax = x - jnp.floor(x / dx) * dx + deltay = y - jnp.floor(y / dy) * dy + deltaz = z - jnp.floor(z / dz) * dz + # Calculate the difference between the particle position and the nearest grid point + + x1 = wrap_around(x0 + 1, Nx) + y1 = wrap_around(y0 + 1, Ny) + z1 = wrap_around(z0 + 1, Nz) + # Calculate the index of the next grid point + + x_minus1 = x0 - 1 + y_minus1 = y0 - 1 + z_minus1 = z0 - 1 + # Calculate the index of the previous grid point + + xpts = [x_minus1, x0, x1] + ypts = [y_minus1, y0, y1] + zpts = [z_minus1, z0, z1] + # place all the points in a list + + x_weights, y_weights, z_weights = jax.lax.cond( + shape_factor == 1, + lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), + lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), + operand=None + ) + # get the weighting factors based on the shape factor + + + for i in range(3): + for j in range(3): + for k in range(3): + rho = rho.at[xpts[i], ypts[j], zpts[k]].add( dm * x_weights[i] * y_weights[j] * z_weights[k], mode='drop') + # distribute the mass of the particles to the grid points using the weighting factors + + return rho + +@jit +def compute_velocity_field(particles, field, direction, world): + """ + Compute the velocity field (v) for a given set of particles in a simulation world. + Parameters: + particles (list): A list of particle species, each containing methods to get the number of particles, + their positions, and their mass. + field (ndarray): The initial velocity field array to be updated. + direction (int): The direction along which to compute the velocity field (0: x, 1: y, 2: z). + world (dict): A dictionary containing the simulation world parameters, including: + - 'dx': Grid spacing in the x-direction. + - 'dy': Grid spacing in the y-direction. + - 'dz': Grid spacing in the z-direction. + - 'x_wind': Window size in the x-direction. + - 'y_wind': Window size in the y-direction. + - 'z_wind': Window size in the z-direction. + Returns: + ndarray: The updated velocity field array. + """ + dx = world['dx'] + dy = world['dy'] + dz = world['dz'] + x_wind = world['x_wind'] + y_wind = world['y_wind'] + z_wind = world['z_wind'] + Nx, Ny, Nz = field.shape + # get the shape of the velocity field array + + field = jnp.zeros_like(field) + # reset field to zero + + for species in particles: + shape_factor = species.get_shape() + # get the shape factor of the species, which determines the weighting function + N_particles = species.get_number_of_particles() + # get the number of particles + x, y, z = species.get_position() + # get the position of the particles in the species + vx, vy, vz = species.get_velocity() + # get the velocity of the particles in the species + + dv = jnp.array([vx, vy, vz])[direction] / N_particles + # select the velocity component based on the direction + + x0 = jnp.floor((x + x_wind / 2) / dx).astype(int) + y0 = jnp.floor((y + y_wind / 2) / dy).astype(int) + z0 = jnp.floor((z + z_wind / 2) / dz).astype(int) + # Calculate the nearest grid points + + deltax = x - jnp.floor(x / dx) * dx + deltay = y - jnp.floor(y / dy) * dy + deltaz = z - jnp.floor(z / dz) * dz + # Calculate the difference between the particle position and the nearest grid point + + x1 = wrap_around(x0 + 1, Nx) + y1 = wrap_around(y0 + 1, Ny) + z1 = wrap_around(z0 + 1, Nz) + # Calculate the index of the next grid point + + x_minus1 = x0 - 1 + y_minus1 = y0 - 1 + z_minus1 = z0 - 1 + # Calculate the index of the previous grid point + + xpts = [x_minus1, x0, x1] + ypts = [y_minus1, y0, y1] + zpts = [z_minus1, z0, z1] + # place all the points in a list + + x_weights, y_weights, z_weights = jax.lax.cond( + shape_factor == 1, + lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), + lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), + operand=None + ) + # get the weighting factors based on the shape factor + + for i in range(3): + for j in range(3): + for k in range(3): + field = field.at[xpts[i], ypts[j], zpts[k]].add( dv * x_weights[i] * y_weights[j] * z_weights[k], mode='drop') + # distribute the velocity of the particles to the grid points using the weighting factors + + return field + + + + +@jit +def compute_pressure_field(particles, field, velocity_field, direction, world): + + dx = world['dx'] + dy = world['dy'] + dz = world['dz'] + x_wind = world['x_wind'] + y_wind = world['y_wind'] + z_wind = world['z_wind'] + Nx, Ny, Nz = field.shape + # get the shape of the velocity field array + + field = jnp.zeros_like(field) + # reset field to zero + + for species in particles: + shape_factor = species.get_shape() + # get the shape factor of the species, which determines the weighting function + x, y, z = species.get_position() + # get the position of the particles in the species + vx, vy, vz = species.get_velocity() + # get the velocity of the particles in the species + + + v = jnp.array([vx, vy, vz])[direction] + # select the velocity component based on the direction + + x0 = jnp.floor((x + x_wind / 2) / dx).astype(int) + y0 = jnp.floor((y + y_wind / 2) / dy).astype(int) + z0 = jnp.floor((z + z_wind / 2) / dz).astype(int) + # Calculate the nearest grid points + + deltax = x - jnp.floor(x / dx) * dx + deltay = y - jnp.floor(y / dy) * dy + deltaz = z - jnp.floor(z / dz) * dz + # Calculate the difference between the particle position and the nearest grid point + + x1 = wrap_around(x0 + 1, Nx) + y1 = wrap_around(y0 + 1, Ny) + z1 = wrap_around(z0 + 1, Nz) + # Calculate the index of the next grid point + + x_minus1 = x0 - 1 + y_minus1 = y0 - 1 + z_minus1 = z0 - 1 + # Calculate the index of the previous grid point + + xpts = [x_minus1, x0, x1] + ypts = [y_minus1, y0, y1] + zpts = [z_minus1, z0, z1] + # place all the points in a list + + x_weights, y_weights, z_weights = jax.lax.cond( + shape_factor == 1, + lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), + lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), + operand=None + ) + # get the weighting factors based on the shape factor + + for i in range(3): + for j in range(3): + for k in range(3): + vbar = v - velocity_field.at[xpts[i], ypts[j], zpts[k]].get() + + field = field.at[xpts[i], ypts[j], zpts[k]].add( vbar**2 * x_weights[i] * y_weights[j] * z_weights[k], mode='drop') + # distribute the pressure moment of the particles to the grid points using the weighting factors + + return field \ No newline at end of file diff --git a/PyPIC3D/diagnostics/openPMD.py b/PyPIC3D/diagnostics/openPMD.py new file mode 100644 index 0000000..a76c9fd --- /dev/null +++ b/PyPIC3D/diagnostics/openPMD.py @@ -0,0 +1,386 @@ + +import openpmd_api as io +import jax.numpy as jnp +import os +import numpy as np +import importlib.metadata + +def _ensure_openpmd_array(data, dtype=np.float64): + arr = np.squeeze(np.asarray(data, dtype=dtype)) + if not arr.flags.c_contiguous or not arr.flags.writeable: + arr = np.array(arr, dtype=dtype, copy=True, order="C") + return arr + + +def _open_openpmd_series(output_path, filename, file_extension=".bp"): + filename = "_".join(filename.split()) + file_extension + # add file extension + series_path = os.path.join(output_path, filename) + access_mode = io.Access.append if os.path.exists(series_path) else io.Access.create + series = io.Series(series_path, access_mode) + series.set_attribute("software", "PyPIC3D") + series.set_attribute("softwareVersion", importlib.metadata.version("PyPIC3D")) + return series + +def _configure_openpmd_mesh(mesh, world, active_dims=(1,1,1)): + mesh.geometry = io.Geometry.cartesian + # openpmd-api 0.16+ removed io.Data_Order; mesh.data_order accepts a string. + mesh.data_order = io.Data_Order.C if hasattr(io, "Data_Order") else "C" + + axes = [] + ds = [] + offsets = [] + # initialize lists for axes, spacings, and offsets + if active_dims[0]: + axes.append("x") + ds.append(float(world["dx"])) + offsets.append(-float(world["x_wind"]) / 2.0) + if active_dims[1]: + axes.append("y") + ds.append(float(world["dy"])) + offsets.append(-float(world["y_wind"]) / 2.0) + if active_dims[2]: + axes.append("z") + ds.append(float(world["dz"])) + offsets.append(-float(world["z_wind"]) / 2.0) + # determine the active axes being used and set them + + mesh.axis_labels = axes + mesh.grid_spacing = ds + mesh.grid_global_offset = offsets + + + # mesh.grid_spacing = [float(world["dx"]), float(world["dy"]), float(world["dz"])] + # mesh.grid_global_offset = [ + # -float(world["x_wind"]) / 2.0, + # -float(world["y_wind"]) / 2.0, + # -float(world["z_wind"]) / 2.0, + # ] + # mesh.axis_labels = ["x", "y", "z"] + mesh.unit_SI = 1.0 + + +def _write_openpmd_scalar_mesh(iteration, name, data, world, active_dims=(1,1,1)): + mesh = iteration.meshes[name] + _configure_openpmd_mesh(mesh, world, active_dims) + array = _ensure_openpmd_array(data) + record = mesh[io.Mesh_Record_Component.SCALAR] + record.reset_dataset(io.Dataset(array.dtype, array.shape)) + record.store_chunk(array, [0] * array.ndim, array.shape) + record.unit_SI = 1.0 + + +def _write_openpmd_vector_mesh(iteration, name, components, world, active_dims=(1,1,1)): + mesh = iteration.meshes[name] + _configure_openpmd_mesh(mesh, world, active_dims) + for component_name, component_data in zip(("x", "y", "z"), components): + array = _ensure_openpmd_array(component_data) + record = mesh[component_name] + record.reset_dataset(io.Dataset(array.dtype, array.shape)) + record.store_chunk(array, [0] * array.ndim, array.shape) + record.unit_SI = 1.0 + + +def write_openpmd_fields_to_iteration(iteration, field_map, world, active_dims=(1,1,1)): + for name, data in field_map.items(): + is_vector = isinstance(data, (list, tuple)) and len(data) == 3 + if is_vector: + _write_openpmd_vector_mesh(iteration, name, data, world, active_dims) + else: + _write_openpmd_scalar_mesh(iteration, name, data, world, active_dims) + + +def write_openpmd_particles_to_iteration(iteration, particles, constants): + if not particles: + return + + C = float(constants["C"]) + + for species in particles: + species_name = species.get_name().replace(" ", "_") + species_group = iteration.particles[species_name] + + x, y, z = species.get_position() + vx, vy, vz = species.get_velocity() + gamma = 1 / jnp.sqrt(1.0 - (vx**2 + vy**2 + vz**2) / C**2) + + x = _ensure_openpmd_array(x) + y = _ensure_openpmd_array(y) + z = _ensure_openpmd_array(z) + vx = _ensure_openpmd_array(vx) + vy = _ensure_openpmd_array(vy) + vz = _ensure_openpmd_array(vz) + gamma = _ensure_openpmd_array(gamma) + + num_particles = x.shape[0] + particle_mass = float(species.mass) + particle_charge = float(species.charge) + + position = species_group["position"] + for component, data in zip(("x", "y", "z"), (x, y, z)): + record_component = position[component] + record_component.reset_dataset(io.Dataset(data.dtype, [num_particles])) + record_component.store_chunk(data, [0], [num_particles]) + record_component.unit_SI = 1.0 + + pos_off = species_group["positionOffset"] + zeros = np.zeros(num_particles, dtype=np.float64) + for comp in ("x", "y", "z"): + rc = pos_off[comp] + rc.reset_dataset(io.Dataset(zeros.dtype, [num_particles])) + rc.store_chunk(zeros, [0], [num_particles]) + rc.unit_SI = 1.0 + + momentum = species_group["momentum"] + for component, data in zip(("x", "y", "z"), (vx, vy, vz)): + record_component = momentum[component] + record_component.reset_dataset(io.Dataset(data.dtype, [num_particles])) + record_component.store_chunk(data * particle_mass * gamma, [0], [num_particles]) + record_component.unit_SI = 1.0 + + weighting = species_group["weighting"] + weights = np.full(num_particles, float(species.weight), dtype=np.float64) + weighting.reset_dataset(io.Dataset(weights.dtype, [num_particles])) + weighting.store_chunk(weights, [0], [num_particles]) + weighting.unit_SI = 1.0 + + charge = species_group["charge"] + charges = np.full(num_particles, particle_charge, dtype=np.float64) + charge.reset_dataset(io.Dataset(charges.dtype, [num_particles])) + charge.store_chunk(charges, [0], [num_particles]) + charge.unit_SI = 1.0 + + mass = species_group["mass"] + masses = np.full(num_particles, particle_mass, dtype=np.float64) + mass.reset_dataset(io.Dataset(masses.dtype, [num_particles])) + mass.store_chunk(masses, [0], [num_particles]) + mass.unit_SI = 1.0 + + +def write_openpmd_fields(fields, world, output_dir, plot_t, t, filename="fields", file_extension=".bp"): + """ + Write all field data to an openPMD file for visualization in ParaView/VisIt. + + Args: + fields (tuple): Field tuple from the solver (E, B, J, rho, ...). + world (dict): Simulation world parameters. + output_dir (str): Base output directory for the simulation. + plot_t (int): openPMD iteration number/index used when writing this step. + t (int): Simulation step index used to compute the physical time (t * world["dt"]). + filename (str): Base name for the openPMD file. + file_extension (str): File extension for the openPMD series (for example, ".bp"). + """ + E, B, J, rho, *rest = fields + field_map = { + "E": E, + "B": B, + "J": J, + "rho": rho, + } + # map field names to their data + + if rest: + field_map["phi"] = rest[0] + for idx, extra in enumerate(rest[1:], start=1): + field_map[f"field_{idx}"] = extra + # add extra fields if present + + Nx, Ny, Nz = rho.shape + active_dims = (Nx > 1, Ny > 1, Nz > 1) + # determine active dimensions + + + series = _open_openpmd_series(output_dir, filename, file_extension=file_extension) + # open or create the openPMD series + iteration = series.iterations[int(plot_t)] + # specify the iteration using the plot number + iteration.time = float(t * world["dt"]) + # set the physical time + iteration.dt = float(world["dt"]) + # set the time step + iteration.time_unit_SI = 1.0 + # set the time unit + write_openpmd_fields_to_iteration(iteration, field_map, world, active_dims) + # write the field data to the iteration + series.flush() + series.close() + # flush and close the series + + +def write_openpmd_particles(particles, world, constants, output_dir, plot_t, t, filename="particles", file_extension=".bp"): + """ + Write all particle data to an openPMD file for visualization in ParaView/VisIt. + + Args: + particles (list): Particle species list. + world (dict): Simulation world parameters. + constants (dict): Physical constants (must include key 'C'). + output_dir (str): Base output directory for the simulation. + t (int): Iteration index. + filename (str): openPMD file name. + """ + series = _open_openpmd_series(output_dir, filename, file_extension=file_extension) + # open or create the openPMD series + iteration = series.iterations[int(plot_t)] + # specify the iteration using the plot number + iteration.time = float(t * world["dt"]) + # set the physical time + iteration.dt = float(world["dt"]) + # set the time step + iteration.time_unit_SI = 1.0 + # set the time unit + write_openpmd_particles_to_iteration(iteration, particles, constants) + # write the particle data to the iteration + series.flush() + series.close() + # flush and close the series + + + +def write_openpmd_initial_particles(particles, world, constants, output_dir, filename="initial_particles.h5"): + """ + Write the initial particle states to separate openPMD files, one per species. + + Args: + particles (list): List of particle species. + world (dict): Dictionary containing the simulation world parameters. + constants (dict): Dictionary of physical constants (must include key 'C' for the speed of light). + output_dir (str): Base output directory for the simulation. + filename (str): Base name of the openPMD output file (species name is prepended). + """ + if not particles: + return + + C = constants['C'] + # speed of light + + output_path = os.path.join(output_dir, "data", "initial_particles") + os.makedirs(output_path, exist_ok=True) + + def make_array_writable(arr): + arr = np.array(arr, dtype=np.float64, copy=True, order="C") + arr.setflags(write=True) + return arr + + for species in particles: + species_name = species.get_name().replace(" ", "_") + series_filename = f"{species_name}_{filename}" + series_path = os.path.join(output_path, series_filename) + + series = io.Series(series_path, io.Access.create) + series.set_attribute("software", "PyPIC3D") + series.set_attribute("softwareVersion", importlib.metadata.version("PyPIC3D")) + + iteration = series.iterations[0] + iteration.time = 0.0 + iteration.dt = float(world["dt"]) + iteration.time_unit_SI = 1.0 + + species_group = iteration.particles[species_name] + + x, y, z = species.get_forward_position() + vx, vy, vz = species.get_velocity() + gamma = 1 / jnp.sqrt(1.0 - (vx**2 + vy**2 + vz**2) / C**2) + # compute the Lorentz factor + + x = make_array_writable(x) + y = make_array_writable(y) + z = make_array_writable(z) + vx = make_array_writable(vx) + vy = make_array_writable(vy) + vz = make_array_writable(vz) + gamma = make_array_writable(gamma) + + num_particles = x.shape[0] + particle_mass = float(species.mass) + particle_charge = float(species.charge) + + position = species_group["position"] + for component, data in zip(("x", "y", "z"), (x, y, z)): + record_component = position[component] + record_component.reset_dataset(io.Dataset(data.dtype, [num_particles])) + record_component.store_chunk(data, [0], [num_particles]) + record_component.unit_SI = 1.0 + + # positionOffset: required by openPMD consumers (WarpX expects it) + pos_off = species_group["positionOffset"] + zeros = np.zeros(num_particles, dtype=np.float64) + for comp in ("x", "y", "z"): + rc = pos_off[comp] + rc.reset_dataset(io.Dataset(zeros.dtype, [num_particles])) + rc.store_chunk(zeros, [0], [num_particles]) + rc.unit_SI = 1.0 + + momentum = species_group["momentum"] + for component, data in zip(("x", "y", "z"), (vx, vy, vz)): + record_component = momentum[component] + record_component.reset_dataset(io.Dataset(data.dtype, [num_particles])) + record_component.store_chunk(data * particle_mass * gamma , [0], [num_particles]) + record_component.unit_SI = 1.0 + + weighting = species_group["weighting"] + weights = np.full(num_particles, float(species.weight), dtype=np.float64) + weighting.reset_dataset(io.Dataset(weights.dtype, [num_particles])) + weighting.store_chunk(weights, [0], [num_particles]) + weighting.unit_SI = 1.0 + + charge = species_group["charge"] + charges = np.full(num_particles, particle_charge, dtype=np.float64) + charge.reset_dataset(io.Dataset(charges.dtype, [num_particles])) + charge.store_chunk(charges, [0], [num_particles]) + charge.unit_SI = 1.0 + + mass = species_group["mass"] + masses = np.full(num_particles, particle_mass, dtype=np.float64) + mass.reset_dataset(io.Dataset(masses.dtype, [num_particles])) + mass.store_chunk(masses, [0], [num_particles]) + mass.unit_SI = 1.0 + + series.flush() + series.close() + +def write_openpmd_initial_fields(fields, world, output_dir, filename="initial_fields.h5"): + """ + Write the initial field states to an openPMD file. + + Args: + fields (tuple): Field tuple from the solver (E, B, J, rho, ...). + world (dict): Simulation world parameters. + output_dir (str): Base output directory for the simulation. + filename (str): openPMD file name. + """ + E, B, J, rho, *rest = fields + field_map = { + "E": E, + "B": B, + "J": J, + "rho": rho, + } + # map field names to their data + + if rest: + field_map["phi"] = rest[0] + for idx, extra in enumerate(rest[1:], start=1): + field_map[f"field_{idx}"] = extra + # add extra fields if present + + Nx, Ny, Nz = rho.shape + active_dims = (Nx > 1, Ny > 1, Nz > 1) + # determine active dimensions + + + output_path = os.path.join(output_dir, "data", "initial_fields") + os.makedirs(output_path, exist_ok=True) + series_path = os.path.join(output_path, filename) + series = io.Series(series_path, io.Access.create) + series.set_attribute("software", "PyPIC3D") + series.set_attribute("softwareVersion", importlib.metadata.version("PyPIC3D")) + # create the openPMD series + + iteration = series.iterations[0] + iteration.time = 0.0 + iteration.dt = float(world["dt"]) + iteration.time_unit_SI = 1.0 + write_openpmd_fields_to_iteration(iteration, field_map, world, active_dims) + series.flush() + series.close() \ No newline at end of file diff --git a/PyPIC3D/diagnostics/plotting.py b/PyPIC3D/diagnostics/plotting.py new file mode 100644 index 0000000..1236fae --- /dev/null +++ b/PyPIC3D/diagnostics/plotting.py @@ -0,0 +1,246 @@ +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt +import jax.numpy as jnp +import os +import plotly.graph_objects as go +import jax +from functools import partial + +def plot_positions(particles, t, x_wind, y_wind, z_wind, path): + """ + Makes an interactive 3D plot of the positions of the particles using Plotly. + + Args: + particles (list): A list of ParticleSpecies objects containing positions. + t (float): The time value. + x_wind (float): The x-axis wind limit. + y_wind (float): The y-axis wind limit. + z_wind (float): The z-axis wind limit. + + Returns: + None + """ + fig = go.Figure() + + colors = ['red', 'blue', 'green', 'purple', 'orange', 'yellow', 'cyan'] + for idx, species in enumerate(particles): + x, y, z = species.get_position() + fig.add_trace(go.Scatter3d( + x=x, y=y, z=z, mode='markers', + marker=dict(size=2, color=colors[idx % len(colors)]), + name=f'Species {idx + 1}' + )) + + fig.update_layout( + scene=dict( + xaxis=dict(range=[-(2/3)*x_wind, (2/3)*x_wind]), + yaxis=dict(range=[-(2/3)*y_wind, (2/3)*y_wind]), + zaxis=dict(range=[-(2/3)*z_wind, (2/3)*z_wind]), + xaxis_title='X (m)', + yaxis_title='Y (m)', + zaxis_title='Z (m)' + ), + title="Particle Positions" + ) + + if not os.path.exists(f"{path}/data/positions"): + os.makedirs(f"{path}/data/positions") + + fig.write_html(f"{path}/data/positions/particles.{t:09}.html") + +def write_particles_phase_space(particles, t, path): + """ + Write the phase space of the particles to a file. + + Args: + particles (Particles): The particles to be written. + t (ndarray): The time values. + name (str): The name of the plot. + + Returns: + None + """ + if not os.path.exists(f"{path}/data/phase_space/x"): + os.makedirs(f"{path}/data/phase_space/x") + if not os.path.exists(f"{path}/data/phase_space/y"): + os.makedirs(f"{path}/data/phase_space/y") + if not os.path.exists(f"{path}/data/phase_space/z"): + os.makedirs(f"{path}/data/phase_space/z") + # Create directory if it doesn't exist + + for species in particles: + x, y, z = species.get_position() + vx, vy, vz = species.get_velocity() + name = species.get_name().replace(" ", "") + + x_phase_space = jnp.stack((x, vx), axis=-1) + y_phase_space = jnp.stack((y, vy), axis=-1) + z_phase_space = jnp.stack((z, vz), axis=-1) + + jnp.save(f"{path}/data/phase_space/x/{name}_phase_space.{t:09}.npy", x_phase_space) + jnp.save(f"{path}/data/phase_space/y/{name}_phase_space.{t:09}.npy", y_phase_space) + jnp.save(f"{path}/data/phase_space/z/{name}_phase_space.{t:09}.npy", z_phase_space) + # write the phase space of the particles to a file + +def particles_phase_space(particles, world, t, name, path): + """ + Plot the phase space of the particles. + + Args: + particles (Particles): The particles to be plotted. + t (ndarray): The time values. + name (str): The name of the plot. + + Returns: + None + """ + + x_wind = world['x_wind'] + y_wind = world['y_wind'] + z_wind = world['z_wind'] + + colors = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] + idx = 0 + order = 10 + for species in particles: + x, y, z = species.get_position() + vx, vy, vz = species.get_velocity() + plt.scatter(x, vx, c=colors[idx], zorder=order, s=1) + idx += 1 + order -= 1 + plt.xlabel("Position") + plt.ylabel("Velocity") + #plt.ylim(-1e10, 1e10) + plt.xlim(-(2/3)*x_wind, (2/3)*x_wind) + plt.title(f"{name} Phase Space") + plt.savefig(f"{path}/data/phase_space/x/{name}_phase_space.{t:09}.png", dpi=300) + plt.close() + + idx = 0 + for species in particles: + x, y, z = species.get_position() + vx, vy, vz = species.get_velocity() + plt.scatter(y, vy, c=colors[idx]) + idx += 1 + plt.xlabel("Position") + plt.ylabel("Velocity") + plt.xlim(-(2/3)*y_wind, (2/3)*y_wind) + plt.title(f"{name} Phase Space") + plt.savefig(f"{path}/data/phase_space/y/{name}_phase_space.{t:09}.png", dpi=150) + plt.close() + + idx = 0 + for species in particles: + x, y, z = species.get_position() + vx, vy, vz = species.get_velocity() + plt.scatter(z, vz, c=colors[idx]) + idx += 1 + plt.xlabel("Position") + plt.ylabel("Velocity") + plt.xlim(-(2/3)*z_wind, (2/3)*z_wind) + plt.title(f"{name} Phase Space") + plt.savefig(f"{path}/data/phase_space/z/{name}_phase_space.{t:09}.png", dpi=150) + plt.close() + + +def plot_initial_histograms(particle_species, world, name, path): + """ + Generates and saves histograms for the initial positions and velocities + of particles in a simulation. + + Parameters: + particle_species (object): An object representing the particle species, + which provides methods `get_position()` and + `get_velocity()` to retrieve particle positions + (x, y, z) and velocities (vx, vy, vz). + world (dict): A dictionary containing the simulation world parameters, + specifically the wind dimensions with keys 'x_wind', + 'y_wind', and 'z_wind'. + name (str): A string representing the name of the particle species or + simulation, used in the titles of the histograms and filenames. + path (str): The directory path where the histogram images will be saved. + + Saves: + Six histogram images: + - Initial X, Y, and Z position histograms. + - Initial X, Y, and Z velocity histograms. + The images are saved in the specified `path` directory with filenames + formatted as "{name}_initial__histogram.png". + """ + + x, y, z = particle_species.get_position() + vx, vy, vz = particle_species.get_velocity() + + x_wind = world['x_wind'] + y_wind = world['y_wind'] + z_wind = world['z_wind'] + + + plt.hist(x, bins=50) + plt.xlabel("X") + plt.ylabel("Number of Particles") + plt.xlim(-(2/3)*x_wind, (2/3)*x_wind) + plt.title(f"{name} Initial X Position Histogram") + plt.savefig(f"{path}/{name}_initial_x_histogram.png", dpi=150) + plt.close() + + plt.hist(y, bins=50) + plt.xlabel("Y") + plt.ylabel("Number of Particles") + plt.xlim(-(2/3)*y_wind, (2/3)*y_wind) + plt.title(f"{name} Initial Y Position Histogram") + plt.savefig(f"{path}/{name}_initial_y_histogram.png", dpi=150) + plt.close() + + plt.hist(z, bins=50) + plt.xlabel("Z") + plt.ylabel("Number of Particles") + plt.xlim(-(2/3)*z_wind, (2/3)*z_wind) + plt.title(f"{name} Initial Z Position Histogram") + plt.savefig(f"{path}/{name}_initial_z_histogram.png", dpi=150) + plt.close() + + plt.hist(vx, bins=50) + plt.xlabel("X Velocity") + plt.ylabel("Number of Particles") + plt.title(f"{name} Initial X Velocity Histogram") + plt.savefig(f"{path}/{name}_initial_x_velocity_histogram.png", dpi=150) + plt.close() + + plt.hist(vy, bins=50) + plt.xlabel("Y Velocity") + plt.ylabel("Number of Particles") + plt.title(f"{name} Initial Y Velocity Histogram") + plt.savefig(f"{path}/{name}_initial_y_velocity_histogram.png", dpi=150) + plt.close() + + plt.hist(vz, bins=50) + plt.xlabel("Z Velocity") + plt.ylabel("Number of Particles") + plt.title(f"{name} Initial Z Velocity Histogram") + plt.savefig(f"{path}/{name}_initial_z_velocity_histogram.png", dpi=150) + plt.close() + + +@partial(jax.jit, static_argnums=(0)) +def write_data(filename, time, data): + """ + Write the given time and data to a file using JAX's callback mechanism. + This function is designed to be used with JAX's just-in-time compilation (jit) to optimize performance. + + Args: + filename (str): The name of the file to write to. + time (float): The time value to write. + data (any): The data to write. + + Returns: + None + """ + + def write_to_file(filename, time, data): + with open(filename, "a") as f: + f.write(f"{time}, {data}\n") + + return jax.debug.callback(write_to_file, filename, time, data, ordered=True) + diff --git a/PyPIC3D/diagnostics/vtk.py b/PyPIC3D/diagnostics/vtk.py new file mode 100644 index 0000000..f7ff14b --- /dev/null +++ b/PyPIC3D/diagnostics/vtk.py @@ -0,0 +1,231 @@ +import vtk +from vtk.util import numpy_support +import numpy as np +import os +from pyevtk.hl import gridToVTK, pointsToVTK +import jax.numpy as jnp + +def plot_fields(fieldx, fieldy, fieldz, t, name, dx, dy, dz): + """ + Plot the fields in a 3D grid. + + Args: + fieldx (ndarray): Array representing the x-component of the field. + fieldy (ndarray): Array representing the y-component of the field. + fieldz (ndarray): Array representing the z-component of the field. + t (float): Time value. + name (str): Name of the field. + dx (float): Spacing between grid points in the x-direction. + dy (float): Spacing between grid points in the y-direction. + dz (float): Spacing between grid points in the z-direction. + + Returns: + None + """ + Nx = fieldx.shape[0] + Ny = fieldx.shape[1] + Nz = fieldx.shape[2] + x = np.linspace(0, Nx, Nx) * dx + y = np.linspace(0, Ny, Ny) * dy + z = np.linspace(0, Nz, Nz) * dz + + # Create directory if it doesn't exist + directory = "./plots/fields" + if not os.path.exists(directory): + os.makedirs(directory) + + gridToVTK(f"./plots/fields/{name}_{t:09}", x, y, z, \ + cellData = {f"{name}_x" : np.asarray(fieldx), \ + f"{name}_y" : np.asarray(fieldy), f"{name}_z" : np.asarray(fieldz)}) +# plot the electric fields in the vtk file format + + +def plot_vtk_particles(particles, t, path): + """ + Plot the particles in VTK format. + + Args: + particles (Particles): The particles to be plotted. + t (ndarray): The time values. + path (str): The path to save the plot. + + Returns: + None + """ + if not os.path.exists(f"{path}/data/particles"): + os.makedirs(f"{path}/data/particles") + + particle_names = [species.get_name().replace(" ", "") for species in particles] + for species in particles: + name = species.get_name().replace(" ", "") + x, y, z = map(np.asarray, species.get_position()) + vx, vy, vz = map(np.asarray, species.get_velocity()) + # Get the position and velocity of the particles + q = np.asarray( species.get_charge() * np.ones_like(vx) ) + # Get the charge of the particles + + pointsToVTK(f"{path}/data/particles/{name}.{t:09}", x, y, z, \ + {"v": (vx, vy, vz), "q": q}) + # save the particles in the vtk file format + + +def plot_field_slice_vtk(field_slices, field_names, slice, grid, t, name, path, world): + """ + Plot a slice of a field in VTK format using Python VTK library. + + Args: + field_slices (list): List of 2D field slices to be plotted. + field_names (list): List of field names corresponding to the slices. + slice (int): Slice direction (0=x-slice, 1=y-slice, 2=z-slice, 3=full 3D). + grid (tuple): The grid dimensions (x, y, z). + t (int): The time step. + name (str): The name of the field. + path (str): The path to save the plot. + world (dict): World parameters containing grid information. + + Returns: + None + """ + + x, y, z = grid + nx, ny, nz = world['Nx'], world['Ny'], world['Nz'] + dx, dy, dz = world['dx'], world['dy'], world['dz'] + + if not os.path.exists(f"{path}/data/{name}_slice"): + os.makedirs(f"{path}/data/{name}_slice") + # Create directory if it doesn't exist + + # Create VTK structured grid + structured_grid = vtk.vtkStructuredGrid() + + if slice == 0: + x = np.asarray([x[nx//2]]) + elif slice == 1: + y = np.asarray([y[ny//2]]) + elif slice == 2: + z = np.asarray([z[nz//2]]) + + structured_grid.SetDimensions(x.shape[0], y.shape[0], z.shape[0]) + # Set the dimensions of the structured grid based on the slice type + + # Efficiently create all grid points using numpy meshgrid and bulk insert + X, Y, Z = np.meshgrid(x, y, z, indexing='ij') + coords = np.column_stack((Z.ravel(), Y.ravel(), X.ravel())) + points = vtk.vtkPoints() + points.SetData(numpy_support.numpy_to_vtk(coords, deep=True)) + structured_grid.SetPoints(points) + # Create points for the structured grid based on the slice type + + for idx, (field_slice, field_name) in enumerate(zip(field_slices, field_names)): + # Ensure field_slice is 2D and handle VTK ordering + field_data = np.asarray(field_slice) + if field_data.ndim == 2: + # VTK expects data in (k,j,i) order, but our slice is typically (j,k) + field_data = field_data.T.flatten() + else: + field_data = field_data.flatten() + + vtk_array = numpy_support.numpy_to_vtk(field_data) + vtk_array.SetName(field_name) + structured_grid.GetPointData().AddArray(vtk_array) + + # Write the VTK file + writer = vtk.vtkStructuredGridWriter() + writer.SetFileName(f"{path}/data/{name}_slice/{name}_slice_{t:09}.vtk") + writer.SetInputData(structured_grid) + writer.Write() + + +def plot_vectorfield_slice_vtk(field_slices, field_names, slice, grid, t, name, path, world): + """ + Plot a slice of a field in VTK format as vector data using Python VTK library. + + Args: + field_slices (list): List of 2D field slices to be plotted. Should be [Fx, Fy, Fz] for vector fields. + field_names (list): List of field names corresponding to the slices (e.g., ['Ex', 'Ey', 'Ez']). + slice (int): Slice direction (0=x-slice, 1=y-slice, 2=z-slice, 3=full 3D). + grid (tuple): The grid dimensions (x, y, z). + t (int): The time step. + name (str): The name of the field. + path (str): The path to save the plot. + world (dict): World parameters containing grid information. + + Returns: + None + """ + + x, y, z = grid + nx, ny, nz = world['Nx'], world['Ny'], world['Nz'] + dx, dy, dz = world['dx'], world['dy'], world['dz'] + + if not os.path.exists(f"{path}/data/{name}_slice"): + os.makedirs(f"{path}/data/{name}_slice") + + # Handle slicing + if slice == 0: + x = np.asarray([x[nx//2]]) + elif slice == 1: + y = np.asarray([y[ny//2]]) + elif slice == 2: + z = np.asarray([z[nz//2]]) + + # Create VTK structured grid + structured_grid = vtk.vtkStructuredGrid() + structured_grid.SetDimensions(x.shape[0], y.shape[0], z.shape[0]) + + # Create grid points + X, Y, Z = np.meshgrid(x, y, z, indexing='ij') + coords = np.column_stack((Z.ravel(), Y.ravel(), X.ravel())) + points = vtk.vtkPoints() + points.SetData(numpy_support.numpy_to_vtk(coords, deep=True)) + structured_grid.SetPoints(points) + + # Stack field slices as vector data + # Each field_slice should be 2D, shape (len(x), len(y)) or similar + # We flatten and stack them as (N, 3) for VTK vector + for field_slice, field_name in zip(field_slices, field_names): + field_arrays = [] + for comp in field_slice: + field_data = np.asarray(comp) + # convert to np array + if field_data.ndim == 2: + field_data = field_data.T.flatten() # VTK expects Fortran order + else: + field_data = field_data.flatten() + field_arrays.append(field_data) + # Stack as (N, 3) + vector_data = np.stack(field_arrays, axis=-1) + # If shape is (N, 3), convert to VTK + vtk_vector_array = numpy_support.numpy_to_vtk(vector_data, deep=True) + vtk_vector_array.SetName(f"{field_name}_vector") + structured_grid.GetPointData().AddArray(vtk_vector_array) + + # Write the VTK file + writer = vtk.vtkStructuredGridWriter() + writer.SetFileName(f"{path}/data/{name}_slice/{name}_slice_{t:09}.vtk") + writer.SetInputData(structured_grid) + writer.Write() + + +def write_slice(field_slice, x1, x2, t, name, path, dt): + """ + Plots a slice of a field and saves it in VTK format. + Parameters: + field_slice (numpy.ndarray): The 2D slice of the field to be plotted. + x1 (numpy.ndarray): The x-coordinates of the field slice. + x2 (numpy.ndarray): The y-coordinates of the field slice. + t (int): The time step or index for the slice. + name (str): The name of the field or slice. + path (str): The directory path where the VTK file will be saved. + dt (float): The time step size (not used in the function but included in parameters). + Returns: + None + """ + + x3 = np.zeros(1) + + field_slice = jnp.expand_dims(field_slice, axis=-1) + + gridToVTK(f"{path}/data/{name}_slice/{name}_slice_{t:09}", x1, x2, x3, \ + cellData = {f"{name}" : field_slice}) + # plot the slice of the field in the vtk file format \ No newline at end of file diff --git a/PyPIC3D/initialization.py b/PyPIC3D/initialization.py index 4f9f4df..f421070 100644 --- a/PyPIC3D/initialization.py +++ b/PyPIC3D/initialization.py @@ -35,8 +35,12 @@ ) -from PyPIC3D.plotting import ( - plot_initial_histograms, write_openpmd_initial_particles +from PyPIC3D.diagnostics.plotting import ( + plot_initial_histograms +) + +from PyPIC3D.diagnostics.openPMD import ( + write_openpmd_initial_particles, write_openpmd_initial_fields ) @@ -71,9 +75,14 @@ def default_parameters(): "plot_errors": False, "plot_dispersion": False, 'plot_chargeconservation': False, - "plot_vtk_particles": True, + "plot_vtk_particles": False, + "plot_vtk_scalars" : False, + "plot_vtk_vectors" : False, + "plot_openpmd_particles": False, + "plot_openpmd_fields": False, "plotting_interval": 10, "dump_particles": False, + "dump_fields": False, } # dictionary for plotting/saving data @@ -334,6 +343,10 @@ def initialize_simulation(toml_file): fields = (E, B, J, rho, phi) # define the fields tuple for the electrodynamic and electrostatic solvers + if plotting_parameters['dump_fields']: + write_openpmd_initial_fields(fields, world, simulation_parameters['output_dir'], filename="initial_fields.h5") + # write the initial fields to an openPMD file + if GPUs: print(f"GPUs Detected! Using GPUs for simulation\n") @@ -384,4 +397,4 @@ def initialize_fields(Nx, Ny, Nz): rho = jax.numpy.zeros(shape = (Nx, Ny, Nz) ) # initialize the electric potential and charge density arrays as 0 - return (Ex, Ey, Ez), (Bx, By, Bz), (Jx, Jy, Jz), phi, rho \ No newline at end of file + return (Ex, Ey, Ez), (Bx, By, Bz), (Jx, Jy, Jz), phi, rho diff --git a/PyPIC3D/plotting.py b/PyPIC3D/plotting.py deleted file mode 100644 index ce43936..0000000 --- a/PyPIC3D/plotting.py +++ /dev/null @@ -1,699 +0,0 @@ -import numpy as np -import matplotlib -matplotlib.use('agg') -import matplotlib.pyplot as plt -from jax import jit -import jax.numpy as jnp -from pyevtk.hl import gridToVTK, pointsToVTK -import os -import plotly.graph_objects as go -import jax -from functools import partial -import vtk -from vtk.util import numpy_support -import openpmd_api as io -import importlib.metadata - -from PyPIC3D.utils import compute_energy - -def plot_rho(rho, t, name, dx, dy, dz): - """ - Plot the density field. - - Args: - rho (ndarray): The density field. - t (int): The time step. - name (str): The name of the plot. - dx (float): The spacing in the x-direction. - dy (float): The spacing in the y-direction. - dz (float): The spacing in the z-direction. - - Returns: - None - """ - Nx = rho.shape[0] - Ny = rho.shape[1] - Nz = rho.shape[2] - x = np.linspace(0, Nx, Nx) * dx - y = np.linspace(0, Ny, Ny) * dy - z = np.linspace(0, Nz, Nz) * dz - - - # Create directory if it doesn't exist - directory = "./plots/rho" - if not os.path.exists(directory): - os.makedirs(directory) - - - gridToVTK(f"./plots/rho/{name}_{t:09}", x, y, z, \ - cellData = {f"{name}" : np.asarray(rho)}) -# plot the charge density in the vtk file format - -def plot_fields(fieldx, fieldy, fieldz, t, name, dx, dy, dz): - """ - Plot the fields in a 3D grid. - - Args: - fieldx (ndarray): Array representing the x-component of the field. - fieldy (ndarray): Array representing the y-component of the field. - fieldz (ndarray): Array representing the z-component of the field. - t (float): Time value. - name (str): Name of the field. - dx (float): Spacing between grid points in the x-direction. - dy (float): Spacing between grid points in the y-direction. - dz (float): Spacing between grid points in the z-direction. - - Returns: - None - """ - Nx = fieldx.shape[0] - Ny = fieldx.shape[1] - Nz = fieldx.shape[2] - x = np.linspace(0, Nx, Nx) * dx - y = np.linspace(0, Ny, Ny) * dy - z = np.linspace(0, Nz, Nz) * dz - - # Create directory if it doesn't exist - directory = "./plots/fields" - if not os.path.exists(directory): - os.makedirs(directory) - - gridToVTK(f"./plots/fields/{name}_{t:09}", x, y, z, \ - cellData = {f"{name}_x" : np.asarray(fieldx), \ - f"{name}_y" : np.asarray(fieldy), f"{name}_z" : np.asarray(fieldz)}) -# plot the electric fields in the vtk file format - -def plot_1dposition(x, name, particle): - """ - Plot the 1D position of a particle. - - Args: - x (ndarray): The x-coordinates of the particle. - name (str): The name of the plot. - - Returns: - None - """ - plt.plot(x) - plt.title(f"{name} Position") - plt.xlabel("Time") - plt.ylabel("Position") - - if not os.path.exists(f"plots/{name}"): - os.makedirs(f"plots/{name}") - - plt.savefig(f"plots/{name}/{particle}_position.png", dpi=300) - plt.close() - - -def plot_positions(particles, t, x_wind, y_wind, z_wind, path): - """ - Makes an interactive 3D plot of the positions of the particles using Plotly. - - Args: - particles (list): A list of ParticleSpecies objects containing positions. - t (float): The time value. - x_wind (float): The x-axis wind limit. - y_wind (float): The y-axis wind limit. - z_wind (float): The z-axis wind limit. - - Returns: - None - """ - fig = go.Figure() - - colors = ['red', 'blue', 'green', 'purple', 'orange', 'yellow', 'cyan'] - for idx, species in enumerate(particles): - x, y, z = species.get_position() - fig.add_trace(go.Scatter3d( - x=x, y=y, z=z, mode='markers', - marker=dict(size=2, color=colors[idx % len(colors)]), - name=f'Species {idx + 1}' - )) - - fig.update_layout( - scene=dict( - xaxis=dict(range=[-(2/3)*x_wind, (2/3)*x_wind]), - yaxis=dict(range=[-(2/3)*y_wind, (2/3)*y_wind]), - zaxis=dict(range=[-(2/3)*z_wind, (2/3)*z_wind]), - xaxis_title='X (m)', - yaxis_title='Y (m)', - zaxis_title='Z (m)' - ), - title="Particle Positions" - ) - - if not os.path.exists(f"{path}/data/positions"): - os.makedirs(f"{path}/data/positions") - - fig.write_html(f"{path}/data/positions/particles.{t:09}.html") - -def write_particles_phase_space(particles, t, path): - """ - Write the phase space of the particles to a file. - - Args: - particles (Particles): The particles to be written. - t (ndarray): The time values. - name (str): The name of the plot. - - Returns: - None - """ - if not os.path.exists(f"{path}/data/phase_space/x"): - os.makedirs(f"{path}/data/phase_space/x") - if not os.path.exists(f"{path}/data/phase_space/y"): - os.makedirs(f"{path}/data/phase_space/y") - if not os.path.exists(f"{path}/data/phase_space/z"): - os.makedirs(f"{path}/data/phase_space/z") - # Create directory if it doesn't exist - - for species in particles: - x, y, z = species.get_position() - vx, vy, vz = species.get_velocity() - name = species.get_name().replace(" ", "") - - x_phase_space = jnp.stack((x, vx), axis=-1) - y_phase_space = jnp.stack((y, vy), axis=-1) - z_phase_space = jnp.stack((z, vz), axis=-1) - - jnp.save(f"{path}/data/phase_space/x/{name}_phase_space.{t:09}.npy", x_phase_space) - jnp.save(f"{path}/data/phase_space/y/{name}_phase_space.{t:09}.npy", y_phase_space) - jnp.save(f"{path}/data/phase_space/z/{name}_phase_space.{t:09}.npy", z_phase_space) - # write the phase space of the particles to a file - -def particles_phase_space(particles, world, t, name, path): - """ - Plot the phase space of the particles. - - Args: - particles (Particles): The particles to be plotted. - t (ndarray): The time values. - name (str): The name of the plot. - - Returns: - None - """ - - x_wind = world['x_wind'] - y_wind = world['y_wind'] - z_wind = world['z_wind'] - - colors = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] - idx = 0 - order = 10 - for species in particles: - x, y, z = species.get_position() - vx, vy, vz = species.get_velocity() - plt.scatter(x, vx, c=colors[idx], zorder=order, s=1) - idx += 1 - order -= 1 - plt.xlabel("Position") - plt.ylabel("Velocity") - #plt.ylim(-1e10, 1e10) - plt.xlim(-(2/3)*x_wind, (2/3)*x_wind) - plt.title(f"{name} Phase Space") - plt.savefig(f"{path}/data/phase_space/x/{name}_phase_space.{t:09}.png", dpi=300) - plt.close() - - idx = 0 - for species in particles: - x, y, z = species.get_position() - vx, vy, vz = species.get_velocity() - plt.scatter(y, vy, c=colors[idx]) - idx += 1 - plt.xlabel("Position") - plt.ylabel("Velocity") - plt.xlim(-(2/3)*y_wind, (2/3)*y_wind) - plt.title(f"{name} Phase Space") - plt.savefig(f"{path}/data/phase_space/y/{name}_phase_space.{t:09}.png", dpi=150) - plt.close() - - idx = 0 - for species in particles: - x, y, z = species.get_position() - vx, vy, vz = species.get_velocity() - plt.scatter(z, vz, c=colors[idx]) - idx += 1 - plt.xlabel("Position") - plt.ylabel("Velocity") - plt.xlim(-(2/3)*z_wind, (2/3)*z_wind) - plt.title(f"{name} Phase Space") - plt.savefig(f"{path}/data/phase_space/z/{name}_phase_space.{t:09}.png", dpi=150) - plt.close() - - -def plot_initial_histograms(particle_species, world, name, path): - """ - Generates and saves histograms for the initial positions and velocities - of particles in a simulation. - - Parameters: - particle_species (object): An object representing the particle species, - which provides methods `get_position()` and - `get_velocity()` to retrieve particle positions - (x, y, z) and velocities (vx, vy, vz). - world (dict): A dictionary containing the simulation world parameters, - specifically the wind dimensions with keys 'x_wind', - 'y_wind', and 'z_wind'. - name (str): A string representing the name of the particle species or - simulation, used in the titles of the histograms and filenames. - path (str): The directory path where the histogram images will be saved. - - Saves: - Six histogram images: - - Initial X, Y, and Z position histograms. - - Initial X, Y, and Z velocity histograms. - The images are saved in the specified `path` directory with filenames - formatted as "{name}_initial__histogram.png". - """ - - x, y, z = particle_species.get_position() - vx, vy, vz = particle_species.get_velocity() - - x_wind = world['x_wind'] - y_wind = world['y_wind'] - z_wind = world['z_wind'] - - - plt.hist(x, bins=50) - plt.xlabel("X") - plt.ylabel("Number of Particles") - plt.xlim(-(2/3)*x_wind, (2/3)*x_wind) - plt.title(f"{name} Initial X Position Histogram") - plt.savefig(f"{path}/{name}_initial_x_histogram.png", dpi=150) - plt.close() - - plt.hist(y, bins=50) - plt.xlabel("Y") - plt.ylabel("Number of Particles") - plt.xlim(-(2/3)*y_wind, (2/3)*y_wind) - plt.title(f"{name} Initial Y Position Histogram") - plt.savefig(f"{path}/{name}_initial_y_histogram.png", dpi=150) - plt.close() - - plt.hist(z, bins=50) - plt.xlabel("Z") - plt.ylabel("Number of Particles") - plt.xlim(-(2/3)*z_wind, (2/3)*z_wind) - plt.title(f"{name} Initial Z Position Histogram") - plt.savefig(f"{path}/{name}_initial_z_histogram.png", dpi=150) - plt.close() - - plt.hist(vx, bins=50) - plt.xlabel("X Velocity") - plt.ylabel("Number of Particles") - plt.title(f"{name} Initial X Velocity Histogram") - plt.savefig(f"{path}/{name}_initial_x_velocity_histogram.png", dpi=150) - plt.close() - - plt.hist(vy, bins=50) - plt.xlabel("Y Velocity") - plt.ylabel("Number of Particles") - plt.title(f"{name} Initial Y Velocity Histogram") - plt.savefig(f"{path}/{name}_initial_y_velocity_histogram.png", dpi=150) - plt.close() - - plt.hist(vz, bins=50) - plt.xlabel("Z Velocity") - plt.ylabel("Number of Particles") - plt.title(f"{name} Initial Z Velocity Histogram") - plt.savefig(f"{path}/{name}_initial_z_velocity_histogram.png", dpi=150) - plt.close() - - -def plot_initial_KE(particles, path): - """ - Plots the initial kinetic energy of the particles. - - Args: - particles (Particles): The particles to be plotted. - t (ndarray): The time values. - name (str): The name of the plot. - - Returns: - None - """ - for particle in particles: - particle_name = particle.get_name().replace(" ", "") - vx, vy, vz = particle.get_velocity() - vmag = jnp.sqrt(vx**2 + vy**2 + vz**2) - KE = 0.5 * particle.get_mass() * vmag**2 - plt.hist(KE, bins=50) - plt.xlabel("Kinetic Energy") - plt.ylabel("Number of Particles") - plt.title("Initial Kinetic Energy") - plt.savefig(f"{path}/data/{particle_name}_initialKE.png", dpi=300) - plt.close() - - -@partial(jax.jit, static_argnums=(2, 3)) -def plot_slice(field_slice, t, name, path, world, dt): - """ - Plots a 2D slice of a field and saves the plot as a PNG file using JAX's debug callback. - - Args: - field_slice (2D array): The 2D array representing the field slice to be plotted. - t (int): The time step index. - name (str): The name of the field being plotted. - path (str): The directory path where the plot will be saved. - world (dict): A dictionary containing the dimensions of the world with keys 'x_wind' and 'y_wind'. - dt (float): The time step duration. - - Returns: - None - """ - - def plot_and_save(field_slice, t, name, path, world, dt): - if not os.path.exists(f"{path}/data/{name}_slice"): - os.makedirs(f"{path}/data/{name}_slice") - # Create directory if it doesn't exist - - plt.title(f'{name} at t={t*dt:.2e}s') - plt.imshow(jnp.swapaxes(field_slice, 0, 1), origin='lower', - extent=[-world['x_wind']/2, world['x_wind']/2, - -world['y_wind']/2, world['y_wind']/2]) - plt.colorbar(label=name) - plt.tight_layout() - plt.savefig(f'{path}/data/{name}_slice/{name}_slice_{t:09}.png', dpi=300) - plt.clf() # Clear the current figure - plt.close('all') # Close all figures to free up memory - - # Use JAX debug callback to execute the plotting function - return jax.debug.callback(plot_and_save, field_slice, t, name, path, world, dt, ordered=True) - - -def write_slice(field_slice, x1, x2, t, name, path, dt): - """ - Plots a slice of a field and saves it in VTK format. - Parameters: - field_slice (numpy.ndarray): The 2D slice of the field to be plotted. - x1 (numpy.ndarray): The x-coordinates of the field slice. - x2 (numpy.ndarray): The y-coordinates of the field slice. - t (int): The time step or index for the slice. - name (str): The name of the field or slice. - path (str): The directory path where the VTK file will be saved. - dt (float): The time step size (not used in the function but included in parameters). - Returns: - None - """ - - x3 = np.zeros(1) - - field_slice = jnp.expand_dims(field_slice, axis=-1) - - gridToVTK(f"{path}/data/{name}_slice/{name}_slice_{t:09}", x1, x2, x3, \ - cellData = {f"{name}" : field_slice}) - # plot the slice of the field in the vtk file format - - -@partial(jax.jit, static_argnums=(0)) -def write_data(filename, time, data): - """ - Write the given time and data to a file using JAX's callback mechanism. - This function is designed to be used with JAX's just-in-time compilation (jit) to optimize performance. - - Args: - filename (str): The name of the file to write to. - time (float): The time value to write. - data (any): The data to write. - - Returns: - None - """ - - def write_to_file(filename, time, data): - with open(filename, "a") as f: - f.write(f"{time}, {data}\n") - - return jax.debug.callback(write_to_file, filename, time, data, ordered=True) - - -def plot_vtk_particles(particles, t, path): - """ - Plot the particles in VTK format. - - Args: - particles (Particles): The particles to be plotted. - t (ndarray): The time values. - path (str): The path to save the plot. - - Returns: - None - """ - if not os.path.exists(f"{path}/data/particles"): - os.makedirs(f"{path}/data/particles") - - particle_names = [species.get_name().replace(" ", "") for species in particles] - for species in particles: - name = species.get_name().replace(" ", "") - x, y, z = map(np.asarray, species.get_position()) - vx, vy, vz = map(np.asarray, species.get_velocity()) - # Get the position and velocity of the particles - q = np.asarray( species.get_charge() * np.ones_like(vx) ) - # Get the charge of the particles - - pointsToVTK(f"{path}/data/particles/{name}.{t:09}", x, y, z, \ - {"v": (vx, vy, vz), "q": q}) - # save the particles in the vtk file format - - -def plot_field_slice_vtk(field_slices, field_names, slice, grid, t, name, path, world): - """ - Plot a slice of a field in VTK format using Python VTK library. - - Args: - field_slices (list): List of 2D field slices to be plotted. - field_names (list): List of field names corresponding to the slices. - slice (int): Slice direction (0=x-slice, 1=y-slice, 2=z-slice, 3=full 3D). - grid (tuple): The grid dimensions (x, y, z). - t (int): The time step. - name (str): The name of the field. - path (str): The path to save the plot. - world (dict): World parameters containing grid information. - - Returns: - None - """ - - x, y, z = grid - nx, ny, nz = world['Nx'], world['Ny'], world['Nz'] - dx, dy, dz = world['dx'], world['dy'], world['dz'] - - if not os.path.exists(f"{path}/data/{name}_slice"): - os.makedirs(f"{path}/data/{name}_slice") - # Create directory if it doesn't exist - - # Create VTK structured grid - structured_grid = vtk.vtkStructuredGrid() - - if slice == 0: - x = np.asarray([x[nx//2]]) - elif slice == 1: - y = np.asarray([y[ny//2]]) - elif slice == 2: - z = np.asarray([z[nz//2]]) - - structured_grid.SetDimensions(x.shape[0], y.shape[0], z.shape[0]) - # Set the dimensions of the structured grid based on the slice type - - # Efficiently create all grid points using numpy meshgrid and bulk insert - X, Y, Z = np.meshgrid(x, y, z, indexing='ij') - coords = np.column_stack((Z.ravel(), Y.ravel(), X.ravel())) - points = vtk.vtkPoints() - points.SetData(numpy_support.numpy_to_vtk(coords, deep=True)) - structured_grid.SetPoints(points) - # Create points for the structured grid based on the slice type - - for idx, (field_slice, field_name) in enumerate(zip(field_slices, field_names)): - # Ensure field_slice is 2D and handle VTK ordering - field_data = np.asarray(field_slice) - if field_data.ndim == 2: - # VTK expects data in (k,j,i) order, but our slice is typically (j,k) - field_data = field_data.T.flatten() - else: - field_data = field_data.flatten() - - vtk_array = numpy_support.numpy_to_vtk(field_data) - vtk_array.SetName(field_name) - structured_grid.GetPointData().AddArray(vtk_array) - - # Write the VTK file - writer = vtk.vtkStructuredGridWriter() - writer.SetFileName(f"{path}/data/{name}_slice/{name}_slice_{t:09}.vtk") - writer.SetInputData(structured_grid) - writer.Write() - - -def plot_vectorfield_slice_vtk(field_slices, field_names, slice, grid, t, name, path, world): - """ - Plot a slice of a field in VTK format as vector data using Python VTK library. - - Args: - field_slices (list): List of 2D field slices to be plotted. Should be [Fx, Fy, Fz] for vector fields. - field_names (list): List of field names corresponding to the slices (e.g., ['Ex', 'Ey', 'Ez']). - slice (int): Slice direction (0=x-slice, 1=y-slice, 2=z-slice, 3=full 3D). - grid (tuple): The grid dimensions (x, y, z). - t (int): The time step. - name (str): The name of the field. - path (str): The path to save the plot. - world (dict): World parameters containing grid information. - - Returns: - None - """ - - x, y, z = grid - nx, ny, nz = world['Nx'], world['Ny'], world['Nz'] - dx, dy, dz = world['dx'], world['dy'], world['dz'] - - if not os.path.exists(f"{path}/data/{name}_slice"): - os.makedirs(f"{path}/data/{name}_slice") - - # Handle slicing - if slice == 0: - x = np.asarray([x[nx//2]]) - elif slice == 1: - y = np.asarray([y[ny//2]]) - elif slice == 2: - z = np.asarray([z[nz//2]]) - - # Create VTK structured grid - structured_grid = vtk.vtkStructuredGrid() - structured_grid.SetDimensions(x.shape[0], y.shape[0], z.shape[0]) - - # Create grid points - X, Y, Z = np.meshgrid(x, y, z, indexing='ij') - coords = np.column_stack((Z.ravel(), Y.ravel(), X.ravel())) - points = vtk.vtkPoints() - points.SetData(numpy_support.numpy_to_vtk(coords, deep=True)) - structured_grid.SetPoints(points) - - # Stack field slices as vector data - # Each field_slice should be 2D, shape (len(x), len(y)) or similar - # We flatten and stack them as (N, 3) for VTK vector - for field_slice, field_name in zip(field_slices, field_names): - field_arrays = [] - for comp in field_slice: - field_data = np.asarray(comp) - # convert to np array - if field_data.ndim == 2: - field_data = field_data.T.flatten() # VTK expects Fortran order - else: - field_data = field_data.flatten() - field_arrays.append(field_data) - # Stack as (N, 3) - vector_data = np.stack(field_arrays, axis=-1) - # If shape is (N, 3), convert to VTK - vtk_vector_array = numpy_support.numpy_to_vtk(vector_data, deep=True) - vtk_vector_array.SetName(f"{field_name}_vector") - structured_grid.GetPointData().AddArray(vtk_vector_array) - - # Write the VTK file - writer = vtk.vtkStructuredGridWriter() - writer.SetFileName(f"{path}/data/{name}_slice/{name}_slice_{t:09}.vtk") - writer.SetInputData(structured_grid) - writer.Write() - - -def write_openpmd_initial_particles(particles, world, constants, output_dir, filename="initial_particles.h5"): - """ - Write the initial particle states to separate openPMD files, one per species. - - Args: - particles (list): List of particle species. - world (dict): Dictionary containing the simulation world parameters. - constants (dict): Dictionary of physical constants (must include key 'C' for the speed of light). - output_dir (str): Base output directory for the simulation. - filename (str): Base name of the openPMD output file (species name is prepended). - """ - if not particles: - return - - C = constants['C'] - # speed of light - - output_path = os.path.join(output_dir, "data", "openpmd") - os.makedirs(output_path, exist_ok=True) - - def make_array_writable(arr): - arr = np.array(arr, dtype=np.float64, copy=True, order="C") - arr.setflags(write=True) - return arr - - for species in particles: - species_name = species.get_name().replace(" ", "_") - series_filename = f"{species_name}_{filename}" - series_path = os.path.join(output_path, series_filename) - - series = io.Series(series_path, io.Access.create) - series.set_attribute("software", "PyPIC3D") - series.set_attribute("softwareVersion", importlib.metadata.version("PyPIC3D")) - - iteration = series.iterations[0] - iteration.time = 0.0 - iteration.dt = float(world["dt"]) - iteration.time_unit_SI = 1.0 - - species_group = iteration.particles[species_name] - - x, y, z = species.get_forward_position() - vx, vy, vz = species.get_velocity() - gamma = 1 / jnp.sqrt(1.0 - (vx**2 + vy**2 + vz**2) / C**2) - # compute the Lorentz factor - - x = make_array_writable(x) - y = make_array_writable(y) - z = make_array_writable(z) - vx = make_array_writable(vx) - vy = make_array_writable(vy) - vz = make_array_writable(vz) - gamma = make_array_writable(gamma) - - num_particles = x.shape[0] - particle_mass = float(species.mass) - particle_charge = float(species.charge) - - position = species_group["position"] - for component, data in zip(("x", "y", "z"), (x, y, z)): - record_component = position[component] - record_component.reset_dataset(io.Dataset(data.dtype, [num_particles])) - record_component.store_chunk(data, [0], [num_particles]) - record_component.unit_SI = 1.0 - - # positionOffset: required by openPMD consumers (WarpX expects it) - pos_off = species_group["positionOffset"] - zeros = np.zeros(num_particles, dtype=np.float64) - for comp in ("x", "y", "z"): - rc = pos_off[comp] - rc.reset_dataset(io.Dataset(zeros.dtype, [num_particles])) - rc.store_chunk(zeros, [0], [num_particles]) - rc.unit_SI = 1.0 - - momentum = species_group["momentum"] - for component, data in zip(("x", "y", "z"), (vx, vy, vz)): - record_component = momentum[component] - record_component.reset_dataset(io.Dataset(data.dtype, [num_particles])) - record_component.store_chunk(data * particle_mass * gamma , [0], [num_particles]) - record_component.unit_SI = 1.0 - - weighting = species_group["weighting"] - weights = np.full(num_particles, float(species.weight), dtype=np.float64) - weighting.reset_dataset(io.Dataset(weights.dtype, [num_particles])) - weighting.store_chunk(weights, [0], [num_particles]) - weighting.unit_SI = 1.0 - - charge = species_group["charge"] - charges = np.full(num_particles, particle_charge, dtype=np.float64) - charge.reset_dataset(io.Dataset(charges.dtype, [num_particles])) - charge.store_chunk(charges, [0], [num_particles]) - charge.unit_SI = 1.0 - - mass = species_group["mass"] - masses = np.full(num_particles, particle_mass, dtype=np.float64) - mass.reset_dataset(io.Dataset(masses.dtype, [num_particles])) - mass.store_chunk(masses, [0], [num_particles]) - mass.unit_SI = 1.0 - - series.flush() - series.close() \ No newline at end of file diff --git a/PyPIC3D/rho.py b/PyPIC3D/rho.py index 94c0d55..57ec029 100644 --- a/PyPIC3D/rho.py +++ b/PyPIC3D/rho.py @@ -5,6 +5,7 @@ # import external libraries from PyPIC3D.utils import digital_filter, wrap_around +from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights # import internal libraries @jit @@ -116,308 +117,4 @@ def compute_rho(particles, rho, world, constants): rho = digital_filter(rho, alpha) # apply a digital filter to the charge density array - return rho - -@jit -def get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz): - """ - Calculate the second-order weights for particle current distribution. - - Args: - deltax, deltay, deltaz (float): Particle position offsets from grid points. - dx, dy, dz (float): Grid spacings in x, y, and z directions. - - Returns: - tuple: Weights for x, y, and z directions. - """ - Sx0 = (3/4) - (deltax/dx)**2 - Sy0 = (3/4) - (deltay/dy)**2 - Sz0 = (3/4) - (deltaz/dz)**2 - - Sx1 = (1/2) * ((1/2) - (deltax/dx))**2 - Sy1 = (1/2) * ((1/2) - (deltay/dy))**2 - Sz1 = (1/2) * ((1/2) - (deltaz/dz))**2 - - Sx_minus1 = (1/2) * ((1/2) + (deltax/dx))**2 - Sy_minus1 = (1/2) * ((1/2) + (deltay/dy))**2 - Sz_minus1 = (1/2) * ((1/2) + (deltaz/dz))**2 - # second order weights - - x_weights = [Sx_minus1, Sx0, Sx1] - y_weights = [Sy_minus1, Sy0, Sy1] - z_weights = [Sz_minus1, Sz0, Sz1] - - return x_weights, y_weights, z_weights - -@jit -def get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz): - """ - Calculate the first-order weights for particle current distribution. - - Args: - deltax, deltay, deltaz (float): Particle position offsets from grid points. - dx, dy, dz (float): Grid spacings in x, y, and z directions. - - Returns: - tuple: Weights for x, y, and z directions. - """ - Sx0 = jnp.asarray(1 - deltax / dx) - Sy0 = jnp.asarray(1 - deltay / dy) - Sz0 = jnp.asarray(1 - deltaz / dz) - - Sx1 = jnp.asarray(deltax / dx) - Sy1 = jnp.asarray(deltay / dy) - Sz1 = jnp.asarray(deltaz / dz) - - Sx_minus1 = jnp.zeros_like(Sx0) - Sy_minus1 = jnp.zeros_like(Sy0) - Sz_minus1 = jnp.zeros_like(Sz0) - # No second-order weights for first-order weighting - - x_weights = [Sx_minus1, Sx0, Sx1] - y_weights = [Sy_minus1, Sy0, Sy1] - z_weights = [Sz_minus1, Sz0, Sz1] - - return x_weights, y_weights, z_weights - - -@jit -def compute_mass_density(particles, rho, world): - """ - Compute the mass density (rho) for a given set of particles in a simulation world. - Parameters: - particles (list): A list of particle species, each containing methods to get the number of particles, - their positions, and their mass. - rho (ndarray): The initial mass density array to be updated. - world (dict): A dictionary containing the simulation world parameters, including: - - 'dx': Grid spacing in the x-direction. - - 'dy': Grid spacing in the y-direction. - - 'dz': Grid spacing in the z-direction. - - 'x_wind': Window size in the x-direction. - - 'y_wind': Window size in the y-direction. - - 'z_wind': Window size in the z-direction. - Returns: - ndarray: The updated charge density array. - """ - dx = world['dx'] - dy = world['dy'] - dz = world['dz'] - x_wind = world['x_wind'] - y_wind = world['y_wind'] - z_wind = world['z_wind'] - Nx, Ny, Nz = rho.shape - # get the shape of the charge density array - - rho = jnp.zeros_like(rho) - # reset rho to zero - - for species in particles: - shape_factor = species.get_shape() - # get the shape factor of the species, which determines the weighting function - N_particles = species.get_number_of_particles() - mass = species.get_mass() - # get the number of particles and their mass - dm = mass / dx / dy / dz - # calculate the mass per unit volume - x, y, z = species.get_position() - # get the position of the particles in the species - - x0 = jnp.floor((x + x_wind / 2) / dx).astype(int) - y0 = jnp.floor((y + y_wind / 2) / dy).astype(int) - z0 = jnp.floor((z + z_wind / 2) / dz).astype(int) - # Calculate the nearest grid points - - deltax = x - jnp.floor(x / dx) * dx - deltay = y - jnp.floor(y / dy) * dy - deltaz = z - jnp.floor(z / dz) * dz - # Calculate the difference between the particle position and the nearest grid point - - x1 = wrap_around(x0 + 1, Nx) - y1 = wrap_around(y0 + 1, Ny) - z1 = wrap_around(z0 + 1, Nz) - # Calculate the index of the next grid point - - x_minus1 = x0 - 1 - y_minus1 = y0 - 1 - z_minus1 = z0 - 1 - # Calculate the index of the previous grid point - - xpts = [x_minus1, x0, x1] - ypts = [y_minus1, y0, y1] - zpts = [z_minus1, z0, z1] - # place all the points in a list - - x_weights, y_weights, z_weights = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), - lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), - operand=None - ) - # get the weighting factors based on the shape factor - - - for i in range(3): - for j in range(3): - for k in range(3): - rho = rho.at[xpts[i], ypts[j], zpts[k]].add( dm * x_weights[i] * y_weights[j] * z_weights[k], mode='drop') - # distribute the mass of the particles to the grid points using the weighting factors - - return rho - -@jit -def compute_velocity_field(particles, field, direction, world): - """ - Compute the velocity field (v) for a given set of particles in a simulation world. - Parameters: - particles (list): A list of particle species, each containing methods to get the number of particles, - their positions, and their mass. - field (ndarray): The initial velocity field array to be updated. - direction (int): The direction along which to compute the velocity field (0: x, 1: y, 2: z). - world (dict): A dictionary containing the simulation world parameters, including: - - 'dx': Grid spacing in the x-direction. - - 'dy': Grid spacing in the y-direction. - - 'dz': Grid spacing in the z-direction. - - 'x_wind': Window size in the x-direction. - - 'y_wind': Window size in the y-direction. - - 'z_wind': Window size in the z-direction. - Returns: - ndarray: The updated velocity field array. - """ - dx = world['dx'] - dy = world['dy'] - dz = world['dz'] - x_wind = world['x_wind'] - y_wind = world['y_wind'] - z_wind = world['z_wind'] - Nx, Ny, Nz = field.shape - # get the shape of the velocity field array - - field = jnp.zeros_like(field) - # reset field to zero - - for species in particles: - shape_factor = species.get_shape() - # get the shape factor of the species, which determines the weighting function - N_particles = species.get_number_of_particles() - # get the number of particles - x, y, z = species.get_position() - # get the position of the particles in the species - vx, vy, vz = species.get_velocity() - # get the velocity of the particles in the species - - dv = jnp.array([vx, vy, vz])[direction] / N_particles - # select the velocity component based on the direction - - x0 = jnp.floor((x + x_wind / 2) / dx).astype(int) - y0 = jnp.floor((y + y_wind / 2) / dy).astype(int) - z0 = jnp.floor((z + z_wind / 2) / dz).astype(int) - # Calculate the nearest grid points - - deltax = x - jnp.floor(x / dx) * dx - deltay = y - jnp.floor(y / dy) * dy - deltaz = z - jnp.floor(z / dz) * dz - # Calculate the difference between the particle position and the nearest grid point - - x1 = wrap_around(x0 + 1, Nx) - y1 = wrap_around(y0 + 1, Ny) - z1 = wrap_around(z0 + 1, Nz) - # Calculate the index of the next grid point - - x_minus1 = x0 - 1 - y_minus1 = y0 - 1 - z_minus1 = z0 - 1 - # Calculate the index of the previous grid point - - xpts = [x_minus1, x0, x1] - ypts = [y_minus1, y0, y1] - zpts = [z_minus1, z0, z1] - # place all the points in a list - - x_weights, y_weights, z_weights = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), - lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), - operand=None - ) - # get the weighting factors based on the shape factor - - for i in range(3): - for j in range(3): - for k in range(3): - field = field.at[xpts[i], ypts[j], zpts[k]].add( dv * x_weights[i] * y_weights[j] * z_weights[k], mode='drop') - # distribute the velocity of the particles to the grid points using the weighting factors - - return field - - - - -@jit -def compute_pressure_field(particles, field, velocity_field, direction, world): - - dx = world['dx'] - dy = world['dy'] - dz = world['dz'] - x_wind = world['x_wind'] - y_wind = world['y_wind'] - z_wind = world['z_wind'] - Nx, Ny, Nz = field.shape - # get the shape of the velocity field array - - field = jnp.zeros_like(field) - # reset field to zero - - for species in particles: - shape_factor = species.get_shape() - # get the shape factor of the species, which determines the weighting function - x, y, z = species.get_position() - # get the position of the particles in the species - vx, vy, vz = species.get_velocity() - # get the velocity of the particles in the species - - - v = jnp.array([vx, vy, vz])[direction] - # select the velocity component based on the direction - - x0 = jnp.floor((x + x_wind / 2) / dx).astype(int) - y0 = jnp.floor((y + y_wind / 2) / dy).astype(int) - z0 = jnp.floor((z + z_wind / 2) / dz).astype(int) - # Calculate the nearest grid points - - deltax = x - jnp.floor(x / dx) * dx - deltay = y - jnp.floor(y / dy) * dy - deltaz = z - jnp.floor(z / dz) * dz - # Calculate the difference between the particle position and the nearest grid point - - x1 = wrap_around(x0 + 1, Nx) - y1 = wrap_around(y0 + 1, Ny) - z1 = wrap_around(z0 + 1, Nz) - # Calculate the index of the next grid point - - x_minus1 = x0 - 1 - y_minus1 = y0 - 1 - z_minus1 = z0 - 1 - # Calculate the index of the previous grid point - - xpts = [x_minus1, x0, x1] - ypts = [y_minus1, y0, y1] - zpts = [z_minus1, z0, z1] - # place all the points in a list - - x_weights, y_weights, z_weights = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), - lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), - operand=None - ) - # get the weighting factors based on the shape factor - - for i in range(3): - for j in range(3): - for k in range(3): - vbar = v - velocity_field.at[xpts[i], ypts[j], zpts[k]].get() - - field = field.at[xpts[i], ypts[j], zpts[k]].add( vbar**2 * x_weights[i] * y_weights[j] * z_weights[k], mode='drop') - # distribute the pressure moment of the particles to the grid points using the weighting factors - - return field \ No newline at end of file + return rho \ No newline at end of file diff --git a/PyPIC3D/shapes.py b/PyPIC3D/shapes.py new file mode 100644 index 0000000..2738aef --- /dev/null +++ b/PyPIC3D/shapes.py @@ -0,0 +1,65 @@ +from jax import jit +import jax.numpy as jnp + + +@jit +def get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz): + """ + Calculate the second-order weights for particle current distribution. + + Args: + deltax, deltay, deltaz (float): Particle position offsets from grid points. + dx, dy, dz (float): Grid spacings in x, y, and z directions. + + Returns: + tuple: Weights for x, y, and z directions. + """ + Sx0 = (3/4) - (deltax/dx)**2 + Sy0 = (3/4) - (deltay/dy)**2 + Sz0 = (3/4) - (deltaz/dz)**2 + + Sx1 = (1/2) * ((1/2) + (deltax/dx))**2 + Sy1 = (1/2) * ((1/2) + (deltay/dy))**2 + Sz1 = (1/2) * ((1/2) + (deltaz/dz))**2 + + Sx_minus1 = (1/2) * ((1/2) - (deltax/dx))**2 + Sy_minus1 = (1/2) * ((1/2) - (deltay/dy))**2 + Sz_minus1 = (1/2) * ((1/2) - (deltaz/dz))**2 + # second order weights + + x_weights = [Sx_minus1, Sx0, Sx1] + y_weights = [Sy_minus1, Sy0, Sy1] + z_weights = [Sz_minus1, Sz0, Sz1] + + return x_weights, y_weights, z_weights + +@jit +def get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz): + """ + Calculate the first-order weights for particle current distribution. + + Args: + deltax, deltay, deltaz (float): Particle position offsets from grid points. + dx, dy, dz (float): Grid spacings in x, y, and z directions. + + Returns: + tuple: Weights for x, y, and z directions. + """ + Sx0 = jnp.asarray(1 - deltax / dx) + Sy0 = jnp.asarray(1 - deltay / dy) + Sz0 = jnp.asarray(1 - deltaz / dz) + + Sx1 = jnp.asarray(deltax / dx) + Sy1 = jnp.asarray(deltay / dy) + Sz1 = jnp.asarray(deltaz / dz) + + Sx_minus1 = jnp.zeros_like(Sx0) + Sy_minus1 = jnp.zeros_like(Sy0) + Sz_minus1 = jnp.zeros_like(Sz0) + # No second-order weights for first-order weighting + + x_weights = [Sx_minus1, Sx0, Sx1] + y_weights = [Sy_minus1, Sy0, Sy1] + z_weights = [Sz_minus1, Sz0, Sz1] + + return x_weights, y_weights, z_weights \ No newline at end of file diff --git a/PyPIC3D/utils.py b/PyPIC3D/utils.py index e329dd2..6f2c02a 100644 --- a/PyPIC3D/utils.py +++ b/PyPIC3D/utils.py @@ -15,6 +15,22 @@ from scipy import stats # import external libraries +def setup_pmd_files(file_path, name, extension=".bp"): + """ + Set up the openPMD file structure for storing simulation data. + + Args: + file_path (str): The path where the openPMD files will be stored. + name (str): The base name for the openPMD files. + Returns: + None + """ + + file = os.path.join(file_path, name + ".pmd") + with open(file, 'w') as f: + f.write(f"{name}{extension}\n") + # create the openPMD file structure + @jit def wrap_around(ix, size): """Wrap around index (scalar or 1D array) to ensure it is within bounds.""" diff --git a/docs/images/PyPICLogo.png b/docs/images/PyPICLogo.png index 5328948..755785b 100644 Binary files a/docs/images/PyPICLogo.png and b/docs/images/PyPICLogo.png differ