From 6176bcd45a9d1fc06d4f9478d9bb9fe91f699f2c Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Tue, 20 Dec 2022 12:23:19 -0800 Subject: [PATCH] Add option for StringMobject to only render one svg --- manimlib/mobject/svg/string_mobject.py | 108 +++++++++++++++---------- 1 file changed, 67 insertions(+), 41 deletions(-) diff --git a/manimlib/mobject/svg/string_mobject.py b/manimlib/mobject/svg/string_mobject.py index b97da8df..adc160a6 100644 --- a/manimlib/mobject/svg/string_mobject.py +++ b/manimlib/mobject/svg/string_mobject.py @@ -9,6 +9,7 @@ from scipy.spatial.distance import cdist from manimlib.constants import WHITE from manimlib.logger import log from manimlib.mobject.svg.svg_mobject import SVGMobject +from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_hex from manimlib.utils.color import hex_to_int @@ -55,12 +56,17 @@ class StringMobject(SVGMobject, ABC): base_color: ManimColor = WHITE, isolate: Selector = (), protect: Selector = (), + # When set to true, only the labelled svg is + # rendered, and its contents are used directly + # for the body of this String Mobject + use_labelled_svg: bool = False, **kwargs ): self.string = string self.base_color = base_color or WHITE self.isolate = isolate self.protect = protect + self.use_labelled_svg = use_labelled_svg self.parse() super().__init__( @@ -72,47 +78,34 @@ class StringMobject(SVGMobject, ABC): ) self.labels = [submob.label for submob in self.submobjects] - def get_file_path(self) -> str: - original_content = self.get_content(is_labelled=False) - return self.get_file_path_by_content(original_content) + def get_file_path(self, is_labelled: bool = False) -> str: + is_labelled = is_labelled or self.use_labelled_svg + return self.get_file_path_by_content(self.get_content(is_labelled)) @abstractmethod def get_file_path_by_content(self, content: str) -> str: return "" - def generate_mobject(self) -> None: - super().generate_mobject() - + def assign_labels_by_color(self, mobjects: list[VMobject]) -> None: + """ + Assuming each mobject in the list `mobjects` has a fill color + meant to represent a numerical label, this assigns those + those numerical labels to each mobject as an attribute + """ labels_count = len(self.labelled_spans) if labels_count == 1: - for submob in self.submobjects: - submob.label = 0 + for mob in mobjects: + mob.label = 0 return - labelled_content = self.get_content(is_labelled=True) - file_path = self.get_file_path_by_content(labelled_content) - labelled_svg = SVGMobject(file_path) - if len(self.submobjects) != len(labelled_svg.submobjects): - log.warning( - "Cannot align submobjects of the labelled svg " + \ - "to the original svg. Skip the labelling process." - ) - for submob in self.submobjects: - submob.label = 0 - return - - self.rearrange_submobjects_by_positions(labelled_svg) unrecognizable_colors = [] - for submob, labelled_svg_submob in zip( - self.submobjects, labelled_svg.submobjects - ): - label = hex_to_int(color_to_hex( - labelled_svg_submob.get_fill_color() - )) + for mob in mobjects: + label = hex_to_int(color_to_hex(mob.get_fill_color())) if label >= labels_count: unrecognizable_colors.append(label) label = 0 - submob.label = label + mob.label = label + if unrecognizable_colors: log.warning( "Unrecognizable color labels detected (%s). " + \ @@ -123,26 +116,59 @@ class StringMobject(SVGMobject, ABC): ) ) + def mobjects_from_file(self, file_path: str) -> list[VMobject]: + submobs = super().mobjects_from_file(file_path) + + if self.use_labelled_svg: + # This means submobjects are colored according to spans + self.assign_labels_by_color(submobs) + return submobs + + # Otherwise, submobs are not colored, so generate a new list + # of submobject which are and use those for labels + unlabelled_submobs = submobs + labelled_content = self.get_content(is_labelled=True) + labelled_file = self.get_file_path_by_content(labelled_content) + labelled_submobs = super().mobjects_from_file(labelled_file) + self.labelled_submobs = labelled_submobs + self.unlabelled_submobs = unlabelled_submobs + + self.assign_labels_by_color(labelled_submobs) + self.rearrange_submobjects_by_positions(labelled_submobs, unlabelled_submobs) + for usm, lsm in zip(unlabelled_submobs, labelled_submobs): + usm.label = lsm.label + + if len(unlabelled_submobs) != len(labelled_submobs): + log.warning( + "Cannot align submobjects of the labelled svg " + \ + "to the original svg. Skip the labelling process." + ) + for usm in unlabelled_submobs: + usm.label = 0 + return unlabelled_submobs + + return unlabelled_submobs + def rearrange_submobjects_by_positions( - self, labelled_svg: SVGMobject + self, labelled_submobs: list[VMobject], unlabelled_submobs: list[VMobject], ) -> None: - # Rearrange submobjects of `labelled_svg` so that - # each submobject is labelled by the nearest one of `labelled_svg`. - # The correctness cannot be ensured, since the svg may - # change significantly after inserting color commands. - if not labelled_svg.submobjects: + """ + Rearrange `labeleled_submobjects` so that each submobject + is labelled by the nearest one of `unlabelled_submobs`. + The correctness cannot be ensured, since the svg may + change significantly after inserting color commands. + """ + if len(labelled_submobs) == 0: return - labelled_svg.replace(self) + labelled_svg = VGroup(*labelled_submobs) + labelled_svg.replace(VGroup(*unlabelled_submobs)) distance_matrix = cdist( - [submob.get_center() for submob in self.submobjects], - [submob.get_center() for submob in labelled_svg.submobjects] + [submob.get_center() for submob in unlabelled_submobs], + [submob.get_center() for submob in labelled_submobs] ) _, indices = linear_sum_assignment(distance_matrix) - labelled_svg.set_submobjects([ - labelled_svg.submobjects[index] - for index in indices - ]) + labelled_submobs[:] = [labelled_submobs[index] for index in indices] # Toolkits