diff --git a/example_scenes.py b/example_scenes.py index 3134d484..9aa3e108 100644 --- a/example_scenes.py +++ b/example_scenes.py @@ -26,7 +26,7 @@ class OpeningManimExample(Scene): matrix = [[1, 1], [0, 1]] linear_transform_words = VGroup( Text("This is what the matrix"), - IntegerMatrix(matrix, include_background_rectangle=True, h_buff=1.0), + IntegerMatrix(matrix, include_background_rectangle=True), Text("looks like") ) linear_transform_words.arrange(RIGHT) diff --git a/manimlib/mobject/matrix.py b/manimlib/mobject/matrix.py index b132381f..45b32395 100644 --- a/manimlib/mobject/matrix.py +++ b/manimlib/mobject/matrix.py @@ -76,7 +76,7 @@ class Matrix(VMobject): self, matrix: Sequence[Sequence[str | float | VMobject]], v_buff: float = 0.8, - h_buff: float = 1.3, + h_buff: float = 1.0, bracket_h_buff: float = 0.2, bracket_v_buff: float = 0.25, add_background_rectangles_to_entries: bool = False, diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index c3554ee2..1f794aff 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -376,7 +376,7 @@ class Mobject(object): return [self] def family_members_with_points(self) -> list[Self]: - return [m for m in self.get_family() if m.has_points()] + return [m for m in self.family if len(m.data) > 0] def get_ancestors(self, extended: bool = False) -> list[Mobject]: """ @@ -449,7 +449,9 @@ class Mobject(object): return self def set_submobjects(self, submobject_list: list[Mobject]) -> Self: - self.remove(*self.submobjects, reassemble=False) + if self.submobjects == submobject_list: + return self + self.clear() self.add(*submobject_list) return self diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 38fdc076..f54d18ec 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -111,6 +111,8 @@ class VMobject(Mobject): self.anti_alias_width = anti_alias_width self.fill_border_width = fill_border_width self._use_winding_fill = use_winding_fill + self._has_fill = False + self._has_stroke = False self.needs_new_triangulation = True self.triangulation = np.zeros(0, dtype='i4') @@ -134,6 +136,16 @@ class VMobject(Mobject): return super().add(*vmobjects) # Colors + def note_changed_fill(self) -> Self: + for submob in self.get_family(): + submob._has_fill = submob.has_fill() + return self + + def note_changed_stroke(self) -> Self: + for submob in self.get_family(): + submob._has_stroke = submob.has_stroke() + return self + def init_colors(self): self.set_fill( color=self.fill_color, @@ -164,6 +176,10 @@ class VMobject(Mobject): for name in names: super().set_rgba_array(rgba_array, name, recurse) + if name == "fill_rgba": + self.note_changed_fill() + elif name == "stroke_rgba": + self.note_changed_stroke() return self def set_fill( @@ -177,6 +193,7 @@ class VMobject(Mobject): if border_width is not None: for mob in self.get_family(recurse): mob.data["fill_border_width"] = border_width + self.note_changed_fill() return self def set_stroke( @@ -202,6 +219,8 @@ class VMobject(Mobject): if background is not None: for mob in self.get_family(recurse): mob.stroke_behind = background + + self.note_changed_stroke() return self def set_backstroke( @@ -255,6 +274,8 @@ class VMobject(Mobject): if shading is not None: mob.set_shading(*shading, recurse=False) + self.note_changed_fill() + self.note_changed_stroke() return self def get_style(self) -> dict[str, Any]: @@ -378,10 +399,12 @@ class VMobject(Mobject): return self.uniforms["anti_alias_width"] def has_stroke(self) -> bool: - return any(self.data['stroke_width']) and any(self.data['stroke_rgba'][:, 3]) + data = self.data if len(self.data) > 0 else self._data_defaults + return any(data['stroke_width']) and any(data['stroke_rgba'][:, 3]) def has_fill(self) -> bool: - return any(self.data['fill_rgba'][:, 3]) + data = self.data if len(self.data) > 0 else self._data_defaults + return any(data['fill_rgba'][:, 3]) def get_opacity(self) -> float: if self.has_fill(): @@ -971,6 +994,10 @@ class VMobject(Mobject): *args, **kwargs ) -> Self: super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) + + self._has_stroke = mobject1._has_stroke or mobject2._has_stroke + self._has_fill = mobject1._has_fill or mobject2._has_fill + if self.has_fill() and not self._use_winding_fill: tri1 = mobject1.get_triangulation() tri2 = mobject2.get_triangulation() @@ -1026,13 +1053,22 @@ class VMobject(Mobject): vmob.pointwise_become_partial(self, a, b) return vmob + def resize_points( + self, + new_length: int, + resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array + ) -> Self: + super().resize_points(new_length, resize_func) + + n_curves = self.get_num_curves() + # Creates the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...) + self.outer_vert_indices = (np.arange(1, 3 * n_curves + 1) * 2) // 3 + return self + def get_outer_vert_indices(self) -> np.ndarray: """ Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...) """ - n_curves = self.get_num_curves() - if len(self.outer_vert_indices) != 3 * n_curves: - self.outer_vert_indices = (np.arange(1, 3 * n_curves + 1) * 2) // 3 return self.outer_vert_indices # Data for shaders that may need refreshing @@ -1202,6 +1238,8 @@ class VMobject(Mobject): @triggers_refreshed_triangulation def set_data(self, data: np.ndarray) -> Self: super().set_data(data) + self.note_changed_fill() + self.note_changed_stroke() return self # TODO, how to be smart about tangents here? @@ -1291,19 +1329,19 @@ class VMobject(Mobject): for submob in family: submob.get_joint_products() indices = submob.get_outer_vert_indices() - has_fill = submob.has_fill() - has_stroke = submob.has_stroke() + has_fill = submob._has_fill + has_stroke = submob._has_stroke back_stroke = has_stroke and submob.stroke_behind front_stroke = has_stroke and not submob.stroke_behind if back_stroke: back_stroke_datas.append(submob.data[stroke_names][indices]) if front_stroke: stroke_datas.append(submob.data[stroke_names][indices]) - if has_fill and self._use_winding_fill: + if has_fill and submob._use_winding_fill: data = submob.data[fill_names] data["base_point"][:] = data["point"][0] fill_datas.append(data[indices]) - if has_fill and not self._use_winding_fill: + if has_fill and not submob._use_winding_fill: fill_datas.append(submob.data[fill_names]) fill_indices.append(submob.get_triangulation()) if has_fill and not front_stroke: diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 2ebf1dec..0a764737 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -7,6 +7,7 @@ import platform import pyperclip import random import time +from functools import wraps from IPython.terminal import pt_inputhooks from IPython.terminal.embed import InteractiveShellEmbed @@ -38,6 +39,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.scene.scene_file_writer import SceneFileWriter from manimlib.utils.family_ops import extract_mobject_family_members from manimlib.utils.family_ops import recursive_mobject_remove +from manimlib.utils.iterables import batch_by_property from typing import TYPE_CHECKING @@ -119,6 +121,7 @@ class Scene(object): self.file_writer = SceneFileWriter(self, **self.file_writer_config) self.mobjects: list[Mobject] = [self.camera.frame] + self.render_groups: list[Mobject] = [] self.id_to_mobject_map: dict[int, Mobject] = dict() self.num_plays: int = 0 self.time: float = 0 @@ -298,7 +301,7 @@ class Scene(object): def get_image(self) -> Image: if self.window is not None: self.camera.use_window_fbo(False) - self.camera.capture(*self.mobjects) + self.camera.capture(*self.render_groups) image = self.camera.get_image() if self.window is not None: self.camera.use_window_fbo(True) @@ -319,7 +322,7 @@ class Scene(object): if self.window: self.window.clear() - self.camera.capture(*self.mobjects) + self.camera.capture(*self.render_groups) if self.window: self.window.swap_buffers() @@ -378,6 +381,30 @@ class Scene(object): def get_mobject_family_members(self) -> list[Mobject]: return extract_mobject_family_members(self.mobjects) + def assemble_render_groups(self): + """ + Rendering can be more efficient when mobjects of the + same type are grouped together, so this function creates + Groups of all clusters of adjacent Mobjects in the scene + """ + batches = batch_by_property(self.mobjects, lambda m: str(type(m))) + + for group in self.render_groups: + group.clear() + self.render_groups = [ + batch[0].get_group_class()(*batch) + for batch, key in batches + ] + + def affects_mobject_list(func: Callable): + @wraps(func) + def wrapper(self, *args, **kwargs): + func(self, *args, **kwargs) + self.assemble_render_groups() + return self + return wrapper + + @affects_mobject_list def add(self, *new_mobjects: Mobject): """ Mobjects will be displayed, from background to @@ -404,6 +431,7 @@ class Scene(object): )) return self + @affects_mobject_list def replace(self, mobject: Mobject, *replacements: Mobject): if mobject in self.mobjects: index = self.mobjects.index(mobject) @@ -414,6 +442,7 @@ class Scene(object): ] return self + @affects_mobject_list def remove(self, *mobjects_to_remove: Mobject): """ Removes anything in mobjects from scenes mobject list, but in the event that one @@ -431,11 +460,13 @@ class Scene(object): self.add(*mobjects) return self + @affects_mobject_list def bring_to_back(self, *mobjects: Mobject): self.remove(*mobjects) self.mobjects = list(mobjects) + self.mobjects return self + @affects_mobject_list def clear(self): self.mobjects = [] return self @@ -608,6 +639,7 @@ class Scene(object): else: self.update_mobjects(0) + @affects_mobject_list def play( self, *proto_animations: Animation | _AnimationBuilder, diff --git a/manimlib/shaders/quadratic_bezier_fill/frag.glsl b/manimlib/shaders/quadratic_bezier_fill/frag.glsl index 22d0edfb..4cfed975 100644 --- a/manimlib/shaders/quadratic_bezier_fill/frag.glsl +++ b/manimlib/shaders/quadratic_bezier_fill/frag.glsl @@ -6,16 +6,12 @@ in vec4 color; in float fill_all; in float orientation; in vec2 uv_coords; -in vec3 point; -in vec3 unit_normal; out vec4 frag_color; -#INSERT finalize_color.glsl - void main() { if (color.a == 0) discard; - frag_color = finalize_color(color, point, unit_normal); + frag_color = color; /* We want negatively oriented triangles to be canceled with positively oriented ones. The easiest way to do this is to give them negative alpha, @@ -33,9 +29,11 @@ void main() { cap is to make sure the original fragment color can be recovered even after blending with an (alpha = 1) color. */ - float a = 0.95 * frag_color.a; - if(winding && orientation < 0) a = -a / (1 - a); - frag_color.a = a; + if(winding){ + float a = 0.95 * frag_color.a; + if(orientation < 0) a = -a / (1 - a); + frag_color.a = a; + } if (bool(fill_all)) return; diff --git a/manimlib/shaders/quadratic_bezier_fill/geom.glsl b/manimlib/shaders/quadratic_bezier_fill/geom.glsl index c9428e67..99e10049 100644 --- a/manimlib/shaders/quadratic_bezier_fill/geom.glsl +++ b/manimlib/shaders/quadratic_bezier_fill/geom.glsl @@ -14,8 +14,6 @@ in vec3 v_unit_normal[3]; out vec4 color; out float fill_all; out float orientation; -out vec3 point; -out vec3 unit_normal; // uv space is where the curve coincides with y = x^2 out vec2 uv_coords; @@ -28,9 +26,12 @@ const vec2 SIMPLE_QUADRATIC[3] = vec2[3]( // Analog of import for manim only #INSERT emit_gl_Position.glsl +#INSERT finalize_color.glsl void emit_triangle(vec3 points[3], vec4 v_color[3]){ + vec3 unit_normal = v_unit_normal[1]; + orientation = sign(determinant(mat3( unit_normal, points[1] - points[0], @@ -39,8 +40,7 @@ void emit_triangle(vec3 points[3], vec4 v_color[3]){ for(int i = 0; i < 3; i++){ uv_coords = SIMPLE_QUADRATIC[i]; - color = v_color[i]; - point = points[i]; + color = finalize_color(v_color[i], points[i], unit_normal); emit_gl_Position(points[i]); EmitVertex(); } @@ -61,8 +61,6 @@ void main(){ // the first anchor is set equal to that anchor if (verts[0] == verts[1]) return; - unit_normal = v_unit_normal[1]; - if(winding){ // Emit main triangle fill_all = 1.0; diff --git a/manimlib/utils/tex.py b/manimlib/utils/tex.py index 719d228a..791cdc5c 100644 --- a/manimlib/utils/tex.py +++ b/manimlib/utils/tex.py @@ -16,7 +16,7 @@ def num_tex_symbols(tex: str) -> int: # \begin{array}{cc}, etc. pattern = "|".join( rf"(\\{s})" + r"(\{\w+\})?(\{\w+\})?(\[\w+\])?" - for s in ["begin", "end", "phantom"] + for s in ["begin", "end", "phantom", "text"] ) tex = re.sub(pattern, "", tex)