diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index a41f2da2..7b13f9c5 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -23,9 +23,9 @@ from manimlib.utils.space_ops import get_norm from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Iterable, Sequence, TypeVar - + from typing import Callable, Iterable, Sequence, TypeVar, Tuple import numpy.typing as npt + from manimlib.constants import ManimColor, np_vector from manimlib.mobject.coordinate_systems import CoordinateSystem from manimlib.mobject.mobject import Mobject @@ -37,7 +37,7 @@ def get_vectorized_rgb_gradient_function( min_value: T, max_value: T, color_map: str -) -> Callable[[npt.ArrayLike], np.ndarray]: +) -> Callable[[npt.ArrayLike], np_vector]: rgbs = np.array(get_colormap_list(color_map)) def func(values): @@ -59,14 +59,14 @@ def get_rgb_gradient_function( min_value: T, max_value: T, color_map: str -) -> Callable[[T], np.ndarray]: +) -> Callable[[T], np_vector]: vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map) return lambda value: vectorized_func([value])[0] def move_along_vector_field( mobject: Mobject, - func: Callable[[np.ndarray], np.ndarray] + func: Callable[[np_vector], np_vector] ) -> Mobject: mobject.add_updater( lambda m, dt: m.shift( @@ -78,7 +78,7 @@ def move_along_vector_field( def move_submobjects_along_vector_field( mobject: Mobject, - func: Callable[[np.ndarray], np.ndarray] + func: Callable[[np_vector], np_vector] ) -> Mobject: def apply_nudge(mob, dt): for submob in mob: @@ -109,7 +109,7 @@ def move_points_along_vector_field( def get_sample_points_from_coordinate_system( coordinate_system: CoordinateSystem, step_multiple: float -) -> it.product[tuple[np.ndarray, ...]]: +) -> it.product[tuple[np_vector, ...]]: ranges = [] for range_args in coordinate_system.get_all_ranges(): _min, _max, step = range_args @@ -121,25 +121,29 @@ def get_sample_points_from_coordinate_system( # Mobjects class VectorField(VGroup): - CONFIG = { - "step_multiple": 0.5, - "magnitude_range": (0, 2), - "color_map": "3b1b_colormap", - # Takes in actual norm, spits out displayed norm - "length_func": lambda norm: 0.45 * sigmoid(norm), - "opacity": 1.0, - "vector_config": {}, - } - def __init__( self, func: Callable[[float, float], Sequence[float]], coordinate_system: CoordinateSystem, + step_multiple: float = 0.5, + magnitude_range: Tuple[float, float] = (0, 2), + color_map: str = "3b1b_colormap", + # Takes in actual norm, spits out displayed norm + length_func: Callable[[float], float] = lambda norm: 0.45 * sigmoid(norm), + opacity: float = 1.0, + vector_config: dict = dict(), **kwargs ): super().__init__(**kwargs) self.func = func self.coordinate_system = coordinate_system + self.step_multiple = step_multiple + self.magnitude_range = magnitude_range + self.color_map = color_map + self.length_func = length_func + self.opacity = opacity + self.vector_config = vector_config + self.value_to_rgb = get_rgb_gradient_function( *self.magnitude_range, self.color_map, ) @@ -177,39 +181,52 @@ class VectorField(VGroup): class StreamLines(VGroup): - CONFIG = { - "step_multiple": 0.5, - "n_repeats": 1, - "noise_factor": None, - # Config for drawing lines - "dt": 0.05, - "arc_len": 3, - "max_time_steps": 200, - "n_samples_per_line": 10, - "cutoff_norm": 15, - # Style info - "stroke_width": 1, - "stroke_color": WHITE, - "stroke_opacity": 1, - "color_by_magnitude": True, - "magnitude_range": (0, 2.0), - "taper_stroke_width": False, - "color_map": "3b1b_colormap", - } - def __init__( self, func: Callable[[float, float], Sequence[float]], coordinate_system: CoordinateSystem, + step_multiple: float = 0.5, + n_repeats: int = 1, + noise_factor: float | None = None, + # Config for drawing lines + dt: float = 0.05, + arc_len: float = 3, + max_time_steps: int = 200, + n_samples_per_line: int = 10, + cutoff_norm: float = 15, + # Style info + stroke_width: float = 1.0, + stroke_color: ManimColor = WHITE, + stroke_opacity: float = 1, + color_by_magnitude: bool = True, + magnitude_range: Tuple[float, float] = (0, 2.0), + taper_stroke_width: bool = False, + color_map: str = "3b1b_colormap", **kwargs ): super().__init__(**kwargs) self.func = func self.coordinate_system = coordinate_system + self.step_multiple = step_multiple + self.n_repeats = n_repeats + self.noise_factor = noise_factor + self.dt = dt + self.arc_len = arc_len + self.max_time_steps = max_time_steps + self.n_samples_per_line = n_samples_per_line + self.cutoff_norm = cutoff_norm + self.stroke_width = stroke_width + self.stroke_color = stroke_color + self.stroke_opacity = stroke_opacity + self.color_by_magnitude = color_by_magnitude + self.magnitude_range = magnitude_range + self.taper_stroke_width = taper_stroke_width + self.color_map = color_map + self.draw_lines() self.init_style() - def point_func(self, point: np.ndarray) -> np.ndarray: + def point_func(self, point: np_vector) -> np_vector: in_coords = self.coordinate_system.p2c(point) out_coords = self.func(*in_coords) return self.coordinate_system.c2p(*out_coords) @@ -239,7 +256,7 @@ class StreamLines(VGroup): lines.append(line) self.set_submobjects(lines) - def get_start_points(self) -> np.ndarray: + def get_start_points(self) -> np_vector: cs = self.coordinate_system sample_coords = get_sample_points_from_coordinate_system( cs, self.step_multiple, @@ -282,27 +299,27 @@ class StreamLines(VGroup): class AnimatedStreamLines(VGroup): - CONFIG = { - "lag_range": 4, - "line_anim_class": VShowPassingFlash, - "line_anim_config": { - # "run_time": 4, - "rate_func": linear, - "time_width": 0.5, - }, - } - - def __init__(self, stream_lines: StreamLines, **kwargs): + def __init__( + self, + stream_lines: StreamLines, + lag_range: float = 4, + line_anim_config: dict = dict( + rate_func=linear, + time_width=1.0, + ), + **kwargs + ): super().__init__(**kwargs) self.stream_lines = stream_lines + for line in stream_lines: - line.anim = self.line_anim_class( + line.anim = VShowPassingFlash( line, run_time=line.virtual_time, - **self.line_anim_config, + **line_anim_config, ) line.anim.begin() - line.time = -self.lag_range * np.random.random() + line.time = -lag_range * np.random.random() self.add(line.anim.mobject) self.add_updater(lambda m, dt: m.update(dt))