Allow VGroup and Group to accept generators and iterables as arguments

This commit is contained in:
Grant Sanderson 2024-03-07 09:23:02 -03:00
parent 4d67361800
commit 1372cf101c
2 changed files with 24 additions and 13 deletions

View file

@ -46,12 +46,12 @@ from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotation_matrix_transpose from manimlib.utils.space_ops import rotation_matrix_transpose
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar, Generic from typing import TypeVar, Generic, Iterable
SubmobjectType = TypeVar('SubmobjectType', bound='Mobject') SubmobjectType = TypeVar('SubmobjectType', bound='Mobject')
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Iterator, Union, Tuple, Optional from typing import Callable, Iterator, Union, Tuple, Optional
import numpy.typing as npt import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
from moderngl.context import Context from moderngl.context import Context
@ -2136,11 +2136,19 @@ class Mobject(object):
class Group(Mobject, Generic[SubmobjectType]): class Group(Mobject, Generic[SubmobjectType]):
def __init__(self, *mobjects: SubmobjectType, **kwargs): def __init__(self, *mobjects: SubmobjectType | Iterable[SubmobjectType], **kwargs):
if not all([isinstance(m, Mobject) for m in mobjects]):
raise Exception("All submobjects must be of type Mobject")
super().__init__(**kwargs) super().__init__(**kwargs)
self.add(*mobjects) self._ingest_args(*mobjects)
def _ingest_args(self, *args: Mobject | Iterable[Mobject]):
if len(args) == 0:
return
if all(isinstance(mob, Mobject) for mob in args):
self.add(*args)
elif isinstance(args[0], Iterable):
self.add(*args[0])
else:
raise Exception(f"Invalid argument to Group of type {type(args[0])}")
def __add__(self, other: Mobject | Group) -> Self: def __add__(self, other: Mobject | Group) -> Self:
assert isinstance(other, Mobject) assert isinstance(other, Mobject)

View file

@ -13,6 +13,7 @@ from manimlib.constants import JOINT_TYPE_MAP
from manimlib.constants import ORIGIN, OUT from manimlib.constants import ORIGIN, OUT
from manimlib.constants import TAU from manimlib.constants import TAU
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.mobject import Point from manimlib.mobject.mobject import Point
from manimlib.utils.bezier import bezier from manimlib.utils.bezier import bezier
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
@ -47,11 +48,11 @@ from manimlib.shader_wrapper import ShaderWrapper
from manimlib.shader_wrapper import FillShaderWrapper from manimlib.shader_wrapper import FillShaderWrapper
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Generic, TypeVar from typing import Generic, TypeVar, Iterable
SubVmobjectType = TypeVar('SubVmobjectType', bound='VMobject') SubVmobjectType = TypeVar('SubVmobjectType', bound='VMobject')
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple, Any from typing import Callable, Tuple, Any
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self
from moderngl.context import Context from moderngl.context import Context
@ -1417,12 +1418,14 @@ class VMobject(Mobject):
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
class VGroup(VMobject, Generic[SubVmobjectType]): class VGroup(Group, VMobject, Generic[SubVmobjectType]):
def __init__(self, *vmobjects: SubVmobjectType, **kwargs): def __init__(self, *vmobjects: SubVmobjectType | Iterable[SubVmobjectType], **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.add(*vmobjects) if any(isinstance(vmob, Mobject) and not isinstance(vmob, VMobject) for vmob in vmobjects):
if vmobjects: raise Exception("Only VMobjects can be passed into VGroup")
self.uniforms.update(vmobjects[0].uniforms) self._ingest_args(*vmobjects)
if self.submobjects:
self.uniforms.update(self.submobjects[0].uniforms)
def __add__(self, other: VMobject) -> Self: def __add__(self, other: VMobject) -> Self:
assert isinstance(other, VMobject) assert isinstance(other, VMobject)