Allow VGroup and Group to accept generators and iterables as arguments

This commit is contained in:
Grant Sanderson 2024-03-07 09:23:02 -03:00
parent 4d67361800
commit 1372cf101c
2 changed files with 24 additions and 13 deletions

View file

@ -46,12 +46,12 @@ from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotation_matrix_transpose
from typing import TYPE_CHECKING
from typing import TypeVar, Generic
from typing import TypeVar, Generic, Iterable
SubmobjectType = TypeVar('SubmobjectType', bound='Mobject')
if TYPE_CHECKING:
from typing import Callable, Iterable, Iterator, Union, Tuple, Optional
from typing import Callable, Iterator, Union, Tuple, Optional
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
from moderngl.context import Context
@ -2136,11 +2136,19 @@ class Mobject(object):
class Group(Mobject, Generic[SubmobjectType]):
def __init__(self, *mobjects: SubmobjectType, **kwargs):
if not all([isinstance(m, Mobject) for m in mobjects]):
raise Exception("All submobjects must be of type Mobject")
def __init__(self, *mobjects: SubmobjectType | Iterable[SubmobjectType], **kwargs):
super().__init__(**kwargs)
self.add(*mobjects)
self._ingest_args(*mobjects)
def _ingest_args(self, *args: Mobject | Iterable[Mobject]):
if len(args) == 0:
return
if all(isinstance(mob, Mobject) for mob in args):
self.add(*args)
elif isinstance(args[0], Iterable):
self.add(*args[0])
else:
raise Exception(f"Invalid argument to Group of type {type(args[0])}")
def __add__(self, other: Mobject | Group) -> Self:
assert isinstance(other, Mobject)

View file

@ -13,6 +13,7 @@ from manimlib.constants import JOINT_TYPE_MAP
from manimlib.constants import ORIGIN, OUT
from manimlib.constants import TAU
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.mobject import Point
from manimlib.utils.bezier import bezier
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
@ -47,11 +48,11 @@ from manimlib.shader_wrapper import ShaderWrapper
from manimlib.shader_wrapper import FillShaderWrapper
from typing import TYPE_CHECKING
from typing import Generic, TypeVar
from typing import Generic, TypeVar, Iterable
SubVmobjectType = TypeVar('SubVmobjectType', bound='VMobject')
if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple, Any
from typing import Callable, Tuple, Any
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self
from moderngl.context import Context
@ -1417,12 +1418,14 @@ class VMobject(Mobject):
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
class VGroup(VMobject, Generic[SubVmobjectType]):
def __init__(self, *vmobjects: SubVmobjectType, **kwargs):
class VGroup(Group, VMobject, Generic[SubVmobjectType]):
def __init__(self, *vmobjects: SubVmobjectType | Iterable[SubVmobjectType], **kwargs):
super().__init__(**kwargs)
self.add(*vmobjects)
if vmobjects:
self.uniforms.update(vmobjects[0].uniforms)
if any(isinstance(vmob, Mobject) and not isinstance(vmob, VMobject) for vmob in vmobjects):
raise Exception("Only VMobjects can be passed into VGroup")
self._ingest_args(*vmobjects)
if self.submobjects:
self.uniforms.update(self.submobjects[0].uniforms)
def __add__(self, other: VMobject) -> Self:
assert isinstance(other, VMobject)