mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Allow VGroup and Group to accept generators and iterables as arguments
This commit is contained in:
parent
4d67361800
commit
1372cf101c
2 changed files with 24 additions and 13 deletions
|
@ -46,12 +46,12 @@ from manimlib.utils.space_ops import get_norm
|
||||||
from manimlib.utils.space_ops import rotation_matrix_transpose
|
from manimlib.utils.space_ops import rotation_matrix_transpose
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from typing import TypeVar, Generic
|
from typing import TypeVar, Generic, Iterable
|
||||||
SubmobjectType = TypeVar('SubmobjectType', bound='Mobject')
|
SubmobjectType = TypeVar('SubmobjectType', bound='Mobject')
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Callable, Iterable, Iterator, Union, Tuple, Optional
|
from typing import Callable, Iterator, Union, Tuple, Optional
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
|
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
|
||||||
from moderngl.context import Context
|
from moderngl.context import Context
|
||||||
|
@ -2136,11 +2136,19 @@ class Mobject(object):
|
||||||
|
|
||||||
|
|
||||||
class Group(Mobject, Generic[SubmobjectType]):
|
class Group(Mobject, Generic[SubmobjectType]):
|
||||||
def __init__(self, *mobjects: SubmobjectType, **kwargs):
|
def __init__(self, *mobjects: SubmobjectType | Iterable[SubmobjectType], **kwargs):
|
||||||
if not all([isinstance(m, Mobject) for m in mobjects]):
|
|
||||||
raise Exception("All submobjects must be of type Mobject")
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.add(*mobjects)
|
self._ingest_args(*mobjects)
|
||||||
|
|
||||||
|
def _ingest_args(self, *args: Mobject | Iterable[Mobject]):
|
||||||
|
if len(args) == 0:
|
||||||
|
return
|
||||||
|
if all(isinstance(mob, Mobject) for mob in args):
|
||||||
|
self.add(*args)
|
||||||
|
elif isinstance(args[0], Iterable):
|
||||||
|
self.add(*args[0])
|
||||||
|
else:
|
||||||
|
raise Exception(f"Invalid argument to Group of type {type(args[0])}")
|
||||||
|
|
||||||
def __add__(self, other: Mobject | Group) -> Self:
|
def __add__(self, other: Mobject | Group) -> Self:
|
||||||
assert isinstance(other, Mobject)
|
assert isinstance(other, Mobject)
|
||||||
|
|
|
@ -13,6 +13,7 @@ from manimlib.constants import JOINT_TYPE_MAP
|
||||||
from manimlib.constants import ORIGIN, OUT
|
from manimlib.constants import ORIGIN, OUT
|
||||||
from manimlib.constants import TAU
|
from manimlib.constants import TAU
|
||||||
from manimlib.mobject.mobject import Mobject
|
from manimlib.mobject.mobject import Mobject
|
||||||
|
from manimlib.mobject.mobject import Group
|
||||||
from manimlib.mobject.mobject import Point
|
from manimlib.mobject.mobject import Point
|
||||||
from manimlib.utils.bezier import bezier
|
from manimlib.utils.bezier import bezier
|
||||||
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
|
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
|
||||||
|
@ -47,11 +48,11 @@ from manimlib.shader_wrapper import ShaderWrapper
|
||||||
from manimlib.shader_wrapper import FillShaderWrapper
|
from manimlib.shader_wrapper import FillShaderWrapper
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar, Iterable
|
||||||
SubVmobjectType = TypeVar('SubVmobjectType', bound='VMobject')
|
SubVmobjectType = TypeVar('SubVmobjectType', bound='VMobject')
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Callable, Iterable, Tuple, Any
|
from typing import Callable, Tuple, Any
|
||||||
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self
|
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self
|
||||||
from moderngl.context import Context
|
from moderngl.context import Context
|
||||||
|
|
||||||
|
@ -1417,12 +1418,14 @@ class VMobject(Mobject):
|
||||||
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
|
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
|
||||||
|
|
||||||
|
|
||||||
class VGroup(VMobject, Generic[SubVmobjectType]):
|
class VGroup(Group, VMobject, Generic[SubVmobjectType]):
|
||||||
def __init__(self, *vmobjects: SubVmobjectType, **kwargs):
|
def __init__(self, *vmobjects: SubVmobjectType | Iterable[SubVmobjectType], **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.add(*vmobjects)
|
if any(isinstance(vmob, Mobject) and not isinstance(vmob, VMobject) for vmob in vmobjects):
|
||||||
if vmobjects:
|
raise Exception("Only VMobjects can be passed into VGroup")
|
||||||
self.uniforms.update(vmobjects[0].uniforms)
|
self._ingest_args(*vmobjects)
|
||||||
|
if self.submobjects:
|
||||||
|
self.uniforms.update(self.submobjects[0].uniforms)
|
||||||
|
|
||||||
def __add__(self, other: VMobject) -> Self:
|
def __add__(self, other: VMobject) -> Self:
|
||||||
assert isinstance(other, VMobject)
|
assert isinstance(other, VMobject)
|
||||||
|
|
Loading…
Add table
Reference in a new issue