From eec6b01a72f0a9ec4e4960c259b59e3b191b7640 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 14 Apr 2022 21:07:31 +0800 Subject: [PATCH] Refactor labelled_string.py --- manimlib/mobject/svg/labelled_string.py | 100 ++++++++++-------------- manimlib/mobject/svg/mtex_mobject.py | 18 +---- manimlib/mobject/svg/text_mobject.py | 62 +++++++-------- 3 files changed, 71 insertions(+), 109 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 3f45c38a..765d96cb 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -4,10 +4,9 @@ from abc import ABC, abstractmethod import itertools as it import re -from manimlib.constants import BLACK, WHITE +from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_rgb from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config @@ -25,7 +24,10 @@ if TYPE_CHECKING: Span = tuple[int, int] -class _StringSVG(SVGMobject): +class LabelledString(SVGMobject, ABC): + """ + An abstract base class for `MTex` and `MarkupText` + """ CONFIG = { "height": None, "stroke_width": 0, @@ -34,16 +36,6 @@ class _StringSVG(SVGMobject): "should_subdivide_sharp_curves": True, "should_remove_null_curves": True, }, - } - - -class LabelledString(_StringSVG, ABC): - """ - An abstract base class for `MTex` and `MarkupText` - """ - CONFIG = { - "base_color": WHITE, - "use_plain_file": False, "isolate": [], } @@ -51,14 +43,11 @@ class LabelledString(_StringSVG, ABC): self.string = string digest_config(self, kwargs) - # Convert `base_color` to hex code. - self.base_color = rgb_to_hex(color_to_rgb( - self.base_color \ - or self.svg_default.get("color", None) \ - or self.svg_default.get("fill_color", None) \ + self.base_color_int = self.color_to_int( + self.svg_default.get("fill_color") \ + or self.svg_default.get("color") \ or WHITE - )) - self.svg_default["fill_color"] = BLACK + ) self.pre_parse() self.parse() @@ -66,7 +55,7 @@ class LabelledString(_StringSVG, ABC): self.post_parse() def get_file_path(self) -> str: - return self.get_file_path_(use_plain_file=False) + return self.get_file_path_(use_plain_file=True) def get_file_path_(self, use_plain_file: bool) -> str: content = self.get_content(use_plain_file) @@ -79,22 +68,34 @@ class LabelledString(_StringSVG, ABC): def generate_mobject(self) -> None: super().generate_mobject() - submob_labels = [ - self.color_to_label(submob.get_fill_color()) - for submob in self.submobjects - ] - if self.use_plain_file or self.has_predefined_local_colors: - file_path = self.get_file_path_(use_plain_file=True) - plain_svg = _StringSVG( - file_path, - svg_default=self.svg_default, - path_string_config=self.path_string_config - ) - self.set_submobjects(plain_svg.submobjects) + if self.label_span_list: + file_path = self.get_file_path_(use_plain_file=False) + labelled_svg = SVGMobject(file_path) + submob_color_ints = [ + self.color_to_int(submob.get_fill_color()) + for submob in labelled_svg.submobjects + ] else: - self.set_fill(self.base_color) - for submob, label in zip(self.submobjects, submob_labels): - submob.label = label + submob_color_ints = [0] * len(self.submobjects) + + if len(self.submobjects) != len(submob_color_ints): + raise ValueError( + "Cannot align submobjects of the labelled svg " + "to the original svg" + ) + + unrecognized_color_ints = remove_list_redundancies(sorted(filter( + lambda color_int: color_int > len(self.label_span_list), + submob_color_ints + ))) + if unrecognized_color_ints: + raise ValueError( + "Unrecognized color label(s) detected: " + f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}" + ) + + for submob, color_int in zip(self.submobjects, submob_color_ints): + submob.label = color_int - 1 def pre_parse(self) -> None: self.string_len = len(self.string) @@ -283,31 +284,14 @@ class LabelledString(_StringSVG, ABC): return index @staticmethod - def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: - r, g, b = rgb_tuple - rg = r * 256 + g - return rg * 256 + b - - @staticmethod - def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return r, g, b + def color_to_int(color: ManimColor) -> int: + hex_code = rgb_to_hex(color_to_rgb(color)) + return int(hex_code[1:], 16) @staticmethod def int_to_hex(rgb_int: int) -> str: return "#{:06x}".format(rgb_int).upper() - @staticmethod - def hex_to_int(rgb_hex: str) -> int: - return int(rgb_hex[1:], 16) - - @staticmethod - def color_to_label(color: ManimColor) -> int: - rgb_tuple = color_to_int_rgb(color) - rgb = LabelledString.rgb_to_int(rgb_tuple) - return rgb - 1 - # Parsing @abstractmethod @@ -387,10 +371,6 @@ class LabelledString(_StringSVG, ABC): def get_content(self, use_plain_file: bool) -> str: return "" - @abstractmethod - def has_predefined_local_colors(self) -> bool: - return False - # Post-parsing def get_labelled_submobjects(self) -> list[VMobject]: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 70128b1f..8c7d4843 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -47,8 +47,6 @@ class MTex(LabelledString): self.__class__.__name__, self.svg_default, self.path_string_config, - self.base_color, - self.use_plain_file, self.isolate, self.tex_string, self.alignment, @@ -78,13 +76,9 @@ class MTex(LabelledString): @staticmethod def get_color_command_str(rgb_int: int) -> str: - rgb_tuple = MTex.int_to_rgb(rgb_int) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, rgb_tuple)), - "}" - ]) + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return f"\\color[RGB]{{{r}, {g}, {b}}}" # Pre-parsing @@ -276,15 +270,11 @@ class MTex(LabelledString): result = "\n".join([self.alignment, result]) if use_plain_file: result = "\n".join([ - self.get_color_command_str(self.hex_to_int(self.base_color)), + self.get_color_command_str(self.base_color_int), result ]) return result - @property - def has_predefined_local_colors(self) -> bool: - return bool(self.command_repl_items) - # Post-parsing def get_cleaned_substr(self, span: Span) -> str: diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index d2bf0b6f..fdcdb5fe 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -27,7 +27,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from colour import Color - from typing import Any, Union + from typing import Union from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup @@ -43,7 +43,7 @@ DEFAULT_LINE_SPACING_SCALE = 0.6 # See https://docs.gtk.org/Pango/pango_markup.html # A tag containing two aliases will cause warning, # so only use the first key of each group of aliases. -SPAN_ATTR_KEY_ALIAS_LIST = ( +MARKUP_KEY_ALIAS_LIST = ( ("font", "font_desc"), ("font_family", "face"), ("font_size", "size"), @@ -77,19 +77,14 @@ SPAN_ATTR_KEY_ALIAS_LIST = ( ("text_transform",), ("segment",), ) -COLOR_RELATED_KEYS = ( +MARKUP_COLOR_KEYS = ( "foreground", - "background", - "underline_color", - "overline_color", - "strikethrough_color" + "background", + "underline_color", + "overline_color", + "strikethrough_color" ) -SPAN_ATTR_KEY_CONVERSION = { - key: key_alias_list[0] - for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST - for key in key_alias_list -} -TAG_TO_ATTR_DICT = { +MARKUP_TAG_CONVERSION_DICT = { "b": {"font_weight": "bold"}, "big": {"font_size": "larger"}, "i": {"font_style": "italic"}, @@ -166,8 +161,6 @@ class MarkupText(LabelledString): self.__class__.__name__, self.svg_default, self.path_string_config, - self.base_color, - self.use_plain_file, self.isolate, self.text, self.is_markup, @@ -258,7 +251,7 @@ class MarkupText(LabelledString): @staticmethod def merge_attr_dicts( - attr_dict_items: list[Span, str, Any] + attr_dict_items: list[tuple[Span, dict[str, str]]] ) -> list[tuple[Span, dict[str, str]]]: index_seq = [0] attr_dict_list = [{}] @@ -344,12 +337,12 @@ class MarkupText(LabelledString): attr_pattern, begin_match_obj.group(3) ) } - elif tag_name in TAG_TO_ATTR_DICT.keys(): + elif tag_name in MARKUP_TAG_CONVERSION_DICT.keys(): if begin_match_obj.group(3): raise ValueError( f"Attributes shan't exist in tag '{tag_name}'" ) - attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() + attr_dict = MARKUP_TAG_CONVERSION_DICT[tag_name].copy() else: raise ValueError(f"Unknown tag: '{tag_name}'") @@ -358,13 +351,13 @@ class MarkupText(LabelledString): ) return result - def get_global_dict_from_config(self) -> dict[str, Any]: + def get_global_dict_from_config(self) -> dict[str, str]: result = { - "line_height": ( + "line_height": str(( (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6, + ) * 0.6), "font_family": self.font, - "font_size": self.font_size * 1024, + "font_size": str(self.font_size * 1024), "font_style": self.slant, "font_weight": self.weight } @@ -382,7 +375,7 @@ class MarkupText(LabelledString): def get_local_dicts_from_config( self - ) -> list[Span, dict[str, Any]]: + ) -> list[Span, dict[str, str]]: return [ (span, {key: val}) for t2x_dict, key in ( @@ -405,9 +398,14 @@ class MarkupText(LabelledString): *self.local_dicts_from_markup, *self.local_dicts_from_config ] + key_conversion_dict = { + key: key_alias_list[0] + for key_alias_list in MARKUP_KEY_ALIAS_LIST + for key in key_alias_list + } return [ (span, { - SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) + key_conversion_dict[key.lower()]: val for key, val in attr_dict.items() }) for span, attr_dict in attr_dict_items @@ -442,7 +440,7 @@ class MarkupText(LabelledString): return [] def get_internal_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_markup] + return [] def get_external_specified_spans(self) -> list[Span]: return [span for span, _ in self.local_dicts_from_config] @@ -468,7 +466,9 @@ class MarkupText(LabelledString): def get_content(self, use_plain_file: bool) -> str: if use_plain_file: attr_dict_items = [ - (self.full_span, {"foreground": self.base_color}), + (self.full_span, { + "foreground": self.int_to_hex(self.base_color_int) + }), *self.predefined_attr_dicts, *[ (span, {}) @@ -480,7 +480,7 @@ class MarkupText(LabelledString): (self.full_span, {"foreground": BLACK}), *[ (span, { - key: BLACK if key in COLOR_RELATED_KEYS else val + key: BLACK if key in MARKUP_COLOR_KEYS else val for key, val in attr_dict.items() }) for span, attr_dict in self.predefined_attr_dicts @@ -502,14 +502,6 @@ class MarkupText(LabelledString): ) return self.get_replaced_substr(self.full_span, span_repl_dict) - @property - def has_predefined_local_colors(self) -> bool: - return any([ - key in COLOR_RELATED_KEYS - for _, attr_dict in self.predefined_attr_dicts - for key in attr_dict.keys() - ]) - # Method alias def get_parts_by_text(self, text: str, **kwargs) -> VGroup: