From 1372cf101cc6e97a863c0eb2a58a46155bd3a225 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 7 Mar 2024 09:23:02 -0300 Subject: [PATCH] Allow VGroup and Group to accept generators and iterables as arguments --- manimlib/mobject/mobject.py | 20 ++++++++++++++------ manimlib/mobject/types/vectorized_mobject.py | 17 ++++++++++------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index e5e4939f..656d8153 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -46,12 +46,12 @@ from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotation_matrix_transpose from typing import TYPE_CHECKING -from typing import TypeVar, Generic +from typing import TypeVar, Generic, Iterable SubmobjectType = TypeVar('SubmobjectType', bound='Mobject') if TYPE_CHECKING: - from typing import Callable, Iterable, Iterator, Union, Tuple, Optional + from typing import Callable, Iterator, Union, Tuple, Optional import numpy.typing as npt from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self from moderngl.context import Context @@ -2136,11 +2136,19 @@ class Mobject(object): class Group(Mobject, Generic[SubmobjectType]): - def __init__(self, *mobjects: SubmobjectType, **kwargs): - if not all([isinstance(m, Mobject) for m in mobjects]): - raise Exception("All submobjects must be of type Mobject") + def __init__(self, *mobjects: SubmobjectType | Iterable[SubmobjectType], **kwargs): super().__init__(**kwargs) - self.add(*mobjects) + self._ingest_args(*mobjects) + + def _ingest_args(self, *args: Mobject | Iterable[Mobject]): + if len(args) == 0: + return + if all(isinstance(mob, Mobject) for mob in args): + self.add(*args) + elif isinstance(args[0], Iterable): + self.add(*args[0]) + else: + raise Exception(f"Invalid argument to Group of type {type(args[0])}") def __add__(self, other: Mobject | Group) -> Self: assert isinstance(other, Mobject) diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 7972d190..7286ff6f 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -13,6 +13,7 @@ from manimlib.constants import JOINT_TYPE_MAP from manimlib.constants import ORIGIN, OUT from manimlib.constants import TAU from manimlib.mobject.mobject import Mobject +from manimlib.mobject.mobject import Group from manimlib.mobject.mobject import Point from manimlib.utils.bezier import bezier from manimlib.utils.bezier import get_quadratic_approximation_of_cubic @@ -47,11 +48,11 @@ from manimlib.shader_wrapper import ShaderWrapper from manimlib.shader_wrapper import FillShaderWrapper from typing import TYPE_CHECKING -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Iterable SubVmobjectType = TypeVar('SubVmobjectType', bound='VMobject') if TYPE_CHECKING: - from typing import Callable, Iterable, Tuple, Any + from typing import Callable, Tuple, Any from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self from moderngl.context import Context @@ -1417,12 +1418,14 @@ class VMobject(Mobject): return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] -class VGroup(VMobject, Generic[SubVmobjectType]): - def __init__(self, *vmobjects: SubVmobjectType, **kwargs): +class VGroup(Group, VMobject, Generic[SubVmobjectType]): + def __init__(self, *vmobjects: SubVmobjectType | Iterable[SubVmobjectType], **kwargs): super().__init__(**kwargs) - self.add(*vmobjects) - if vmobjects: - self.uniforms.update(vmobjects[0].uniforms) + if any(isinstance(vmob, Mobject) and not isinstance(vmob, VMobject) for vmob in vmobjects): + raise Exception("Only VMobjects can be passed into VGroup") + self._ingest_args(*vmobjects) + if self.submobjects: + self.uniforms.update(self.submobjects[0].uniforms) def __add__(self, other: VMobject) -> Self: assert isinstance(other, VMobject)