Update the Vector Field interface

This commit is contained in:
Grant Sanderson 2024-11-12 11:21:19 -08:00
parent b84376d6fd
commit 64ae1364ca
3 changed files with 108 additions and 62 deletions

View file

@ -45,6 +45,12 @@ DEFAULT_X_RANGE = (-8.0, 8.0, 1.0)
DEFAULT_Y_RANGE = (-4.0, 4.0, 1.0) DEFAULT_Y_RANGE = (-4.0, 4.0, 1.0)
def full_range_specifier(range_args):
if len(range_args) == 2:
return (*range_args, 1)
return range_args
class CoordinateSystem(ABC): class CoordinateSystem(ABC):
""" """
Abstract class for Axes and NumberPlane Abstract class for Axes and NumberPlane
@ -57,8 +63,8 @@ class CoordinateSystem(ABC):
y_range: RangeSpecifier = DEFAULT_Y_RANGE, y_range: RangeSpecifier = DEFAULT_Y_RANGE,
num_sampled_graph_points_per_tick: int = 5, num_sampled_graph_points_per_tick: int = 5,
): ):
self.x_range = x_range self.x_range = full_range_specifier(x_range)
self.y_range = y_range self.y_range = full_range_specifier(y_range)
self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick
@abstractmethod @abstractmethod
@ -536,7 +542,7 @@ class ThreeDAxes(Axes):
): ):
Axes.__init__(self, x_range, y_range, **kwargs) Axes.__init__(self, x_range, y_range, **kwargs)
self.z_range = z_range self.z_range = full_range_specifier(z_range)
self.z_axis = self.create_axis( self.z_axis = self.create_axis(
self.z_range, self.z_range,
axis_config=merge_dicts_recursively( axis_config=merge_dicts_recursively(

View file

@ -15,6 +15,7 @@ from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.bezier import inverse_interpolate
from manimlib.utils.color import get_colormap_list from manimlib.utils.color import get_colormap_list
from manimlib.utils.color import rgb_to_color from manimlib.utils.color import rgb_to_color
from manimlib.utils.color import get_color_map
from manimlib.utils.dict_ops import merge_dicts_recursively from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.iterables import cartesian_product from manimlib.utils.iterables import cartesian_product
from manimlib.utils.rate_functions import linear from manimlib.utils.rate_functions import linear
@ -25,7 +26,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar, Tuple from typing import Callable, Iterable, Sequence, TypeVar, Tuple
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array, Vect4Array
from manimlib.mobject.coordinate_systems import CoordinateSystem from manimlib.mobject.coordinate_systems import CoordinateSystem
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
@ -33,6 +34,7 @@ if TYPE_CHECKING:
T = TypeVar("T") T = TypeVar("T")
#### Delete these two ###
def get_vectorized_rgb_gradient_function( def get_vectorized_rgb_gradient_function(
min_value: T, min_value: T,
max_value: T, max_value: T,
@ -52,6 +54,7 @@ def get_vectorized_rgb_gradient_function(
inter_alphas = inter_alphas.repeat(3).reshape((len(indices), 3)) inter_alphas = inter_alphas.repeat(3).reshape((len(indices), 3))
result = interpolate(rgbs[indices], rgbs[next_indices], inter_alphas) result = interpolate(rgbs[indices], rgbs[next_indices], inter_alphas)
return result return result
return func return func
@ -62,6 +65,7 @@ def get_rgb_gradient_function(
) -> Callable[[float], Vect3]: ) -> Callable[[float], Vect3]:
vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map) vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map)
return lambda value: vectorized_func(np.array([value]))[0] return lambda value: vectorized_func(np.array([value]))[0]
####
def move_along_vector_field( def move_along_vector_field(
@ -106,16 +110,16 @@ def move_points_along_vector_field(
return mobject return mobject
def get_sample_points_from_coordinate_system( def get_sample_coords(
coordinate_system: CoordinateSystem, coordinate_system: CoordinateSystem,
step_multiple: float step_multiple: float = 1.0
) -> it.product[tuple[Vect3, ...]]: ) -> it.product[tuple[Vect3, ...]]:
ranges = [] ranges = []
for range_args in coordinate_system.get_all_ranges(): for range_args in coordinate_system.get_all_ranges():
_min, _max, step = range_args _min, _max, step = range_args
step *= step_multiple step *= step_multiple
ranges.append(np.arange(_min, _max + step, step)) ranges.append(np.arange(_min, _max + step, step))
return it.product(*ranges) return np.array(list(it.product(*ranges)))
# Mobjects # Mobjects
@ -124,51 +128,53 @@ def get_sample_points_from_coordinate_system(
class VectorField(VMobject): class VectorField(VMobject):
def __init__( def __init__(
self, self,
func, func: Callable[Sequence[float], Sequence[float]],
coordinate_system: CoordinateSystem,
step_multiple: float = 0.5,
magnitude_range: Optional[Tuple[float, float]] = None,
color_map_name: Optional[str] = "3b1b_colormap",
color_map: Optional[Callable[Sequence[float]], Vect4Array] = None,
stroke_color: ManimColor = BLUE, stroke_color: ManimColor = BLUE,
stroke_opacity: float = 1.0, stroke_opacity: float = 1.0,
center: Vect3 = ORIGIN,
sample_points: Optional[Vect3Array] = None,
x_density: float = 2.0,
y_density: float = 2.0,
z_density: float = 2.0,
width: float = 14.0,
height: float = 8.0,
depth: float = 0.0,
stroke_width: float = 2, stroke_width: float = 2,
tip_width_ratio: float = 4, tip_width_ratio: float = 4,
tip_len_to_width: float = 0.01, tip_len_to_width: float = 0.01,
max_vect_len: float | None = None, max_vect_len: float | None = None,
min_drawn_norm: float = 1e-2,
flat_stroke: bool = False, flat_stroke: bool = False,
norm_to_opacity_func=None, norm_to_opacity_func=None, # TODO, check on this
norm_to_rgb_func=None,
**kwargs **kwargs
): ):
self.func = func self.func = func
self.coordinate_system = coordinate_system
self.stroke_width = stroke_width self.stroke_width = stroke_width
self.tip_width_ratio = tip_width_ratio self.tip_width_ratio = tip_width_ratio
self.tip_len_to_width = tip_len_to_width self.tip_len_to_width = tip_len_to_width
self.min_drawn_norm = min_drawn_norm
self.norm_to_opacity_func = norm_to_opacity_func self.norm_to_opacity_func = norm_to_opacity_func
self.norm_to_rgb_func = norm_to_rgb_func
if max_vect_len is not None: # Search for sample_points
self.max_vect_len = max_vect_len self.sample_coords = get_sample_coords(coordinate_system, step_multiple)
self.update_sample_points()
if max_vect_len is None:
self.max_displayed_vect_len = get_norm(self.sample_points[1] - self.sample_points[0])
else: else:
densities = np.array([x_density, y_density, z_density]) self.max_displayed_vect_len = max_vect_len * coordinate_system.get_x_unit_size()
dims = np.array([width, height, depth])
self.max_vect_len = 1.0 / densities[dims > 0].mean()
if sample_points is None: # Prepare the color map
self.sample_points = self.get_sample_points( if magnitude_range is None:
center, width, height, depth, max_value = max(map(get_norm, func(self.sample_coords)))
x_density, y_density, z_density magnitude_range = (0, max_value)
)
self.magnitude_range = magnitude_range
if color_map is not None:
self.color_map = color_map
elif color_map_name is not None:
self.color_map = get_color_map(color_map_name)
else: else:
self.sample_points = sample_points self.color_map = None
self.init_base_stroke_width_array(len(self.sample_points)) self.init_base_stroke_width_array(len(self.sample_coords))
super().__init__( super().__init__(
stroke_color=stroke_color, stroke_color=stroke_color,
@ -177,7 +183,7 @@ class VectorField(VMobject):
**kwargs **kwargs
) )
n_samples = len(self.sample_points) n_samples = len(self.sample_coords)
self.set_points(np.zeros((8 * n_samples - 1, 3))) self.set_points(np.zeros((8 * n_samples - 1, 3)))
self.set_stroke(width=stroke_width) self.set_stroke(width=stroke_width)
self.set_joint_type('no_joint') self.set_joint_type('no_joint')
@ -211,8 +217,8 @@ class VectorField(VMobject):
arr[7::8] = 0 arr[7::8] = 0
self.base_stroke_width_array = arr self.base_stroke_width_array = arr
def set_sample_points(self, sample_points: Vect3Array): def set_sample_coords(self, sample_points: VectArray):
self.sample_points = sample_points self.sample_coords = sample_coords
return self return self
def set_stroke(self, color=None, width=None, opacity=None, behind=None, flat=None, recurse=True): def set_stroke(self, color=None, width=None, opacity=None, behind=None, flat=None, recurse=True):
@ -227,35 +233,40 @@ class VectorField(VMobject):
self.stroke_width = width self.stroke_width = width
return self return self
def update_sample_points(self):
self.sample_points = self.coordinate_system.c2p(*self.sample_coords.T)
def update_vectors(self): def update_vectors(self):
tip_width = self.tip_width_ratio * self.stroke_width tip_width = self.tip_width_ratio * self.stroke_width
tip_len = self.tip_len_to_width * tip_width tip_len = self.tip_len_to_width * tip_width
samples = self.sample_points
# Get raw outputs and lengths # Outputs in the coordinate system
outputs = self.func(samples) outputs = self.func(self.sample_coords)
norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis] output_norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis]
# How long should the arrows be drawn? # Corresponding vector values in global coordinates
max_len = self.max_vect_len out_vects = self.coordinate_system.c2p(*outputs.T) - self.coordinate_system.get_origin()
out_vect_norms = np.linalg.norm(out_vects, axis=1)[:, np.newaxis]
unit_outputs = np.zeros_like(out_vects)
np.true_divide(out_vects, out_vect_norms, out=unit_outputs, where=(out_vect_norms > 0))
# How long should the arrows be drawn, in global coordinates
max_len = self.max_displayed_vect_len
if max_len < np.inf: if max_len < np.inf:
drawn_norms = max_len * np.tanh(norms / max_len) drawn_norms = max_len * np.tanh(out_vect_norms / max_len)
else: else:
drawn_norms = norms drawn_norms = out_vect_norms
# What's the distance from the base of an arrow to # What's the distance from the base of an arrow to
# the base of its head? # the base of its head?
dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf) dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf) # Mixing units!
# Set all points # Set all points
unit_outputs = np.zeros_like(outputs)
np.true_divide(outputs, norms, out=unit_outputs, where=(norms > self.min_drawn_norm))
points = self.get_points() points = self.get_points()
points[0::8] = samples points[0::8] = self.sample_points
points[2::8] = samples + dist_to_head_base * unit_outputs points[2::8] = self.sample_points + dist_to_head_base * unit_outputs
points[4::8] = points[2::8] points[4::8] = points[2::8]
points[6::8] = samples + drawn_norms * unit_outputs points[6::8] = self.sample_points + drawn_norms * unit_outputs
for i in (1, 3, 5): for i in (1, 3, 5):
points[i::8] = 0.5 * (points[i - 1::8] + points[i + 1::8]) points[i::8] = 0.5 * (points[i - 1::8] + points[i + 1::8])
points[7::8] = points[6:-1:8] points[7::8] = points[6:-1:8]
@ -267,14 +278,16 @@ class VectorField(VMobject):
self.get_stroke_widths()[:] = width_scalars * width_arr self.get_stroke_widths()[:] = width_scalars * width_arr
# Potentially adjust opacity and color # Potentially adjust opacity and color
if self.color_map is not None:
self.get_stroke_colors() # Ensures the array is updated to appropriate length
low, high = self.magnitude_range
self.data['stroke_rgba'][:] = self.color_map(
inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1])
)
if self.norm_to_opacity_func is not None: if self.norm_to_opacity_func is not None:
self.get_stroke_opacities()[:] = self.norm_to_opacity_func( self.get_stroke_opacities()[:] = self.norm_to_opacity_func(
np.repeat(norms, 8)[:-1] np.repeat(output_norms, 8)[:-1]
)
if self.norm_to_rgb_func is not None:
self.get_stroke_colors()
self.data['stroke_rgba'][:, :3] = self.norm_to_rgb_func(
np.repeat(norms, 8)[:-1]
) )
self.note_changed_data() self.note_changed_data()
@ -321,11 +334,11 @@ class OldVectorField(VGroup):
self.opacity = opacity self.opacity = opacity
self.vector_config = dict(vector_config) self.vector_config = dict(vector_config)
self.value_to_rgb = get_rgb_gradient_function( self.value_to_rgb = get_vectorized_rgb_gradient_function(
*self.magnitude_range, self.color_map, *self.magnitude_range, self.color_map,
) )
samples = get_sample_points_from_coordinate_system( samples = get_sample_coords(
coordinate_system, self.step_multiple coordinate_system, self.step_multiple
) )
self.add(*( self.add(*(
@ -438,7 +451,7 @@ class StreamLines(VGroup):
def get_start_points(self) -> Vect3Array: def get_start_points(self) -> Vect3Array:
cs = self.coordinate_system cs = self.coordinate_system
sample_coords = get_sample_points_from_coordinate_system( sample_coords = get_sample_coords(
cs, self.step_multiple, cs, self.step_multiple,
) )

View file

@ -14,8 +14,8 @@ from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Iterable, Sequence from typing import Iterable, Sequence, Callable
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, NDArray
def color_to_rgb(color: ManimColor) -> Vect3: def color_to_rgb(color: ManimColor) -> Vect3:
@ -134,6 +134,33 @@ def random_bright_color(
)) ))
def get_colormap_from_colors(colors: Iterable[ManimColor]) -> Callable[[Sequence[float]], Vect4Array]:
"""
Returns a funciton which takes in values between 0 and 1, and returns
a corresponding list of rgba values
"""
rgbas = np.array([color_to_rgba(color) for color in colors])
def func(values):
alphas = np.clip(values, 0, 1)
scaled_alphas = alphas * (len(rgbas) - 1)
indices = scaled_alphas.astype(int)
next_indices = np.clip(indices + 1, 0, len(rgbas) - 1)
inter_alphas = scaled_alphas % 1
inter_alphas = inter_alphas.repeat(4).reshape((len(indices), 4))
result = interpolate(rgbas[indices], rgbas[next_indices], inter_alphas)
return result
return func
def get_color_map(map_name: str) -> Callable[[Sequence[float]], Vect4Array]:
if map_name == "3b1b_colormap":
return get_colormap_from_colors(COLORMAP_3B1B)
return plt.get_cmap(map_name)
# Delete this?
def get_colormap_list( def get_colormap_list(
map_name: str = "viridis", map_name: str = "viridis",
n_colors: int = 9 n_colors: int = 9