diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 269905d5..4ef6b762 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -939,7 +939,8 @@ class Mobject(Container): The simplest mobject to be transformed to or from self. Should by a point of the appropriate type """ - raise Exception("Not implemented") + message = "get_point_mobject not implemented for {}" + raise Exception(message.format(self.__class__.__name__)) def align_points(self, mobject): count1 = self.get_num_points() @@ -954,23 +955,12 @@ class Mobject(Container): raise Exception("Not implemented") def align_submobjects(self, mobject): - # If one is empty, and the other is not, - # push it into its submobject list - self_has_points, mob_has_points = [ - mob.get_num_points() > 0 - for mob in (self, mobject) - ] - if self_has_points and not mob_has_points: - mobject.null_point_align(self) - elif mob_has_points and not self_has_points: - self.null_point_align(mobject) - self_count = len(self.submobjects) - mob_count = len(mobject.submobjects) - diff = self_count - mob_count - if diff < 0: - self.add_n_more_submobjects(-diff) - elif diff > 0: - mobject.add_n_more_submobjects(diff) + mob1 = self + mob2 = mobject + n1 = len(mob1.submobjects) + n2 = len(mob2.submobjects) + mob1.add_n_more_submobjects(max(0, n2 - n1)) + mob2.add_n_more_submobjects(max(0, n1 - n2)) return self def null_point_align(self, mobject): @@ -992,19 +982,31 @@ class Mobject(Container): return self def add_n_more_submobjects(self, n): + if n == 0: + return + curr = len(self.submobjects) - if n > 0 and curr == 0: - self.add(self.copy()) - n -= 1 - curr += 1 - indices = curr * np.arange(curr + n) // (curr + n) - new_submobjects = [] - for index in indices: - submob = self.submobjects[index] - if submob in new_submobjects: - submob = self.repeat_submobject(submob) - new_submobjects.append(submob) - self.submobjects = new_submobjects + if curr == 0: + # If empty, simply add n point mobjects + self.submobjects = [ + self.get_point_mobject() + for k in range(n) + ] + return + + target = curr + n + # TODO, factor this out to utils + repeat_indices = (np.arange(target) * curr) // target + split_factors = [ + sum(repeat_indices == i) + for i in range(curr) + ] + new_submobs = [] + for submob, sf in zip(self.submobjects, split_factors): + new_submobs.append(submob) + for k in range(1, sf): + new_submobs.append(submob.get_point_mobject()) + self.submobjects = new_submobs return self def repeat_submobject(self, submob):