diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index 38ca9dc5..df784190 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -5,7 +5,8 @@ import itertools as it import numpy as np from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH -from manimlib.constants import WHITE +from manimlib.constants import BLUE, WHITE +from manimlib.constants import ORIGIN from manimlib.animation.indication import VShowPassingFlash from manimlib.mobject.geometry import Arrow from manimlib.mobject.types.vectorized_mobject import VGroup @@ -15,6 +16,7 @@ from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.color import get_colormap_list from manimlib.utils.color import rgb_to_color from manimlib.utils.dict_ops import merge_dicts_recursively +from manimlib.utils.iterables import cartesian_product from manimlib.utils.rate_functions import linear from manimlib.utils.simple_functions import sigmoid from manimlib.utils.space_ops import get_norm @@ -118,7 +120,184 @@ def get_sample_points_from_coordinate_system( # Mobjects -class VectorField(VGroup): + +class VectorField(VMobject): + def __init__( + self, + func, + stroke_color: ManimColor = BLUE, + stroke_opacity: float = 1.0, + center: Vect3 = ORIGIN, + sample_points: Optional[Vect3Array] = None, + x_density: float = 2.0, + y_density: float = 2.0, + z_density: float = 2.0, + width: float = 14.0, + height: float = 8.0, + depth: float = 0.0, + stroke_width: float = 2, + tip_width_ratio: float = 4, + tip_len_to_width: float = 0.01, + max_vect_len: float | None = None, + min_drawn_norm: float = 1e-2, + flat_stroke: bool = False, + norm_to_opacity_func=None, + norm_to_rgb_func=None, + **kwargs + ): + self.func = func + self.stroke_width = stroke_width + self.tip_width_ratio = tip_width_ratio + self.tip_len_to_width = tip_len_to_width + self.min_drawn_norm = min_drawn_norm + self.norm_to_opacity_func = norm_to_opacity_func + self.norm_to_rgb_func = norm_to_rgb_func + + if max_vect_len is not None: + self.max_vect_len = max_vect_len + else: + densities = np.array([x_density, y_density, z_density]) + dims = np.array([width, height, depth]) + self.max_vect_len = 1.0 / densities[dims > 0].mean() + + if sample_points is None: + self.sample_points = self.get_sample_points( + center, width, height, depth, + x_density, y_density, z_density + ) + else: + self.sample_points = sample_points + + self.init_base_stroke_width_array(len(self.sample_points)) + + super().__init__( + stroke_color=stroke_color, + stroke_opacity=stroke_opacity, + flat_stroke=flat_stroke, + **kwargs + ) + + n_samples = len(self.sample_points) + self.set_points(np.zeros((8 * n_samples - 1, 3))) + self.set_stroke(width=stroke_width) + self.set_joint_type('no_joint') + self.update_vectors() + + def get_sample_points( + self, + center: np.ndarray, + width: float, + height: float, + depth: float, + x_density: float, + y_density: float, + z_density: float + ) -> np.ndarray: + to_corner = np.array([width / 2, height / 2, depth / 2]) + spacings = 1.0 / np.array([x_density, y_density, z_density]) + to_corner = spacings * (to_corner / spacings).astype(int) + lower_corner = center - to_corner + upper_corner = center + to_corner + spacings + return cartesian_product(*( + np.arange(low, high, space) + for low, high, space in zip(lower_corner, upper_corner, spacings) + )) + + def init_base_stroke_width_array(self, n_sample_points): + arr = np.ones(8 * n_sample_points - 1) + arr[4::8] = self.tip_width_ratio + arr[5::8] = self.tip_width_ratio * 0.5 + arr[6::8] = 0 + arr[7::8] = 0 + self.base_stroke_width_array = arr + + def set_sample_points(self, sample_points: Vect3Array): + self.sample_points = sample_points + return self + + def set_stroke(self, color=None, width=None, opacity=None, behind=None, flat=None, recurse=True): + super().set_stroke(color, None, opacity, behind, flat, recurse) + if width is not None: + self.set_stroke_width(float(width)) + return self + + def set_stroke_width(self, width: float): + if self.get_num_points() > 0: + self.get_stroke_widths()[:] = width * self.base_stroke_width_array + self.stroke_width = width + return self + + def update_vectors(self): + tip_width = self.tip_width_ratio * self.stroke_width + tip_len = self.tip_len_to_width * tip_width + samples = self.sample_points + + # Get raw outputs and lengths + outputs = self.func(samples) + norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis] + + # How long should the arrows be drawn? + max_len = self.max_vect_len + if max_len < np.inf: + drawn_norms = max_len * np.tanh(norms / max_len) + else: + drawn_norms = norms + + # What's the distance from the base of an arrow to + # the base of its head? + dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf) + + # Set all points + unit_outputs = np.zeros_like(outputs) + np.true_divide(outputs, norms, out=unit_outputs, where=(norms > self.min_drawn_norm)) + + points = self.get_points() + points[0::8] = samples + points[2::8] = samples + dist_to_head_base * unit_outputs + points[4::8] = points[2::8] + points[6::8] = samples + drawn_norms * unit_outputs + for i in (1, 3, 5): + points[i::8] = 0.5 * (points[i - 1::8] + points[i + 1::8]) + points[7::8] = points[6:-1:8] + + # Adjust stroke widths + width_arr = self.stroke_width * self.base_stroke_width_array + width_scalars = np.clip(drawn_norms / tip_len, 0, 1) + width_scalars = np.repeat(width_scalars, 8)[:-1] + self.get_stroke_widths()[:] = width_scalars * width_arr + + # Potentially adjust opacity and color + if self.norm_to_opacity_func is not None: + self.get_stroke_opacities()[:] = self.norm_to_opacity_func( + np.repeat(norms, 8)[:-1] + ) + if self.norm_to_rgb_func is not None: + self.get_stroke_colors() + self.data['stroke_rgba'][:, :3] = self.norm_to_rgb_func( + np.repeat(norms, 8)[:-1] + ) + + self.note_changed_data() + return self + + +class TimeVaryingVectorField(VectorField): + def __init__( + self, + # Takes in an array of points and a float for time + time_func, + **kwargs + ): + self.time = 0 + super().__init__(func=lambda p: time_func(p, self.time), **kwargs) + self.add_updater(lambda m, dt: m.increment_time(dt)) + always(self.update_vectors) + + def increment_time(self, dt): + self.time += dt + + +class OldVectorField(VGroup): def __init__( self, func: Callable[[float, float], Sequence[float]],