From 43821ab2baafe44e5b0de99ed8e4d97cc44e2d20 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 5 Dec 2024 10:09:15 -0600 Subject: [PATCH] Make caching on disk a decorator, and update implementations for Tex and Text mobjects --- manimlib/mobject/svg/old_tex_mobject.py | 6 +- manimlib/mobject/svg/tex_mobject.py | 7 +- manimlib/mobject/svg/text_mobject.py | 111 +++++++++++++----------- manimlib/utils/cache.py | 32 +++++-- manimlib/utils/customization.py | 2 + manimlib/utils/tex_file_writing.py | 20 ++++- 6 files changed, 104 insertions(+), 74 deletions(-) diff --git a/manimlib/mobject/svg/old_tex_mobject.py b/manimlib/mobject/svg/old_tex_mobject.py index 49bc671d..7adc216e 100644 --- a/manimlib/mobject/svg/old_tex_mobject.py +++ b/manimlib/mobject/svg/old_tex_mobject.py @@ -77,11 +77,7 @@ class SingleStringTex(SVGMobject): ) def get_svg_string_by_content(self, content: str) -> str: - return get_cached_value( - key=hash_string(str((content, self.template, self.additional_preamble))), - value_func=lambda: latex_to_svg(content, self.template, self.additional_preamble), - message=f"Writing {self.tex_string}..." - ) + return latex_to_svg(content, self.template, self.additional_preamble) def get_tex_file_body(self, tex_string: str) -> str: new_tex = self.get_modified_expression(tex_string) diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index 3ad52642..3e0a460a 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -6,7 +6,6 @@ from pathlib import Path from manimlib.mobject.svg.string_mobject import StringMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.utils.cache import get_cached_value from manimlib.utils.color import color_to_hex from manimlib.utils.color import hex_to_int from manimlib.utils.tex_file_writing import latex_to_svg @@ -86,11 +85,7 @@ class Tex(StringMobject): ) def get_svg_string_by_content(self, content: str) -> str: - return get_cached_value( - key=hash_string(str((content, self.template, self.additional_preamble))), - value_func=lambda: latex_to_svg(content, self.template, self.additional_preamble), - message=f"Writing {self.tex_string}..." - ) + return latex_to_svg(content, self.template, self.additional_preamble, short_tex=self.tex_string) def _handle_scale_side_effects(self, scale_factor: float) -> Self: self.font_size *= scale_factor diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 5aaeb6a9..fd0e891b 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -16,7 +16,7 @@ from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH from manimlib.constants import NORMAL from manimlib.logger import log from manimlib.mobject.svg.string_mobject import StringMobject -from manimlib.utils.cache import get_cached_value +from manimlib.utils.cache import cache_on_disk from manimlib.utils.color import color_to_hex from manimlib.utils.color import int_to_hex from manimlib.utils.customization import get_customization @@ -51,6 +51,57 @@ class _Alignment: self.value = _Alignment.VAL_DICT[s.upper()] +@cache_on_disk +def markup_to_svg_string( + markup_str: str, + justify: bool = False, + indent: float = 0, + alignment: str = "", + line_width: float | None = None, +) -> str: + validate_error = manimpango.MarkupUtils.validate(markup_str) + if validate_error: + raise ValueError( + f"Invalid markup string \"{markup_str}\"\n" + \ + f"{validate_error}" + ) + + # `manimpango` is under construction, + # so the following code is intended to suit its interface + alignment = _Alignment(alignment) + if line_width is None: + pango_width = -1 + else: + pango_width = line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH + + # Write the result to a temporary svg file, and return it's contents. + # TODO, better would be to have this not write to file at all + with tempfile.NamedTemporaryFile(suffix='.svg', mode='r+') as tmp: + manimpango.MarkupUtils.text2svg( + text=markup_str, + font="", # Already handled + slant="NORMAL", # Already handled + weight="NORMAL", # Already handled + size=1, # Already handled + _=0, # Empty parameter + disable_liga=False, + file_name=tmp.name, + START_X=0, + START_Y=0, + width=DEFAULT_CANVAS_WIDTH, + height=DEFAULT_CANVAS_HEIGHT, + justify=justify, + indent=indent, + line_spacing=None, # Already handled + alignment=alignment, + pango_width=pango_width + ) + + # Read the contents + tmp.seek(0) + return tmp.read() + + class MarkupText(StringMobject): # See https://docs.gtk.org/Pango/pango_markup.html MARKUP_TAGS = { @@ -172,59 +223,13 @@ class MarkupText(StringMobject): ) def get_svg_string_by_content(self, content: str) -> str: - key = hash_string(str(( + self.content = content + return markup_to_svg_string( content, - self.justify, - self.indent, - self.alignment, - self.line_width - ))) - return get_cached_value(key, lambda: self.markup_to_svg_string(content)) - - def markup_to_svg_string(self, markup_str: str) -> str: - self.validate_markup_string(markup_str) - - # `manimpango` is under construction, - # so the following code is intended to suit its interface - alignment = _Alignment(self.alignment) - if self.line_width is None: - pango_width = -1 - else: - pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH - - with tempfile.NamedTemporaryFile(suffix='.svg', mode='r+') as tmp: - manimpango.MarkupUtils.text2svg( - text=markup_str, - font="", # Already handled - slant="NORMAL", # Already handled - weight="NORMAL", # Already handled - size=1, # Already handled - _=0, # Empty parameter - disable_liga=False, - file_name=tmp.name, - START_X=0, - START_Y=0, - width=DEFAULT_CANVAS_WIDTH, - height=DEFAULT_CANVAS_HEIGHT, - justify=self.justify, - indent=self.indent, - line_spacing=None, # Already handled - alignment=alignment, - pango_width=pango_width - ) - - # Read the contents - tmp.seek(0) - return tmp.read() - - @staticmethod - def validate_markup_string(markup_str: str) -> None: - validate_error = manimpango.MarkupUtils.validate(markup_str) - if not validate_error: - return - raise ValueError( - f"Invalid markup string \"{markup_str}\"\n" + \ - f"{validate_error}" + justify=self.justify, + indent=self.indent, + alignment=self.alignment, + line_width=self.line_width ) # Toolkits diff --git a/manimlib/utils/cache.py b/manimlib/utils/cache.py index 2c1ce860..58ff7ca1 100644 --- a/manimlib/utils/cache.py +++ b/manimlib/utils/cache.py @@ -1,22 +1,38 @@ +from __future__ import annotations + import os from diskcache import Cache from contextlib import contextmanager +from functools import wraps from manimlib.utils.directories import get_cache_dir +from manimlib.utils.simple_functions import hash_string + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + T = TypeVar('T') CACHE_SIZE = 1e9 # 1 Gig +_cache = Cache(get_cache_dir(), size_limit=CACHE_SIZE) -def get_cached_value(key, value_func, message=""): - cache = Cache(get_cache_dir(), size_limit=CACHE_SIZE) +def cache_on_disk(func: Callable[..., T]) -> Callable[..., T]: + @wraps(func) + def wrapper(*args, **kwargs): + key = hash_string("".join(map(str, [func.__name__, args, kwargs]))) + value = _cache.get(key) + if value is None: + # print(f"Executing {func.__name__}({args[0]}, ...)") + value = func(*args, **kwargs) + _cache.set(key, value) + return value + return wrapper - value = cache.get(key) - if value is None: - with display_during_execution(message): - value = value_func() - cache.set(key, value) - return value + +def clear_cache(): + _cache.clear() @contextmanager diff --git a/manimlib/utils/customization.py b/manimlib/utils/customization.py index 7426deb6..a011ec07 100644 --- a/manimlib/utils/customization.py +++ b/manimlib/utils/customization.py @@ -11,7 +11,9 @@ CUSTOMIZATION = {} def get_customization(): if not CUSTOMIZATION: + print(CUSTOMIZATION) CUSTOMIZATION.update(get_custom_config()) + print(CUSTOMIZATION) directories = CUSTOMIZATION["directories"] # Unless user has specified otherwise, use the system default temp # directory for storing tex files, mobject_data, etc. diff --git a/manimlib/utils/tex_file_writing.py b/manimlib/utils/tex_file_writing.py index b3b14725..d5208c02 100644 --- a/manimlib/utils/tex_file_writing.py +++ b/manimlib/utils/tex_file_writing.py @@ -8,6 +8,7 @@ import subprocess from pathlib import Path import tempfile +from manimlib.utils.cache import cache_on_disk from manimlib.config import get_custom_config from manimlib.config import get_manim_dir from manimlib.logger import log @@ -62,10 +63,13 @@ def get_full_tex(content: str, preamble: str = ""): )) + "\n" +@cache_on_disk def latex_to_svg( latex: str, template: str = "", - additional_preamble: str = "" + additional_preamble: str = "", + short_tex: str = "", + show_message_during_execution: bool = True, ) -> str: """Convert LaTeX string to SVG string. @@ -81,6 +85,13 @@ def latex_to_svg( LatexError: If LaTeX compilation fails NotImplementedError: If compiler is not supported """ + if show_message_during_execution: + max_message_len = 80 + message = f"Writing {short_tex or latex}" + if len(message) > max_message_len: + message = message[:max_message_len - 3] + "..." + print(message, end="\r") + tex_config = get_tex_config() if template and template != tex_config["template"]: tex_config = get_tex_template_config(template) @@ -147,7 +158,12 @@ def latex_to_svg( ) # Return SVG string - return process.stdout.decode('utf-8') + result = process.stdout.decode('utf-8') + + if show_message_during_execution: + print(" " * len(message), end="\r") + + return result class LatexError(Exception):