mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Update the Vector Field interface
This commit is contained in:
parent
b84376d6fd
commit
64ae1364ca
3 changed files with 108 additions and 62 deletions
|
@ -45,6 +45,12 @@ DEFAULT_X_RANGE = (-8.0, 8.0, 1.0)
|
|||
DEFAULT_Y_RANGE = (-4.0, 4.0, 1.0)
|
||||
|
||||
|
||||
def full_range_specifier(range_args):
|
||||
if len(range_args) == 2:
|
||||
return (*range_args, 1)
|
||||
return range_args
|
||||
|
||||
|
||||
class CoordinateSystem(ABC):
|
||||
"""
|
||||
Abstract class for Axes and NumberPlane
|
||||
|
@ -57,8 +63,8 @@ class CoordinateSystem(ABC):
|
|||
y_range: RangeSpecifier = DEFAULT_Y_RANGE,
|
||||
num_sampled_graph_points_per_tick: int = 5,
|
||||
):
|
||||
self.x_range = x_range
|
||||
self.y_range = y_range
|
||||
self.x_range = full_range_specifier(x_range)
|
||||
self.y_range = full_range_specifier(y_range)
|
||||
self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick
|
||||
|
||||
@abstractmethod
|
||||
|
@ -536,7 +542,7 @@ class ThreeDAxes(Axes):
|
|||
):
|
||||
Axes.__init__(self, x_range, y_range, **kwargs)
|
||||
|
||||
self.z_range = z_range
|
||||
self.z_range = full_range_specifier(z_range)
|
||||
self.z_axis = self.create_axis(
|
||||
self.z_range,
|
||||
axis_config=merge_dicts_recursively(
|
||||
|
|
|
@ -15,6 +15,7 @@ from manimlib.utils.bezier import interpolate
|
|||
from manimlib.utils.bezier import inverse_interpolate
|
||||
from manimlib.utils.color import get_colormap_list
|
||||
from manimlib.utils.color import rgb_to_color
|
||||
from manimlib.utils.color import get_color_map
|
||||
from manimlib.utils.dict_ops import merge_dicts_recursively
|
||||
from manimlib.utils.iterables import cartesian_product
|
||||
from manimlib.utils.rate_functions import linear
|
||||
|
@ -25,7 +26,7 @@ from typing import TYPE_CHECKING
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable, Iterable, Sequence, TypeVar, Tuple
|
||||
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array
|
||||
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array, Vect4Array
|
||||
|
||||
from manimlib.mobject.coordinate_systems import CoordinateSystem
|
||||
from manimlib.mobject.mobject import Mobject
|
||||
|
@ -33,6 +34,7 @@ if TYPE_CHECKING:
|
|||
T = TypeVar("T")
|
||||
|
||||
|
||||
#### Delete these two ###
|
||||
def get_vectorized_rgb_gradient_function(
|
||||
min_value: T,
|
||||
max_value: T,
|
||||
|
@ -52,6 +54,7 @@ def get_vectorized_rgb_gradient_function(
|
|||
inter_alphas = inter_alphas.repeat(3).reshape((len(indices), 3))
|
||||
result = interpolate(rgbs[indices], rgbs[next_indices], inter_alphas)
|
||||
return result
|
||||
|
||||
return func
|
||||
|
||||
|
||||
|
@ -62,6 +65,7 @@ def get_rgb_gradient_function(
|
|||
) -> Callable[[float], Vect3]:
|
||||
vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map)
|
||||
return lambda value: vectorized_func(np.array([value]))[0]
|
||||
####
|
||||
|
||||
|
||||
def move_along_vector_field(
|
||||
|
@ -106,16 +110,16 @@ def move_points_along_vector_field(
|
|||
return mobject
|
||||
|
||||
|
||||
def get_sample_points_from_coordinate_system(
|
||||
def get_sample_coords(
|
||||
coordinate_system: CoordinateSystem,
|
||||
step_multiple: float
|
||||
step_multiple: float = 1.0
|
||||
) -> it.product[tuple[Vect3, ...]]:
|
||||
ranges = []
|
||||
for range_args in coordinate_system.get_all_ranges():
|
||||
_min, _max, step = range_args
|
||||
step *= step_multiple
|
||||
ranges.append(np.arange(_min, _max + step, step))
|
||||
return it.product(*ranges)
|
||||
return np.array(list(it.product(*ranges)))
|
||||
|
||||
|
||||
# Mobjects
|
||||
|
@ -124,51 +128,53 @@ def get_sample_points_from_coordinate_system(
|
|||
class VectorField(VMobject):
|
||||
def __init__(
|
||||
self,
|
||||
func,
|
||||
func: Callable[Sequence[float], Sequence[float]],
|
||||
coordinate_system: CoordinateSystem,
|
||||
step_multiple: float = 0.5,
|
||||
magnitude_range: Optional[Tuple[float, float]] = None,
|
||||
color_map_name: Optional[str] = "3b1b_colormap",
|
||||
color_map: Optional[Callable[Sequence[float]], Vect4Array] = None,
|
||||
stroke_color: ManimColor = BLUE,
|
||||
stroke_opacity: float = 1.0,
|
||||
center: Vect3 = ORIGIN,
|
||||
sample_points: Optional[Vect3Array] = None,
|
||||
x_density: float = 2.0,
|
||||
y_density: float = 2.0,
|
||||
z_density: float = 2.0,
|
||||
width: float = 14.0,
|
||||
height: float = 8.0,
|
||||
depth: float = 0.0,
|
||||
stroke_width: float = 2,
|
||||
tip_width_ratio: float = 4,
|
||||
tip_len_to_width: float = 0.01,
|
||||
max_vect_len: float | None = None,
|
||||
min_drawn_norm: float = 1e-2,
|
||||
flat_stroke: bool = False,
|
||||
norm_to_opacity_func=None,
|
||||
norm_to_rgb_func=None,
|
||||
norm_to_opacity_func=None, # TODO, check on this
|
||||
**kwargs
|
||||
):
|
||||
self.func = func
|
||||
self.coordinate_system = coordinate_system
|
||||
self.stroke_width = stroke_width
|
||||
self.tip_width_ratio = tip_width_ratio
|
||||
self.tip_len_to_width = tip_len_to_width
|
||||
self.min_drawn_norm = min_drawn_norm
|
||||
self.norm_to_opacity_func = norm_to_opacity_func
|
||||
self.norm_to_rgb_func = norm_to_rgb_func
|
||||
|
||||
if max_vect_len is not None:
|
||||
self.max_vect_len = max_vect_len
|
||||
# Search for sample_points
|
||||
self.sample_coords = get_sample_coords(coordinate_system, step_multiple)
|
||||
self.update_sample_points()
|
||||
|
||||
if max_vect_len is None:
|
||||
self.max_displayed_vect_len = get_norm(self.sample_points[1] - self.sample_points[0])
|
||||
else:
|
||||
densities = np.array([x_density, y_density, z_density])
|
||||
dims = np.array([width, height, depth])
|
||||
self.max_vect_len = 1.0 / densities[dims > 0].mean()
|
||||
self.max_displayed_vect_len = max_vect_len * coordinate_system.get_x_unit_size()
|
||||
|
||||
if sample_points is None:
|
||||
self.sample_points = self.get_sample_points(
|
||||
center, width, height, depth,
|
||||
x_density, y_density, z_density
|
||||
)
|
||||
# Prepare the color map
|
||||
if magnitude_range is None:
|
||||
max_value = max(map(get_norm, func(self.sample_coords)))
|
||||
magnitude_range = (0, max_value)
|
||||
|
||||
self.magnitude_range = magnitude_range
|
||||
|
||||
if color_map is not None:
|
||||
self.color_map = color_map
|
||||
elif color_map_name is not None:
|
||||
self.color_map = get_color_map(color_map_name)
|
||||
else:
|
||||
self.sample_points = sample_points
|
||||
self.color_map = None
|
||||
|
||||
self.init_base_stroke_width_array(len(self.sample_points))
|
||||
self.init_base_stroke_width_array(len(self.sample_coords))
|
||||
|
||||
super().__init__(
|
||||
stroke_color=stroke_color,
|
||||
|
@ -177,7 +183,7 @@ class VectorField(VMobject):
|
|||
**kwargs
|
||||
)
|
||||
|
||||
n_samples = len(self.sample_points)
|
||||
n_samples = len(self.sample_coords)
|
||||
self.set_points(np.zeros((8 * n_samples - 1, 3)))
|
||||
self.set_stroke(width=stroke_width)
|
||||
self.set_joint_type('no_joint')
|
||||
|
@ -211,8 +217,8 @@ class VectorField(VMobject):
|
|||
arr[7::8] = 0
|
||||
self.base_stroke_width_array = arr
|
||||
|
||||
def set_sample_points(self, sample_points: Vect3Array):
|
||||
self.sample_points = sample_points
|
||||
def set_sample_coords(self, sample_points: VectArray):
|
||||
self.sample_coords = sample_coords
|
||||
return self
|
||||
|
||||
def set_stroke(self, color=None, width=None, opacity=None, behind=None, flat=None, recurse=True):
|
||||
|
@ -227,35 +233,40 @@ class VectorField(VMobject):
|
|||
self.stroke_width = width
|
||||
return self
|
||||
|
||||
def update_sample_points(self):
|
||||
self.sample_points = self.coordinate_system.c2p(*self.sample_coords.T)
|
||||
|
||||
def update_vectors(self):
|
||||
tip_width = self.tip_width_ratio * self.stroke_width
|
||||
tip_len = self.tip_len_to_width * tip_width
|
||||
samples = self.sample_points
|
||||
|
||||
# Get raw outputs and lengths
|
||||
outputs = self.func(samples)
|
||||
norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis]
|
||||
# Outputs in the coordinate system
|
||||
outputs = self.func(self.sample_coords)
|
||||
output_norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis]
|
||||
|
||||
# How long should the arrows be drawn?
|
||||
max_len = self.max_vect_len
|
||||
# Corresponding vector values in global coordinates
|
||||
out_vects = self.coordinate_system.c2p(*outputs.T) - self.coordinate_system.get_origin()
|
||||
out_vect_norms = np.linalg.norm(out_vects, axis=1)[:, np.newaxis]
|
||||
unit_outputs = np.zeros_like(out_vects)
|
||||
np.true_divide(out_vects, out_vect_norms, out=unit_outputs, where=(out_vect_norms > 0))
|
||||
|
||||
# How long should the arrows be drawn, in global coordinates
|
||||
max_len = self.max_displayed_vect_len
|
||||
if max_len < np.inf:
|
||||
drawn_norms = max_len * np.tanh(norms / max_len)
|
||||
drawn_norms = max_len * np.tanh(out_vect_norms / max_len)
|
||||
else:
|
||||
drawn_norms = norms
|
||||
drawn_norms = out_vect_norms
|
||||
|
||||
# What's the distance from the base of an arrow to
|
||||
# the base of its head?
|
||||
dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf)
|
||||
dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf) # Mixing units!
|
||||
|
||||
# Set all points
|
||||
unit_outputs = np.zeros_like(outputs)
|
||||
np.true_divide(outputs, norms, out=unit_outputs, where=(norms > self.min_drawn_norm))
|
||||
|
||||
points = self.get_points()
|
||||
points[0::8] = samples
|
||||
points[2::8] = samples + dist_to_head_base * unit_outputs
|
||||
points[0::8] = self.sample_points
|
||||
points[2::8] = self.sample_points + dist_to_head_base * unit_outputs
|
||||
points[4::8] = points[2::8]
|
||||
points[6::8] = samples + drawn_norms * unit_outputs
|
||||
points[6::8] = self.sample_points + drawn_norms * unit_outputs
|
||||
for i in (1, 3, 5):
|
||||
points[i::8] = 0.5 * (points[i - 1::8] + points[i + 1::8])
|
||||
points[7::8] = points[6:-1:8]
|
||||
|
@ -267,14 +278,16 @@ class VectorField(VMobject):
|
|||
self.get_stroke_widths()[:] = width_scalars * width_arr
|
||||
|
||||
# Potentially adjust opacity and color
|
||||
if self.color_map is not None:
|
||||
self.get_stroke_colors() # Ensures the array is updated to appropriate length
|
||||
low, high = self.magnitude_range
|
||||
self.data['stroke_rgba'][:] = self.color_map(
|
||||
inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1])
|
||||
)
|
||||
|
||||
if self.norm_to_opacity_func is not None:
|
||||
self.get_stroke_opacities()[:] = self.norm_to_opacity_func(
|
||||
np.repeat(norms, 8)[:-1]
|
||||
)
|
||||
if self.norm_to_rgb_func is not None:
|
||||
self.get_stroke_colors()
|
||||
self.data['stroke_rgba'][:, :3] = self.norm_to_rgb_func(
|
||||
np.repeat(norms, 8)[:-1]
|
||||
np.repeat(output_norms, 8)[:-1]
|
||||
)
|
||||
|
||||
self.note_changed_data()
|
||||
|
@ -321,11 +334,11 @@ class OldVectorField(VGroup):
|
|||
self.opacity = opacity
|
||||
self.vector_config = dict(vector_config)
|
||||
|
||||
self.value_to_rgb = get_rgb_gradient_function(
|
||||
self.value_to_rgb = get_vectorized_rgb_gradient_function(
|
||||
*self.magnitude_range, self.color_map,
|
||||
)
|
||||
|
||||
samples = get_sample_points_from_coordinate_system(
|
||||
samples = get_sample_coords(
|
||||
coordinate_system, self.step_multiple
|
||||
)
|
||||
self.add(*(
|
||||
|
@ -438,7 +451,7 @@ class StreamLines(VGroup):
|
|||
|
||||
def get_start_points(self) -> Vect3Array:
|
||||
cs = self.coordinate_system
|
||||
sample_coords = get_sample_points_from_coordinate_system(
|
||||
sample_coords = get_sample_coords(
|
||||
cs, self.step_multiple,
|
||||
)
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@ from manimlib.utils.iterables import resize_with_interpolation
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Iterable, Sequence
|
||||
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array
|
||||
from typing import Iterable, Sequence, Callable
|
||||
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, NDArray
|
||||
|
||||
|
||||
def color_to_rgb(color: ManimColor) -> Vect3:
|
||||
|
@ -134,6 +134,33 @@ def random_bright_color(
|
|||
))
|
||||
|
||||
|
||||
def get_colormap_from_colors(colors: Iterable[ManimColor]) -> Callable[[Sequence[float]], Vect4Array]:
|
||||
"""
|
||||
Returns a funciton which takes in values between 0 and 1, and returns
|
||||
a corresponding list of rgba values
|
||||
"""
|
||||
rgbas = np.array([color_to_rgba(color) for color in colors])
|
||||
|
||||
def func(values):
|
||||
alphas = np.clip(values, 0, 1)
|
||||
scaled_alphas = alphas * (len(rgbas) - 1)
|
||||
indices = scaled_alphas.astype(int)
|
||||
next_indices = np.clip(indices + 1, 0, len(rgbas) - 1)
|
||||
inter_alphas = scaled_alphas % 1
|
||||
inter_alphas = inter_alphas.repeat(4).reshape((len(indices), 4))
|
||||
result = interpolate(rgbas[indices], rgbas[next_indices], inter_alphas)
|
||||
return result
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def get_color_map(map_name: str) -> Callable[[Sequence[float]], Vect4Array]:
|
||||
if map_name == "3b1b_colormap":
|
||||
return get_colormap_from_colors(COLORMAP_3B1B)
|
||||
return plt.get_cmap(map_name)
|
||||
|
||||
|
||||
# Delete this?
|
||||
def get_colormap_list(
|
||||
map_name: str = "viridis",
|
||||
n_colors: int = 9
|
||||
|
|
Loading…
Add table
Reference in a new issue