diff --git a/docs/source/development/changelog.rst b/docs/source/development/changelog.rst index 6e865dc8..077e823c 100644 --- a/docs/source/development/changelog.rst +++ b/docs/source/development/changelog.rst @@ -1,6 +1,67 @@ Changelog ========= +v1.6.1 +------ + +Fixed bugs +^^^^^^^^^^ +- Fixed the bug of ``MTex`` with multi-line tex string (`#1785 `__) +- Fixed ``interpolate`` (`#1788 `__) +- Fixed ``ImageMobject`` (`#1791 `__) + +Refactor +^^^^^^^^ +- Added ``\overset`` as a special string in ``Tex`` (`#1783 `__) +- Added ``outer_interpolate`` to perform interpolation using ``np.outer`` on arrays (`#1788 `__) + +v1.6.0 +------ + +Breaking changes +^^^^^^^^^^^^^^^^ +- **Python 3.6 is no longer supported** (`#1736 `__) + +Fixed bugs +^^^^^^^^^^ +- Fixed the width of riemann rectangles (`#1762 `__) +- Bug fixed in cases where empty array is passed to shader (`#1764 `__) +- Fixed ``AddTextWordByWord`` (`#1772 `__) +- Fixed ``ControlsExample`` (`#1781 `__) + + +New features +^^^^^^^^^^^^ +- Added more functions to ``Text`` (details: `#1751 `__) +- Allowed ``interpolate`` to work on an array of alpha values (`#1764 `__) +- Allowed ``Numberline.number_to_point`` and ``CoordinateSystem.coords_to_point`` to work on an array of inputs (`#1764 `__) +- Added a basic ``Prismify`` to turn a flat ``VMobject`` into something with depth (`#1764 `__) +- Added ``GlowDots``, analogous to ``GlowDot`` (`#1764 `__) +- Added ``TransformMatchingStrings`` which is compatible with ``Text`` and ``MTex`` (`#1772 `__) +- Added support for ``substring`` and ``case_sensitive`` parameters for ``LabelledString.get_parts_by_string`` (`#1780 `__) + + +Refactor +^^^^^^^^ +- Added type hints (`#1736 `__) +- Specifid UTF-8 encoding for tex files (`#1748 `__) +- Refactored ``Text`` with the latest manimpango (`#1751 `__) +- Reorganized getters for ``ParametricCurve`` (`#1757 `__) +- Refactored ``CameraFrame`` to use ``scipy.spatial.transform.Rotation`` (`#1764 `__) +- Refactored rotation methods to use ``scipy.spatial.transform.Rotation`` (`#1764 `__) +- Used ``stroke_color`` to init ``Arrow`` (`#1764 `__) +- Refactored ``Mobject.set_rgba_array_by_color`` (`#1764 `__) +- Made panning more sensitive to mouse movements (`#1764 `__) +- Added loading progress for large SVGs (`#1766 `__) +- Added getter/setter of ``field_of_view`` for ``CameraFrame`` (`#1770 `__) +- Renamed ``focal_distance`` to ``focal_dist_to_height`` and added getter/setter (`#1770 `__) +- Added getter and setter for ``VMobject.joint_type`` (`#1770 `__) +- Refactored ``VCube`` (`#1770 `__) +- Refactored ``Prism`` to receive ``width height depth`` instead of ``dimensions`` (`#1770 `__) +- Refactored ``Text``, ``MarkupText`` and ``MTex`` based on ``LabelledString`` (`#1772 `__) +- Refactored ``LabelledString`` and relevant classes (`#1779 `__) + + v1.5.0 ------ @@ -9,7 +70,7 @@ Fixed bugs - Bug fix for the case of calling ``Write`` on a null object (`#1740 `__) -New Features +New features ^^^^^^^^^^^^ - Added ``TransformMatchingMTex`` (`#1725 `__) - Added ``ImplicitFunction`` (`#1727 `__) @@ -60,7 +121,7 @@ Fixed bugs - Fixed some bugs of SVG path string parser (`#1717 `__) - Fixed some bugs of ``MTex`` (`#1720 `__) -New Features +New features ^^^^^^^^^^^^ - Added option to add ticks on x-axis in ``BarChart`` (`#1694 `__) - Added ``lable_buff`` config parameter for ``Brace`` (`#1704 `__) @@ -99,7 +160,7 @@ Fixed bugs - Fixed bug in ``ShowSubmobjectsOneByOne`` (`bcd0990 `__) - Fixed bug in ``TransformMatchingParts`` (`7023548 `__) -New Features +New features ^^^^^^^^^^^^ - Added CLI flag ``--log-level`` to specify log level (`e10f850 `__) @@ -167,7 +228,7 @@ Fixed bugs - Fixed bug with ``CoordinateSystem.get_lines_parallel_to_axis`` (`c726eb7 `__) - Fixed ``ComplexPlane`` -i display bug (`7732d2f `__) -New Features +New features ^^^^^^^^^^^^ - Supported the elliptical arc command ``A`` for ``SVGMobject`` (`#1598 `__) @@ -230,7 +291,7 @@ Fixed bugs - Rewrote ``earclip_triangulation`` to fix triangulation - Allowed sound_file_name to be taken in without extensions -New Features +New features ^^^^^^^^^^^^ - Added :class:`~manimlib.animation.indication.VShowPassingFlash` diff --git a/example_scenes.py b/example_scenes.py index 70321ce5..5f61b874 100644 --- a/example_scenes.py +++ b/example_scenes.py @@ -650,7 +650,7 @@ class ControlsExample(Scene): def text_updater(old_text): assert(isinstance(old_text, Text)) - new_text = Text(self.textbox.get_value(), size=old_text.size) + new_text = Text(self.textbox.get_value(), font_size=old_text.font_size) # new_text.align_data_and_family(old_text) new_text.move_to(old_text) if self.checkbox.get_value(): diff --git a/manimlib/__init__.py b/manimlib/__init__.py index 13e41ec0..a0147cf7 100644 --- a/manimlib/__init__.py +++ b/manimlib/__init__.py @@ -37,6 +37,7 @@ from manimlib.mobject.probability import * from manimlib.mobject.shape_matchers import * from manimlib.mobject.svg.brace import * from manimlib.mobject.svg.drawings import * +from manimlib.mobject.svg.labelled_string import * from manimlib.mobject.svg.mtex_mobject import * from manimlib.mobject.svg.svg_mobject import * from manimlib.mobject.svg.tex_mobject import * diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 00588b46..6499d0af 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -7,6 +7,7 @@ import numpy as np from manimlib.animation.animation import Animation from manimlib.animation.composition import Succession +from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.config_ops import digest_config @@ -202,23 +203,19 @@ class ShowSubmobjectsOneByOne(ShowIncreasingSubsets): self.mobject.set_submobjects([self.all_submobs[index - 1]]) -# TODO, this is broken... -class AddTextWordByWord(Succession): +class AddTextWordByWord(ShowIncreasingSubsets): CONFIG = { # If given a value for run_time, it will - # override the time_per_char + # override the time_per_word "run_time": None, - "time_per_char": 0.06, + "time_per_word": 0.2, + "rate_func": linear, } - def __init__(self, text_mobject, **kwargs): + def __init__(self, string_mobject, **kwargs): + assert isinstance(string_mobject, LabelledString) + grouped_mobject = string_mobject.submob_groups digest_config(self, kwargs) - tpc = self.time_per_char - anims = it.chain(*[ - [ - ShowIncreasingSubsets(word, run_time=tpc * len(word)), - Animation(word, run_time=0.005 * len(word)**1.5), - ] - for word in text_mobject - ]) - super().__init__(*anims, **kwargs) + if self.run_time is None: + self.run_time = self.time_per_word * len(grouped_mobject) + super().__init__(grouped_mobject, **kwargs) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 90ffa76f..dab88005 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -12,7 +12,7 @@ from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Group -from manimlib.mobject.svg.mtex_mobject import MTex +from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config @@ -153,106 +153,108 @@ class TransformMatchingTex(TransformMatchingParts): return mobject.get_tex() -class TransformMatchingMTex(AnimationGroup): +class TransformMatchingStrings(AnimationGroup): CONFIG = { "key_map": dict(), + "transform_mismatches": False, } - def __init__(self, source_mobject: MTex, target_mobject: MTex, **kwargs): + def __init__(self, + source: LabelledString, + target: LabelledString, + **kwargs + ): digest_config(self, kwargs) - assert isinstance(source_mobject, MTex) - assert isinstance(target_mobject, MTex) + assert isinstance(source, LabelledString) + assert isinstance(target, LabelledString) anims = [] - rest_source_submobs = source_mobject.submobjects.copy() - rest_target_submobs = target_mobject.submobjects.copy() + source_indices = list(range(len(source.labelled_submobjects))) + target_indices = list(range(len(target.labelled_submobjects))) - def add_anim_from(anim_class, func, source_attr, target_attr=None): - if target_attr is None: - target_attr = source_attr - source_parts = func(source_mobject, source_attr) - target_parts = func(target_mobject, target_attr) - filtered_source_parts = [ - submob_part for submob_part in source_parts - if all([ - submob in rest_source_submobs - for submob in submob_part - ]) + def get_indices_lists(mobject, parts): + return [ + [ + mobject.labelled_submobjects.index(submob) + for submob in part + ] + for part in parts ] - filtered_target_parts = [ - submob_part for submob_part in target_parts - if all([ - submob in rest_target_submobs - for submob in submob_part - ]) - ] - if not (filtered_source_parts and filtered_target_parts): - return - anims.append(anim_class( - VGroup(*filtered_source_parts), - VGroup(*filtered_target_parts), - **kwargs - )) - for submob in it.chain(*filtered_source_parts): - rest_source_submobs.remove(submob) - for submob in it.chain(*filtered_target_parts): - rest_target_submobs.remove(submob) - def get_submobs_from_keys(mobject, keys): - if not isinstance(keys, tuple): - keys = (keys,) - indices = [] + def add_anims_from(anim_class, func, source_args, target_args=None): + if target_args is None: + target_args = source_args.copy() + for source_arg, target_arg in zip(source_args, target_args): + source_parts = func(source, source_arg) + target_parts = func(target, target_arg) + source_indices_lists = list(filter( + lambda indices_list: all([ + index in source_indices + for index in indices_list + ]), get_indices_lists(source, source_parts) + )) + target_indices_lists = list(filter( + lambda indices_list: all([ + index in target_indices + for index in indices_list + ]), get_indices_lists(target, target_parts) + )) + if not source_indices_lists or not target_indices_lists: + continue + anims.append(anim_class(source_parts, target_parts, **kwargs)) + for index in it.chain(*source_indices_lists): + source_indices.remove(index) + for index in it.chain(*target_indices_lists): + target_indices.remove(index) + + def get_common_substrs(substrs_from_source, substrs_from_target): + return sorted([ + substr for substr in substrs_from_source + if substr and substr in substrs_from_target + ], key=len, reverse=True) + + def get_parts_from_keys(mobject, keys): + if isinstance(keys, str): + keys = [keys] + result = VGroup() for key in keys: - if isinstance(key, int): - indices.append(key) - elif isinstance(key, range): - indices.extend(key) - elif isinstance(key, str): - all_parts = mobject.get_parts_by_tex(key) - indices.extend(it.chain(*[ - mobject.indices_of_part(part) for part in all_parts - ])) - else: + if not isinstance(key, str): raise TypeError(key) - return VGroup(VGroup(*[ - mobject[i] for i in remove_list_redundancies(indices) - ])) + result.add(*mobject.get_parts_by_string(key)) + return result - for source_key, target_key in self.key_map.items(): - add_anim_from( - ReplacementTransform, get_submobs_from_keys, - source_key, target_key + add_anims_from( + ReplacementTransform, get_parts_from_keys, + self.key_map.keys(), self.key_map.values() + ) + add_anims_from( + FadeTransformPieces, + LabelledString.get_parts_by_string, + get_common_substrs( + source.specified_substrs, + target.specified_substrs ) + ) + add_anims_from( + FadeTransformPieces, + LabelledString.get_parts_by_group_substr, + get_common_substrs( + source.group_substrs, + target.group_substrs + ) + ) - common_specified_substrings = sorted(list( - set(source_mobject.get_specified_substrings()).intersection( - target_mobject.get_specified_substrings() + rest_source = VGroup(*[source[index] for index in source_indices]) + rest_target = VGroup(*[target[index] for index in target_indices]) + if self.transform_mismatches: + anims.append( + ReplacementTransform(rest_source, rest_target, **kwargs) ) - ), key=len, reverse=True) - for part_tex_string in common_specified_substrings: - add_anim_from( - FadeTransformPieces, MTex.get_parts_by_tex, part_tex_string + else: + anims.append( + FadeOutToPoint(rest_source, target.get_center(), **kwargs) ) - - common_submob_tex_strings = { - source_submob.get_tex() for source_submob in source_mobject - }.intersection({ - target_submob.get_tex() for target_submob in target_mobject - }) - for tex_string in common_submob_tex_strings: - add_anim_from( - FadeTransformPieces, - lambda mobject, attr: VGroup(*[ - VGroup(mob) for mob in mobject - if mob.get_tex() == attr - ]), - tex_string + anims.append( + FadeInFromPoint(rest_target, source.get_center(), **kwargs) ) - anims.append(FadeOutToPoint( - VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs - )) - anims.append(FadeInFromPoint( - VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs - )) - super().__init__(*anims) diff --git a/manimlib/mobject/number_line.py b/manimlib/mobject/number_line.py index 13b6a13b..bc96b55a 100644 --- a/manimlib/mobject/number_line.py +++ b/manimlib/mobject/number_line.py @@ -7,6 +7,7 @@ from manimlib.mobject.geometry import Line from manimlib.mobject.numbers import DecimalNumber from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.bezier import interpolate +from manimlib.utils.bezier import outer_interpolate from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import merge_dicts_recursively from manimlib.utils.simple_functions import fdiv @@ -106,7 +107,7 @@ class NumberLine(Line): def number_to_point(self, number: float | np.ndarray) -> np.ndarray: alpha = (number - self.x_min) / (self.x_max - self.x_min) - return interpolate(self.get_start(), self.get_end(), alpha) + return outer_interpolate(self.get_start(), self.get_end(), alpha) def point_to_number(self, point: np.ndarray) -> float: points = self.get_points() diff --git a/manimlib/mobject/numbers.py b/manimlib/mobject/numbers.py index bd003fb6..beac837c 100644 --- a/manimlib/mobject/numbers.py +++ b/manimlib/mobject/numbers.py @@ -6,12 +6,9 @@ from manimlib.constants import * from manimlib.mobject.svg.tex_mobject import SingleStringTex from manimlib.mobject.svg.text_mobject import Text from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.utils.iterables import hash_obj T = TypeVar("T", bound=VMobject) -string_to_mob_map: dict[str, VMobject] = {} - class DecimalNumber(VMobject): CONFIG = { @@ -92,9 +89,7 @@ class DecimalNumber(VMobject): return self.data["font_size"][0] def string_to_mob(self, string: str, mob_class: Type[T] = Text, **kwargs) -> T: - if (string, hash_obj(kwargs)) not in string_to_mob_map: - string_to_mob_map[(string, hash_obj(kwargs))] = mob_class(string, font_size=1, **kwargs) - mob = string_to_mob_map[(string, hash_obj(kwargs))].copy() + mob = mob_class(string, font_size=1, **kwargs) mob.scale(self.get_font_size()) return mob diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py new file mode 100644 index 00000000..58c47094 --- /dev/null +++ b/manimlib/mobject/svg/labelled_string.py @@ -0,0 +1,546 @@ +from __future__ import annotations + +import re +import colour +import itertools as it +from typing import Iterable, Union, Sequence +from abc import ABC, abstractmethod + +from manimlib.constants import BLACK, WHITE +from manimlib.mobject.svg.svg_mobject import SVGMobject +from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.utils.color import color_to_int_rgb +from manimlib.utils.color import color_to_rgb +from manimlib.utils.color import rgb_to_hex +from manimlib.utils.config_ops import digest_config +from manimlib.utils.iterables import remove_list_redundancies + + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] + Span = tuple[int, int] + + +class _StringSVG(SVGMobject): + CONFIG = { + "height": None, + "stroke_width": 0, + "stroke_color": WHITE, + "path_string_config": { + "should_subdivide_sharp_curves": True, + "should_remove_null_curves": True, + }, + } + + +class LabelledString(_StringSVG, ABC): + """ + An abstract base class for `MTex` and `MarkupText` + """ + CONFIG = { + "base_color": WHITE, + "use_plain_file": False, + "isolate": [], + } + + def __init__(self, string: str, **kwargs): + self.string = string + digest_config(self, kwargs) + + # Convert `base_color` to hex code. + self.base_color = rgb_to_hex(color_to_rgb( + self.base_color \ + or self.svg_default.get("color", None) \ + or self.svg_default.get("fill_color", None) \ + or WHITE + )) + self.svg_default["fill_color"] = BLACK + + self.pre_parse() + self.parse() + super().__init__() + self.post_parse() + + def get_file_path(self) -> str: + return self.get_file_path_(use_plain_file=False) + + def get_file_path_(self, use_plain_file: bool) -> str: + content = self.get_content(use_plain_file) + return self.get_file_path_by_content(content) + + @abstractmethod + def get_file_path_by_content(self, content: str) -> str: + return "" + + def generate_mobject(self) -> None: + super().generate_mobject() + + submob_labels = [ + self.color_to_label(submob.get_fill_color()) + for submob in self.submobjects + ] + if self.use_plain_file or self.has_predefined_local_colors: + file_path = self.get_file_path_(use_plain_file=True) + plain_svg = _StringSVG( + file_path, + svg_default=self.svg_default, + path_string_config=self.path_string_config + ) + self.set_submobjects(plain_svg.submobjects) + else: + self.set_fill(self.base_color) + for submob, label in zip(self.submobjects, submob_labels): + submob.label = label + + def pre_parse(self) -> None: + self.string_len = len(self.string) + self.full_span = (0, self.string_len) + + def parse(self) -> None: + self.command_repl_items = self.get_command_repl_items() + self.command_spans = self.get_command_spans() + self.extra_entity_spans = self.get_extra_entity_spans() + self.entity_spans = self.get_entity_spans() + self.extra_ignored_spans = self.get_extra_ignored_spans() + self.skipped_spans = self.get_skipped_spans() + self.internal_specified_spans = self.get_internal_specified_spans() + self.external_specified_spans = self.get_external_specified_spans() + self.specified_spans = self.get_specified_spans() + self.label_span_list = self.get_label_span_list() + self.check_overlapping() + + def post_parse(self) -> None: + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] + self.labelled_submobjects = self.get_labelled_submobjects() + self.specified_substrs = self.get_specified_substrs() + self.group_items = self.get_group_items() + self.group_substrs = self.get_group_substrs() + self.submob_groups = self.get_submob_groups() + + def copy(self): + return self.deepcopy() + + # Toolkits + + def get_substr(self, span: Span) -> str: + return self.string[slice(*span)] + + def finditer( + self, pattern: str, flags: int = 0, **kwargs + ) -> Iterable[re.Match]: + return re.compile(pattern, flags).finditer(self.string, **kwargs) + + def search( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: + return re.compile(pattern, flags).search(self.string, **kwargs) + + def match( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: + return re.compile(pattern, flags).match(self.string, **kwargs) + + def find_spans(self, pattern: str, **kwargs) -> list[Span]: + return [ + match_obj.span() + for match_obj in self.finditer(pattern, **kwargs) + ] + + def find_substr(self, substr: str, **kwargs) -> list[Span]: + if not substr: + return [] + return self.find_spans(re.escape(substr), **kwargs) + + def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: + return list(it.chain(*[ + self.find_substr(substr, **kwargs) + for substr in remove_list_redundancies(substrs) + ])) + + @staticmethod + def get_neighbouring_pairs(iterable: list) -> list[tuple]: + return list(zip(iterable[:-1], iterable[1:])) + + @staticmethod + def span_contains(span_0: Span, span_1: Span) -> bool: + return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] + + @staticmethod + def get_complement_spans( + interval_spans: list[Span], universal_span: Span + ) -> list[Span]: + if not interval_spans: + return [universal_span] + + span_ends, span_begins = zip(*interval_spans) + return list(zip( + (universal_span[0], *span_begins), + (*span_ends, universal_span[1]) + )) + + @staticmethod + def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: + if not vals: + return [] + + unique_vals = [vals[0]] + indices = [0] + for index, val in enumerate(vals): + if val == unique_vals[-1]: + continue + unique_vals.append(val) + indices.append(index) + indices.append(len(vals)) + spans = LabelledString.get_neighbouring_pairs(indices) + return list(zip(unique_vals, spans)) + + @staticmethod + def find_region_index(seq: list[int], val: int) -> int: + # Returns an integer in `range(-1, len(seq))` satisfying + # `seq[result] <= val < seq[result + 1]`. + # `seq` should be sorted in ascending order. + if not seq or val < seq[0]: + return -1 + result = len(seq) - 1 + while val < seq[result]: + result -= 1 + return result + + @staticmethod + def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: + sorted_seq = sorted(seq) + index = LabelledString.find_region_index(sorted_seq, val) + return sorted_seq[index + index_shift] + + @staticmethod + def generate_span_repl_dict( + inserted_string_pairs: list[tuple[Span, tuple[str, str]]], + other_repl_items: list[tuple[Span, str]] + ) -> dict[Span, str]: + result = dict(other_repl_items) + if not inserted_string_pairs: + return result + + indices, _, _, inserted_strings = zip(*sorted([ + ( + span[flag], + -flag, + -span[1 - flag], + str_pair[flag] + ) + for span, str_pair in inserted_string_pairs + for flag in range(2) + ])) + result.update({ + (index, index): "".join(inserted_strings[slice(*item_span)]) + for index, item_span + in LabelledString.compress_neighbours(indices) + }) + return result + + def get_replaced_substr( + self, span: Span, span_repl_dict: dict[Span, str] + ) -> str: + repl_spans = sorted(filter( + lambda repl_span: self.span_contains(span, repl_span), + span_repl_dict.keys() + )) + if not all( + span_0[1] <= span_1[0] + for span_0, span_1 in self.get_neighbouring_pairs(repl_spans) + ): + raise ValueError("Overlapping replacement") + + pieces = [ + self.get_substr(piece_span) + for piece_span in self.get_complement_spans(repl_spans, span) + ] + repl_strs = [span_repl_dict[repl_span] for repl_span in repl_spans] + repl_strs.append("") + return "".join(it.chain(*zip(pieces, repl_strs))) + + @staticmethod + def rslide(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted(skipped)) + while index in transfer_dict.keys(): + index = transfer_dict[index] + return index + + @staticmethod + def lslide(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted([ + skipped_span[::-1] for skipped_span in skipped + ], reverse=True)) + while index in transfer_dict.keys(): + index = transfer_dict[index] + return index + + @staticmethod + def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: + r, g, b = rgb_tuple + rg = r * 256 + g + return rg * 256 + b + + @staticmethod + def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return r, g, b + + @staticmethod + def int_to_hex(rgb_int: int) -> str: + return "#{:06x}".format(rgb_int).upper() + + @staticmethod + def hex_to_int(rgb_hex: str) -> int: + return int(rgb_hex[1:], 16) + + @staticmethod + def color_to_label(color: ManimColor) -> int: + rgb_tuple = color_to_int_rgb(color) + rgb = LabelledString.rgb_to_int(rgb_tuple) + return rgb - 1 + + # Parsing + + @abstractmethod + def get_command_repl_items(self) -> list[tuple[Span, str]]: + return [] + + def get_command_spans(self) -> list[Span]: + return [cmd_span for cmd_span, _ in self.command_repl_items] + + @abstractmethod + def get_extra_entity_spans(self) -> list[Span]: + return [] + + def get_entity_spans(self) -> list[Span]: + return list(it.chain( + self.command_spans, + self.extra_entity_spans + )) + + @abstractmethod + def get_extra_ignored_spans(self) -> list[int]: + return [] + + def get_skipped_spans(self) -> list[Span]: + return list(it.chain( + self.find_spans(r"\s"), + self.command_spans, + self.extra_ignored_spans + )) + + def shrink_span(self, span: Span) -> Span: + return ( + self.rslide(span[0], self.skipped_spans), + self.lslide(span[1], self.skipped_spans) + ) + + @abstractmethod + def get_internal_specified_spans(self) -> list[Span]: + return [] + + @abstractmethod + def get_external_specified_spans(self) -> list[Span]: + return [] + + def get_specified_spans(self) -> list[Span]: + spans = list(it.chain( + self.internal_specified_spans, + self.external_specified_spans, + self.find_substrs(self.isolate) + )) + shrinked_spans = list(filter( + lambda span: span[0] < span[1] and not any([ + entity_span[0] < index < entity_span[1] + for index in span + for entity_span in self.entity_spans + ]), + [self.shrink_span(span) for span in spans] + )) + return remove_list_redundancies(shrinked_spans) + + @abstractmethod + def get_label_span_list(self) -> list[Span]: + return [] + + def check_overlapping(self) -> None: + for span_0, span_1 in it.product(self.label_span_list, repeat=2): + if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: + continue + raise ValueError( + "Partially overlapping substrings detected: " + f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" + ) + + @abstractmethod + def get_content(self, use_plain_file: bool) -> str: + return "" + + @abstractmethod + def has_predefined_local_colors(self) -> bool: + return False + + # Post-parsing + + def get_labelled_submobjects(self) -> list[VMobject]: + return [submob for _, submob in self.labelled_submobject_items] + + def get_cleaned_substr(self, span: Span) -> str: + span_repl_dict = dict.fromkeys(self.command_spans, "") + return self.get_replaced_substr(span, span_repl_dict) + + def get_specified_substrs(self) -> list[str]: + return remove_list_redundancies([ + self.get_cleaned_substr(span) + for span in self.specified_spans + ]) + + def get_group_items(self) -> list[tuple[str, VGroup]]: + if not self.labelled_submobject_items: + return [] + + labels, labelled_submobjects = zip(*self.labelled_submobject_items) + group_labels, labelled_submob_spans = zip( + *self.compress_neighbours(labels) + ) + ordered_spans = [ + self.label_span_list[label] if label != -1 else self.full_span + for label in group_labels + ] + interval_spans = [ + ( + next_span[0] + if self.span_contains(prev_span, next_span) + else prev_span[1], + prev_span[1] + if self.span_contains(next_span, prev_span) + else next_span[0] + ) + for prev_span, next_span in self.get_neighbouring_pairs( + ordered_spans + ) + ] + shrinked_spans = [ + self.shrink_span(span) + for span in self.get_complement_spans( + interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) + ) + ] + group_substrs = [ + self.get_cleaned_substr(span) if span[0] < span[1] else "" + for span in shrinked_spans + ] + submob_groups = VGroup(*[ + VGroup(*labelled_submobjects[slice(*submob_span)]) + for submob_span in labelled_submob_spans + ]) + return list(zip(group_substrs, submob_groups)) + + def get_group_substrs(self) -> list[str]: + return [group_substr for group_substr, _ in self.group_items] + + def get_submob_groups(self) -> list[VGroup]: + return [submob_group for _, submob_group in self.group_items] + + def get_parts_by_group_substr(self, substr: str) -> VGroup: + return VGroup(*[ + group + for group_substr, group in self.group_items + if group_substr == substr + ]) + + # Selector + + def find_span_components( + self, custom_span: Span, substring: bool = True + ) -> list[Span]: + shrinked_span = self.shrink_span(custom_span) + if shrinked_span[0] >= shrinked_span[1]: + return [] + + if substring: + indices = remove_list_redundancies(list(it.chain( + self.full_span, + *self.label_span_list + ))) + span_begin = self.take_nearest_value( + indices, shrinked_span[0], 0 + ) + span_end = self.take_nearest_value( + indices, shrinked_span[1] - 1, 1 + ) + else: + span_begin, span_end = shrinked_span + + span_choices = sorted(filter( + lambda span: self.span_contains((span_begin, span_end), span), + self.label_span_list + )) + # Choose spans that reach the farthest. + span_choices_dict = dict(span_choices) + + result = [] + while span_begin < span_end: + if span_begin not in span_choices_dict.keys(): + span_begin += 1 + continue + next_begin = span_choices_dict[span_begin] + result.append((span_begin, next_begin)) + span_begin = next_begin + return result + + def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: + labels = [ + label for label, span in enumerate(self.label_span_list) + if any([ + self.span_contains(span_component, span) + for span_component in self.find_span_components( + custom_span, **kwargs + ) + ]) + ] + return VGroup(*[ + submob for label, submob in self.labelled_submobject_items + if label in labels + ]) + + def get_parts_by_string( + self, substr: str, + case_sensitive: bool = True, regex: bool = False, **kwargs + ) -> VGroup: + flags = 0 + if not case_sensitive: + flags |= re.I + pattern = substr if regex else re.escape(substr) + return VGroup(*[ + self.get_part_by_custom_span(span, **kwargs) + for span in self.find_spans(pattern, flags=flags) + if span[0] < span[1] + ]) + + def get_part_by_string( + self, substr: str, index: int = 0, **kwargs + ) -> VMobject: + return self.get_parts_by_string(substr, **kwargs)[index] + + def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): + self.get_parts_by_string(substr, **kwargs).set_color(color) + return self + + def set_color_by_string_to_color_map( + self, string_to_color_map: dict[str, ManimColor], **kwargs + ): + for substr, color in string_to_color_map.items(): + self.set_color_by_string(substr, color, **kwargs) + return self + + def get_string(self) -> str: + return self.string diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 73303eab..fb7922e1 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,447 +1,41 @@ from __future__ import annotations -import re -import colour import itertools as it -from types import MethodType -from typing import Iterable, Union, Sequence +import colour +from typing import Union, Sequence -from manimlib.constants import WHITE -from manimlib.mobject.svg.svg_mobject import SVGMobject -from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.utils.color import color_to_int_rgb -from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import adjacent_pairs -from manimlib.utils.iterables import remove_list_redundancies +from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.utils.tex_file_writing import tex_to_svg_file from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution -from manimlib.logger import log from typing import TYPE_CHECKING if TYPE_CHECKING: from manimlib.mobject.types.vectorized_mobject import VMobject + from manimlib.mobject.types.vectorized_mobject import VGroup ManimColor = Union[str, colour.Color, Sequence[float]] + Span = tuple[int, int] SCALE_FACTOR_PER_FONT_POINT = 0.001 -def _get_neighbouring_pairs(iterable: Iterable) -> list: - return list(adjacent_pairs(iterable))[:-1] - - -class _TexParser(object): - def __init__(self, tex_string: str, additional_substrings: list[str]): - self.tex_string = tex_string - self.whitespace_indices = self.get_whitespace_indices() - self.backslash_indices = self.get_backslash_indices() - self.script_indices = self.get_script_indices() - self.brace_indices_dict = self.get_brace_indices_dict() - self.tex_span_list: list[tuple[int, int]] = [] - self.script_span_to_char_dict: dict[tuple[int, int], str] = {} - self.script_span_to_tex_span_dict: dict[ - tuple[int, int], tuple[int, int] - ] = {} - self.neighbouring_script_span_pairs: list[tuple[int, int]] = [] - self.specified_substrings: list[str] = [] - self.add_tex_span((0, len(tex_string))) - self.break_up_by_scripts() - self.break_up_by_double_braces() - self.break_up_by_additional_substrings(additional_substrings) - self.tex_span_list.sort(key=lambda t: (t[0], -t[1])) - self.specified_substrings = remove_list_redundancies( - self.specified_substrings - ) - self.containing_labels_dict = self.get_containing_labels_dict() - - def add_tex_span(self, tex_span: tuple[int, int]) -> None: - if tex_span not in self.tex_span_list: - self.tex_span_list.append(tex_span) - - def get_whitespace_indices(self) -> list[int]: - return [ - match_obj.start() - for match_obj in re.finditer(r"\s", self.tex_string) - ] - - def get_backslash_indices(self) -> list[int]: - # Newlines (`\\`) don't count. - return [ - match_obj.end() - 1 - for match_obj in re.finditer(r"\\+", self.tex_string) - if len(match_obj.group()) % 2 == 1 - ] - - def filter_out_escaped_characters(self, indices) -> list[int]: - return list(filter( - lambda index: index - 1 not in self.backslash_indices, - indices - )) - - def get_script_indices(self) -> list[int]: - return self.filter_out_escaped_characters([ - match_obj.start() - for match_obj in re.finditer(r"[_^]", self.tex_string) - ]) - - def get_brace_indices_dict(self) -> dict[int, int]: - tex_string = self.tex_string - indices = self.filter_out_escaped_characters([ - match_obj.start() - for match_obj in re.finditer(r"[{}]", tex_string) - ]) - result = {} - left_brace_indices_stack = [] - for index in indices: - if tex_string[index] == "{": - left_brace_indices_stack.append(index) - else: - left_brace_index = left_brace_indices_stack.pop() - result[left_brace_index] = index - return result - - def break_up_by_scripts(self) -> None: - # Match subscripts & superscripts. - tex_string = self.tex_string - whitespace_indices = self.whitespace_indices - brace_indices_dict = self.brace_indices_dict - script_spans = [] - for script_index in self.script_indices: - script_char = tex_string[script_index] - extended_begin = script_index - while extended_begin - 1 in whitespace_indices: - extended_begin -= 1 - script_begin = script_index + 1 - while script_begin in whitespace_indices: - script_begin += 1 - if script_begin in brace_indices_dict.keys(): - script_end = brace_indices_dict[script_begin] + 1 - else: - pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") - match_obj = pattern.match(tex_string, pos=script_begin) - if not match_obj: - script_name = { - "_": "subscript", - "^": "superscript" - }[script_char] - log.warning( - f"Unclear {script_name} detected while parsing. " - "Please use braces to clarify" - ) - continue - script_end = match_obj.end() - tex_span = (script_begin, script_end) - script_span = (extended_begin, script_end) - script_spans.append(script_span) - self.add_tex_span(tex_span) - self.script_span_to_char_dict[script_span] = script_char - self.script_span_to_tex_span_dict[script_span] = tex_span - - if not script_spans: - return - - _, sorted_script_spans = zip(*sorted([ - (index, script_span) - for script_span in script_spans - for index in script_span - ])) - for span_0, span_1 in _get_neighbouring_pairs(sorted_script_spans): - if span_0[1] == span_1[0]: - self.neighbouring_script_span_pairs.append((span_0, span_1)) - - def break_up_by_double_braces(self) -> None: - # Match paired double braces (`{{...}}`). - tex_string = self.tex_string - reversed_indices_dict = dict( - item[::-1] for item in self.brace_indices_dict.items() - ) - skip = False - for prev_right_index, right_index in _get_neighbouring_pairs( - list(reversed_indices_dict.keys()) - ): - if skip: - skip = False - continue - if right_index != prev_right_index + 1: - continue - left_index = reversed_indices_dict[right_index] - prev_left_index = reversed_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: - continue - tex_span = (left_index, right_index + 1) - self.add_tex_span(tex_span) - self.specified_substrings.append(tex_string[slice(*tex_span)]) - skip = True - - def break_up_by_additional_substrings( - self, - additional_substrings: list[str] - ) -> None: - stripped_substrings = sorted(remove_list_redundancies([ - string.strip() - for string in additional_substrings - ])) - if "" in stripped_substrings: - stripped_substrings.remove("") - - tex_string = self.tex_string - all_tex_spans = [] - for string in stripped_substrings: - match_objs = list(re.finditer(re.escape(string), tex_string)) - if not match_objs: - continue - self.specified_substrings.append(string) - for match_obj in match_objs: - all_tex_spans.append(match_obj.span()) - - former_script_spans_dict = dict([ - script_span_pair[0][::-1] - for script_span_pair in self.neighbouring_script_span_pairs - ]) - for span_begin, span_end in all_tex_spans: - # Deconstruct spans containing one out of two scripts. - if span_end in former_script_spans_dict.keys(): - span_end = former_script_spans_dict[span_end] - if span_begin >= span_end: - continue - self.add_tex_span((span_begin, span_end)) - - def get_containing_labels_dict(self) -> dict[tuple[int, int], list[int]]: - tex_span_list = self.tex_span_list - result = { - tex_span: [] - for tex_span in tex_span_list - } - overlapping_tex_span_pairs = [] - for index_0, span_0 in enumerate(tex_span_list): - for index_1, span_1 in enumerate(tex_span_list[index_0:]): - if span_0[1] <= span_1[0]: - continue - if span_0[1] < span_1[1]: - overlapping_tex_span_pairs.append((span_0, span_1)) - result[span_0].append(index_0 + index_1) - if overlapping_tex_span_pairs: - tex_string = self.tex_string - log.error("Partially overlapping substrings detected:") - for tex_span_pair in overlapping_tex_span_pairs: - log.error(", ".join( - f"\"{tex_string[slice(*tex_span)]}\"" - for tex_span in tex_span_pair - )) - raise ValueError - return result - - def get_labelled_tex_string(self) -> str: - indices, _, flags, labels = zip(*sorted([ - (*tex_span[::(1, -1)[flag]], flag, label) - for label, tex_span in enumerate(self.tex_span_list) - for flag in range(2) - ], key=lambda t: (t[0], -t[2], -t[1]))) - command_pieces = [ - ("{{" + self.get_color_command(label), "}}")[flag] - for flag, label in zip(flags, labels) - ][1:-1] - command_pieces.insert(0, "") - string_pieces = [ - self.tex_string[slice(*tex_span)] - for tex_span in _get_neighbouring_pairs(indices) - ] - return "".join(it.chain(*zip(command_pieces, string_pieces))) - - @staticmethod - def get_color_command(label: int) -> str: - rg, b = divmod(label, 256) - r, g = divmod(rg, 256) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, (r, g, b))), - "}" - ]) - - def get_sorted_submob_indices(self, submob_labels: list[int]) -> list[int]: - def script_span_to_submob_range(script_span): - tex_span = self.script_span_to_tex_span_dict[script_span] - submob_indices = [ - index for index, label in enumerate(submob_labels) - if label in self.containing_labels_dict[tex_span] - ] - return range(submob_indices[0], submob_indices[-1] + 1) - - filtered_script_span_pairs = filter( - lambda script_span_pair: all([ - self.script_span_to_char_dict[script_span] == character - for script_span, character in zip(script_span_pair, "_^") - ]), - self.neighbouring_script_span_pairs - ) - switch_range_pairs = sorted([ - tuple([ - script_span_to_submob_range(script_span) - for script_span in script_span_pair - ]) - for script_span_pair in filtered_script_span_pairs - ], key=lambda t: (t[0].stop, -t[0].start)) - result = list(range(len(submob_labels))) - for range_0, range_1 in switch_range_pairs: - result = [ - *result[:range_1.start], - *result[range_0.start:range_0.stop], - *result[range_1.stop:range_0.start], - *result[range_1.start:range_1.stop], - *result[range_0.stop:] - ] - return result - - def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]: - ordered_tex_spans = [ - self.tex_span_list[label] for label in submob_labels - ] - ordered_containing_labels = [ - self.containing_labels_dict[tex_span] - for tex_span in ordered_tex_spans - ] - ordered_span_begins, ordered_span_ends = zip(*ordered_tex_spans) - string_span_begins = [ - prev_end if prev_label in containing_labels else curr_begin - for prev_end, prev_label, containing_labels, curr_begin in zip( - ordered_span_ends[:-1], submob_labels[:-1], - ordered_containing_labels[1:], ordered_span_begins[1:] - ) - ] - string_span_begins.insert(0, ordered_span_begins[0]) - string_span_ends = [ - next_begin if next_label in containing_labels else curr_end - for next_begin, next_label, containing_labels, curr_end in zip( - ordered_span_begins[1:], submob_labels[1:], - ordered_containing_labels[:-1], ordered_span_ends[:-1] - ) - ] - string_span_ends.append(ordered_span_ends[-1]) - - tex_string = self.tex_string - left_brace_indices = sorted(self.brace_indices_dict.keys()) - right_brace_indices = sorted(self.brace_indices_dict.values()) - ignored_indices = sorted(it.chain( - self.whitespace_indices, - left_brace_indices, - right_brace_indices, - self.script_indices - )) - result = [] - for span_begin, span_end in zip(string_span_begins, string_span_ends): - while span_begin in ignored_indices: - span_begin += 1 - if span_begin >= span_end: - result.append("") - continue - while span_end - 1 in ignored_indices: - span_end -= 1 - unclosed_left_brace = 0 - unclosed_right_brace = 0 - for index in range(span_begin, span_end): - if index in left_brace_indices: - unclosed_left_brace += 1 - elif index in right_brace_indices: - if unclosed_left_brace == 0: - unclosed_right_brace += 1 - else: - unclosed_left_brace -= 1 - result.append("".join([ - unclosed_right_brace * "{", - tex_string[span_begin:span_end], - unclosed_left_brace * "}" - ])) - return result - - def find_span_components_of_custom_span( - self, - custom_span: tuple[int, int] - ) -> list[tuple[int, int]] | None: - skipped_indices = sorted(it.chain( - self.whitespace_indices, - self.script_indices - )) - tex_span_choices = sorted(filter( - lambda tex_span: all([ - tex_span[0] >= custom_span[0], - tex_span[1] <= custom_span[1] - ]), - self.tex_span_list - )) - # Choose spans that reach the farthest. - tex_span_choices_dict = dict(tex_span_choices) - - span_begin, span_end = custom_span - result = [] - while span_begin != span_end: - if span_begin not in tex_span_choices_dict.keys(): - if span_begin in skipped_indices: - span_begin += 1 - continue - return None - next_begin = tex_span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_containing_labels_by_tex_spans( - self, - tex_spans: list[tuple[int, int]] - ) -> list[int]: - return remove_list_redundancies(list(it.chain(*[ - self.containing_labels_dict[tex_span] - for tex_span in tex_spans - ]))) - - def get_specified_substrings(self) -> list[str]: - return self.specified_substrings - - def get_isolated_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.tex_string[slice(*tex_span)] - for tex_span in self.tex_span_list - ]) - - -class _TexSVG(SVGMobject): +class MTex(LabelledString): CONFIG = { - "height": None, - "fill_opacity": 1.0, - "stroke_width": 0, - "path_string_config": { - "should_subdivide_sharp_curves": True, - "should_remove_null_curves": True, - }, - } - - -class MTex(_TexSVG): - CONFIG = { - "color": WHITE, "font_size": 48, "alignment": "\\centering", "tex_environment": "align*", - "isolate": [], "tex_to_color_map": {}, - "use_plain_tex": False, } def __init__(self, tex_string: str, **kwargs): - digest_config(self, kwargs) - tex_string = tex_string.strip() # Prevent from passing an empty string. if not tex_string: - tex_string = "\\quad" + tex_string = "\\\\" self.tex_string = tex_string - self.parser = _TexParser( - self.tex_string, - [*self.tex_to_color_map.keys(), *self.isolate] - ) - super().__init__(**kwargs) + super().__init__(tex_string, **kwargs) self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @@ -452,183 +46,288 @@ class MTex(_TexSVG): self.__class__.__name__, self.svg_default, self.path_string_config, + self.base_color, + self.use_plain_file, + self.isolate, self.tex_string, - self.parser.specified_substrings, self.alignment, self.tex_environment, - self.use_plain_tex + self.tex_to_color_map ) - def get_file_path(self) -> str: - return self.get_file_path_(use_plain_tex=self.use_plain_tex) - - def get_file_path_(self, use_plain_tex: bool) -> str: - if use_plain_tex: - tex_string = self.tex_string - else: - tex_string = self.parser.get_labelled_tex_string() - - full_tex = self.get_tex_file_body(tex_string) + def get_file_path_by_content(self, content: str) -> str: + tex_config = get_tex_config() + full_tex = tex_config["tex_body"].replace( + tex_config["text_to_replace"], + content + ) with display_during_execution(f"Writing \"{self.tex_string}\""): - file_path = self.tex_to_svg_file_path(full_tex) + file_path = tex_to_svg_file(full_tex) return file_path - def get_tex_file_body(self, tex_string: str) -> str: + def pre_parse(self) -> None: + super().pre_parse() + self.backslash_indices = self.get_backslash_indices() + self.brace_index_pairs = self.get_brace_index_pairs() + self.script_char_spans = self.get_script_char_spans() + self.script_content_spans = self.get_script_content_spans() + self.script_spans = self.get_script_spans() + + # Toolkits + + @staticmethod + def get_color_command_str(rgb_int: int) -> str: + rgb_tuple = MTex.int_to_rgb(rgb_int) + return "".join([ + "\\color[RGB]", + "{", + ",".join(map(str, rgb_tuple)), + "}" + ]) + + # Pre-parsing + + def get_backslash_indices(self) -> list[int]: + # The latter of `\\` doesn't count. + return list(it.chain(*[ + range(span[0], span[1], 2) + for span in self.find_spans(r"\\+") + ])) + + def get_unescaped_char_spans(self, chars: str): + return sorted(filter( + lambda span: span[0] - 1 not in self.backslash_indices, + self.find_substrs(list(chars)) + )) + + def get_brace_index_pairs(self) -> list[Span]: + left_brace_indices = [] + right_brace_indices = [] + left_brace_indices_stack = [] + for span in self.get_unescaped_char_spans("{}"): + index = span[0] + if self.get_substr(span) == "{": + left_brace_indices_stack.append(index) + else: + if not left_brace_indices_stack: + raise ValueError("Missing '{' inserted") + left_brace_index = left_brace_indices_stack.pop() + left_brace_indices.append(left_brace_index) + right_brace_indices.append(index) + if left_brace_indices_stack: + raise ValueError("Missing '}' inserted") + return list(zip(left_brace_indices, right_brace_indices)) + + def get_script_char_spans(self) -> list[int]: + return self.get_unescaped_char_spans("_^") + + def get_script_content_spans(self) -> list[Span]: + result = [] + brace_indices_dict = dict(self.brace_index_pairs) + script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" + for script_char_span in self.script_char_spans: + span_begin = self.match(r"\s*", pos=script_char_span[1]).end() + if span_begin in brace_indices_dict.keys(): + span_end = brace_indices_dict[span_begin] + 1 + else: + match_obj = self.match(script_pattern, pos=span_begin) + if not match_obj: + script_name = { + "_": "subscript", + "^": "superscript" + }[script_char] + raise ValueError( + f"Unclear {script_name} detected while parsing. " + "Please use braces to clarify" + ) + span_end = match_obj.end() + result.append((span_begin, span_end)) + return result + + def get_script_spans(self) -> list[Span]: + return [ + ( + self.search(r"\s*$", endpos=script_char_span[0]).start(), + script_content_span[1] + ) + for script_char_span, script_content_span in zip( + self.script_char_spans, self.script_content_spans + ) + ] + + # Parsing + + def get_command_repl_items(self) -> list[tuple[Span, str]]: + color_related_command_dict = { + "color": (1, False), + "textcolor": (1, False), + "pagecolor": (1, True), + "colorbox": (1, True), + "fcolorbox": (2, True), + } + result = [] + backslash_indices = self.backslash_indices + right_brace_indices = [ + right_index + for left_index, right_index in self.brace_index_pairs + ] + pattern = "".join([ + r"\\", + "(", + "|".join(color_related_command_dict.keys()), + ")", + r"(?![a-zA-Z])" + ]) + for match_obj in self.finditer(pattern): + span_begin, cmd_end = match_obj.span() + if span_begin not in backslash_indices: + continue + cmd_name = match_obj.group(1) + n_braces, substitute_cmd = color_related_command_dict[cmd_name] + span_end = self.take_nearest_value( + right_brace_indices, cmd_end, n_braces + ) + 1 + if substitute_cmd: + repl_str = "\\" + cmd_name + n_braces * "{black}" + else: + repl_str = "" + result.append(((span_begin, span_end), repl_str)) + return result + + def get_extra_entity_spans(self) -> list[Span]: + return [ + self.match(r"\\([a-zA-Z]+|.)", pos=index).span() + for index in self.backslash_indices + ] + + def get_extra_ignored_spans(self) -> list[int]: + return self.script_char_spans.copy() + + def get_internal_specified_spans(self) -> list[Span]: + # Match paired double braces (`{{...}}`). + result = [] + reversed_brace_indices_dict = dict([ + pair[::-1] for pair in self.brace_index_pairs + ]) + skip = False + for prev_right_index, right_index in self.get_neighbouring_pairs( + list(reversed_brace_indices_dict.keys()) + ): + if skip: + skip = False + continue + if right_index != prev_right_index + 1: + continue + left_index = reversed_brace_indices_dict[right_index] + prev_left_index = reversed_brace_indices_dict[prev_right_index] + if left_index != prev_left_index - 1: + continue + result.append((left_index, right_index + 1)) + skip = True + return result + + def get_external_specified_spans(self) -> list[Span]: + return self.find_substrs(list(self.tex_to_color_map.keys())) + + def get_label_span_list(self) -> list[Span]: + result = self.script_content_spans.copy() + for span_begin, span_end in self.specified_spans: + shrinked_end = self.lslide(span_end, self.script_spans) + if span_begin >= shrinked_end: + continue + shrinked_span = (span_begin, shrinked_end) + if shrinked_span in result: + continue + result.append(shrinked_span) + return result + + def get_content(self, use_plain_file: bool) -> str: + if use_plain_file: + span_repl_dict = {} + else: + extended_label_span_list = [ + span + if span in self.script_content_spans + else (span[0], self.rslide(span[1], self.script_spans)) + for span in self.label_span_list + ] + inserted_string_pairs = [ + (span, ( + "{{" + self.get_color_command_str(label + 1), + "}}" + )) + for label, span in enumerate(extended_label_span_list) + ] + span_repl_dict = self.generate_span_repl_dict( + inserted_string_pairs, + self.command_repl_items + ) + result = self.get_replaced_substr(self.full_span, span_repl_dict) + if self.tex_environment: - tex_string = "\n".join([ + result = "\n".join([ f"\\begin{{{self.tex_environment}}}", - tex_string, + result, f"\\end{{{self.tex_environment}}}" ]) if self.alignment: - tex_string = "\n".join([self.alignment, tex_string]) - - tex_config = get_tex_config() - return tex_config["tex_body"].replace( - tex_config["text_to_replace"], - tex_string - ) + result = "\n".join([self.alignment, result]) + if use_plain_file: + result = "\n".join([ + self.get_color_command_str(self.hex_to_int(self.base_color)), + result + ]) + return result - @staticmethod - def tex_to_svg_file_path(tex_file_content: str) -> str: - return tex_to_svg_file(tex_file_content) + @property + def has_predefined_local_colors(self) -> bool: + return bool(self.command_repl_items) - def generate_mobject(self) -> None: - super().generate_mobject() + # Post-parsing - if not self.use_plain_tex: - labelled_svg_glyphs = self - else: - file_path = self.get_file_path_(use_plain_tex=False) - labelled_svg_glyphs = _TexSVG(file_path) + def get_cleaned_substr(self, span: Span) -> str: + substr = super().get_cleaned_substr(span) + if not self.brace_index_pairs: + return substr - glyph_labels = [ - self.color_to_label(labelled_glyph.get_fill_color()) - for labelled_glyph in labelled_svg_glyphs - ] - rearranged_submobs = self.rearrange_submobjects( - self.submobjects, glyph_labels - ) - self.set_submobjects(rearranged_submobs) - - @staticmethod - def color_to_label(color: ManimColor) -> int: - r, g, b = color_to_int_rgb(color) - rg = r * 256 + g - return rg * 256 + b - - def rearrange_submobjects( - self, - svg_glyphs: list[VMobject], - glyph_labels: list[int] - ) -> list[VMobject]: - if not svg_glyphs: - return [] - - # Simply pack together adjacent mobjects with the same label. - submobjects = [] - submob_labels = [] - new_glyphs = [] - current_glyph_label = glyph_labels[0] - for glyph, label in zip(svg_glyphs, glyph_labels): - if label == current_glyph_label: - new_glyphs.append(glyph) - else: - submobject = VGroup(*new_glyphs) - submob_labels.append(current_glyph_label) - submobjects.append(submobject) - new_glyphs = [glyph] - current_glyph_label = label - submobject = VGroup(*new_glyphs) - submob_labels.append(current_glyph_label) - submobjects.append(submobject) - - indices = self.parser.get_sorted_submob_indices(submob_labels) - rearranged_submobjects = [submobjects[index] for index in indices] - rearranged_labels = [submob_labels[index] for index in indices] - - submob_tex_strings = self.parser.get_submob_tex_strings( - rearranged_labels - ) - for submob, label, submob_tex in zip( - rearranged_submobjects, rearranged_labels, submob_tex_strings - ): - submob.submob_label = label - submob.tex_string = submob_tex - # Support `get_tex()` method here. - submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) - return rearranged_submobjects - - def get_part_by_tex_spans( - self, - tex_spans: list[tuple[int, int]] - ) -> VGroup: - labels = self.parser.get_containing_labels_by_tex_spans(tex_spans) - return VGroup(*filter( - lambda submob: submob.submob_label in labels, - self.submobjects - )) - - def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup: - tex_spans = self.parser.find_span_components_of_custom_span( - custom_span - ) - if tex_spans is None: - tex = self.tex_string[slice(*custom_span)] - raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") - return self.get_part_by_tex_spans(tex_spans) - - def get_parts_by_tex(self, tex: str) -> VGroup: - return VGroup(*[ - self.get_part_by_custom_span(match_obj.span()) - for match_obj in re.finditer( - re.escape(tex.strip()), self.tex_string - ) + # Balance braces. + left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) + unclosed_left_braces = 0 + unclosed_right_braces = 0 + for index in range(*span): + if index in left_brace_indices: + unclosed_left_braces += 1 + elif index in right_brace_indices: + if unclosed_left_braces == 0: + unclosed_right_braces += 1 + else: + unclosed_left_braces -= 1 + return "".join([ + unclosed_right_braces * "{", + substr, + unclosed_left_braces * "}" ]) - def get_part_by_tex(self, tex: str, index: int = 0) -> VMobject: - all_parts = self.get_parts_by_tex(tex) - return all_parts[index] + # Method alias - def set_color_by_tex(self, tex: str, color: ManimColor): - self.get_parts_by_tex(tex).set_color(color) - return self + def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: + return self.get_parts_by_string(tex, **kwargs) + + def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: + return self.get_part_by_string(tex, **kwargs) + + def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): + return self.set_color_by_string(tex, color, **kwargs) def set_color_by_tex_to_color_map( - self, - tex_to_color_map: dict[str, ManimColor] + self, tex_to_color_map: dict[str, ManimColor], **kwargs ): - for tex, color in tex_to_color_map.items(): - self.set_color_by_tex(tex, color) - return self - - def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: - indices = [ - index for index, submob in enumerate(self.submobjects) - if submob in part - ] - if not indices: - raise ValueError("Failed to find part in tex") - return indices - - def indices_of_part_by_tex(self, tex: str, index: int = 0) -> list[int]: - part = self.get_part_by_tex(tex, index=index) - return self.indices_of_part(part) + return self.set_color_by_string_to_color_map( + tex_to_color_map, **kwargs + ) def get_tex(self) -> str: - return self.tex_string - - def get_submob_tex(self) -> list[str]: - return [ - submob.get_tex() - for submob in self.submobjects - ] - - def get_specified_substrings(self) -> list[str]: - return self.parser.get_specified_substrings() - - def get_isolated_substrings(self) -> list[str]: - return self.parser.get_isolated_substrings() + return self.get_string() class MTexText(MTex): diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index c057c1b2..b44c107f 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -173,6 +173,8 @@ class SVGMobject(VMobject): else: log.warning(f"Unsupported element type: {type(shape)}") continue + if not mob.has_points(): + continue self.apply_style_to_mobject(mob, shape) if isinstance(shape, se.Transformable) and shape.apply: self.handle_transform(mob, shape.transform) diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index 717f1c24..619f5bc9 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -100,6 +100,18 @@ class SingleStringTex(SVGMobject): filler = "{\\quad}" tex += filler + should_add_double_filler = reduce(op.or_, [ + tex == "\\overset", + # TODO: these can't be used since they change + # the latex draw order. + # tex == "\\frac", # you can use \\over as a alternative + # tex == "\\dfrac", + # tex == "\\binom", + ]) + if should_add_double_filler: + filler = "{\\quad}{\\quad}" + tex += filler + if tex == "\\substack": tex = "\\quad" diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index a13d1d80..c3c3be19 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -2,11 +2,10 @@ from __future__ import annotations import os import re -import typing +import itertools as it from pathlib import Path - -import xml.sax.saxutils as saxutils from contextlib import contextmanager +import typing from typing import Iterable, Sequence, Union import pygments @@ -17,198 +16,88 @@ from manimpango import MarkupUtils from manimlib.logger import log from manimlib.constants import * -from manimlib.mobject.geometry import Dot -from manimlib.mobject.svg.svg_mobject import SVGMobject -from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.utils.customization import get_customization from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir +from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: from manimlib.mobject.types.vectorized_mobject import VMobject + from manimlib.mobject.types.vectorized_mobject import VGroup + ManimColor = Union[str, colour.Color, Sequence[float]] + Span = tuple[int, int] + TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 -class _TextParser(object): - # See https://docs.gtk.org/Pango/pango_markup.html - # A tag containing two aliases will cause warning, - # so only use the first key of each group of aliases. - SPAN_ATTR_KEY_ALIAS_LIST = ( - ("font", "font_desc"), - ("font_family", "face"), - ("font_size", "size"), - ("font_style", "style"), - ("font_weight", "weight"), - ("font_variant", "variant"), - ("font_stretch", "stretch"), - ("font_features",), - ("foreground", "fgcolor", "color"), - ("background", "bgcolor"), - ("alpha", "fgalpha"), - ("background_alpha", "bgalpha"), - ("underline",), - ("underline_color",), - ("overline",), - ("overline_color",), - ("rise",), - ("baseline_shift",), - ("font_scale",), - ("strikethrough",), - ("strikethrough_color",), - ("fallback",), - ("lang",), - ("letter_spacing",), - ("gravity",), - ("gravity_hint",), - ("show",), - ("insert_hyphens",), - ("allow_breaks",), - ("line_height",), - ("text_transform",), - ("segment",), - ) - SPAN_ATTR_KEY_CONVERSION = { - key: key_alias_list[0] - for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST - for key in key_alias_list - } - - TAG_TO_ATTR_DICT = { - "b": {"font_weight": "bold"}, - "big": {"font_size": "larger"}, - "i": {"font_style": "italic"}, - "s": {"strikethrough": "true"}, - "sub": {"baseline_shift": "subscript", "font_scale": "subscript"}, - "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, - "small": {"font_size": "smaller"}, - "tt": {"font_family": "monospace"}, - "u": {"underline": "single"}, - } - - def __init__(self, text: str = "", is_markup: bool = True): - self.text = text - self.is_markup = is_markup - self.global_attrs = {} - self.local_attrs = {(0, len(self.text)): {}} - self.tag_strings = set() - if is_markup: - self.parse_markup() - - def parse_markup(self) -> None: - tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*('[^']*'|"[^"]*")\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(?:(?:'([^']*)')|(?:"([^"]*)"))""" - start_match_obj_stack = [] - match_obj_pairs = [] - for match_obj in re.finditer(tag_pattern, self.text): - if not match_obj.group(1): - start_match_obj_stack.append(match_obj) - else: - match_obj_pairs.append((start_match_obj_stack.pop(), match_obj)) - self.tag_strings.add(match_obj.group()) - assert not start_match_obj_stack, "Unclosed tag(s) detected" - - for start_match_obj, end_match_obj in match_obj_pairs: - tag_name = start_match_obj.group(2) - assert tag_name == end_match_obj.group(2), "Unmatched tag names" - assert not end_match_obj.group(3), "Attributes shan't exist in ending tags" - if tag_name == "span": - attr_dict = { - match.group(1): match.group(2) or match.group(3) - for match in re.finditer(attr_pattern, start_match_obj.group(3)) - } - elif tag_name in _TextParser.TAG_TO_ATTR_DICT.keys(): - assert not start_match_obj.group(3), f"Attributes shan't exist in tag '{tag_name}'" - attr_dict = _TextParser.TAG_TO_ATTR_DICT[tag_name] - else: - raise AssertionError(f"Unknown tag: '{tag_name}'") - - text_span = (start_match_obj.end(), end_match_obj.start()) - self.update_local_attrs(text_span, attr_dict) - - @staticmethod - def convert_key_alias(key: str) -> str: - return _TextParser.SPAN_ATTR_KEY_CONVERSION[key] - - @staticmethod - def update_attr_dict(attr_dict: dict[str, str], key: str, value: typing.Any) -> None: - converted_key = _TextParser.convert_key_alias(key) - attr_dict[converted_key] = str(value) - - def update_global_attr(self, key: str, value: typing.Any) -> None: - _TextParser.update_attr_dict(self.global_attrs, key, value) - - def update_global_attrs(self, attr_dict: dict[str, typing.Any]) -> None: - for key, value in attr_dict.items(): - self.update_global_attr(key, value) - - def update_local_attr(self, span: tuple[int, int], key: str, value: typing.Any) -> None: - if span[0] >= span[1]: - log.warning(f"Span {span} doesn't match any part of the string") - return - - if span in self.local_attrs.keys(): - _TextParser.update_attr_dict(self.local_attrs[span], key, value) - return - - span_triplets = [] - for sp, attr_dict in self.local_attrs.items(): - if sp[1] <= span[0] or span[1] <= sp[0]: - continue - span_to_become = (max(sp[0], span[0]), min(sp[1], span[1])) - spans_to_add = [] - if sp[0] < span[0]: - spans_to_add.append((sp[0], span[0])) - if span[1] < sp[1]: - spans_to_add.append((span[1], sp[1])) - span_triplets.append((sp, span_to_become, spans_to_add)) - for span_to_remove, span_to_become, spans_to_add in span_triplets: - attr_dict = self.local_attrs.pop(span_to_remove) - for span_to_add in spans_to_add: - self.local_attrs[span_to_add] = attr_dict.copy() - self.local_attrs[span_to_become] = attr_dict - _TextParser.update_attr_dict(self.local_attrs[span_to_become], key, value) - - def update_local_attrs(self, text_span: tuple[int, int], attr_dict: dict[str, typing.Any]) -> None: - for key, value in attr_dict.items(): - self.update_local_attr(text_span, key, value) - - def remove_tags(self, string: str) -> str: - for tag_string in self.tag_strings: - string = string.replace(tag_string, "") - return string - - def get_text_pieces(self) -> list[tuple[str, dict[str, str]]]: - result = [] - for span in sorted(self.local_attrs.keys()): - text_piece = self.remove_tags(self.text[slice(*span)]) - if not text_piece: - continue - if not self.is_markup: - text_piece = saxutils.escape(text_piece) - attr_dict = self.global_attrs.copy() - attr_dict.update(self.local_attrs[span]) - result.append((text_piece, attr_dict)) - return result - - def get_markup_str_with_attrs(self) -> str: - return "".join([ - f"{text_piece}" - for text_piece, attr_dict in self.get_text_pieces() - ]) - - @staticmethod - def get_attr_dict_str(attr_dict: dict[str, str]) -> str: - return " ".join([ - f"{key}='{value}'" - for key, value in attr_dict.items() - ]) +# See https://docs.gtk.org/Pango/pango_markup.html +# A tag containing two aliases will cause warning, +# so only use the first key of each group of aliases. +SPAN_ATTR_KEY_ALIAS_LIST = ( + ("font", "font_desc"), + ("font_family", "face"), + ("font_size", "size"), + ("font_style", "style"), + ("font_weight", "weight"), + ("font_variant", "variant"), + ("font_stretch", "stretch"), + ("font_features",), + ("foreground", "fgcolor", "color"), + ("background", "bgcolor"), + ("alpha", "fgalpha"), + ("background_alpha", "bgalpha"), + ("underline",), + ("underline_color",), + ("overline",), + ("overline_color",), + ("rise",), + ("baseline_shift",), + ("font_scale",), + ("strikethrough",), + ("strikethrough_color",), + ("fallback",), + ("lang",), + ("letter_spacing",), + ("gravity",), + ("gravity_hint",), + ("show",), + ("insert_hyphens",), + ("allow_breaks",), + ("line_height",), + ("text_transform",), + ("segment",), +) +COLOR_RELATED_KEYS = ( + "foreground", + "background", + "underline_color", + "overline_color", + "strikethrough_color" +) +SPAN_ATTR_KEY_CONVERSION = { + key: key_alias_list[0] + for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST + for key in key_alias_list +} +TAG_TO_ATTR_DICT = { + "b": {"font_weight": "bold"}, + "big": {"font_size": "larger"}, + "i": {"font_style": "italic"}, + "s": {"strikethrough": "true"}, + "sub": {"baseline_shift": "subscript", "font_scale": "subscript"}, + "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, + "small": {"font_size": "smaller"}, + "tt": {"font_family": "monospace"}, + "u": {"underline": "single"}, +} # Temporary handler @@ -223,16 +112,9 @@ class _Alignment: self.value = _Alignment.VAL_DICT[s.upper()] -class Text(SVGMobject): +class MarkupText(LabelledString): CONFIG = { - # Mobject - "stroke_width": 0, - "svg_default": { - "color": WHITE, - }, - "height": None, - # Text - "is_markup": False, + "is_markup": True, "font_size": 48, "lsh": None, "justify": False, @@ -240,8 +122,6 @@ class Text(SVGMobject): "alignment": "LEFT", "line_width_factor": None, "font": "", - "disable_ligatures": True, - "apply_space_chars": True, "slant": NORMAL, "weight": NORMAL, "gradient": None, @@ -257,13 +137,22 @@ class Text(SVGMobject): def __init__(self, text: str, **kwargs): self.full2short(kwargs) digest_config(self, kwargs) - validate_error = MarkupUtils.validate(text) - if validate_error: - raise ValueError(validate_error) - self.text = text - self.parser = _TextParser(text, is_markup=self.is_markup) - super().__init__(**kwargs) + if not self.font: + self.font = get_customization()["style"]["font"] + if self.is_markup: + validate_error = MarkupUtils.validate(text) + if validate_error: + raise ValueError(validate_error) + + self.text = text + super().__init__(text, **kwargs) + + if self.t2g: + log.warning( + "Manim currently cannot parse gradient from svg. " + "Please set gradient via `set_color_by_gradient`.", + ) if self.gradient: self.set_color_by_gradient(*self.gradient) if self.height is None: @@ -275,6 +164,9 @@ class Text(SVGMobject): self.__class__.__name__, self.svg_default, self.path_string_config, + self.base_color, + self.use_plain_file, + self.isolate, self.text, self.is_markup, self.font_size, @@ -284,8 +176,6 @@ class Text(SVGMobject): self.alignment, self.line_width_factor, self.font, - self.disable_ligatures, - self.apply_space_chars, self.slant, self.weight, self.t2c, @@ -296,68 +186,28 @@ class Text(SVGMobject): self.local_configs ) - def get_file_path(self) -> str: - full_markup = self.get_full_markup_str() + def full2short(self, config: dict) -> None: + conversion_dict = { + "line_spacing_height": "lsh", + "text2color": "t2c", + "text2font": "t2f", + "text2gradient": "t2g", + "text2slant": "t2s", + "text2weight": "t2w" + } + for kwargs in [config, self.CONFIG]: + for long_name, short_name in conversion_dict.items(): + if long_name in kwargs: + kwargs[short_name] = kwargs.pop(long_name) + + def get_file_path_by_content(self, content: str) -> str: svg_file = os.path.join( - get_text_dir(), tex_hash(full_markup) + ".svg" + get_text_dir(), tex_hash(content) + ".svg" ) if not os.path.exists(svg_file): - self.markup_to_svg(full_markup, svg_file) + self.markup_to_svg(content, svg_file) return svg_file - def get_full_markup_str(self) -> str: - if self.t2g: - log.warning( - "Manim currently cannot parse gradient from svg. " - "Please set gradient via `set_color_by_gradient`.", - ) - - config_style_dict = self.generate_config_style_dict() - global_attr_dict = { - "line_height": ((self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1) * 0.6, - "font_family": self.font or get_customization()["style"]["font"], - "font_size": self.font_size * 1024, - "font_style": self.slant, - "font_weight": self.weight, - # TODO, it seems this doesn't work - "font_features": "liga=0,dlig=0,clig=0,hlig=0" if self.disable_ligatures else None, - "foreground": config_style_dict.get("fill", None), - "alpha": config_style_dict.get("fill-opacity", None) - } - global_attr_dict = { - k: v - for k, v in global_attr_dict.items() - if v is not None - } - global_attr_dict.update(self.global_config) - self.parser.update_global_attrs(global_attr_dict) - - local_attr_items = [ - (word_or_text_span, {key: value}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for word_or_text_span, value in t2x_dict.items() - ] - local_attr_items.extend(self.local_configs.items()) - for word_or_text_span, local_config in local_attr_items: - for text_span in self.find_indexes(word_or_text_span): - self.parser.update_local_attrs(text_span, local_config) - - return self.parser.get_markup_str_with_attrs() - - def find_indexes(self, word_or_text_span: str | tuple[int, int]) -> list[tuple[int, int]]: - if isinstance(word_or_text_span, tuple): - return [word_or_text_span] - - return [ - match_obj.span() - for match_obj in re.finditer(re.escape(word_or_text_span), self.text) - ] - def markup_to_svg(self, markup_str: str, file_name: str) -> str: # `manimpango` is under construction, # so the following code is intended to suit its interface @@ -374,7 +224,7 @@ class Text(SVGMobject): weight="NORMAL", # Already handled size=1, # Already handled _=0, # Empty parameter - disable_liga=False, # Already handled + disable_liga=False, file_name=file_name, START_X=0, START_Y=0, @@ -387,63 +237,302 @@ class Text(SVGMobject): pango_width=pango_width ) - def generate_mobject(self) -> None: - super().generate_mobject() + def pre_parse(self) -> None: + super().pre_parse() + self.tag_items_from_markup = self.get_tag_items_from_markup() + self.global_dict_from_config = self.get_global_dict_from_config() + self.local_dicts_from_markup = self.get_local_dicts_from_markup() + self.local_dicts_from_config = self.get_local_dicts_from_config() + self.predefined_attr_dicts = self.get_predefined_attr_dicts() - # Remove empty paths - submobjects = list(filter(lambda submob: submob.has_points(), self)) + # Toolkits - # Apply space characters - if self.apply_space_chars: - content_str = self.parser.remove_tags(self.text) - if self.is_markup: - content_str = saxutils.unescape(content_str) - for match_obj in re.finditer(r"\s", content_str): - char_index = match_obj.start() - space = Dot(radius=0, fill_opacity=0, stroke_opacity=0) - space.move_to(submobjects[max(char_index - 1, 0)].get_center()) - submobjects.insert(char_index, space) - self.set_submobjects(submobjects) + @staticmethod + def get_attr_dict_str(attr_dict: dict[str, str]) -> str: + return " ".join([ + f"{key}='{val}'" + for key, val in attr_dict.items() + ]) - def full2short(self, config: dict) -> None: - conversion_dict = { - "line_spacing_height": "lsh", - "text2color": "t2c", - "text2font": "t2f", - "text2gradient": "t2g", - "text2slant": "t2s", - "text2weight": "t2w" - } - for kwargs in [config, self.CONFIG]: - for long_name, short_name in conversion_dict.items(): - if long_name in kwargs: - kwargs[short_name] = kwargs.pop(long_name) - - def get_parts_by_text(self, word: str) -> VGroup: - if self.is_markup: - log.warning( - "Slicing MarkupText via `get_parts_by_text`, " - "the result could be unexpected." - ) - elif not self.apply_space_chars: - log.warning( - "Slicing Text via `get_parts_by_text` without applying spaces, " - "the result could be unexpected." - ) - return VGroup(*( - self[i:j] - for i, j in self.find_indexes(word) + @staticmethod + def merge_attr_dicts( + attr_dict_items: list[Span, str, typing.Any] + ) -> list[tuple[Span, dict[str, str]]]: + index_seq = [0] + attr_dict_list = [{}] + for span, attr_dict in attr_dict_items: + if span[0] >= span[1]: + continue + region_indices = [ + MarkupText.find_region_index(index_seq, index) + for index in span + ] + for flag in (1, 0): + if index_seq[region_indices[flag]] == span[flag]: + continue + region_index = region_indices[flag] + index_seq.insert(region_index + 1, span[flag]) + attr_dict_list.insert( + region_index + 1, attr_dict_list[region_index].copy() + ) + region_indices[flag] += 1 + if flag == 0: + region_indices[1] += 1 + for key, val in attr_dict.items(): + if not key: + continue + for mid_dict in attr_dict_list[slice(*region_indices)]: + mid_dict[key] = val + return list(zip( + MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) - def get_part_by_text(self, word: str) -> VMobject | None: - parts = self.get_parts_by_text(word) - return parts[0] if parts else None + def find_substr_or_span( + self, substr_or_span: str | tuple[int | None, int | None] + ) -> list[Span]: + if isinstance(substr_or_span, str): + return self.find_substr(substr_or_span) + + span = tuple([ + ( + min(index, self.string_len) + if index >= 0 + else max(index + self.string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(substr_or_span, self.full_span) + ]) + if span[0] >= span[1]: + return [] + return [span] + + # Pre-parsing + + def get_tag_items_from_markup( + self + ) -> list[tuple[Span, Span, dict[str, str]]]: + if not self.is_markup: + return [] + + tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" + begin_match_obj_stack = [] + match_obj_pairs = [] + for match_obj in self.finditer(tag_pattern): + if not match_obj.group(1): + begin_match_obj_stack.append(match_obj) + else: + match_obj_pairs.append( + (begin_match_obj_stack.pop(), match_obj) + ) + if begin_match_obj_stack: + raise ValueError("Unclosed tag(s) detected") + + result = [] + for begin_match_obj, end_match_obj in match_obj_pairs: + tag_name = begin_match_obj.group(2) + if tag_name != end_match_obj.group(2): + raise ValueError("Unmatched tag names") + if end_match_obj.group(3): + raise ValueError("Attributes shan't exist in ending tags") + if tag_name == "span": + attr_dict = { + match.group(1): match.group(3) + for match in re.finditer( + attr_pattern, begin_match_obj.group(3) + ) + } + elif tag_name in TAG_TO_ATTR_DICT.keys(): + if begin_match_obj.group(3): + raise ValueError( + f"Attributes shan't exist in tag '{tag_name}'" + ) + attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() + else: + raise ValueError(f"Unknown tag: '{tag_name}'") + + result.append( + (begin_match_obj.span(), end_match_obj.span(), attr_dict) + ) + return result + + def get_global_dict_from_config(self) -> dict[str, typing.Any]: + result = { + "line_height": ( + (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 + ) * 0.6, + "font_family": self.font, + "font_size": self.font_size * 1024, + "font_style": self.slant, + "font_weight": self.weight + } + result.update(self.global_config) + return result + + def get_local_dicts_from_markup( + self + ) -> list[Span, dict[str, str]]: + return sorted([ + ((begin_tag_span[0], end_tag_span[1]), attr_dict) + for begin_tag_span, end_tag_span, attr_dict + in self.tag_items_from_markup + ]) + + def get_local_dicts_from_config( + self + ) -> list[Span, dict[str, typing.Any]]: + return [ + (span, {key: val}) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for substr_or_span, val in t2x_dict.items() + for span in self.find_substr_or_span(substr_or_span) + ] + [ + (span, local_config) + for substr_or_span, local_config in self.local_configs.items() + for span in self.find_substr_or_span(substr_or_span) + ] + + def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: + attr_dict_items = [ + (self.full_span, self.global_dict_from_config), + *self.local_dicts_from_markup, + *self.local_dicts_from_config + ] + return [ + (span, { + SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) + for key, val in attr_dict.items() + }) + for span, attr_dict in attr_dict_items + ] + + # Parsing + + def get_command_repl_items(self) -> list[tuple[Span, str]]: + result = [ + (tag_span, "") + for begin_tag, end_tag, _ in self.tag_items_from_markup + for tag_span in (begin_tag, end_tag) + ] + if not self.is_markup: + result += [ + (span, escaped) + for char, escaped in ( + ("&", "&"), + (">", ">"), + ("<", "<") + ) + for span in self.find_substr(char) + ] + return result + + def get_extra_entity_spans(self) -> list[Span]: + if not self.is_markup: + return [] + return self.find_spans(r"&.*?;") + + def get_extra_ignored_spans(self) -> list[int]: + return [] + + def get_internal_specified_spans(self) -> list[Span]: + return [span for span, _ in self.local_dicts_from_markup] + + def get_external_specified_spans(self) -> list[Span]: + return [span for span, _ in self.local_dicts_from_config] + + def get_label_span_list(self) -> list[Span]: + breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( + self.find_spans(r"\s+"), + self.find_spans(r"\b"), + self.specified_spans + )))) + breakup_indices = sorted(filter( + lambda index: not any([ + span[0] < index < span[1] + for span in self.entity_spans + ]), + breakup_indices + )) + return list(filter( + lambda span: self.get_substr(span).strip(), + self.get_neighbouring_pairs(breakup_indices) + )) + + def get_content(self, use_plain_file: bool) -> str: + if use_plain_file: + attr_dict_items = [ + (self.full_span, {"foreground": self.base_color}), + *self.predefined_attr_dicts, + *[ + (span, {}) + for span in self.label_span_list + ] + ] + else: + attr_dict_items = [ + (self.full_span, {"foreground": BLACK}), + *[ + (span, { + key: BLACK if key in COLOR_RELATED_KEYS else val + for key, val in attr_dict.items() + }) + for span, attr_dict in self.predefined_attr_dicts + ], + *[ + (span, {"foreground": self.int_to_hex(label + 1)}) + for label, span in enumerate(self.label_span_list) + ] + ] + inserted_string_pairs = [ + (span, ( + f"", + "" + )) + for span, attr_dict in self.merge_attr_dicts(attr_dict_items) + ] + span_repl_dict = self.generate_span_repl_dict( + inserted_string_pairs, self.command_repl_items + ) + return self.get_replaced_substr(self.full_span, span_repl_dict) + + @property + def has_predefined_local_colors(self) -> bool: + return any([ + key in COLOR_RELATED_KEYS + for _, attr_dict in self.predefined_attr_dicts + for key in attr_dict.keys() + ]) + + # Method alias + + def get_parts_by_text(self, text: str, **kwargs) -> VGroup: + return self.get_parts_by_string(text, **kwargs) + + def get_part_by_text(self, text: str, **kwargs) -> VMobject: + return self.get_part_by_string(text, **kwargs) + + def set_color_by_text(self, text: str, color: ManimColor, **kwargs): + return self.set_color_by_string(text, color, **kwargs) + + def set_color_by_text_to_color_map( + self, text_to_color_map: dict[str, ManimColor], **kwargs + ): + return self.set_color_by_string_to_color_map( + text_to_color_map, **kwargs + ) + + def get_text(self) -> str: + return self.get_string() -class MarkupText(Text): +class Text(MarkupText): CONFIG = { - "is_markup": True, - "apply_space_chars": False, + "is_markup": False, } @@ -461,7 +550,9 @@ class Code(MarkupText): digest_config(self, kwargs) self.code = code lexer = pygments.lexers.get_lexer_by_name(self.language) - formatter = pygments.formatters.PangoMarkupFormatter(style=self.code_style) + formatter = pygments.formatters.PangoMarkupFormatter( + style=self.code_style + ) markup = pygments.highlight(code, lexer, formatter) markup = re.sub(r"", "", markup) super().__init__(markup, **kwargs) diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index d3f11f2b..54166d36 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -48,6 +48,9 @@ class ImageMobject(Mobject): mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) return self + def set_color(self, color, opacity=None, recurse=None): + return self + def point_to_rgb(self, point: np.ndarray) -> np.ndarray: x0, y0 = self.get_corner(UL)[:2] x1, y1 = self.get_corner(DR)[:2] diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index 16ec3e10..71d3d2b9 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -80,15 +80,10 @@ def partial_quadratic_bezier_points( # Linear interpolation variants -def interpolate(start: T, end: T, alpha: float) -> T: + +def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T: try: - if isinstance(alpha, float): - return (1 - alpha) * start + alpha * end - # Otherwise, assume alpha is a list or array, and return - # an appropriated shaped array of all corresponding - # interpolations - result = np.outer(1 - alpha, start) + np.outer(alpha, end) - return result.reshape((*np.shape(alpha), *np.shape(start))) + return (1 - alpha) * start + alpha * end except TypeError: log.debug(f"`start` parameter with type `{type(start)}` and dtype `{start.dtype}`") log.debug(f"`end` parameter with type `{type(end)}` and dtype `{end.dtype}`") @@ -97,6 +92,15 @@ def interpolate(start: T, end: T, alpha: float) -> T: sys.exit(2) +def outer_interpolate( + start: np.ndarray | float, + end: np.ndarray | float, + alpha: np.ndarray | float, +) -> T: + result = np.outer(1 - alpha, start) + np.outer(alpha, end) + return result.reshape((*np.shape(alpha), *np.shape(start))) + + def set_array_by_interpolation( arr: np.ndarray, arr1: np.ndarray, diff --git a/manimlib/utils/init_config.py b/manimlib/utils/init_config.py index cb0a1787..36ae9d4b 100644 --- a/manimlib/utils/init_config.py +++ b/manimlib/utils/init_config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import yaml import inspect diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index 6729d359..99788a42 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -1,6 +1,5 @@ from __future__ import annotations -import itertools as it from typing import Callable, Iterable, Sequence, TypeVar import numpy as np @@ -36,10 +35,6 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: return [e for e in l1 if e not in l2] -def all_elements_are_instances(iterable: Iterable, Class: type) -> bool: - return all([isinstance(e, Class) for e in iterable]) - - def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]: return zip(*[ [*objects[k:], *objects[:k]] @@ -133,30 +128,6 @@ def make_even( ) -def make_even_by_cycling( - iterable_1: Iterable[T], - iterable_2: Iterable[S] -) -> tuple[list[T], list[S]]: - length = max(len(iterable_1), len(iterable_2)) - cycle1 = it.cycle(iterable_1) - cycle2 = it.cycle(iterable_2) - return ( - [next(cycle1) for x in range(length)], - [next(cycle2) for x in range(length)] - ) - - -def remove_nones(sequence: Iterable) -> list: - return [x for x in sequence if x] - - -# Note this is redundant with it.chain - - -def concatenate_lists(*list_of_lists): - return [item for l in list_of_lists for item in l] - - def hash_obj(obj: object) -> int: if isinstance(obj, dict): new_obj = {k: hash_obj(v) for k, v in obj.items()} diff --git a/manimlib/utils/space_ops.py b/manimlib/utils/space_ops.py index a959b8a5..e6ecfb13 100644 --- a/manimlib/utils/space_ops.py +++ b/manimlib/utils/space_ops.py @@ -154,7 +154,7 @@ def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float: n2 = get_norm(v2) if n1 == 0 or n2 == 0: return 0 - cos_angle = np.dot(v1, v2) / (n1 * n2) + cos_angle = np.dot(v1, v2) / np.float64(n1 * n2) return math.acos(clip(cos_angle, -1, 1)) diff --git a/setup.cfg b/setup.cfg index b52ed439..934f051c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = manimgl -version = 1.5.0 +version = 1.6.1 author = Grant Sanderson author_email= grant@3blue1brown.com description = Animation engine for explanatory math videos