Adjust some typings (#1765)

* Adjust some typings

* Adjust typings
This commit is contained in:
YishiMichael 2022-03-23 14:17:34 +08:00 committed by GitHub
parent aefde2969f
commit 6ad8636fab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 54 additions and 76 deletions

View file

@ -84,8 +84,6 @@ Text
.. code-block:: python
START_X = 30
START_Y = 20
NORMAL = "NORMAL"
ITALIC = "ITALIC"
OBLIQUE = "OBLIQUE"

View file

@ -64,8 +64,6 @@ JOINT_TYPE_MAP = {
}
# Related to Text
START_X = 30
START_Y = 20
NORMAL = "NORMAL"
ITALIC = "ITALIC"
OBLIQUE = "OBLIQUE"

View file

@ -18,7 +18,12 @@ from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.logger import log
ManimColor = Union[str, colour.Color, Sequence[float]]
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject
ManimColor = Union[str, colour.Color, Sequence[float]]
SCALE_FACTOR_PER_FONT_POINT = 0.001
@ -29,7 +34,7 @@ def _get_neighbouring_pairs(iterable: Iterable) -> list:
class _TexParser(object):
def __init__(self, tex_string: str, additional_substrings: str):
def __init__(self, tex_string: str, additional_substrings: list[str]):
self.tex_string = tex_string
self.whitespace_indices = self.get_whitespace_indices()
self.backslash_indices = self.get_backslash_indices()
@ -173,7 +178,7 @@ class _TexParser(object):
def break_up_by_additional_substrings(
self,
additional_substrings: Iterable[str]
additional_substrings: list[str]
) -> None:
stripped_substrings = sorted(remove_list_redundancies([
string.strip()
@ -257,7 +262,7 @@ class _TexParser(object):
"}"
])
def get_sorted_submob_indices(self, submob_labels: Iterable[int]) -> list[int]:
def get_sorted_submob_indices(self, submob_labels: list[int]) -> list[int]:
def script_span_to_submob_range(script_span):
tex_span = self.script_span_to_tex_span_dict[script_span]
submob_indices = [
@ -291,7 +296,7 @@ class _TexParser(object):
]
return result
def get_submob_tex_strings(self, submob_labels: Iterable[int]) -> list[str]:
def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]:
ordered_tex_spans = [
self.tex_span_list[label] for label in submob_labels
]
@ -385,7 +390,7 @@ class _TexParser(object):
def get_containing_labels_by_tex_spans(
self,
tex_spans: Iterable[tuple[int, int]]
tex_spans: list[tuple[int, int]]
) -> list[int]:
return remove_list_redundancies(list(it.chain(*[
self.containing_labels_dict[tex_span]
@ -442,9 +447,7 @@ class MTex(_TexSVG):
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
@property
def hash_seed(
self
) -> tuple[str, dict[str], dict[str, bool], str, list[str], str, str, bool]:
def hash_seed(self) -> tuple:
return (
self.__class__.__name__,
self.svg_default,
@ -457,9 +460,9 @@ class MTex(_TexSVG):
)
def get_file_path(self) -> str:
return self._get_file_path(self.use_plain_tex)
return self.get_file_path_(use_plain_tex=self.use_plain_tex)
def _get_file_path(self, use_plain_tex: bool) -> str:
def get_file_path_(self, use_plain_tex: bool) -> str:
if use_plain_tex:
tex_string = self.tex_string
else:
@ -496,15 +499,17 @@ class MTex(_TexSVG):
if not self.use_plain_tex:
labelled_svg_glyphs = self
else:
file_path = self._get_file_path(use_plain_tex=False)
file_path = self.get_file_path_(use_plain_tex=False)
labelled_svg_glyphs = _TexSVG(file_path)
glyph_labels = [
self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs
]
mob = self.build_mobject(self, glyph_labels)
self.set_submobjects(mob.submobjects)
rearranged_submobs = self.rearrange_submobjects(
self.submobjects, glyph_labels
)
self.set_submobjects(rearranged_submobs)
@staticmethod
def color_to_label(color: ManimColor) -> int:
@ -512,13 +517,13 @@ class MTex(_TexSVG):
rg = r * 256 + g
return rg * 256 + b
def build_mobject(
def rearrange_submobjects(
self,
svg_glyphs: _TexSVG | None,
glyph_labels: Iterable[int]
) -> VGroup:
svg_glyphs: list[VMobject],
glyph_labels: list[int]
) -> list[VMobject]:
if not svg_glyphs:
return VGroup()
return []
# Simply pack together adjacent mobjects with the same label.
submobjects = []
@ -552,11 +557,11 @@ class MTex(_TexSVG):
submob.tex_string = submob_tex
# Support `get_tex()` method here.
submob.get_tex = MethodType(lambda inst: inst.tex_string, submob)
return VGroup(*rearranged_submobjects)
return rearranged_submobjects
def get_part_by_tex_spans(
self,
tex_spans: Iterable[tuple[int, int]]
tex_spans: list[tuple[int, int]]
) -> VGroup:
labels = self.parser.get_containing_labels_by_tex_spans(tex_spans)
return VGroup(*filter(
@ -581,7 +586,7 @@ class MTex(_TexSVG):
)
])
def get_part_by_tex(self, tex: str, index: int = 0) -> VGroup:
def get_part_by_tex(self, tex: str, index: int = 0) -> VMobject:
all_parts = self.get_parts_by_tex(tex)
return all_parts[index]
@ -597,7 +602,7 @@ class MTex(_TexSVG):
self.set_color_by_tex(tex, color)
return self
def indices_of_part(self, part: Iterable[VGroup]) -> list[int]:
def indices_of_part(self, part: Iterable[VMobject]) -> list[int]:
indices = [
index for index, submob in enumerate(self.submobjects)
if submob in part

View file

@ -76,7 +76,7 @@ class SVGMobject(VMobject):
SVG_HASH_TO_MOB_MAP[hash_val] = self.copy()
@property
def hash_seed(self) -> tuple[str, dict[str], dict[str, bool], str]:
def hash_seed(self) -> tuple:
# Returns data which can uniquely represent the result of `init_points`.
# The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
return (
@ -189,30 +189,6 @@ class SVGMobject(VMobject):
mob.shift(vec)
return mob
def get_mobject_from(self, shape: se.GraphicObject) -> VMobject | None:
shape_class_to_func_map: dict[
type, Callable[[se.GraphicObject], VMobject]
] = {
se.Path: self.path_to_mobject,
se.SimpleLine: self.line_to_mobject,
se.Rect: self.rect_to_mobject,
se.Circle: self.circle_to_mobject,
se.Ellipse: self.ellipse_to_mobject,
se.Polygon: self.polygon_to_mobject,
se.Polyline: self.polyline_to_mobject,
# se.Text: self.text_to_mobject, # TODO
}
for shape_class, func in shape_class_to_func_map.items():
if isinstance(shape, shape_class):
mob = func(shape)
self.apply_style_to_mobject(mob, shape)
return mob
shape_class_name = shape.__class__.__name__
if shape_class_name != "SVGElement":
log.warning(f"Unsupported element type: {shape_class_name}")
return None
@staticmethod
def apply_style_to_mobject(
mob: VMobject,

View file

@ -50,7 +50,7 @@ class SingleStringTex(SVGMobject):
self.organize_submobjects_left_to_right()
@property
def hash_seed(self) -> tuple[str, dict[str], dict[str, bool], str, str, bool]:
def hash_seed(self) -> tuple:
return (
self.__class__.__name__,
self.svg_default,

View file

@ -12,7 +12,6 @@ from typing import Iterable, Sequence, Union
import pygments
import pygments.formatters
import pygments.lexers
import pygments.styles
from manimpango import MarkupUtils
@ -20,6 +19,7 @@ from manimlib.logger import log
from manimlib.constants import *
from manimlib.mobject.geometry import Dot
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.customization import get_customization
from manimlib.utils.tex_file_writing import tex_hash
from manimlib.utils.config_ops import digest_config
@ -30,9 +30,7 @@ from manimlib.utils.directories import get_text_dir
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.mobject.types.vectorized_mobject import VMobject
ManimColor = Union[str, colour.Color, Sequence[float]]
TEXT_MOB_SCALE_FACTOR = 0.0076
DEFAULT_LINE_SPACING_SCALE = 0.6
@ -199,14 +197,14 @@ class _TextParser(object):
result.append((text_piece, attr_dict))
return result
def get_markup_str_with_attrs(self):
def get_markup_str_with_attrs(self) -> str:
return "".join([
f"<span {_TextParser.get_attr_dict_str(attr_dict)}>{text_piece}</span>"
for text_piece, attr_dict in self.get_text_pieces()
])
@staticmethod
def get_attr_dict_str(attr_dict: dict[str, str]):
def get_attr_dict_str(attr_dict: dict[str, str]) -> str:
return " ".join([
f"{key}='{value}'"
for key, value in attr_dict.items()
@ -215,9 +213,14 @@ class _TextParser(object):
# Temporary handler
class _Alignment:
VAL_LIST = ["LEFT", "CENTER", "RIGHT"]
def __init__(self, s):
self.value = _Alignment.VAL_LIST.index(s.upper())
VAL_DICT = {
"LEFT": 0,
"CENTER": 1,
"RIGHT": 2
}
def __init__(self, s: str):
self.value = _Alignment.VAL_DICT[s.upper()]
class Text(SVGMobject):
@ -251,7 +254,7 @@ class Text(SVGMobject):
"local_configs": {},
}
def __init__(self, text, **kwargs):
def __init__(self, text: str, **kwargs):
self.full2short(kwargs)
digest_config(self, kwargs)
validate_error = MarkupUtils.validate(text)
@ -267,7 +270,7 @@ class Text(SVGMobject):
self.scale(TEXT_MOB_SCALE_FACTOR)
@property
def hash_seed(self):
def hash_seed(self) -> tuple:
return (
self.__class__.__name__,
self.svg_default,
@ -293,7 +296,7 @@ class Text(SVGMobject):
self.local_configs
)
def get_file_path(self):
def get_file_path(self) -> str:
full_markup = self.get_full_markup_str()
svg_file = os.path.join(
get_text_dir(), tex_hash(full_markup) + ".svg"
@ -302,7 +305,7 @@ class Text(SVGMobject):
self.markup_to_svg(full_markup, svg_file)
return svg_file
def get_full_markup_str(self):
def get_full_markup_str(self) -> str:
if self.t2g:
log.warning(
"Manim currently cannot parse gradient from svg. "
@ -346,7 +349,7 @@ class Text(SVGMobject):
return self.parser.get_markup_str_with_attrs()
def find_indexes(self, word_or_text_span):
def find_indexes(self, word_or_text_span: str | tuple[int, int]) -> list[tuple[int, int]]:
if isinstance(word_or_text_span, tuple):
return [word_or_text_span]
@ -355,7 +358,7 @@ class Text(SVGMobject):
for match_obj in re.finditer(re.escape(word_or_text_span), self.text)
]
def markup_to_svg(self, markup_str, file_name):
def markup_to_svg(self, markup_str: str, file_name: str) -> str:
# `manimpango` is under construction,
# so the following code is intended to suit its interface
alignment = _Alignment(self.alignment)
@ -384,7 +387,7 @@ class Text(SVGMobject):
pango_width=pango_width
)
def generate_mobject(self):
def generate_mobject(self) -> None:
super().generate_mobject()
# Remove empty paths
@ -395,15 +398,14 @@ class Text(SVGMobject):
content_str = self.parser.remove_tags(self.text)
if self.is_markup:
content_str = saxutils.unescape(content_str)
for char_index, char in enumerate(content_str):
if not re.match(r"\s", char):
continue
for match_obj in re.finditer(r"\s", content_str):
char_index = match_obj.start()
space = Dot(radius=0, fill_opacity=0, stroke_opacity=0)
space.move_to(submobjects[max(char_index - 1, 0)].get_center())
submobjects.insert(char_index, space)
self.set_submobjects(submobjects)
def full2short(self, config):
def full2short(self, config: dict) -> None:
conversion_dict = {
"line_spacing_height": "lsh",
"text2color": "t2c",
@ -417,7 +419,7 @@ class Text(SVGMobject):
if long_name in kwargs:
kwargs[short_name] = kwargs.pop(long_name)
def get_parts_by_text(self, word):
def get_parts_by_text(self, word: str) -> VGroup:
if self.is_markup:
log.warning(
"Slicing MarkupText via `get_parts_by_text`, "
@ -433,7 +435,7 @@ class Text(SVGMobject):
for i, j in self.find_indexes(word)
))
def get_part_by_text(self, word):
def get_part_by_text(self, word: str) -> VMobject | None:
parts = self.get_parts_by_text(word)
return parts[0] if parts else None
@ -445,7 +447,6 @@ class MarkupText(Text):
}
class Code(MarkupText):
CONFIG = {
"font": "Consolas",
@ -456,7 +457,7 @@ class Code(MarkupText):
"code_style": "monokai",
}
def __init__(self, code, **kwargs):
def __init__(self, code: str, **kwargs):
digest_config(self, kwargs)
self.code = code
lexer = pygments.lexers.get_lexer_by_name(self.language)