Video work (#2402)

* Bug fix for TransformMatchingStrings with incompatible lengths

* Change faded line in NumberPlane initialization to be more explicit, and lower opacity

* Add option hide_zero_components_on_complex to DecimalNumber

* Validate syntax before reloading

* Add remembered stroke_config to TracedPath

* Add CLAUDE.md to gitignore

* Move pre-calculated traced points to TracingTail

* Fix interplay between time_span and alpha in Animation

* Clearer init for points in TracingTail

* Fix CoordinateSystem.get_area_under_graph

* Allow ComplexPlane.n2p to take in array of complex numbers

* Add put_start_on and put_end_on

* Add Slider

* Add \minus option for Tex to give shorter negative sign

* Put interp_by_hsl option in various color interpretation functions

* Swap priority of matched_keys vs key_map is TransformMatchingStrings

* Have z-index apply recursively

* Set self.svg_string property for SVGMobject

* Fix num_decimal_places config in Tex.make_number_changeable

* Add Surface. color_by_uv_function

* Add VMobject. set_color_by_proportion

* Add \mathcal to tex_to_symbol_count
This commit is contained in:
Grant Sanderson 2025-10-14 09:15:39 -05:00 committed by GitHub
parent 41613db7ec
commit e5298385ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 227 additions and 68 deletions

1
.gitignore vendored
View file

@ -152,3 +152,4 @@ dmypy.json
/videos
/custom_config.yml
test.py
CLAUDE.md

View file

@ -88,7 +88,7 @@ class TransformMatchingParts(AnimationGroup):
if not source_is_new or not target_is_new:
return
transform_type = self.mismatch_animation
transform_type = self.mismatch_animation
if source.has_same_shape_as(target):
transform_type = self.match_animation
@ -154,16 +154,16 @@ class TransformMatchingStrings(TransformMatchingParts):
counts2 = list(map(target.substr_to_path_count, syms2))
# Start with user specified matches
blocks = [(source[key], target[key]) for key in matched_keys]
blocks += [(source[key1], target[key2]) for key1, key2 in key_map.items()]
blocks = [(source[key1], target[key2]) for key1, key2 in key_map.items()]
blocks += [(source[key], target[key]) for key in matched_keys]
# Nullify any intersections with those matches in the two symbol lists
for sub_source, sub_target in blocks:
for i in range(len(syms1)):
if source[i] in sub_source.family_members_with_points():
if i < len(source) and source[i] in sub_source.family_members_with_points():
syms1[i] = "Null1"
for j in range(len(syms2)):
if target[j] in sub_target.family_members_with_points():
if j < len(target) and target[j] in sub_target.family_members_with_points():
syms2[j] = "Null2"
# Group together longest matching substrings

View file

@ -46,7 +46,8 @@ class UpdateFromAlphaFunc(Animation):
super().__init__(mobject, suspend_mobject_updating=suspend_mobject_updating, **kwargs)
def interpolate_mobject(self, alpha: float) -> None:
self.update_function(self.mobject, alpha)
true_alpha = self.rate_func(self.time_spanned_alpha(alpha))
self.update_function(self.mobject, true_alpha)
class MaintainPositionRelativeTo(Animation):

View file

@ -101,10 +101,17 @@ class TracedPath(VMobject):
traced_point_func: Callable[[], Vect3],
time_traced: float = np.inf,
time_per_anchor: float = 1.0 / 15,
stroke_width: float | Iterable[float] = 2.0,
stroke_color: ManimColor = DEFAULT_MOBJECT_COLOR,
stroke_width: float | Iterable[float] = 2.0,
stroke_opacity: float = 1.0,
**kwargs
):
self.stroke_config = dict(
color=stroke_color,
width=stroke_width,
opacity=stroke_opacity,
)
super().__init__(**kwargs)
self.traced_point_func = traced_point_func
self.time_traced = time_traced
@ -112,7 +119,6 @@ class TracedPath(VMobject):
self.time: float = 0
self.traced_points: list[np.ndarray] = []
self.add_updater(lambda m, dt: m.update_path(dt))
self.always.set_stroke(stroke_color, stroke_width)
def update_path(self, dt: float) -> Self:
if dt == 0:
@ -136,6 +142,8 @@ class TracedPath(VMobject):
if points:
self.set_points_smoothly(points)
self.set_stroke(**self.stroke_config)
self.time += dt
return self
@ -145,21 +153,24 @@ class TracingTail(TracedPath):
self,
mobject_or_func: Mobject | Callable[[], np.ndarray],
time_traced: float = 1.0,
stroke_color: ManimColor = DEFAULT_MOBJECT_COLOR,
stroke_width: float | Iterable[float] = (0, 3),
stroke_opacity: float | Iterable[float] = (0, 1),
stroke_color: ManimColor = DEFAULT_MOBJECT_COLOR,
**kwargs
):
if isinstance(mobject_or_func, Mobject):
func = mobject_or_func.get_center
else:
func = mobject_or_func
super().__init__(
func,
time_traced=time_traced,
stroke_color=stroke_color,
stroke_width=stroke_width,
stroke_opacity=stroke_opacity,
stroke_color=stroke_color,
**kwargs
)
self.add_updater(lambda m: m.set_stroke(width=stroke_width, opacity=stroke_opacity))
curr_point = self.traced_point_func()
n_points = int(self.time_traced / self.time_per_anchor)
self.traced_points: list[np.ndarray] = n_points * [curr_point]

View file

@ -412,15 +412,19 @@ class CoordinateSystem(ABC):
rect.set_fill(negative_color)
return result
def get_area_under_graph(self, graph, x_range, fill_color=BLUE, fill_opacity=0.5):
if not hasattr(graph, "x_range"):
raise Exception("Argument `graph` must have attribute `x_range`")
def get_area_under_graph(self, graph, x_range=None, fill_color=BLUE, fill_opacity=0.5):
if x_range is None:
x_range = [
self.x_axis.p2n(graph.get_start()),
self.x_axis.p2n(graph.get_end()),
]
alpha_bounds = [
inverse_interpolate(*graph.x_range, x)
inverse_interpolate(*graph.x_range[:2], x)
for x in x_range
]
sub_graph = graph.copy()
sub_graph.clear_updaters()
sub_graph.pointwise_become_partial(graph, *alpha_bounds)
sub_graph.add_line_to(self.c2p(x_range[1], 0))
sub_graph.add_line_to(self.c2p(x_range[0], 0))
@ -638,7 +642,10 @@ class NumberPlane(Axes):
stroke_opacity=1,
),
# Defaults to a faded version of line_config
faded_line_style: dict = dict(),
faded_line_style: dict = dict(
stroke_width=1,
stroke_opacity=0.25,
),
faded_line_ratio: int = 4,
make_smooth_after_applying_functions: bool = True,
**kwargs
@ -651,14 +658,8 @@ class NumberPlane(Axes):
self.init_background_lines()
def init_background_lines(self) -> None:
if not self.faded_line_style:
style = dict(self.background_line_style)
# For anything numerical, like stroke_width
# and stroke_opacity, chop it in half
for key in style:
if isinstance(style[key], numbers.Number):
style[key] *= 0.5
self.faded_line_style = style
if "stroke_color" not in self.faded_line_style:
self.faded_line_style["stroke_color"] = self.background_line_style["stroke_color"]
self.background_lines, self.faded_lines = self.get_lines()
self.background_lines.set_style(**self.background_line_style)
@ -726,11 +727,10 @@ class NumberPlane(Axes):
class ComplexPlane(NumberPlane):
def number_to_point(self, number: complex | float) -> Vect3:
number = complex(number)
return self.coords_to_point(number.real, number.imag)
def number_to_point(self, number: complex | float | np.array) -> Vect3:
return self.coords_to_point(np.real(number), np.imag(number))
def n2p(self, number: complex | float) -> Vect3:
def n2p(self, number: complex | float | np.array) -> Vect3:
return self.number_to_point(number)
def point_to_number(self, point: Vect3) -> complex:

View file

@ -1232,8 +1232,9 @@ class Mobject(object):
def set_z(self, z: float, direction: Vect3 = ORIGIN) -> Self:
return self.set_coord(z, 2, direction)
def set_z_index(self, z_index: int) -> Self:
self.z_index = z_index
def set_z_index(self, z_index: int, recurse=True) -> Self:
for mob in self.get_family(recurse):
mob.z_index = z_index
return self
def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self:
@ -1284,6 +1285,14 @@ class Mobject(object):
self.scale((length + buff) / length)
return self
def put_start_on(self, point: Vect3) -> Self:
self.shift(point - self.get_start())
return self
def put_end_on(self, point: Vect3) -> Self:
self.shift(point - self.get_end())
return self
def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self:
curr_start, curr_end = self.get_start_and_end()
curr_vect = curr_end - curr_start

View file

@ -4,19 +4,23 @@ import numpy as np
from manimlib.constants import DOWN, LEFT, RIGHT, UP
from manimlib.constants import DEFAULT_LIGHT_COLOR
from manimlib.constants import MED_SMALL_BUFF
from manimlib.mobject.geometry import Line
from manimlib.constants import MED_SMALL_BUFF, SMALL_BUFF
from manimlib.constants import YELLOW, DEG
from manimlib.mobject.geometry import Line, ArrowTip
from manimlib.mobject.numbers import DecimalNumber
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.value_tracker import ValueTracker
from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import outer_interpolate
from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.simple_functions import fdiv
from manimlib.utils.space_ops import rotate_vector, angle_of_vector
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable, Optional
from typing import Iterable, Optional, Tuple, Dict, Any
from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier
@ -235,3 +239,67 @@ class UnitInterval(NumberLine):
decimal_number_config=decimal_number_config,
**kwargs
)
class Slider(VGroup):
def __init__(
self,
value_tracker: ValueTracker,
x_range: Tuple[float, float] = (-5, 5),
var_name: Optional[str] = None,
width: float = 3,
unit_size: float = 1,
arrow_width: float = 0.15,
arrow_length: float = 0.15,
arrow_color: ManimColor = YELLOW,
font_size: int = 24,
label_buff: float = SMALL_BUFF,
num_decimal_places: int = 2,
tick_size: float = 0.05,
number_line_config: Dict[str, Any] = dict(),
arrow_tip_config: Dict[str, Any] = dict(),
decimal_config: Dict[str, Any] = dict(),
angle: float = 0,
label_direction: Optional[np.ndarray] = None,
add_tick_labels: bool = True,
tick_label_font_size: int = 16,
):
get_value = value_tracker.get_value
if label_direction is None:
label_direction = np.round(rotate_vector(UP, angle), 2)
# Initialize number line
number_line_kw = dict(x_range=x_range, width=width, tick_size=tick_size)
number_line_kw.update(number_line_config)
number_line = NumberLine(**number_line_kw)
number_line.rotate(angle)
if add_tick_labels:
number_line.add_numbers(
font_size=tick_label_font_size,
buff=2 * tick_size,
direction=-label_direction
)
# Initialize arrow tip
arrow_tip_kw = dict(
width=arrow_width,
length=arrow_length,
fill_color=arrow_color,
angle=-180 * DEG + angle_of_vector(label_direction),
)
arrow_tip_kw.update(arrow_tip_config)
tip = ArrowTip(**arrow_tip_kw)
tip.add_updater(lambda m: m.move_to(number_line.n2p(get_value()), -label_direction))
# Initialize label
dec_string = f"{{:.{num_decimal_places}f}}".format(0)
lhs = f"{var_name} = " if var_name is not None else ""
label = Tex(lhs + dec_string, font_size=font_size)
label[var_name].set_fill(arrow_color)
decimal = label.make_number_changeable(dec_string)
decimal.add_updater(lambda m: m.set_value(get_value()))
label.add_updater(lambda m: m.next_to(tip, label_direction, label_buff))
# Assemble group
super().__init__(number_line, tip, label)
self.set_stroke(behind=True)

View file

@ -47,6 +47,7 @@ class DecimalNumber(VMobject):
show_ellipsis: bool = False,
unit: str | None = None, # Aligned to bottom unless it starts with "^"
include_background_rectangle: bool = False,
hide_zero_components_on_complex: bool = True,
edge_to_fix: Vect3 = LEFT,
font_size: float = 48,
text_config: dict = dict(), # Do not pass in font_size here
@ -60,6 +61,7 @@ class DecimalNumber(VMobject):
self.show_ellipsis = show_ellipsis
self.unit = unit
self.include_background_rectangle = include_background_rectangle
self.hide_zero_components_on_complex = hide_zero_components_on_complex
self.edge_to_fix = edge_to_fix
self.font_size = font_size
self.text_config = dict(text_config)
@ -120,7 +122,14 @@ class DecimalNumber(VMobject):
def get_num_string(self, number: float | complex) -> str:
if isinstance(number, complex):
formatter = self.get_complex_formatter()
if self.hide_zero_components_on_complex and number.imag == 0:
number = number.real
formatter = self.get_formatter()
elif self.hide_zero_components_on_complex and number.real == 0:
number = number.imag
formatter = self.get_formatter() + "i"
else:
formatter = self.get_complex_formatter()
else:
formatter = self.get_formatter()
if self.num_decimal_places == 0 and isinstance(number, float):

View file

@ -73,7 +73,7 @@ class SVGMobject(VMobject):
elif file_name != "":
self.svg_string = self.file_name_to_svg_string(file_name)
elif self.file_name != "":
self.file_name_to_svg_string(self.file_name)
self.svg_string = self.file_name_to_svg_string(self.file_name)
else:
raise Exception("Must specify either a file_name or svg_string SVGMobject")

View file

@ -242,15 +242,10 @@ class Tex(StringMobject):
decimal_mobs = []
for part in parts:
if "." in substr:
num_decimal_places = len(substr.split(".")[1])
else:
num_decimal_places = 0
decimal_mob = DecimalNumber(
float(value),
num_decimal_places=num_decimal_places,
**config,
)
if "num_decimal_places" not in config:
ndp = len(substr.split(".")[1]) if "." in substr else 0
config["num_decimal_places"] = ndp
decimal_mob = DecimalNumber(float(value), **config)
decimal_mob.replace(part)
decimal_mob.match_style(part)
if len(part) > 1:

View file

@ -75,16 +75,12 @@ class Surface(Mobject):
@Mobject.affects_data
def init_points(self):
dim = self.dim
nu, nv = self.resolution
u_range = np.linspace(*self.u_range, nu)
v_range = np.linspace(*self.v_range, nv)
# Get three lists:
# - Points generated by pure uv values
# - Those generated by values nudged by du
# - Those generated by values nudged by dv
uv_grid = np.array([[[u, v] for v in v_range] for u in u_range])
nu, nv = self.resolution
uv_grid = self.get_uv_grid()
uv_plus_du = uv_grid.copy()
uv_plus_du[:, :, 0] += self.epsilon
uv_plus_dv = uv_grid.copy()
@ -93,7 +89,7 @@ class Surface(Mobject):
points, du_points, dv_points = [
np.apply_along_axis(
lambda p: self.uv_func(*p), 2, grid
).reshape((nu * nv, dim))
).reshape((nu * nv, self.dim))
for grid in (uv_grid, uv_plus_du, uv_plus_dv)
]
crosses = cross(du_points - points, dv_points - points)
@ -102,9 +98,20 @@ class Surface(Mobject):
self.set_points(points)
self.data['d_normal_point'] = points + self.normal_nudge * normals
def get_uv_grid(self) -> np.array:
"""
Returns an (nu, nv, 2) array of all pairs of u, v values, where
(nu, nv) is the resolution
"""
nu, nv = self.resolution
u_range = np.linspace(*self.u_range, nu)
v_range = np.linspace(*self.v_range, nv)
U, V = np.meshgrid(u_range, v_range, indexing='ij')
return np.stack([U, V], axis=-1)
def uv_to_point(self, u, v):
nu, nv = self.resolution
uv_grid = np.reshape(self.get_points(), (nu, nv, self.dim))
verts_by_uv = np.reshape(self.get_points(), (nu, nv, self.dim))
alpha1 = clip(inverse_interpolate(*self.u_range[:2], u), 0, 1)
alpha2 = clip(inverse_interpolate(*self.v_range[:2], v), 0, 1)
@ -115,10 +122,10 @@ class Surface(Mobject):
u_int_plus = min(u_int + 1, nu - 1)
v_int_plus = min(v_int + 1, nv - 1)
a = uv_grid[u_int, v_int, :]
b = uv_grid[u_int, v_int_plus, :]
c = uv_grid[u_int_plus, v_int, :]
d = uv_grid[u_int_plus, v_int_plus, :]
a = verts_by_uv[u_int, v_int, :]
b = verts_by_uv[u_int, v_int_plus, :]
c = verts_by_uv[u_int_plus, v_int, :]
d = verts_by_uv[u_int_plus, v_int_plus, :]
u_res = scaled_u % 1
v_res = scaled_v % 1
@ -240,6 +247,14 @@ class Surface(Mobject):
self.add_updater(updater)
return self
def color_by_uv_function(self, uv_to_color: Callable[[Vect2], Color]):
uv_grid = self.get_uv_grid()
self.set_rgba_array_by_color([
uv_to_color(u, v)
for u, v in uv_grid.reshape(-1, 2)
])
return self
def get_shader_vert_indices(self) -> np.ndarray:
return self.get_triangle_indices()

View file

@ -303,6 +303,11 @@ class VMobject(Mobject):
self.set_stroke(opacity=opacity, recurse=recurse)
return self
def set_color_by_proportion(self, prop_to_color: Callable[[float], Color]) -> Self:
colors = list(map(prop_to_color, np.linspace(0, 1, self.get_num_points())))
self.set_stroke(color=colors)
return self
def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True) -> Self:
self.set_uniform(recurse, anti_alias_width=anti_alias_width)
return self

View file

@ -109,6 +109,31 @@ class InteractiveSceneEmbed:
self.shell.set_custom_exc((Exception,), custom_exc)
def validate_syntax(self, file_path: str) -> bool:
"""
Validates the syntax of a Python file without executing it.
Returns True if syntax is valid, False otherwise.
Prints syntax errors to the console if found.
"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
source_code = f.read()
# Use compile() to check for syntax errors without executing
compile(source_code, file_path, 'exec')
return True
except SyntaxError as e:
print(f"\nSyntax Error in {file_path}:")
print(f" Line {e.lineno}: {e.text.strip() if e.text else ''}")
print(f" {' ' * (e.offset - 1 if e.offset else 0)}^")
print(f" {e.msg}")
return False
except Exception as e:
print(f"\nError reading {file_path}: {e}")
return False
def reload_scene(self, embed_line: int | None = None) -> None:
"""
Reloads the scene just like the `manimgl` command would do with the
@ -132,6 +157,14 @@ class InteractiveSceneEmbed:
`set_custom_exc` method, we cannot break out of the IPython shell by
this means.
"""
# Get the current file path for syntax validation
current_file = self.shell.user_module.__file__
# Validate syntax before attempting reload
if not self.validate_syntax(current_file):
print("[ERROR] Reload cancelled due to syntax errors. Fix the errors and try again.")
return
# Update the global run configuration.
run_config = manim_config.run
run_config.is_reload = True

View file

@ -24,7 +24,8 @@ default:
\usepackage{pifont}
\DisableLigatures{encoding = *, family = * }
\linespread{1}
%% Borrowed from https://tex.stackexchange.com/questions/6058/making-a-shorter-minus
\DeclareMathSymbol{\minus}{\mathbin}{AMSa}{"39}
ctex:
description: ""
compiler: xelatex

View file

@ -78,19 +78,25 @@ def int_to_hex(rgb_int: int) -> str:
def color_gradient(
reference_colors: Iterable[ManimColor],
length_of_output: int
length_of_output: int,
interp_by_hsl: bool = False,
) -> list[Color]:
if length_of_output == 0:
return []
rgbs = list(map(color_to_rgb, reference_colors))
alphas = np.linspace(0, (len(rgbs) - 1), length_of_output)
n_ref_colors = len(reference_colors)
alphas = np.linspace(0, (n_ref_colors - 1), length_of_output)
floors = alphas.astype('int')
alphas_mod1 = alphas % 1
# End edge case
alphas_mod1[-1] = 1
floors[-1] = len(rgbs) - 2
floors[-1] = n_ref_colors - 2
return [
rgb_to_color(np.sqrt(interpolate(rgbs[i]**2, rgbs[i + 1]**2, alpha)))
interpolate_color(
reference_colors[i],
reference_colors[i + 1],
alpha,
interp_by_hsl=interp_by_hsl,
)
for i, alpha in zip(floors, alphas_mod1)
]
@ -98,10 +104,16 @@ def color_gradient(
def interpolate_color(
color1: ManimColor,
color2: ManimColor,
alpha: float
alpha: float,
interp_by_hsl: bool = False,
) -> Color:
rgb = np.sqrt(interpolate(color_to_rgb(color1)**2, color_to_rgb(color2)**2, alpha))
return rgb_to_color(rgb)
if interp_by_hsl:
hsl1 = np.array(Color(color1).get_hsl())
hsl2 = np.array(Color(color2).get_hsl())
return Color(hsl=interpolate(hsl1, hsl2, alpha))
else:
rgb = np.sqrt(interpolate(color_to_rgb(color1)**2, color_to_rgb(color2)**2, alpha))
return rgb_to_color(rgb)
def interpolate_color_by_hsl(
@ -109,9 +121,7 @@ def interpolate_color_by_hsl(
color2: ManimColor,
alpha: float
) -> Color:
hsl1 = np.array(Color(color1).get_hsl())
hsl2 = np.array(Color(color2).get_hsl())
return Color(hsl=interpolate(hsl1, hsl2, alpha))
return interpolate_color(color1, color2, alpha, interp_by_hsl=True)
def average_color(*colors: ManimColor) -> Color:

View file

@ -104,6 +104,7 @@ TEX_TO_SYMBOL_COUNT = {
R"\mapsto": 2,
R"\markright": 0,
R"\mathds": 0,
R"\mathcal": 0,
R"\max": 3,
R"\mbox": 0,
R"\medskip": 0,