Add a few type hints to specify VMobject family always consists of VMobjects

This commit is contained in:
Grant Sanderson 2022-12-17 17:03:34 -08:00
parent 1f0427d685
commit 0e558db122
2 changed files with 10 additions and 4 deletions

View file

@ -314,7 +314,7 @@ class Mobject(object):
# Family matters
def __getitem__(self, value):
def __getitem__(self, value: int | slice) -> Mobject:
if isinstance(value, slice):
GroupClass = self.get_group_class()
return GroupClass(*self.split().__getitem__(value))
@ -326,7 +326,7 @@ class Mobject(object):
def __len__(self):
return len(self.split())
def split(self):
def split(self) -> list[Mobject]:
return self.submobjects
def assemble_family(self):

View file

@ -130,6 +130,14 @@ class VMobject(Mobject):
def get_grid(self, *args, **kwargs) -> VGroup:
return super().get_grid(*args, **kwargs)
def __getitem__(self, value: int | slice) -> VMobject:
return super().__getitem__(value)
def add(self, *vmobjects: VMobject):
if not all((isinstance(m, VMobject) for m in vmobjects)):
raise Exception("All submobjects must be of type VMobject")
super().add(*vmobjects)
# Colors
def init_colors(self):
self.set_fill(
@ -1160,8 +1168,6 @@ class VMobject(Mobject):
class VGroup(VMobject):
def __init__(self, *vmobjects: VMobject, **kwargs):
if not all([isinstance(m, VMobject) for m in vmobjects]):
raise Exception("All submobjects must be of type VMobject")
super().__init__(**kwargs)
self.add(*vmobjects)