# Source code for wake_t.fields.analytical_field

"""Contains the class used to define analytic fields."""

from typing import Callable, Optional, List
import numpy as np

from .base import Field
from wake_t.utilities.numba import njit_parallel


# Define type alias for the signature of a user-provided field component.
# Arguments, in order: the x, y, z positions (arrays), the time ``t``, the
# output array ``fld`` (filled in place by the function) and the array of
# constants. ``AnalyticalField`` converts the user-supplied list of constants
# with ``np.array`` before passing it, and any return value is ignored.
FieldFunction = Callable[
    [np.ndarray, np.ndarray, np.ndarray, float, np.ndarray, np.ndarray], None
]


class AnalyticalField(Field):
    """Class used to define fields with analytical components.

    The given components (Ex, Ey, Ez, Bx, By, Bz) must be functions taking
    6 arguments: 3 arrays containing the x, y, z positions where to
    calculate the field; the time ``t``; 1 array with the same size where
    the field values will be stored (filled in place); and the array of
    constants. Any value returned by these functions is ignored. The given
    functions must be written in a way which allows them to be compiled
    with ``numba``.

    Not all components need to be given. Those which are not specified are
    simply skipped when gathered: the corresponding output array is left
    unmodified, which results in a zero field as long as the output arrays
    passed to ``_gather`` are zero-initialized.

    In addition to the field components, a list of constants can also be
    given. It is converted to a numpy array and always passed to the field
    functions, where it can be used to compute the field.

    Parameters
    ----------
    e_x : callable, optional
        Function defining the Ex component.
    e_y : callable, optional
        Function defining the Ey component.
    e_z : callable, optional
        Function defining the Ez component.
    b_x : callable, optional
        Function defining the Bx component.
    b_y : callable, optional
        Function defining the By component.
    b_z : callable, optional
        Function defining the Bz component.
    constants : list, optional
        List of constants to be passed to each component.

    Examples
    --------
    >>> from numba import prange
    >>> def linear_ex(x, y, z, t, ex, constants):
    ...     ex_slope = constants[0]
    ...     for i in prange(x.shape[0]):
    ...         ex[i] = ex_slope * x[i]
    ...
    >>> ex = AnalyticalField(e_x=linear_ex, constants=[1e6])

    """

    def __init__(
        self,
        e_x: Optional[FieldFunction] = None,
        e_y: Optional[FieldFunction] = None,
        e_z: Optional[FieldFunction] = None,
        b_x: Optional[FieldFunction] = None,
        b_y: Optional[FieldFunction] = None,
        b_z: Optional[FieldFunction] = None,
        constants: Optional[List] = None,
    ) -> None:
        super().__init__()
        constants = [] if constants is None else constants

        def no_field(x, y, z, t, fld, k):
            """Default field component. Leaves ``fld`` untouched."""
            pass

        def compile_or_default(func):
            # Compile user-provided components with numba; components that
            # were not given fall back to the no-op ``no_field``.
            return njit_parallel(func) if func is not None else no_field

        self.__e_x = compile_or_default(e_x)
        self.__e_y = compile_or_default(e_y)
        self.__e_z = compile_or_default(e_z)
        self.__b_x = compile_or_default(b_x)
        self.__b_y = compile_or_default(b_y)
        self.__b_z = compile_or_default(b_z)
        # Stored as a numpy array so that it can be passed to the compiled
        # field functions.
        self.constants = np.array(constants)

    def _pre_gather(self, x, y, z, t):
        """Function that is automatically called just before gathering.

        This method can be overwritten by derived classes and used to,
        for example, pre-compute any useful quantities. This method is not
        compiled by numba.

        Parameters
        ----------
        x, y, z : ndarray
            Positions where the field will be gathered.
        t : float
            Time at which the field will be gathered.
        """
        pass

    def _gather(self, x, y, z, t, ex, ey, ez, bx, by, bz, bunch_name):
        """Evaluate all field components at the given positions and time.

        Each component function writes its values in place into the
        corresponding output array (``ex``, ``ey``, ``ez``, ``bx``, ``by``,
        ``bz``). Components that were not specified leave their output
        array unmodified.

        Parameters
        ----------
        x, y, z : ndarray
            Positions where the field is gathered.
        t : float
            Time at which the field is gathered.
        ex, ey, ez, bx, by, bz : ndarray
            Output arrays, filled in place.
        bunch_name : str
            Accepted for interface compatibility; not used by this class.
        """
        self._pre_gather(x, y, z, t)
        self.__e_x(x, y, z, t, ex, self.constants)
        self.__e_y(x, y, z, t, ey, self.constants)
        self.__e_z(x, y, z, t, ez, self.constants)
        self.__b_x(x, y, z, t, bx, self.constants)
        self.__b_y(x, y, z, t, by, self.constants)
        self.__b_z(x, y, z, t, bz, self.constants)