From 0ad5a0e76e8c075380b94be753efd6f1eb47bf65 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 15 Nov 2024 09:07:46 -0800 Subject: [PATCH] Further development on VectorField --- manimlib/mobject/vector_field.py | 332 ++++++++++++++++--------------- 1 file changed, 175 insertions(+), 157 deletions(-) diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index 0a1eb94a..1c882ba6 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -3,6 +3,7 @@ from __future__ import annotations import itertools as it import numpy as np +from scipy.integrate import solve_ivp from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH from manimlib.constants import BLUE, WHITE @@ -26,7 +27,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable, Iterable, Sequence, TypeVar, Tuple - from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array, Vect4Array + from manimlib.typing import ManimColor, Vect3, VectN, Vect2Array, Vect3Array, Vect4Array from manimlib.mobject.coordinate_systems import CoordinateSystem from manimlib.mobject.mobject import Mobject @@ -68,6 +69,16 @@ def get_rgb_gradient_function( #### +def ode_solution_points(function, state0, time, dt=0.01): + solution = solve_ivp( + lambda t, state: function(state), + t_span=(0, time), + y0=state0, + t_eval=np.arange(0, time, dt) + ) + return solution.y.T + + def move_along_vector_field( mobject: Mobject, func: Callable[[Vect3], Vect3] @@ -128,12 +139,12 @@ def get_sample_coords( class VectorField(VMobject): def __init__( self, - func: Callable[Sequence[float], Sequence[float]], + func: Callable[[VectArray], VectArray], coordinate_system: CoordinateSystem, step_multiple: float = 0.5, magnitude_range: Optional[Tuple[float, float]] = None, color_map_name: Optional[str] = "3b1b_colormap", - color_map: Optional[Callable[Sequence[float]], Vect4Array] = None, + color_map: Optional[Callable[[Sequence[float]], Vect4Array]] = None, stroke_color: ManimColor = BLUE, stroke_opacity: float = 1.0, stroke_width: float = 2, @@ -281,9 +292,9 @@ class VectorField(VMobject): if self.color_map is not None: self.get_stroke_colors() # Ensures the array is updated to appropriate length low, high = self.magnitude_range - self.data['stroke_rgba'][:] = self.color_map( + self.data['stroke_rgba'][:, :3] = self.color_map( inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1]) - ) + )[:, :3] if self.norm_to_opacity_func is not None: self.get_stroke_opacities()[:] = self.norm_to_opacity_func( @@ -310,6 +321,165 @@ class TimeVaryingVectorField(VectorField): self.time += dt +class StreamLines(VGroup): + def __init__( + self, + func: Callable[[VectArray], VectArray], + coordinate_system: CoordinateSystem, + step_multiple: float = 0.5, + n_repeats: int = 1, + noise_factor: float | None = None, + # Config for drawing lines + solution_time: float = 3, + dt: float = 0.05, + arc_len: float = 3, + max_time_steps: int = 200, + n_samples_per_line: int = 10, + cutoff_norm: float = 15, + # Style info + stroke_width: float = 1.0, + stroke_color: ManimColor = WHITE, + stroke_opacity: float = 1, + color_by_magnitude: bool = True, + magnitude_range: Tuple[float, float] = (0, 2.0), + taper_stroke_width: bool = False, + color_map: str = "3b1b_colormap", + **kwargs + ): + super().__init__(**kwargs) + self.func = func + self.coordinate_system = coordinate_system + self.step_multiple = step_multiple + self.n_repeats = n_repeats + self.noise_factor = noise_factor + self.solution_time = solution_time + self.dt = dt + self.arc_len = arc_len + self.max_time_steps = max_time_steps + self.n_samples_per_line = n_samples_per_line + self.cutoff_norm = cutoff_norm + self.stroke_width = stroke_width + self.stroke_color = stroke_color + self.stroke_opacity = stroke_opacity + self.color_by_magnitude = color_by_magnitude + self.magnitude_range = magnitude_range + self.taper_stroke_width = taper_stroke_width + self.color_map = color_map + + self.draw_lines() + self.init_style() + + def point_func(self, points: Vect3Array) -> Vect3: + in_coords = np.array(self.coordinate_system.p2c(points)).T + out_coords = self.func(in_coords) + origin = self.coordinate_system.get_origin() + return self.coordinate_system.c2p(*out_coords.T) - origin + + def draw_lines(self) -> None: + lines = [] + origin = self.coordinate_system.get_origin() + + # Todo, it feels like coordinate system should just have + # the ODE solver built into it, no? + lines = [] + for coords in self.get_sample_coords(): + solution_coords = ode_solution_points(self.func, coords, self.solution_time, self.dt) + line = VMobject() + line.set_points_smoothly(self.coordinate_system.c2p(*solution_coords.T)) + # TODO, account for arc length somehow? + line.virtual_time = self.solution_time + lines.append(line) + self.set_submobjects(lines) + + def get_sample_coords(self): + cs = self.coordinate_system + sample_coords = get_sample_coords(cs, self.step_multiple) + + noise_factor = self.noise_factor + if noise_factor is None: + noise_factor = cs.get_x_unit_size() * self.step_multiple * 0.5 + + return np.array([ + coords + noise_factor * np.random.random(coords.shape) + for n in range(self.n_repeats) + for coords in sample_coords + ]) + + def get_start_points(self) -> Vect3Array: + cs = self.coordinate_system + sample_coords = get_sample_coords(cs, self.step_multiple) + + noise_factor = self.noise_factor + if noise_factor is None: + noise_factor = cs.get_x_unit_size() * self.step_multiple * 0.5 + + return np.array([ + cs.c2p(*coords) + noise_factor * np.random.random(3) + for n in range(self.n_repeats) + for coords in sample_coords + ]) + + def init_style(self) -> None: + if self.color_by_magnitude: + values_to_rgbs = get_vectorized_rgb_gradient_function( + *self.magnitude_range, self.color_map, + ) + cs = self.coordinate_system + for line in self.submobjects: + norms = [ + get_norm(self.func(*cs.p2c(point))) + for point in line.get_points() + ] + rgbs = values_to_rgbs(norms) + rgbas = np.zeros((len(rgbs), 4)) + rgbas[:, :3] = rgbs + rgbas[:, 3] = self.stroke_opacity + line.set_rgba_array(rgbas, "stroke_rgba") + else: + self.set_stroke(self.stroke_color, opacity=self.stroke_opacity) + + if self.taper_stroke_width: + width = [0, self.stroke_width, 0] + else: + width = self.stroke_width + self.set_stroke(width=width) + + +class AnimatedStreamLines(VGroup): + def __init__( + self, + stream_lines: StreamLines, + lag_range: float = 4, + rate_multiple: float = 1.0, + line_anim_config: dict = dict( + rate_func=linear, + time_width=1.0, + ), + **kwargs + ): + super().__init__(**kwargs) + self.stream_lines = stream_lines + + for line in stream_lines: + line.anim = VShowPassingFlash( + line, + run_time=line.virtual_time / rate_multiple, + **line_anim_config, + ) + line.anim.begin() + line.time = -lag_range * np.random.random() + self.add(line.anim.mobject) + + self.add_updater(lambda m, dt: m.update(dt)) + + def update(self, dt: float) -> None: + stream_lines = self.stream_lines + for line in stream_lines: + line.time += dt + adjusted_time = max(line.time, 0) % line.anim.run_time + line.anim.update(adjusted_time / line.anim.run_time) + + class OldVectorField(VGroup): def __init__( self, @@ -371,155 +541,3 @@ class OldVectorField(VGroup): opacity=self.opacity, ) return vect - - -class StreamLines(VGroup): - def __init__( - self, - func: Callable[[float, float], Sequence[float]], - coordinate_system: CoordinateSystem, - step_multiple: float = 0.5, - n_repeats: int = 1, - noise_factor: float | None = None, - # Config for drawing lines - dt: float = 0.05, - arc_len: float = 3, - max_time_steps: int = 200, - n_samples_per_line: int = 10, - cutoff_norm: float = 15, - # Style info - stroke_width: float = 1.0, - stroke_color: ManimColor = WHITE, - stroke_opacity: float = 1, - color_by_magnitude: bool = True, - magnitude_range: Tuple[float, float] = (0, 2.0), - taper_stroke_width: bool = False, - color_map: str = "3b1b_colormap", - **kwargs - ): - super().__init__(**kwargs) - self.func = func - self.coordinate_system = coordinate_system - self.step_multiple = step_multiple - self.n_repeats = n_repeats - self.noise_factor = noise_factor - self.dt = dt - self.arc_len = arc_len - self.max_time_steps = max_time_steps - self.n_samples_per_line = n_samples_per_line - self.cutoff_norm = cutoff_norm - self.stroke_width = stroke_width - self.stroke_color = stroke_color - self.stroke_opacity = stroke_opacity - self.color_by_magnitude = color_by_magnitude - self.magnitude_range = magnitude_range - self.taper_stroke_width = taper_stroke_width - self.color_map = color_map - - self.draw_lines() - self.init_style() - - def point_func(self, point: Vect3) -> Vect3: - in_coords = self.coordinate_system.p2c(point) - out_coords = self.func(*in_coords) - return self.coordinate_system.c2p(*out_coords) - - def draw_lines(self) -> None: - lines = [] - origin = self.coordinate_system.get_origin() - for point in self.get_start_points(): - points = [point] - total_arc_len = 0 - time = 0 - for x in range(self.max_time_steps): - time += self.dt - last_point = points[-1] - new_point = last_point + self.dt * (self.point_func(last_point) - origin) - points.append(new_point) - total_arc_len += get_norm(new_point - last_point) - if get_norm(last_point) > self.cutoff_norm: - break - if total_arc_len > self.arc_len: - break - line = VMobject() - line.virtual_time = time - step = max(1, int(len(points) / self.n_samples_per_line)) - line.set_points_as_corners(points[::step]) - line.make_smooth(approx=True) - lines.append(line) - self.set_submobjects(lines) - - def get_start_points(self) -> Vect3Array: - cs = self.coordinate_system - sample_coords = get_sample_coords( - cs, self.step_multiple, - ) - - noise_factor = self.noise_factor - if noise_factor is None: - noise_factor = cs.x_range[2] * self.step_multiple * 0.5 - - return np.array([ - cs.c2p(*coords) + noise_factor * np.random.random(3) - for n in range(self.n_repeats) - for coords in sample_coords - ]) - - def init_style(self) -> None: - if self.color_by_magnitude: - values_to_rgbs = get_vectorized_rgb_gradient_function( - *self.magnitude_range, self.color_map, - ) - cs = self.coordinate_system - for line in self.submobjects: - norms = [ - get_norm(self.func(*cs.p2c(point))) - for point in line.get_points() - ] - rgbs = values_to_rgbs(norms) - rgbas = np.zeros((len(rgbs), 4)) - rgbas[:, :3] = rgbs - rgbas[:, 3] = self.stroke_opacity - line.set_rgba_array(rgbas, "stroke_rgba") - else: - self.set_stroke(self.stroke_color, opacity=self.stroke_opacity) - - if self.taper_stroke_width: - width = [0, self.stroke_width, 0] - else: - width = self.stroke_width - self.set_stroke(width=width) - - -class AnimatedStreamLines(VGroup): - def __init__( - self, - stream_lines: StreamLines, - lag_range: float = 4, - line_anim_config: dict = dict( - rate_func=linear, - time_width=1.0, - ), - **kwargs - ): - super().__init__(**kwargs) - self.stream_lines = stream_lines - - for line in stream_lines: - line.anim = VShowPassingFlash( - line, - run_time=line.virtual_time, - **line_anim_config, - ) - line.anim.begin() - line.time = -lag_range * np.random.random() - self.add(line.anim.mobject) - - self.add_updater(lambda m, dt: m.update(dt)) - - def update(self, dt: float) -> None: - stream_lines = self.stream_lines - for line in stream_lines: - line.time += dt - adjusted_time = max(line.time, 0) % line.anim.run_time - line.anim.update(adjusted_time / line.anim.run_time)