mirror of
https://github.com/3b1b/manim.git
synced 2025-08-21 05:44:04 +00:00
Further development on VectorField
This commit is contained in:
parent
64ae1364ca
commit
0ad5a0e76e
1 changed files with 175 additions and 157 deletions
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
import itertools as it
|
import itertools as it
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy.integrate import solve_ivp
|
||||||
|
|
||||||
from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH
|
from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH
|
||||||
from manimlib.constants import BLUE, WHITE
|
from manimlib.constants import BLUE, WHITE
|
||||||
|
@ -26,7 +27,7 @@ from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Callable, Iterable, Sequence, TypeVar, Tuple
|
from typing import Callable, Iterable, Sequence, TypeVar, Tuple
|
||||||
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array, Vect4Array
|
from manimlib.typing import ManimColor, Vect3, VectN, Vect2Array, Vect3Array, Vect4Array
|
||||||
|
|
||||||
from manimlib.mobject.coordinate_systems import CoordinateSystem
|
from manimlib.mobject.coordinate_systems import CoordinateSystem
|
||||||
from manimlib.mobject.mobject import Mobject
|
from manimlib.mobject.mobject import Mobject
|
||||||
|
@ -68,6 +69,16 @@ def get_rgb_gradient_function(
|
||||||
####
|
####
|
||||||
|
|
||||||
|
|
||||||
|
def ode_solution_points(function, state0, time, dt=0.01):
|
||||||
|
solution = solve_ivp(
|
||||||
|
lambda t, state: function(state),
|
||||||
|
t_span=(0, time),
|
||||||
|
y0=state0,
|
||||||
|
t_eval=np.arange(0, time, dt)
|
||||||
|
)
|
||||||
|
return solution.y.T
|
||||||
|
|
||||||
|
|
||||||
def move_along_vector_field(
|
def move_along_vector_field(
|
||||||
mobject: Mobject,
|
mobject: Mobject,
|
||||||
func: Callable[[Vect3], Vect3]
|
func: Callable[[Vect3], Vect3]
|
||||||
|
@ -128,12 +139,12 @@ def get_sample_coords(
|
||||||
class VectorField(VMobject):
|
class VectorField(VMobject):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
func: Callable[Sequence[float], Sequence[float]],
|
func: Callable[[VectArray], VectArray],
|
||||||
coordinate_system: CoordinateSystem,
|
coordinate_system: CoordinateSystem,
|
||||||
step_multiple: float = 0.5,
|
step_multiple: float = 0.5,
|
||||||
magnitude_range: Optional[Tuple[float, float]] = None,
|
magnitude_range: Optional[Tuple[float, float]] = None,
|
||||||
color_map_name: Optional[str] = "3b1b_colormap",
|
color_map_name: Optional[str] = "3b1b_colormap",
|
||||||
color_map: Optional[Callable[Sequence[float]], Vect4Array] = None,
|
color_map: Optional[Callable[[Sequence[float]], Vect4Array]] = None,
|
||||||
stroke_color: ManimColor = BLUE,
|
stroke_color: ManimColor = BLUE,
|
||||||
stroke_opacity: float = 1.0,
|
stroke_opacity: float = 1.0,
|
||||||
stroke_width: float = 2,
|
stroke_width: float = 2,
|
||||||
|
@ -281,9 +292,9 @@ class VectorField(VMobject):
|
||||||
if self.color_map is not None:
|
if self.color_map is not None:
|
||||||
self.get_stroke_colors() # Ensures the array is updated to appropriate length
|
self.get_stroke_colors() # Ensures the array is updated to appropriate length
|
||||||
low, high = self.magnitude_range
|
low, high = self.magnitude_range
|
||||||
self.data['stroke_rgba'][:] = self.color_map(
|
self.data['stroke_rgba'][:, :3] = self.color_map(
|
||||||
inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1])
|
inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1])
|
||||||
)
|
)[:, :3]
|
||||||
|
|
||||||
if self.norm_to_opacity_func is not None:
|
if self.norm_to_opacity_func is not None:
|
||||||
self.get_stroke_opacities()[:] = self.norm_to_opacity_func(
|
self.get_stroke_opacities()[:] = self.norm_to_opacity_func(
|
||||||
|
@ -310,6 +321,165 @@ class TimeVaryingVectorField(VectorField):
|
||||||
self.time += dt
|
self.time += dt
|
||||||
|
|
||||||
|
|
||||||
|
class StreamLines(VGroup):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
func: Callable[[VectArray], VectArray],
|
||||||
|
coordinate_system: CoordinateSystem,
|
||||||
|
step_multiple: float = 0.5,
|
||||||
|
n_repeats: int = 1,
|
||||||
|
noise_factor: float | None = None,
|
||||||
|
# Config for drawing lines
|
||||||
|
solution_time: float = 3,
|
||||||
|
dt: float = 0.05,
|
||||||
|
arc_len: float = 3,
|
||||||
|
max_time_steps: int = 200,
|
||||||
|
n_samples_per_line: int = 10,
|
||||||
|
cutoff_norm: float = 15,
|
||||||
|
# Style info
|
||||||
|
stroke_width: float = 1.0,
|
||||||
|
stroke_color: ManimColor = WHITE,
|
||||||
|
stroke_opacity: float = 1,
|
||||||
|
color_by_magnitude: bool = True,
|
||||||
|
magnitude_range: Tuple[float, float] = (0, 2.0),
|
||||||
|
taper_stroke_width: bool = False,
|
||||||
|
color_map: str = "3b1b_colormap",
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.func = func
|
||||||
|
self.coordinate_system = coordinate_system
|
||||||
|
self.step_multiple = step_multiple
|
||||||
|
self.n_repeats = n_repeats
|
||||||
|
self.noise_factor = noise_factor
|
||||||
|
self.solution_time = solution_time
|
||||||
|
self.dt = dt
|
||||||
|
self.arc_len = arc_len
|
||||||
|
self.max_time_steps = max_time_steps
|
||||||
|
self.n_samples_per_line = n_samples_per_line
|
||||||
|
self.cutoff_norm = cutoff_norm
|
||||||
|
self.stroke_width = stroke_width
|
||||||
|
self.stroke_color = stroke_color
|
||||||
|
self.stroke_opacity = stroke_opacity
|
||||||
|
self.color_by_magnitude = color_by_magnitude
|
||||||
|
self.magnitude_range = magnitude_range
|
||||||
|
self.taper_stroke_width = taper_stroke_width
|
||||||
|
self.color_map = color_map
|
||||||
|
|
||||||
|
self.draw_lines()
|
||||||
|
self.init_style()
|
||||||
|
|
||||||
|
def point_func(self, points: Vect3Array) -> Vect3:
|
||||||
|
in_coords = np.array(self.coordinate_system.p2c(points)).T
|
||||||
|
out_coords = self.func(in_coords)
|
||||||
|
origin = self.coordinate_system.get_origin()
|
||||||
|
return self.coordinate_system.c2p(*out_coords.T) - origin
|
||||||
|
|
||||||
|
def draw_lines(self) -> None:
|
||||||
|
lines = []
|
||||||
|
origin = self.coordinate_system.get_origin()
|
||||||
|
|
||||||
|
# Todo, it feels like coordinate system should just have
|
||||||
|
# the ODE solver built into it, no?
|
||||||
|
lines = []
|
||||||
|
for coords in self.get_sample_coords():
|
||||||
|
solution_coords = ode_solution_points(self.func, coords, self.solution_time, self.dt)
|
||||||
|
line = VMobject()
|
||||||
|
line.set_points_smoothly(self.coordinate_system.c2p(*solution_coords.T))
|
||||||
|
# TODO, account for arc length somehow?
|
||||||
|
line.virtual_time = self.solution_time
|
||||||
|
lines.append(line)
|
||||||
|
self.set_submobjects(lines)
|
||||||
|
|
||||||
|
def get_sample_coords(self):
|
||||||
|
cs = self.coordinate_system
|
||||||
|
sample_coords = get_sample_coords(cs, self.step_multiple)
|
||||||
|
|
||||||
|
noise_factor = self.noise_factor
|
||||||
|
if noise_factor is None:
|
||||||
|
noise_factor = cs.get_x_unit_size() * self.step_multiple * 0.5
|
||||||
|
|
||||||
|
return np.array([
|
||||||
|
coords + noise_factor * np.random.random(coords.shape)
|
||||||
|
for n in range(self.n_repeats)
|
||||||
|
for coords in sample_coords
|
||||||
|
])
|
||||||
|
|
||||||
|
def get_start_points(self) -> Vect3Array:
|
||||||
|
cs = self.coordinate_system
|
||||||
|
sample_coords = get_sample_coords(cs, self.step_multiple)
|
||||||
|
|
||||||
|
noise_factor = self.noise_factor
|
||||||
|
if noise_factor is None:
|
||||||
|
noise_factor = cs.get_x_unit_size() * self.step_multiple * 0.5
|
||||||
|
|
||||||
|
return np.array([
|
||||||
|
cs.c2p(*coords) + noise_factor * np.random.random(3)
|
||||||
|
for n in range(self.n_repeats)
|
||||||
|
for coords in sample_coords
|
||||||
|
])
|
||||||
|
|
||||||
|
def init_style(self) -> None:
|
||||||
|
if self.color_by_magnitude:
|
||||||
|
values_to_rgbs = get_vectorized_rgb_gradient_function(
|
||||||
|
*self.magnitude_range, self.color_map,
|
||||||
|
)
|
||||||
|
cs = self.coordinate_system
|
||||||
|
for line in self.submobjects:
|
||||||
|
norms = [
|
||||||
|
get_norm(self.func(*cs.p2c(point)))
|
||||||
|
for point in line.get_points()
|
||||||
|
]
|
||||||
|
rgbs = values_to_rgbs(norms)
|
||||||
|
rgbas = np.zeros((len(rgbs), 4))
|
||||||
|
rgbas[:, :3] = rgbs
|
||||||
|
rgbas[:, 3] = self.stroke_opacity
|
||||||
|
line.set_rgba_array(rgbas, "stroke_rgba")
|
||||||
|
else:
|
||||||
|
self.set_stroke(self.stroke_color, opacity=self.stroke_opacity)
|
||||||
|
|
||||||
|
if self.taper_stroke_width:
|
||||||
|
width = [0, self.stroke_width, 0]
|
||||||
|
else:
|
||||||
|
width = self.stroke_width
|
||||||
|
self.set_stroke(width=width)
|
||||||
|
|
||||||
|
|
||||||
|
class AnimatedStreamLines(VGroup):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stream_lines: StreamLines,
|
||||||
|
lag_range: float = 4,
|
||||||
|
rate_multiple: float = 1.0,
|
||||||
|
line_anim_config: dict = dict(
|
||||||
|
rate_func=linear,
|
||||||
|
time_width=1.0,
|
||||||
|
),
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.stream_lines = stream_lines
|
||||||
|
|
||||||
|
for line in stream_lines:
|
||||||
|
line.anim = VShowPassingFlash(
|
||||||
|
line,
|
||||||
|
run_time=line.virtual_time / rate_multiple,
|
||||||
|
**line_anim_config,
|
||||||
|
)
|
||||||
|
line.anim.begin()
|
||||||
|
line.time = -lag_range * np.random.random()
|
||||||
|
self.add(line.anim.mobject)
|
||||||
|
|
||||||
|
self.add_updater(lambda m, dt: m.update(dt))
|
||||||
|
|
||||||
|
def update(self, dt: float) -> None:
|
||||||
|
stream_lines = self.stream_lines
|
||||||
|
for line in stream_lines:
|
||||||
|
line.time += dt
|
||||||
|
adjusted_time = max(line.time, 0) % line.anim.run_time
|
||||||
|
line.anim.update(adjusted_time / line.anim.run_time)
|
||||||
|
|
||||||
|
|
||||||
class OldVectorField(VGroup):
|
class OldVectorField(VGroup):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -371,155 +541,3 @@ class OldVectorField(VGroup):
|
||||||
opacity=self.opacity,
|
opacity=self.opacity,
|
||||||
)
|
)
|
||||||
return vect
|
return vect
|
||||||
|
|
||||||
|
|
||||||
class StreamLines(VGroup):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
func: Callable[[float, float], Sequence[float]],
|
|
||||||
coordinate_system: CoordinateSystem,
|
|
||||||
step_multiple: float = 0.5,
|
|
||||||
n_repeats: int = 1,
|
|
||||||
noise_factor: float | None = None,
|
|
||||||
# Config for drawing lines
|
|
||||||
dt: float = 0.05,
|
|
||||||
arc_len: float = 3,
|
|
||||||
max_time_steps: int = 200,
|
|
||||||
n_samples_per_line: int = 10,
|
|
||||||
cutoff_norm: float = 15,
|
|
||||||
# Style info
|
|
||||||
stroke_width: float = 1.0,
|
|
||||||
stroke_color: ManimColor = WHITE,
|
|
||||||
stroke_opacity: float = 1,
|
|
||||||
color_by_magnitude: bool = True,
|
|
||||||
magnitude_range: Tuple[float, float] = (0, 2.0),
|
|
||||||
taper_stroke_width: bool = False,
|
|
||||||
color_map: str = "3b1b_colormap",
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.func = func
|
|
||||||
self.coordinate_system = coordinate_system
|
|
||||||
self.step_multiple = step_multiple
|
|
||||||
self.n_repeats = n_repeats
|
|
||||||
self.noise_factor = noise_factor
|
|
||||||
self.dt = dt
|
|
||||||
self.arc_len = arc_len
|
|
||||||
self.max_time_steps = max_time_steps
|
|
||||||
self.n_samples_per_line = n_samples_per_line
|
|
||||||
self.cutoff_norm = cutoff_norm
|
|
||||||
self.stroke_width = stroke_width
|
|
||||||
self.stroke_color = stroke_color
|
|
||||||
self.stroke_opacity = stroke_opacity
|
|
||||||
self.color_by_magnitude = color_by_magnitude
|
|
||||||
self.magnitude_range = magnitude_range
|
|
||||||
self.taper_stroke_width = taper_stroke_width
|
|
||||||
self.color_map = color_map
|
|
||||||
|
|
||||||
self.draw_lines()
|
|
||||||
self.init_style()
|
|
||||||
|
|
||||||
def point_func(self, point: Vect3) -> Vect3:
|
|
||||||
in_coords = self.coordinate_system.p2c(point)
|
|
||||||
out_coords = self.func(*in_coords)
|
|
||||||
return self.coordinate_system.c2p(*out_coords)
|
|
||||||
|
|
||||||
def draw_lines(self) -> None:
|
|
||||||
lines = []
|
|
||||||
origin = self.coordinate_system.get_origin()
|
|
||||||
for point in self.get_start_points():
|
|
||||||
points = [point]
|
|
||||||
total_arc_len = 0
|
|
||||||
time = 0
|
|
||||||
for x in range(self.max_time_steps):
|
|
||||||
time += self.dt
|
|
||||||
last_point = points[-1]
|
|
||||||
new_point = last_point + self.dt * (self.point_func(last_point) - origin)
|
|
||||||
points.append(new_point)
|
|
||||||
total_arc_len += get_norm(new_point - last_point)
|
|
||||||
if get_norm(last_point) > self.cutoff_norm:
|
|
||||||
break
|
|
||||||
if total_arc_len > self.arc_len:
|
|
||||||
break
|
|
||||||
line = VMobject()
|
|
||||||
line.virtual_time = time
|
|
||||||
step = max(1, int(len(points) / self.n_samples_per_line))
|
|
||||||
line.set_points_as_corners(points[::step])
|
|
||||||
line.make_smooth(approx=True)
|
|
||||||
lines.append(line)
|
|
||||||
self.set_submobjects(lines)
|
|
||||||
|
|
||||||
def get_start_points(self) -> Vect3Array:
|
|
||||||
cs = self.coordinate_system
|
|
||||||
sample_coords = get_sample_coords(
|
|
||||||
cs, self.step_multiple,
|
|
||||||
)
|
|
||||||
|
|
||||||
noise_factor = self.noise_factor
|
|
||||||
if noise_factor is None:
|
|
||||||
noise_factor = cs.x_range[2] * self.step_multiple * 0.5
|
|
||||||
|
|
||||||
return np.array([
|
|
||||||
cs.c2p(*coords) + noise_factor * np.random.random(3)
|
|
||||||
for n in range(self.n_repeats)
|
|
||||||
for coords in sample_coords
|
|
||||||
])
|
|
||||||
|
|
||||||
def init_style(self) -> None:
|
|
||||||
if self.color_by_magnitude:
|
|
||||||
values_to_rgbs = get_vectorized_rgb_gradient_function(
|
|
||||||
*self.magnitude_range, self.color_map,
|
|
||||||
)
|
|
||||||
cs = self.coordinate_system
|
|
||||||
for line in self.submobjects:
|
|
||||||
norms = [
|
|
||||||
get_norm(self.func(*cs.p2c(point)))
|
|
||||||
for point in line.get_points()
|
|
||||||
]
|
|
||||||
rgbs = values_to_rgbs(norms)
|
|
||||||
rgbas = np.zeros((len(rgbs), 4))
|
|
||||||
rgbas[:, :3] = rgbs
|
|
||||||
rgbas[:, 3] = self.stroke_opacity
|
|
||||||
line.set_rgba_array(rgbas, "stroke_rgba")
|
|
||||||
else:
|
|
||||||
self.set_stroke(self.stroke_color, opacity=self.stroke_opacity)
|
|
||||||
|
|
||||||
if self.taper_stroke_width:
|
|
||||||
width = [0, self.stroke_width, 0]
|
|
||||||
else:
|
|
||||||
width = self.stroke_width
|
|
||||||
self.set_stroke(width=width)
|
|
||||||
|
|
||||||
|
|
||||||
class AnimatedStreamLines(VGroup):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
stream_lines: StreamLines,
|
|
||||||
lag_range: float = 4,
|
|
||||||
line_anim_config: dict = dict(
|
|
||||||
rate_func=linear,
|
|
||||||
time_width=1.0,
|
|
||||||
),
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.stream_lines = stream_lines
|
|
||||||
|
|
||||||
for line in stream_lines:
|
|
||||||
line.anim = VShowPassingFlash(
|
|
||||||
line,
|
|
||||||
run_time=line.virtual_time,
|
|
||||||
**line_anim_config,
|
|
||||||
)
|
|
||||||
line.anim.begin()
|
|
||||||
line.time = -lag_range * np.random.random()
|
|
||||||
self.add(line.anim.mobject)
|
|
||||||
|
|
||||||
self.add_updater(lambda m, dt: m.update(dt))
|
|
||||||
|
|
||||||
def update(self, dt: float) -> None:
|
|
||||||
stream_lines = self.stream_lines
|
|
||||||
for line in stream_lines:
|
|
||||||
line.time += dt
|
|
||||||
adjusted_time = max(line.time, 0) % line.anim.run_time
|
|
||||||
line.anim.update(adjusted_time / line.anim.run_time)
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue