diff --git a/example_scenes.py b/example_scenes.py index 9aa3e108..127e438d 100644 --- a/example_scenes.py +++ b/example_scenes.py @@ -174,16 +174,17 @@ class TexTransformExample(Scene): self.add(lines[0]) # The animation TransformMatchingStrings will line up parts # of the source and target which have matching substring strings. - # Here, giving it a little path_arc makes each part sort of - # rotate into their final positions, which feels appropriate - # for the idea of rearranging an equation + # Here, giving it a little path_arc makes each part rotate into + # their final positions, which feels appropriate for the idea of + # rearranging an equation self.play( TransformMatchingStrings( lines[0].copy(), lines[1], # matched_keys specifies which substring should # line up. If it's not specified, the animation - # will try its best, but may not quite give the - # intended effect + # will align the longest matching substrings. + # In this case, the substring "^2 = C^2" would + # trip it up matched_keys=["A^2", "B^2", "C^2"], # When you want a substring from the source # to go to a non-equal substring from the target, @@ -206,25 +207,57 @@ class TexTransformExample(Scene): ), ) self.wait(2) - - # You can also index into Tex mobject (or other StringMobjects) - # by substrings and regular expressions - top_equation = lines[0] - low_equation = lines[3] - - self.play(LaggedStartMap(FlashAround, low_equation["C"], lag_ratio=0.5)) - self.play(LaggedStartMap(FlashAround, low_equation["B"], lag_ratio=0.5)) - self.play(LaggedStartMap(FlashAround, top_equation[re.compile(r"\w\^2")])) - self.play(Indicate(low_equation[R"\sqrt"])) - self.wait() self.play(LaggedStartMap(FadeOut, lines, shift=2 * RIGHT)) + # TransformMatchingShapes will try to line up all pieces of a + # source mobject with those of a target, regardless of the + # what Mobject type they are. + source = Text("the morse code", height=1) + target = Text("here come dots", height=1) + saved_source = source.copy() + + self.play(Write(source)) + self.wait() + kw = dict(run_time=3, path_arc=PI / 2) + self.play(TransformMatchingShapes(source, target, **kw)) + self.wait() + self.play(TransformMatchingShapes(target, saved_source, **kw)) + self.wait() + + +class TexIndexing(Scene): + def construct(self): + # You can index into Tex mobject (or other StringMobjects) by substrings + equation = Tex(R"e^{\pi i} = -1", font_size=144) + + self.add(equation) + self.play(FlashAround(equation["e"])) + self.wait() + self.play(Indicate(equation[R"\pi"])) + self.wait() + self.play(TransformFromCopy( + equation[R"e^{\pi i}"].copy().set_opacity(0.5), + equation["-1"], + path_arc=-PI / 2, + run_time=3 + )) + self.play(FadeOut(equation)) + + # Or regular expressions + equation = Tex("A^2 + B^2 = C^2", font_size=144) + + self.play(Write(equation)) + for part in equation[re.compile(r"\w\^2")]: + self.play(FlashAround(part)) + self.wait() + self.play(FadeOut(equation)) + # Indexing by substrings like this may not work when # the order in which Latex draws symbols does not match # the order in which they show up in the string. # For example, here the infinity is drawn before the sigma # so we don't get the desired behavior. - equation = Tex(R"\sum_{n = 1}^\infty \frac{1}{n^2} = \frac{\pi^2}{6}") + equation = Tex(R"\sum_{n = 1}^\infty \frac{1}{n^2} = \frac{\pi^2}{6}", font_size=72) self.play(FadeIn(equation)) self.play(equation[R"\infty"].animate.set_color(RED)) # Doesn't hit the infinity self.wait() @@ -236,27 +269,14 @@ class TexTransformExample(Scene): equation = Tex( R"\sum_{n = 1}^\infty {1 \over n^2} = {\pi^2 \over 6}", # Explicitly mark "\infty" as a substring you might want to access - isolate=[R"\infty"] + isolate=[R"\infty"], + font_size=72 ) self.play(FadeIn(equation)) self.play(equation[R"\infty"].animate.set_color(RED)) # Got it! self.wait() self.play(FadeOut(equation)) - # TransformMatchingShapes will try to line up all pieces of a - # source mobject with those of a target, regardless of the - # what Mobject type they are. - source = Text("the morse code", height=1) - target = Text("here come dots", height=1) - - self.play(Write(source)) - self.wait() - kw = dict(run_time=3, path_arc=PI / 2) - self.play(TransformMatchingShapes(source, target, **kw)) - self.wait() - self.play(TransformMatchingShapes(target, source, **kw)) - self.wait() - class UpdatersExample(Scene): def construct(self): diff --git a/manimlib/animation/composition.py b/manimlib/animation/composition.py index a1c0ed37..e4ad5656 100644 --- a/manimlib/animation/composition.py +++ b/manimlib/animation/composition.py @@ -165,7 +165,7 @@ class LaggedStart(AnimationGroup): class LaggedStartMap(LaggedStart): def __init__( self, - AnimationClass: type, + anim_func: Callable[[Mobject], Animation], group: Mobject, arg_creator: Callable[[Mobject], tuple] | None = None, run_time: float = 2.0, @@ -175,7 +175,7 @@ class LaggedStartMap(LaggedStart): anim_kwargs = dict(kwargs) anim_kwargs.pop("lag_ratio", None) super().__init__( - *(AnimationClass(submob, **anim_kwargs) for submob in group), + *(anim_func(submob, **anim_kwargs) for submob in group), run_time=run_time, lag_ratio=lag_ratio, ) diff --git a/manimlib/animation/transform.py b/manimlib/animation/transform.py index 8192a36e..fdc262a9 100644 --- a/manimlib/animation/transform.py +++ b/manimlib/animation/transform.py @@ -74,8 +74,6 @@ class Transform(Animation): def finish(self) -> None: super().finish() self.mobject.unlock_data() - if self.target_mobject is not None and self.rate_func(1) == 1: - self.mobject.become(self.target_mobject) def create_target(self) -> Mobject: # Has no meaningful effect here, but may be useful diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index b832cf0e..7e30d2ad 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -1,19 +1,15 @@ from __future__ import annotations import itertools as it - -import numpy as np +from difflib import SequenceMatcher from manimlib.animation.composition import AnimationGroup from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint -from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject -from manimlib.mobject.mobject import Group -from manimlib.mobject.svg.string_mobject import StringMobject -from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject +from manimlib.mobject.svg.string_mobject import StringMobject from typing import TYPE_CHECKING @@ -131,28 +127,64 @@ class TransformMatchingStrings(TransformMatchingParts): target: StringMobject, matched_keys: Iterable[str] = [], key_map: dict[str, str] = dict(), - matched_pairs: Iterable[tuple[Mobject, Mobject]] = [], + matched_pairs: Iterable[tuple[VMobject, VMobject]] = [], **kwargs, ): - matched_pairs = list(matched_pairs) + [ - *[(source[key], target[key]) for key in matched_keys], - *[(source[key1], target[key2]) for key1, key2 in key_map.items()], - *[ - (source[substr], target[substr]) - for substr in [ - *source.get_specified_substrings(), - *target.get_specified_substrings(), - *source.get_symbol_substrings(), - *target.get_symbol_substrings(), - ] - ] + matched_pairs = [ + *matched_pairs, + *self.matching_blocks(source, target, matched_keys, key_map), ] + super().__init__( source, target, matched_pairs=matched_pairs, **kwargs, ) + def matching_blocks( + self, + source: StringMobject, + target: StringMobject, + matched_keys: Iterable[str], + key_map: dict[str, str] + ) -> list[tuple[VMobject, VMobject]]: + syms1 = source.get_symbol_substrings() + syms2 = target.get_symbol_substrings() + counts1 = list(map(source.substr_to_path_count, syms1)) + counts2 = list(map(target.substr_to_path_count, syms2)) + + # Start with user specified matches + blocks = [(source[key], target[key]) for key in matched_keys] + blocks += [(source[key1], target[key2]) for key1, key2 in key_map.items()] + + # Nullify any intersections with those matches in the two symbol lists + for sub_source, sub_target in blocks: + for i in range(len(syms1)): + if source[i] in sub_source.family_members_with_points(): + syms1[i] = "Null1" + for j in range(len(syms2)): + if target[j] in sub_target.family_members_with_points(): + syms2[j] = "Null2" + + # Group together longest matching substrings + while True: + matcher = SequenceMatcher(None, syms1, syms2) + match = matcher.find_longest_match(0, len(syms1), 0, len(syms2)) + if match.size == 0: + break + + i1 = sum(counts1[:match.a]) + i2 = sum(counts2[:match.b]) + size = sum(counts1[match.a:match.a + match.size]) + + blocks.append((source[i1:i1 + size], target[i2:i2 + size])) + + for i in range(match.size): + syms1[match.a + i] = "Null1" + syms2[match.b + i] = "Null2" + + return blocks + class TransformMatchingTex(TransformMatchingStrings): """Alias for TransformMatchingStrings""" diff --git a/manimlib/camera/camera_frame.py b/manimlib/camera/camera_frame.py index deafefe3..e8077d7b 100644 --- a/manimlib/camera/camera_frame.py +++ b/manimlib/camera/camera_frame.py @@ -41,11 +41,6 @@ class CameraFrame(Mobject): self.set_height(frame_shape[1], stretch=True) self.move_to(center_point) - def note_changed_data(self, recurse_up: bool = True): - super().note_changed_data(recurse_up) - self.get_view_matrix(refresh=True) - self.get_implied_camera_location(refresh=True) - def set_orientation(self, rotation: Rotation): self.uniforms["orientation"][:] = rotation.as_quat() return self @@ -89,7 +84,7 @@ class CameraFrame(Mobject): Returns a 4x4 for the affine transformation mapping a point into the camera's internal coordinate system """ - if refresh: + if self._data_has_changed: shift = np.identity(4) rotation = np.identity(4) scale_mat = np.identity(4) @@ -169,10 +164,12 @@ class CameraFrame(Mobject): self.rotate(dgamma, self.get_inverse_camera_rotation_matrix()[2]) return self + @Mobject.affects_data def set_focal_distance(self, focal_distance: float): self.uniforms["fovy"] = 2 * math.atan(0.5 * self.get_height() / focal_distance) return self + @Mobject.affects_data def set_field_of_view(self, field_of_view: float): self.uniforms["fovy"] = field_of_view return self @@ -202,8 +199,8 @@ class CameraFrame(Mobject): def get_field_of_view(self) -> float: return self.uniforms["fovy"] - def get_implied_camera_location(self, refresh=False) -> np.ndarray: - if refresh: + def get_implied_camera_location(self) -> np.ndarray: + if self._data_has_changed: to_camera = self.get_inverse_camera_rotation_matrix()[2] dist = self.get_focal_distance() self.camera_location = self.get_center() + dist * to_camera diff --git a/manimlib/config.py b/manimlib/config.py index e3e93420..bb5e57ef 100644 --- a/manimlib/config.py +++ b/manimlib/config.py @@ -93,6 +93,14 @@ def parse_cli(): action="store_true", help="Render to a movie file with an alpha channel", ) + parser.add_argument( + "--vcodec", + help="Video codec to use with ffmpeg", + ) + parser.add_argument( + "--pix_fmt", + help="Pixel format to use for the output of ffmpeg, defaults to `yuv420p`", + ) parser.add_argument( "-q", "--quiet", action="store_true", @@ -160,6 +168,12 @@ def parse_cli(): action="store_true", help="Show progress bar for each animation", ) + parser.add_argument( + "--prerun", + action="store_true", + help="Calculate total framecount, to display in a progress bar, by doing " + \ + "an initial run of the scene which skips animations." + ) parser.add_argument( "--video_dir", help="Directory to write video", @@ -386,7 +400,7 @@ def get_output_directory(args: Namespace, custom_config: dict) -> str: def get_file_writer_config(args: Namespace, custom_config: dict) -> dict: - return { + result = { "write_to_movie": not args.skip_animations and args.write_file, "break_into_partial_movies": custom_config["break_into_partial_movies"], "save_last_frame": args.skip_animations and args.write_file, @@ -402,6 +416,18 @@ def get_file_writer_config(args: Namespace, custom_config: dict) -> dict: "quiet": args.quiet, } + if args.vcodec: + result["video_codec"] = args.vcodec + elif args.transparent: + result["video_codec"] = 'prores_ks' + elif args.gif: + result["video_codec"] = '' + + if args.pix_fmt: + result["pix_fmt"] = args.pix_fmt + + return result + def get_window_config(args: Namespace, custom_config: dict, camera_config: dict) -> dict: # Default to making window half the screen size @@ -489,6 +515,7 @@ def get_configuration(args: Namespace) -> dict: "presenter_mode": args.presenter_mode, "leave_progress_bars": args.leave_progress_bars, "show_animation_progress": args.show_animation_progress, + "prerun": args.prerun, "embed_exception_mode": custom_config["embed_exception_mode"], "embed_error_sound": custom_config["embed_error_sound"], } diff --git a/manimlib/extract_scene.py b/manimlib/extract_scene.py index 66c27a2b..101c5725 100644 --- a/manimlib/extract_scene.py +++ b/manimlib/extract_scene.py @@ -79,40 +79,35 @@ def compute_total_frames(scene_class, scene_config): return int(total_time * scene_config["camera_config"]["fps"]) -def get_scenes_to_render(scene_classes, scene_config, config): - if config["write_all"]: - return [sc(**scene_config) for sc in scene_classes] +def scene_from_class(scene_class, scene_config, config): + fw_config = scene_config["file_writer_config"] + if fw_config["write_to_movie"] and config["prerun"]: + fw_config["total_frames"] = compute_total_frames(scene_class, scene_config) + return scene_class(**scene_config) - result = [] - for scene_name in config["scene_names"]: - found = False - for scene_class in scene_classes: - if scene_class.__name__ == scene_name: - fw_config = scene_config["file_writer_config"] - if fw_config["write_to_movie"]: - fw_config["total_frames"] = compute_total_frames(scene_class, scene_config) - scene = scene_class(**scene_config) - result.append(scene) - found = True - break - if not found and (scene_name != ""): - log.error(f"No scene named {scene_name} found") - if result: - return result - - # another case - result=[] - if len(scene_classes) == 1: - scene_classes = [scene_classes[0]] + +def get_scenes_to_render(all_scene_classes, scene_config, config): + if config["write_all"]: + return [sc(**scene_config) for sc in all_scene_classes] + + names_to_classes = {sc.__name__ : sc for sc in all_scene_classes} + scene_names = config["scene_names"] + + for name in set.difference(set(scene_names), names_to_classes): + log.error(f"No scene named {name} found") + scene_names.remove(name) + + if scene_names: + classes_to_run = [names_to_classes[name] for name in scene_names] + elif len(all_scene_classes) == 1: + classes_to_run = [all_scene_classes[0]] else: - scene_classes = prompt_user_for_choice(scene_classes) - for scene_class in scene_classes: - fw_config = scene_config["file_writer_config"] - if fw_config["write_to_movie"]: - fw_config["total_frames"] = compute_total_frames(scene_class, scene_config) - scene = scene_class(**scene_config) - result.append(scene) - return result + classes_to_run = prompt_user_for_choice(all_scene_classes) + + return [ + scene_from_class(scene_class, scene_config, config) + for scene_class in classes_to_run + ] def get_scene_classes_from_module(module): diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 34052b16..2f974389 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -105,6 +105,7 @@ class Mobject(object): self.bounding_box: Vect3Array = np.zeros((3, 3)) self._shaders_initialized: bool = False self._data_has_changed: bool = True + self.shader_code_replacements: dict[str, str] = dict() self.init_data() self._data_defaults = np.ones(1, dtype=self.data.dtype) @@ -738,7 +739,7 @@ class Mobject(object): ) if len(points1) != len(points2): return False - return bool(np.isclose(points1, points2).all()) + return bool(np.isclose(points1, points2, atol=self.get_width() * 1e-2).all()) # Creating new Mobjects from this one @@ -1895,12 +1896,12 @@ class Mobject(object): # Shader code manipulation + @affects_data def replace_shader_code(self, old: str, new: str) -> Self: - # TODO, will this work with VMobject structure, given - # that it does not simpler return shader_wrappers of - # family? - for wrapper in self.get_shader_wrapper_list(): - wrapper.replace_code(old, new) + self.shader_code_replacements[old] = new + self._shaders_initialized = False + for mob in self.get_ancestors(): + mob._shaders_initialized = False return self def set_color_by_code(self, glsl_code: str) -> Self: @@ -1967,8 +1968,10 @@ class Mobject(object): self.shader_wrapper.vert_data = self.get_shader_data() self.shader_wrapper.vert_indices = self.get_shader_vert_indices() - self.shader_wrapper.update_program_uniforms(self.get_uniforms()) + self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms()) self.shader_wrapper.depth_test = self.depth_test + for old, new in self.shader_code_replacements.items(): + self.shader_wrapper.replace_code(old, new) return self.shader_wrapper def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: @@ -2004,9 +2007,7 @@ class Mobject(object): shader_wrapper.generate_vao() self._data_has_changed = False for shader_wrapper in self.shader_wrappers: - shader_wrapper.depth_test = self.depth_test - shader_wrapper.update_program_uniforms(self.get_uniforms()) - shader_wrapper.update_program_uniforms(camera_uniforms, universal=True) + shader_wrapper.update_program_uniforms(camera_uniforms) shader_wrapper.pre_render() shader_wrapper.render() diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 461d0fe6..8c43b80f 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -166,20 +166,12 @@ class VMobject(Mobject): def set_rgba_array( self, rgba_array: Vect4Array, - name: str | None = None, + name: str = "stroke_rgba", recurse: bool = False ) -> Self: - if name is None: - names = ["fill_rgba", "stroke_rgba"] - else: - names = [name] - - for name in names: - super().set_rgba_array(rgba_array, name, recurse) - if name == "fill_rgba": - self.note_changed_fill() - elif name == "stroke_rgba": - self.note_changed_stroke() + super().set_rgba_array(rgba_array, name, recurse) + self.note_changed_fill() + self.note_changed_stroke() return self def set_fill( @@ -1262,11 +1254,10 @@ class VMobject(Mobject): def set_animating_status(self, is_animating: bool, recurse: bool = True): super().set_animating_status(is_animating, recurse) - if is_animating: - for submob in self.get_family(recurse): - submob.get_joint_products(refresh=True) - if not submob._use_winding_fill: - submob.get_triangulation() + for submob in self.get_family(recurse): + submob.get_joint_products(refresh=True) + if not submob._use_winding_fill: + submob.get_triangulation() return self # For shaders @@ -1284,14 +1275,14 @@ class VMobject(Mobject): self.fill_shader_wrapper = FillShaderWrapper( ctx=ctx, vert_data=fill_data, - uniforms=self.uniforms, + mobject_uniforms=self.uniforms, shader_folder=self.fill_shader_folder, render_primitive=self.fill_render_primitive, ) self.stroke_shader_wrapper = ShaderWrapper( ctx=ctx, vert_data=stroke_data, - uniforms=self.uniforms, + mobject_uniforms=self.uniforms, shader_folder=self.stroke_shader_folder, render_primitive=self.stroke_render_primitive, ) @@ -1301,6 +1292,11 @@ class VMobject(Mobject): self.fill_shader_wrapper, self.stroke_shader_wrapper, ] + for sw in self.shader_wrappers: + family = self.family_members_with_points() + rep = family[0] if family else self + for old, new in rep.shader_code_replacements.items(): + sw.replace_code(old, new) def refresh_shader_wrapper_id(self) -> Self: if not self._shaders_initialized: @@ -1309,11 +1305,6 @@ class VMobject(Mobject): wrapper.refresh_id() return self - def get_uniforms(self): - # TODO, account for submob uniforms separately? - self.uniforms.update(self.family_members_with_points()[0].uniforms) - return self.uniforms - def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: if not self._shaders_initialized: self.init_shader_data(ctx) @@ -1325,32 +1316,25 @@ class VMobject(Mobject): fill_names = self.fill_data_names stroke_names = self.stroke_data_names - # Build up data lists + fill_family = (sm for sm in family if sm._has_fill) + stroke_family = (sm for sm in family if sm._has_stroke) + + # Build up fill data lists fill_datas = [] fill_indices = [] fill_border_datas = [] - stroke_datas = [] - back_stroke_datas = [] - for submob in family: - submob.get_joint_products() + for submob in fill_family: indices = submob.get_outer_vert_indices() - has_fill = submob._has_fill - has_stroke = submob._has_stroke - back_stroke = has_stroke and submob.stroke_behind - front_stroke = has_stroke and not submob.stroke_behind - if back_stroke: - back_stroke_datas.append(submob.data[stroke_names][indices]) - if front_stroke: - stroke_datas.append(submob.data[stroke_names][indices]) - if has_fill and submob._use_winding_fill: + if submob._use_winding_fill: data = submob.data[fill_names] data["base_point"][:] = data["point"][0] fill_datas.append(data[indices]) - if has_fill and not submob._use_winding_fill: + else: fill_datas.append(submob.data[fill_names]) fill_indices.append(submob.get_triangulation()) - if has_fill and not front_stroke: + if (not submob._has_stroke) or submob.stroke_behind: # Add fill border + submob.get_joint_products() names = list(stroke_names) names[names.index('stroke_rgba')] = 'fill_rgba' names[names.index('stroke_width')] = 'fill_border_width' @@ -1359,11 +1343,26 @@ class VMobject(Mobject): ) fill_border_datas.append(border_stroke_data[indices]) + # Build up stroke data lists + stroke_datas = [] + back_stroke_datas = [] + for submob in stroke_family: + submob.get_joint_products() + indices = submob.get_outer_vert_indices() + if submob.stroke_behind: + back_stroke_datas.append(submob.data[stroke_names][indices]) + else: + stroke_datas.append(submob.data[stroke_names][indices]) + shader_wrappers = [ self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]), self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None), self.stroke_shader_wrapper.read_in(stroke_datas), ] + for sw in shader_wrappers: + rep = family[0] # Representative family member + sw.bind_to_mobject_uniforms(rep.get_uniforms()) + sw.depth_test = rep.depth_test return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] @@ -1371,6 +1370,8 @@ class VGroup(VMobject): def __init__(self, *vmobjects: VMobject, **kwargs): super().__init__(**kwargs) self.add(*vmobjects) + if vmobjects: + self.uniforms.update(vmobjects[0].uniforms) def __add__(self, other: VMobject) -> Self: assert(isinstance(other, VMobject)) diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index 02463de2..38ca9dc5 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -13,6 +13,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.color import get_colormap_list +from manimlib.utils.color import rgb_to_color from manimlib.utils.dict_ops import merge_dicts_recursively from manimlib.utils.rate_functions import linear from manimlib.utils.simple_functions import sigmoid @@ -173,7 +174,10 @@ class VectorField(VGroup): **vector_config ) vect.shift(_input - origin) - vect.set_rgba_array([[*self.value_to_rgb(norm), self.opacity]]) + vect.set_color( + rgb_to_color(self.value_to_rgb(norm)), + opacity=self.opacity, + ) return vect diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 0a764737..ed4bd5fc 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -387,7 +387,10 @@ class Scene(object): same type are grouped together, so this function creates Groups of all clusters of adjacent Mobjects in the scene """ - batches = batch_by_property(self.mobjects, lambda m: str(type(m))) + batches = batch_by_property( + self.mobjects, + lambda m: str(type(m)) + str(m.get_uniforms()) + ) for group in self.render_groups: group.clear() @@ -554,6 +557,7 @@ class Scene(object): leave=self.leave_progress_bars, ascii=True if platform.system() == 'Windows' else None, desc=desc, + bar_format="{l_bar} {n_fmt:3}/{total_fmt:3} {rate_fmt}{postfix}", ) else: return times @@ -723,6 +727,7 @@ class Scene(object): def get_state(self) -> SceneState: return SceneState(self) + @affects_mobject_list def restore_state(self, scene_state: SceneState): scene_state.restore_scene(self) diff --git a/manimlib/scene/scene_file_writer.py b/manimlib/scene/scene_file_writer.py index d288e7dd..fdcabb8f 100644 --- a/manimlib/scene/scene_file_writer.py +++ b/manimlib/scene/scene_file_writer.py @@ -47,6 +47,8 @@ class SceneFileWriter(object): quiet: bool = False, total_frames: int = 0, progress_description_len: int = 40, + video_codec: str = "libx264", + pixel_format: str = "yuv420p", ): self.scene: Scene = scene self.write_to_movie = write_to_movie @@ -63,6 +65,8 @@ class SceneFileWriter(object): self.quiet = quiet self.total_frames = total_frames self.progress_description_len = progress_description_len + self.video_codec = video_codec + self.pixel_format = pixel_format # State during file writing self.writing_process: sp.Popen | None = None @@ -262,32 +266,26 @@ class SceneFileWriter(object): '-an', # Tells FFMPEG not to expect any audio '-loglevel', 'error', ] - if self.movie_file_extension == ".mov": - # This is if the background of the exported - # video should be transparent. - command += [ - '-vcodec', 'prores_ks', - ] - elif self.movie_file_extension == ".gif": - command += [] - else: - command += [ - '-vcodec', 'libx264', - '-pix_fmt', 'yuv420p', - ] + if self.video_codec: + command += ['-vcodec', self.video_codec] + if self.pixel_format: + command += ['-pix_fmt', self.pixel_format] command += [self.temp_file_path] self.writing_process = sp.Popen(command, stdin=sp.PIPE) - if self.total_frames > 0 and not self.quiet: + if not self.quiet: self.progress_display = ProgressDisplay( range(self.total_frames), - # bar_format="{l_bar}{bar}|{n_fmt}/{total_fmt}", leave=False, ascii=True if platform.system() == 'Windows' else None, dynamic_ncols=True, ) self.set_progress_display_description() + def use_fast_encoding(self): + self.video_codec = "libx264rgb" + self.pixel_format = "rgb32" + def begin_insert(self): # Begin writing process self.write_to_movie = True diff --git a/manimlib/shader_wrapper.py b/manimlib/shader_wrapper.py index dc5de477..4e811147 100644 --- a/manimlib/shader_wrapper.py +++ b/manimlib/shader_wrapper.py @@ -15,6 +15,7 @@ from manimlib.utils.shaders import image_path_to_texture from manimlib.utils.shaders import get_texture_id from manimlib.utils.shaders import get_fill_canvas from manimlib.utils.shaders import release_texture +from manimlib.utils.shaders import set_program_uniform from typing import TYPE_CHECKING @@ -37,7 +38,7 @@ class ShaderWrapper(object): vert_data: np.ndarray, vert_indices: Optional[np.ndarray] = None, shader_folder: Optional[str] = None, - uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables + mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures. depth_test: bool = False, render_primitive: int = moderngl.TRIANGLE_STRIP, @@ -47,13 +48,14 @@ class ShaderWrapper(object): self.vert_indices = (vert_indices or np.zeros(0)).astype(int) self.vert_attributes = vert_data.dtype.names self.shader_folder = shader_folder - self.uniforms: UniformDict = dict() self.depth_test = depth_test self.render_primitive = render_primitive + self.program_uniform_mirror: UniformDict = dict() + self.bind_to_mobject_uniforms(mobject_uniforms or dict()) + self.init_program_code() self.init_program() - self.update_program_uniforms(uniforms or dict()) if texture_paths is not None: self.init_textures(texture_paths) self.init_vao() @@ -91,14 +93,17 @@ class ShaderWrapper(object): self.ibo = None self.vao = None + def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict): + self.mobject_uniforms = mobject_uniforms + def __eq__(self, shader_wrapper: ShaderWrapper): return all(( np.all(self.vert_data == shader_wrapper.vert_data), np.all(self.vert_indices == shader_wrapper.vert_indices), self.shader_folder == shader_wrapper.shader_folder, all( - self.uniforms[key] == shader_wrapper.uniforms[key] - for key in self.uniforms + self.mobject_uniforms[key] == shader_wrapper.mobject_uniforms[key] + for key in self.mobject_uniforms ), self.depth_test == shader_wrapper.depth_test, self.render_primitive == shader_wrapper.render_primitive, @@ -122,31 +127,25 @@ class ShaderWrapper(object): def get_id(self) -> str: return self.id - def get_program_id(self) -> int: - return self.program_id - def create_id(self) -> str: # A unique id for a shader + program_id = hash("".join( + self.program_code[f"{name}_shader"] or "" + for name in ("vertex", "geometry", "fragment") + )) return "|".join(map(str, [ - self.program_id, - self.uniforms, + program_id, + self.mobject_uniforms, self.depth_test, self.render_primitive, ])) def refresh_id(self) -> None: - self.program_id = self.create_program_id() self.id = self.create_id() - def create_program_id(self) -> int: - return hash("".join(( - self.program_code[f"{name}_shader"] or "" - for name in ("vertex", "geometry", "fragment") - ))) - def replace_code(self, old: str, new: str) -> None: code_map = self.program_code - for (name, code) in code_map.items(): + for name in code_map: if code_map[name] is None: continue code_map[name] = re.sub(old, new, code_map[name]) @@ -155,9 +154,9 @@ class ShaderWrapper(object): # Changing context def use_clip_plane(self): - if "clip_plane" not in self.uniforms: + if "clip_plane" not in self.mobject_uniforms: return False - return any(self.uniforms["clip_plane"]) + return any(self.mobject_uniforms["clip_plane"]) def set_ctx_depth_test(self, enable: bool = True) -> None: if enable: @@ -222,18 +221,11 @@ class ShaderWrapper(object): assert(self.vao is not None) self.vao.render() - def update_program_uniforms(self, uniforms: UniformDict, universal: bool = False): + def update_program_uniforms(self, camera_uniforms: UniformDict): if self.program is None: return - for name, value in uniforms.items(): - if name not in self.program: - continue - if isinstance(value, np.ndarray) and value.ndim > 0: - value = tuple(value) - if universal and self.uniforms.get(name, None) == value: - continue - self.program[name].value = value - self.uniforms[name] = value + for name, value in (*self.mobject_uniforms.items(), *camera_uniforms.items()): + set_program_uniform(self.program, name, value) def get_vertex_buffer_object(self, refresh: bool = True): if refresh: diff --git a/manimlib/utils/shaders.py b/manimlib/utils/shaders.py index 2a24fd76..7cc1ba5c 100644 --- a/manimlib/utils/shaders.py +++ b/manimlib/utils/shaders.py @@ -9,7 +9,6 @@ import numpy as np from manimlib.config import parse_cli from manimlib.config import get_configuration -from manimlib.utils.customization import get_customization from manimlib.utils.directories import get_shader_dir from manimlib.utils.file_ops import find_file @@ -17,11 +16,13 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Sequence, Optional, Tuple + from manimlib.typing import UniformDict from moderngl.vertex_array import VertexArray from moderngl.framebuffer import Framebuffer ID_TO_TEXTURE: dict[int, moderngl.Texture] = dict() +PROGRAM_UNIFORM_MIRRORS: dict[int, dict[str, float | tuple]] = dict() @lru_cache() @@ -63,6 +64,38 @@ def get_shader_program( ) +def set_program_uniform( + program: moderngl.Program, + name: str, + value: float | tuple | np.ndarray +) -> bool: + """ + Sets a program uniform, and also keeps track of a dictionary + of previously set uniforms for that program so that it + doesn't needlessly reset it, requiring an exchange with gpu + memory, if it sees the same value again. + + Returns True if changed the program, False if it left it as is. + """ + + pid = id(program) + if pid not in PROGRAM_UNIFORM_MIRRORS: + PROGRAM_UNIFORM_MIRRORS[pid] = dict() + uniform_mirror = PROGRAM_UNIFORM_MIRRORS[pid] + + if type(value) is np.ndarray and value.ndim > 0: + value = tuple(value) + if uniform_mirror.get(name, None) == value: + return False + + try: + program[name].value = value + except KeyError: + return False + uniform_mirror[name] = value + return True + + @lru_cache() def get_shader_code_from_file(filename: str) -> str | None: if not filename: diff --git a/manimlib/window.py b/manimlib/window.py index 9bc5090a..299a4c8d 100644 --- a/manimlib/window.py +++ b/manimlib/window.py @@ -29,6 +29,7 @@ class Window(PygletWindow): size: tuple[int, int] = (1280, 720), samples = 0 ): + scene.window = self super().__init__(size=size, samples=samples) self.default_size = size