diff --git a/.gitignore b/.gitignore index c8350ec3..2822abd4 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,4 @@ dmypy.json /videos /custom_config.yml test.py +CLAUDE.md diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 7e30d2ad..72a79e48 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -88,7 +88,7 @@ class TransformMatchingParts(AnimationGroup): if not source_is_new or not target_is_new: return - transform_type = self.mismatch_animation + transform_type = self.mismatch_animation if source.has_same_shape_as(target): transform_type = self.match_animation @@ -154,16 +154,16 @@ class TransformMatchingStrings(TransformMatchingParts): counts2 = list(map(target.substr_to_path_count, syms2)) # Start with user specified matches - blocks = [(source[key], target[key]) for key in matched_keys] - blocks += [(source[key1], target[key2]) for key1, key2 in key_map.items()] + blocks = [(source[key1], target[key2]) for key1, key2 in key_map.items()] + blocks += [(source[key], target[key]) for key in matched_keys] # Nullify any intersections with those matches in the two symbol lists for sub_source, sub_target in blocks: for i in range(len(syms1)): - if source[i] in sub_source.family_members_with_points(): + if i < len(source) and source[i] in sub_source.family_members_with_points(): syms1[i] = "Null1" for j in range(len(syms2)): - if target[j] in sub_target.family_members_with_points(): + if j < len(target) and target[j] in sub_target.family_members_with_points(): syms2[j] = "Null2" # Group together longest matching substrings diff --git a/manimlib/animation/update.py b/manimlib/animation/update.py index a1d69337..ed332447 100644 --- a/manimlib/animation/update.py +++ b/manimlib/animation/update.py @@ -46,7 +46,8 @@ class UpdateFromAlphaFunc(Animation): super().__init__(mobject, suspend_mobject_updating=suspend_mobject_updating, **kwargs) def interpolate_mobject(self, alpha: float) -> None: - self.update_function(self.mobject, alpha) + true_alpha = self.rate_func(self.time_spanned_alpha(alpha)) + self.update_function(self.mobject, true_alpha) class MaintainPositionRelativeTo(Animation): diff --git a/manimlib/mobject/changing.py b/manimlib/mobject/changing.py index 15c0b504..fd399308 100644 --- a/manimlib/mobject/changing.py +++ b/manimlib/mobject/changing.py @@ -101,10 +101,17 @@ class TracedPath(VMobject): traced_point_func: Callable[[], Vect3], time_traced: float = np.inf, time_per_anchor: float = 1.0 / 15, - stroke_width: float | Iterable[float] = 2.0, stroke_color: ManimColor = DEFAULT_MOBJECT_COLOR, + stroke_width: float | Iterable[float] = 2.0, + stroke_opacity: float = 1.0, **kwargs ): + self.stroke_config = dict( + color=stroke_color, + width=stroke_width, + opacity=stroke_opacity, + ) + super().__init__(**kwargs) self.traced_point_func = traced_point_func self.time_traced = time_traced @@ -112,7 +119,6 @@ class TracedPath(VMobject): self.time: float = 0 self.traced_points: list[np.ndarray] = [] self.add_updater(lambda m, dt: m.update_path(dt)) - self.always.set_stroke(stroke_color, stroke_width) def update_path(self, dt: float) -> Self: if dt == 0: @@ -136,6 +142,8 @@ class TracedPath(VMobject): if points: self.set_points_smoothly(points) + self.set_stroke(**self.stroke_config) + self.time += dt return self @@ -145,21 +153,24 @@ class TracingTail(TracedPath): self, mobject_or_func: Mobject | Callable[[], np.ndarray], time_traced: float = 1.0, + stroke_color: ManimColor = DEFAULT_MOBJECT_COLOR, stroke_width: float | Iterable[float] = (0, 3), stroke_opacity: float | Iterable[float] = (0, 1), - stroke_color: ManimColor = DEFAULT_MOBJECT_COLOR, **kwargs ): if isinstance(mobject_or_func, Mobject): func = mobject_or_func.get_center else: func = mobject_or_func + super().__init__( func, time_traced=time_traced, + stroke_color=stroke_color, stroke_width=stroke_width, stroke_opacity=stroke_opacity, - stroke_color=stroke_color, **kwargs ) - self.add_updater(lambda m: m.set_stroke(width=stroke_width, opacity=stroke_opacity)) + curr_point = self.traced_point_func() + n_points = int(self.time_traced / self.time_per_anchor) + self.traced_points: list[np.ndarray] = n_points * [curr_point] diff --git a/manimlib/mobject/coordinate_systems.py b/manimlib/mobject/coordinate_systems.py index bb8b5aa1..11dcb8d7 100644 --- a/manimlib/mobject/coordinate_systems.py +++ b/manimlib/mobject/coordinate_systems.py @@ -412,15 +412,19 @@ class CoordinateSystem(ABC): rect.set_fill(negative_color) return result - def get_area_under_graph(self, graph, x_range, fill_color=BLUE, fill_opacity=0.5): - if not hasattr(graph, "x_range"): - raise Exception("Argument `graph` must have attribute `x_range`") + def get_area_under_graph(self, graph, x_range=None, fill_color=BLUE, fill_opacity=0.5): + if x_range is None: + x_range = [ + self.x_axis.p2n(graph.get_start()), + self.x_axis.p2n(graph.get_end()), + ] alpha_bounds = [ - inverse_interpolate(*graph.x_range, x) + inverse_interpolate(*graph.x_range[:2], x) for x in x_range ] sub_graph = graph.copy() + sub_graph.clear_updaters() sub_graph.pointwise_become_partial(graph, *alpha_bounds) sub_graph.add_line_to(self.c2p(x_range[1], 0)) sub_graph.add_line_to(self.c2p(x_range[0], 0)) @@ -638,7 +642,10 @@ class NumberPlane(Axes): stroke_opacity=1, ), # Defaults to a faded version of line_config - faded_line_style: dict = dict(), + faded_line_style: dict = dict( + stroke_width=1, + stroke_opacity=0.25, + ), faded_line_ratio: int = 4, make_smooth_after_applying_functions: bool = True, **kwargs @@ -651,14 +658,8 @@ class NumberPlane(Axes): self.init_background_lines() def init_background_lines(self) -> None: - if not self.faded_line_style: - style = dict(self.background_line_style) - # For anything numerical, like stroke_width - # and stroke_opacity, chop it in half - for key in style: - if isinstance(style[key], numbers.Number): - style[key] *= 0.5 - self.faded_line_style = style + if "stroke_color" not in self.faded_line_style: + self.faded_line_style["stroke_color"] = self.background_line_style["stroke_color"] self.background_lines, self.faded_lines = self.get_lines() self.background_lines.set_style(**self.background_line_style) @@ -726,11 +727,10 @@ class NumberPlane(Axes): class ComplexPlane(NumberPlane): - def number_to_point(self, number: complex | float) -> Vect3: - number = complex(number) - return self.coords_to_point(number.real, number.imag) + def number_to_point(self, number: complex | float | np.array) -> Vect3: + return self.coords_to_point(np.real(number), np.imag(number)) - def n2p(self, number: complex | float) -> Vect3: + def n2p(self, number: complex | float | np.array) -> Vect3: return self.number_to_point(number) def point_to_number(self, point: Vect3) -> complex: diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 6fc99055..28172cd5 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -1232,8 +1232,9 @@ class Mobject(object): def set_z(self, z: float, direction: Vect3 = ORIGIN) -> Self: return self.set_coord(z, 2, direction) - def set_z_index(self, z_index: int) -> Self: - self.z_index = z_index + def set_z_index(self, z_index: int, recurse=True) -> Self: + for mob in self.get_family(recurse): + mob.z_index = z_index return self def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self: @@ -1284,6 +1285,14 @@ class Mobject(object): self.scale((length + buff) / length) return self + def put_start_on(self, point: Vect3) -> Self: + self.shift(point - self.get_start()) + return self + + def put_end_on(self, point: Vect3) -> Self: + self.shift(point - self.get_end()) + return self + def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self: curr_start, curr_end = self.get_start_and_end() curr_vect = curr_end - curr_start diff --git a/manimlib/mobject/number_line.py b/manimlib/mobject/number_line.py index eedbe7c3..07ad4b23 100644 --- a/manimlib/mobject/number_line.py +++ b/manimlib/mobject/number_line.py @@ -4,19 +4,23 @@ import numpy as np from manimlib.constants import DOWN, LEFT, RIGHT, UP from manimlib.constants import DEFAULT_LIGHT_COLOR -from manimlib.constants import MED_SMALL_BUFF -from manimlib.mobject.geometry import Line +from manimlib.constants import MED_SMALL_BUFF, SMALL_BUFF +from manimlib.constants import YELLOW, DEG +from manimlib.mobject.geometry import Line, ArrowTip from manimlib.mobject.numbers import DecimalNumber +from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.mobject.value_tracker import ValueTracker from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import outer_interpolate from manimlib.utils.dict_ops import merge_dicts_recursively from manimlib.utils.simple_functions import fdiv +from manimlib.utils.space_ops import rotate_vector, angle_of_vector from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Iterable, Optional + from typing import Iterable, Optional, Tuple, Dict, Any from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier @@ -235,3 +239,67 @@ class UnitInterval(NumberLine): decimal_number_config=decimal_number_config, **kwargs ) + + +class Slider(VGroup): + def __init__( + self, + value_tracker: ValueTracker, + x_range: Tuple[float, float] = (-5, 5), + var_name: Optional[str] = None, + width: float = 3, + unit_size: float = 1, + arrow_width: float = 0.15, + arrow_length: float = 0.15, + arrow_color: ManimColor = YELLOW, + font_size: int = 24, + label_buff: float = SMALL_BUFF, + num_decimal_places: int = 2, + tick_size: float = 0.05, + number_line_config: Dict[str, Any] = dict(), + arrow_tip_config: Dict[str, Any] = dict(), + decimal_config: Dict[str, Any] = dict(), + angle: float = 0, + label_direction: Optional[np.ndarray] = None, + add_tick_labels: bool = True, + tick_label_font_size: int = 16, + ): + get_value = value_tracker.get_value + if label_direction is None: + label_direction = np.round(rotate_vector(UP, angle), 2) + + # Initialize number line + number_line_kw = dict(x_range=x_range, width=width, tick_size=tick_size) + number_line_kw.update(number_line_config) + number_line = NumberLine(**number_line_kw) + number_line.rotate(angle) + if add_tick_labels: + number_line.add_numbers( + font_size=tick_label_font_size, + buff=2 * tick_size, + direction=-label_direction + ) + + # Initialize arrow tip + arrow_tip_kw = dict( + width=arrow_width, + length=arrow_length, + fill_color=arrow_color, + angle=-180 * DEG + angle_of_vector(label_direction), + ) + arrow_tip_kw.update(arrow_tip_config) + tip = ArrowTip(**arrow_tip_kw) + tip.add_updater(lambda m: m.move_to(number_line.n2p(get_value()), -label_direction)) + + # Initialize label + dec_string = f"{{:.{num_decimal_places}f}}".format(0) + lhs = f"{var_name} = " if var_name is not None else "" + label = Tex(lhs + dec_string, font_size=font_size) + label[var_name].set_fill(arrow_color) + decimal = label.make_number_changeable(dec_string) + decimal.add_updater(lambda m: m.set_value(get_value())) + label.add_updater(lambda m: m.next_to(tip, label_direction, label_buff)) + + # Assemble group + super().__init__(number_line, tip, label) + self.set_stroke(behind=True) diff --git a/manimlib/mobject/numbers.py b/manimlib/mobject/numbers.py index 7fd8fe4d..ba6d73f5 100644 --- a/manimlib/mobject/numbers.py +++ b/manimlib/mobject/numbers.py @@ -47,6 +47,7 @@ class DecimalNumber(VMobject): show_ellipsis: bool = False, unit: str | None = None, # Aligned to bottom unless it starts with "^" include_background_rectangle: bool = False, + hide_zero_components_on_complex: bool = True, edge_to_fix: Vect3 = LEFT, font_size: float = 48, text_config: dict = dict(), # Do not pass in font_size here @@ -60,6 +61,7 @@ class DecimalNumber(VMobject): self.show_ellipsis = show_ellipsis self.unit = unit self.include_background_rectangle = include_background_rectangle + self.hide_zero_components_on_complex = hide_zero_components_on_complex self.edge_to_fix = edge_to_fix self.font_size = font_size self.text_config = dict(text_config) @@ -120,7 +122,14 @@ class DecimalNumber(VMobject): def get_num_string(self, number: float | complex) -> str: if isinstance(number, complex): - formatter = self.get_complex_formatter() + if self.hide_zero_components_on_complex and number.imag == 0: + number = number.real + formatter = self.get_formatter() + elif self.hide_zero_components_on_complex and number.real == 0: + number = number.imag + formatter = self.get_formatter() + "i" + else: + formatter = self.get_complex_formatter() else: formatter = self.get_formatter() if self.num_decimal_places == 0 and isinstance(number, float): diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 739644ca..7bd408e0 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -73,7 +73,7 @@ class SVGMobject(VMobject): elif file_name != "": self.svg_string = self.file_name_to_svg_string(file_name) elif self.file_name != "": - self.file_name_to_svg_string(self.file_name) + self.svg_string = self.file_name_to_svg_string(self.file_name) else: raise Exception("Must specify either a file_name or svg_string SVGMobject") diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index 71767f19..18aeefa3 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -242,15 +242,10 @@ class Tex(StringMobject): decimal_mobs = [] for part in parts: - if "." in substr: - num_decimal_places = len(substr.split(".")[1]) - else: - num_decimal_places = 0 - decimal_mob = DecimalNumber( - float(value), - num_decimal_places=num_decimal_places, - **config, - ) + if "num_decimal_places" not in config: + ndp = len(substr.split(".")[1]) if "." in substr else 0 + config["num_decimal_places"] = ndp + decimal_mob = DecimalNumber(float(value), **config) decimal_mob.replace(part) decimal_mob.match_style(part) if len(part) > 1: diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index 914677aa..ed2cc1da 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -75,16 +75,12 @@ class Surface(Mobject): @Mobject.affects_data def init_points(self): - dim = self.dim - nu, nv = self.resolution - u_range = np.linspace(*self.u_range, nu) - v_range = np.linspace(*self.v_range, nv) - # Get three lists: # - Points generated by pure uv values # - Those generated by values nudged by du # - Those generated by values nudged by dv - uv_grid = np.array([[[u, v] for v in v_range] for u in u_range]) + nu, nv = self.resolution + uv_grid = self.get_uv_grid() uv_plus_du = uv_grid.copy() uv_plus_du[:, :, 0] += self.epsilon uv_plus_dv = uv_grid.copy() @@ -93,7 +89,7 @@ class Surface(Mobject): points, du_points, dv_points = [ np.apply_along_axis( lambda p: self.uv_func(*p), 2, grid - ).reshape((nu * nv, dim)) + ).reshape((nu * nv, self.dim)) for grid in (uv_grid, uv_plus_du, uv_plus_dv) ] crosses = cross(du_points - points, dv_points - points) @@ -102,9 +98,20 @@ class Surface(Mobject): self.set_points(points) self.data['d_normal_point'] = points + self.normal_nudge * normals + def get_uv_grid(self) -> np.array: + """ + Returns an (nu, nv, 2) array of all pairs of u, v values, where + (nu, nv) is the resolution + """ + nu, nv = self.resolution + u_range = np.linspace(*self.u_range, nu) + v_range = np.linspace(*self.v_range, nv) + U, V = np.meshgrid(u_range, v_range, indexing='ij') + return np.stack([U, V], axis=-1) + def uv_to_point(self, u, v): nu, nv = self.resolution - uv_grid = np.reshape(self.get_points(), (nu, nv, self.dim)) + verts_by_uv = np.reshape(self.get_points(), (nu, nv, self.dim)) alpha1 = clip(inverse_interpolate(*self.u_range[:2], u), 0, 1) alpha2 = clip(inverse_interpolate(*self.v_range[:2], v), 0, 1) @@ -115,10 +122,10 @@ class Surface(Mobject): u_int_plus = min(u_int + 1, nu - 1) v_int_plus = min(v_int + 1, nv - 1) - a = uv_grid[u_int, v_int, :] - b = uv_grid[u_int, v_int_plus, :] - c = uv_grid[u_int_plus, v_int, :] - d = uv_grid[u_int_plus, v_int_plus, :] + a = verts_by_uv[u_int, v_int, :] + b = verts_by_uv[u_int, v_int_plus, :] + c = verts_by_uv[u_int_plus, v_int, :] + d = verts_by_uv[u_int_plus, v_int_plus, :] u_res = scaled_u % 1 v_res = scaled_v % 1 @@ -240,6 +247,14 @@ class Surface(Mobject): self.add_updater(updater) return self + def color_by_uv_function(self, uv_to_color: Callable[[Vect2], Color]): + uv_grid = self.get_uv_grid() + self.set_rgba_array_by_color([ + uv_to_color(u, v) + for u, v in uv_grid.reshape(-1, 2) + ]) + return self + def get_shader_vert_indices(self) -> np.ndarray: return self.get_triangle_indices() diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 912afe0c..20005533 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -303,6 +303,11 @@ class VMobject(Mobject): self.set_stroke(opacity=opacity, recurse=recurse) return self + def set_color_by_proportion(self, prop_to_color: Callable[[float], Color]) -> Self: + colors = list(map(prop_to_color, np.linspace(0, 1, self.get_num_points()))) + self.set_stroke(color=colors) + return self + def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True) -> Self: self.set_uniform(recurse, anti_alias_width=anti_alias_width) return self diff --git a/manimlib/scene/scene_embed.py b/manimlib/scene/scene_embed.py index 429693dd..6ae0e560 100644 --- a/manimlib/scene/scene_embed.py +++ b/manimlib/scene/scene_embed.py @@ -109,6 +109,31 @@ class InteractiveSceneEmbed: self.shell.set_custom_exc((Exception,), custom_exc) + def validate_syntax(self, file_path: str) -> bool: + """ + Validates the syntax of a Python file without executing it. + Returns True if syntax is valid, False otherwise. + Prints syntax errors to the console if found. + """ + try: + with open(file_path, 'r', encoding='utf-8') as f: + source_code = f.read() + + # Use compile() to check for syntax errors without executing + compile(source_code, file_path, 'exec') + return True + + except SyntaxError as e: + print(f"\nSyntax Error in {file_path}:") + print(f" Line {e.lineno}: {e.text.strip() if e.text else ''}") + print(f" {' ' * (e.offset - 1 if e.offset else 0)}^") + print(f" {e.msg}") + return False + + except Exception as e: + print(f"\nError reading {file_path}: {e}") + return False + def reload_scene(self, embed_line: int | None = None) -> None: """ Reloads the scene just like the `manimgl` command would do with the @@ -132,6 +157,14 @@ class InteractiveSceneEmbed: `set_custom_exc` method, we cannot break out of the IPython shell by this means. """ + # Get the current file path for syntax validation + current_file = self.shell.user_module.__file__ + + # Validate syntax before attempting reload + if not self.validate_syntax(current_file): + print("[ERROR] Reload cancelled due to syntax errors. Fix the errors and try again.") + return + # Update the global run configuration. run_config = manim_config.run run_config.is_reload = True diff --git a/manimlib/tex_templates.yml b/manimlib/tex_templates.yml index 5627a6f3..301ad7d4 100644 --- a/manimlib/tex_templates.yml +++ b/manimlib/tex_templates.yml @@ -24,7 +24,8 @@ default: \usepackage{pifont} \DisableLigatures{encoding = *, family = * } \linespread{1} - + %% Borrowed from https://tex.stackexchange.com/questions/6058/making-a-shorter-minus + \DeclareMathSymbol{\minus}{\mathbin}{AMSa}{"39} ctex: description: "" compiler: xelatex diff --git a/manimlib/utils/color.py b/manimlib/utils/color.py index 25118d4f..6106e4ff 100644 --- a/manimlib/utils/color.py +++ b/manimlib/utils/color.py @@ -78,19 +78,25 @@ def int_to_hex(rgb_int: int) -> str: def color_gradient( reference_colors: Iterable[ManimColor], - length_of_output: int + length_of_output: int, + interp_by_hsl: bool = False, ) -> list[Color]: if length_of_output == 0: return [] - rgbs = list(map(color_to_rgb, reference_colors)) - alphas = np.linspace(0, (len(rgbs) - 1), length_of_output) + n_ref_colors = len(reference_colors) + alphas = np.linspace(0, (n_ref_colors - 1), length_of_output) floors = alphas.astype('int') alphas_mod1 = alphas % 1 # End edge case alphas_mod1[-1] = 1 - floors[-1] = len(rgbs) - 2 + floors[-1] = n_ref_colors - 2 return [ - rgb_to_color(np.sqrt(interpolate(rgbs[i]**2, rgbs[i + 1]**2, alpha))) + interpolate_color( + reference_colors[i], + reference_colors[i + 1], + alpha, + interp_by_hsl=interp_by_hsl, + ) for i, alpha in zip(floors, alphas_mod1) ] @@ -98,10 +104,16 @@ def color_gradient( def interpolate_color( color1: ManimColor, color2: ManimColor, - alpha: float + alpha: float, + interp_by_hsl: bool = False, ) -> Color: - rgb = np.sqrt(interpolate(color_to_rgb(color1)**2, color_to_rgb(color2)**2, alpha)) - return rgb_to_color(rgb) + if interp_by_hsl: + hsl1 = np.array(Color(color1).get_hsl()) + hsl2 = np.array(Color(color2).get_hsl()) + return Color(hsl=interpolate(hsl1, hsl2, alpha)) + else: + rgb = np.sqrt(interpolate(color_to_rgb(color1)**2, color_to_rgb(color2)**2, alpha)) + return rgb_to_color(rgb) def interpolate_color_by_hsl( @@ -109,9 +121,7 @@ def interpolate_color_by_hsl( color2: ManimColor, alpha: float ) -> Color: - hsl1 = np.array(Color(color1).get_hsl()) - hsl2 = np.array(Color(color2).get_hsl()) - return Color(hsl=interpolate(hsl1, hsl2, alpha)) + return interpolate_color(color1, color2, alpha, interp_by_hsl=True) def average_color(*colors: ManimColor) -> Color: diff --git a/manimlib/utils/tex_to_symbol_count.py b/manimlib/utils/tex_to_symbol_count.py index 049e9fde..e0c27d89 100644 --- a/manimlib/utils/tex_to_symbol_count.py +++ b/manimlib/utils/tex_to_symbol_count.py @@ -104,6 +104,7 @@ TEX_TO_SYMBOL_COUNT = { R"\mapsto": 2, R"\markright": 0, R"\mathds": 0, + R"\mathcal": 0, R"\max": 3, R"\mbox": 0, R"\medskip": 0,