From 6ad8636fab8767ac3111c83a5bcbc8535412d42a Mon Sep 17 00:00:00 2001 From: YishiMichael <50232075+YishiMichael@users.noreply.github.com> Date: Wed, 23 Mar 2022 14:17:34 +0800 Subject: [PATCH] Adjust some typings (#1765) * Adjust some typings * Adjust typings --- docs/source/documentation/constants.rst | 2 - manimlib/constants.py | 2 - manimlib/mobject/svg/mtex_mobject.py | 51 ++++++++++++++----------- manimlib/mobject/svg/svg_mobject.py | 26 +------------ manimlib/mobject/svg/tex_mobject.py | 2 +- manimlib/mobject/svg/text_mobject.py | 47 ++++++++++++----------- 6 files changed, 54 insertions(+), 76 deletions(-) diff --git a/docs/source/documentation/constants.rst b/docs/source/documentation/constants.rst index cbf96bae..7dabd50e 100644 --- a/docs/source/documentation/constants.rst +++ b/docs/source/documentation/constants.rst @@ -84,8 +84,6 @@ Text .. code-block:: python - START_X = 30 - START_Y = 20 NORMAL = "NORMAL" ITALIC = "ITALIC" OBLIQUE = "OBLIQUE" diff --git a/manimlib/constants.py b/manimlib/constants.py index 6c82bba2..590a9cda 100644 --- a/manimlib/constants.py +++ b/manimlib/constants.py @@ -64,8 +64,6 @@ JOINT_TYPE_MAP = { } # Related to Text -START_X = 30 -START_Y = 20 NORMAL = "NORMAL" ITALIC = "ITALIC" OBLIQUE = "OBLIQUE" diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index d3617d0b..73303eab 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -18,7 +18,12 @@ from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution from manimlib.logger import log -ManimColor = Union[str, colour.Color, Sequence[float]] + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] SCALE_FACTOR_PER_FONT_POINT = 0.001 @@ -29,7 +34,7 @@ def _get_neighbouring_pairs(iterable: Iterable) -> list: class _TexParser(object): - def __init__(self, tex_string: str, additional_substrings: str): + def __init__(self, tex_string: str, additional_substrings: list[str]): self.tex_string = tex_string self.whitespace_indices = self.get_whitespace_indices() self.backslash_indices = self.get_backslash_indices() @@ -173,7 +178,7 @@ class _TexParser(object): def break_up_by_additional_substrings( self, - additional_substrings: Iterable[str] + additional_substrings: list[str] ) -> None: stripped_substrings = sorted(remove_list_redundancies([ string.strip() @@ -257,7 +262,7 @@ class _TexParser(object): "}" ]) - def get_sorted_submob_indices(self, submob_labels: Iterable[int]) -> list[int]: + def get_sorted_submob_indices(self, submob_labels: list[int]) -> list[int]: def script_span_to_submob_range(script_span): tex_span = self.script_span_to_tex_span_dict[script_span] submob_indices = [ @@ -291,7 +296,7 @@ class _TexParser(object): ] return result - def get_submob_tex_strings(self, submob_labels: Iterable[int]) -> list[str]: + def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]: ordered_tex_spans = [ self.tex_span_list[label] for label in submob_labels ] @@ -385,7 +390,7 @@ class _TexParser(object): def get_containing_labels_by_tex_spans( self, - tex_spans: Iterable[tuple[int, int]] + tex_spans: list[tuple[int, int]] ) -> list[int]: return remove_list_redundancies(list(it.chain(*[ self.containing_labels_dict[tex_span] @@ -442,9 +447,7 @@ class MTex(_TexSVG): self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @property - def hash_seed( - self - ) -> tuple[str, dict[str], dict[str, bool], str, list[str], str, str, bool]: + def hash_seed(self) -> tuple: return ( self.__class__.__name__, self.svg_default, @@ -457,9 +460,9 @@ class MTex(_TexSVG): ) def get_file_path(self) -> str: - return self._get_file_path(self.use_plain_tex) + return self.get_file_path_(use_plain_tex=self.use_plain_tex) - def _get_file_path(self, use_plain_tex: bool) -> str: + def get_file_path_(self, use_plain_tex: bool) -> str: if use_plain_tex: tex_string = self.tex_string else: @@ -496,15 +499,17 @@ class MTex(_TexSVG): if not self.use_plain_tex: labelled_svg_glyphs = self else: - file_path = self._get_file_path(use_plain_tex=False) + file_path = self.get_file_path_(use_plain_tex=False) labelled_svg_glyphs = _TexSVG(file_path) glyph_labels = [ self.color_to_label(labelled_glyph.get_fill_color()) for labelled_glyph in labelled_svg_glyphs ] - mob = self.build_mobject(self, glyph_labels) - self.set_submobjects(mob.submobjects) + rearranged_submobs = self.rearrange_submobjects( + self.submobjects, glyph_labels + ) + self.set_submobjects(rearranged_submobs) @staticmethod def color_to_label(color: ManimColor) -> int: @@ -512,13 +517,13 @@ class MTex(_TexSVG): rg = r * 256 + g return rg * 256 + b - def build_mobject( + def rearrange_submobjects( self, - svg_glyphs: _TexSVG | None, - glyph_labels: Iterable[int] - ) -> VGroup: + svg_glyphs: list[VMobject], + glyph_labels: list[int] + ) -> list[VMobject]: if not svg_glyphs: - return VGroup() + return [] # Simply pack together adjacent mobjects with the same label. submobjects = [] @@ -552,11 +557,11 @@ class MTex(_TexSVG): submob.tex_string = submob_tex # Support `get_tex()` method here. submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) - return VGroup(*rearranged_submobjects) + return rearranged_submobjects def get_part_by_tex_spans( self, - tex_spans: Iterable[tuple[int, int]] + tex_spans: list[tuple[int, int]] ) -> VGroup: labels = self.parser.get_containing_labels_by_tex_spans(tex_spans) return VGroup(*filter( @@ -581,7 +586,7 @@ class MTex(_TexSVG): ) ]) - def get_part_by_tex(self, tex: str, index: int = 0) -> VGroup: + def get_part_by_tex(self, tex: str, index: int = 0) -> VMobject: all_parts = self.get_parts_by_tex(tex) return all_parts[index] @@ -597,7 +602,7 @@ class MTex(_TexSVG): self.set_color_by_tex(tex, color) return self - def indices_of_part(self, part: Iterable[VGroup]) -> list[int]: + def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: indices = [ index for index, submob in enumerate(self.submobjects) if submob in part diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 86e77f00..1ba0923b 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -76,7 +76,7 @@ class SVGMobject(VMobject): SVG_HASH_TO_MOB_MAP[hash_val] = self.copy() @property - def hash_seed(self) -> tuple[str, dict[str], dict[str, bool], str]: + def hash_seed(self) -> tuple: # Returns data which can uniquely represent the result of `init_points`. # The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`. return ( @@ -189,30 +189,6 @@ class SVGMobject(VMobject): mob.shift(vec) return mob - def get_mobject_from(self, shape: se.GraphicObject) -> VMobject | None: - shape_class_to_func_map: dict[ - type, Callable[[se.GraphicObject], VMobject] - ] = { - se.Path: self.path_to_mobject, - se.SimpleLine: self.line_to_mobject, - se.Rect: self.rect_to_mobject, - se.Circle: self.circle_to_mobject, - se.Ellipse: self.ellipse_to_mobject, - se.Polygon: self.polygon_to_mobject, - se.Polyline: self.polyline_to_mobject, - # se.Text: self.text_to_mobject, # TODO - } - for shape_class, func in shape_class_to_func_map.items(): - if isinstance(shape, shape_class): - mob = func(shape) - self.apply_style_to_mobject(mob, shape) - return mob - - shape_class_name = shape.__class__.__name__ - if shape_class_name != "SVGElement": - log.warning(f"Unsupported element type: {shape_class_name}") - return None - @staticmethod def apply_style_to_mobject( mob: VMobject, diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index a7627783..717f1c24 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -50,7 +50,7 @@ class SingleStringTex(SVGMobject): self.organize_submobjects_left_to_right() @property - def hash_seed(self) -> tuple[str, dict[str], dict[str, bool], str, str, bool]: + def hash_seed(self) -> tuple: return ( self.__class__.__name__, self.svg_default, diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index c6abe3c6..a13d1d80 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -12,7 +12,6 @@ from typing import Iterable, Sequence, Union import pygments import pygments.formatters import pygments.lexers -import pygments.styles from manimpango import MarkupUtils @@ -20,6 +19,7 @@ from manimlib.logger import log from manimlib.constants import * from manimlib.mobject.geometry import Dot from manimlib.mobject.svg.svg_mobject import SVGMobject +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.customization import get_customization from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config @@ -30,9 +30,7 @@ from manimlib.utils.directories import get_text_dir from typing import TYPE_CHECKING if TYPE_CHECKING: - import colour from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, colour.Color, Sequence[float]] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 @@ -199,14 +197,14 @@ class _TextParser(object): result.append((text_piece, attr_dict)) return result - def get_markup_str_with_attrs(self): + def get_markup_str_with_attrs(self) -> str: return "".join([ f"{text_piece}" for text_piece, attr_dict in self.get_text_pieces() ]) @staticmethod - def get_attr_dict_str(attr_dict: dict[str, str]): + def get_attr_dict_str(attr_dict: dict[str, str]) -> str: return " ".join([ f"{key}='{value}'" for key, value in attr_dict.items() @@ -215,9 +213,14 @@ class _TextParser(object): # Temporary handler class _Alignment: - VAL_LIST = ["LEFT", "CENTER", "RIGHT"] - def __init__(self, s): - self.value = _Alignment.VAL_LIST.index(s.upper()) + VAL_DICT = { + "LEFT": 0, + "CENTER": 1, + "RIGHT": 2 + } + + def __init__(self, s: str): + self.value = _Alignment.VAL_DICT[s.upper()] class Text(SVGMobject): @@ -251,7 +254,7 @@ class Text(SVGMobject): "local_configs": {}, } - def __init__(self, text, **kwargs): + def __init__(self, text: str, **kwargs): self.full2short(kwargs) digest_config(self, kwargs) validate_error = MarkupUtils.validate(text) @@ -267,7 +270,7 @@ class Text(SVGMobject): self.scale(TEXT_MOB_SCALE_FACTOR) @property - def hash_seed(self): + def hash_seed(self) -> tuple: return ( self.__class__.__name__, self.svg_default, @@ -293,7 +296,7 @@ class Text(SVGMobject): self.local_configs ) - def get_file_path(self): + def get_file_path(self) -> str: full_markup = self.get_full_markup_str() svg_file = os.path.join( get_text_dir(), tex_hash(full_markup) + ".svg" @@ -302,7 +305,7 @@ class Text(SVGMobject): self.markup_to_svg(full_markup, svg_file) return svg_file - def get_full_markup_str(self): + def get_full_markup_str(self) -> str: if self.t2g: log.warning( "Manim currently cannot parse gradient from svg. " @@ -346,7 +349,7 @@ class Text(SVGMobject): return self.parser.get_markup_str_with_attrs() - def find_indexes(self, word_or_text_span): + def find_indexes(self, word_or_text_span: str | tuple[int, int]) -> list[tuple[int, int]]: if isinstance(word_or_text_span, tuple): return [word_or_text_span] @@ -355,7 +358,7 @@ class Text(SVGMobject): for match_obj in re.finditer(re.escape(word_or_text_span), self.text) ] - def markup_to_svg(self, markup_str, file_name): + def markup_to_svg(self, markup_str: str, file_name: str) -> str: # `manimpango` is under construction, # so the following code is intended to suit its interface alignment = _Alignment(self.alignment) @@ -384,7 +387,7 @@ class Text(SVGMobject): pango_width=pango_width ) - def generate_mobject(self): + def generate_mobject(self) -> None: super().generate_mobject() # Remove empty paths @@ -395,15 +398,14 @@ class Text(SVGMobject): content_str = self.parser.remove_tags(self.text) if self.is_markup: content_str = saxutils.unescape(content_str) - for char_index, char in enumerate(content_str): - if not re.match(r"\s", char): - continue + for match_obj in re.finditer(r"\s", content_str): + char_index = match_obj.start() space = Dot(radius=0, fill_opacity=0, stroke_opacity=0) space.move_to(submobjects[max(char_index - 1, 0)].get_center()) submobjects.insert(char_index, space) self.set_submobjects(submobjects) - def full2short(self, config): + def full2short(self, config: dict) -> None: conversion_dict = { "line_spacing_height": "lsh", "text2color": "t2c", @@ -417,7 +419,7 @@ class Text(SVGMobject): if long_name in kwargs: kwargs[short_name] = kwargs.pop(long_name) - def get_parts_by_text(self, word): + def get_parts_by_text(self, word: str) -> VGroup: if self.is_markup: log.warning( "Slicing MarkupText via `get_parts_by_text`, " @@ -433,7 +435,7 @@ class Text(SVGMobject): for i, j in self.find_indexes(word) )) - def get_part_by_text(self, word): + def get_part_by_text(self, word: str) -> VMobject | None: parts = self.get_parts_by_text(word) return parts[0] if parts else None @@ -445,7 +447,6 @@ class MarkupText(Text): } - class Code(MarkupText): CONFIG = { "font": "Consolas", @@ -456,7 +457,7 @@ class Code(MarkupText): "code_style": "monokai", } - def __init__(self, code, **kwargs): + def __init__(self, code: str, **kwargs): digest_config(self, kwargs) self.code = code lexer = pygments.lexers.get_lexer_by_name(self.language)