From 88370d4d5da329d4866b5edd68cf7e267d0d822e Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Wed, 4 Dec 2024 19:11:21 -0600 Subject: [PATCH] Have StringMobject work with svg strings rather than necessarily writing to file Change SVGMobject to allow taking in a string of svg code as an input --- manimlib/mobject/svg/string_mobject.py | 15 ++++--- manimlib/mobject/svg/svg_mobject.py | 28 +++++++----- manimlib/mobject/svg/tex_mobject.py | 7 ++- manimlib/mobject/svg/text_mobject.py | 59 ++++++++++++++------------ 4 files changed, 63 insertions(+), 46 deletions(-) diff --git a/manimlib/mobject/svg/string_mobject.py b/manimlib/mobject/svg/string_mobject.py index 98277da5..85031d87 100644 --- a/manimlib/mobject/svg/string_mobject.py +++ b/manimlib/mobject/svg/string_mobject.py @@ -66,17 +66,18 @@ class StringMobject(SVGMobject, ABC): self.use_labelled_svg = use_labelled_svg self.parse() - super().__init__(**kwargs) + svg_string = self.get_svg_string() + super().__init__(svg_string=svg_string, **kwargs) self.set_stroke(stroke_color, stroke_width) self.set_fill(fill_color, border_width=fill_border_width) self.labels = [submob.label for submob in self.submobjects] - def get_file_path(self, is_labelled: bool = False) -> str: - is_labelled = is_labelled or self.use_labelled_svg - return self.get_file_path_by_content(self.get_content(is_labelled)) + def get_svg_string(self, is_labelled: bool = False) -> str: + content = self.get_content(is_labelled or self.use_labelled_svg) + return self.get_svg_string_by_content(content) @abstractmethod - def get_file_path_by_content(self, content: str) -> str: + def get_svg_string_by_content(self, content: str) -> str: return "" def assign_labels_by_color(self, mobjects: list[VMobject]) -> None: @@ -109,8 +110,8 @@ class StringMobject(SVGMobject, ABC): ) ) - def mobjects_from_file(self, file_path: str) -> list[VMobject]: - submobs = super().mobjects_from_file(file_path) + def mobjects_from_svg_string(self, svg_string: str) -> list[VMobject]: + submobs = super().mobjects_from_svg_string(svg_string) if self.use_labelled_svg: # This means submobjects are colored according to spans diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 020ed762..688fa4c5 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -6,6 +6,7 @@ from xml.etree import ElementTree as ET import numpy as np import svgelements as se import io +from pathlib import Path from manimlib.constants import RIGHT from manimlib.logger import log @@ -43,6 +44,7 @@ class SVGMobject(VMobject): def __init__( self, file_name: str = "", + svg_string: str = "", should_center: bool = True, height: float | None = None, width: float | None = None, @@ -67,11 +69,19 @@ class SVGMobject(VMobject): path_string_config: dict = dict(), **kwargs ): - self.file_name = file_name or self.file_name + if svg_string != "": + self.svg_string = svg_string + elif file_name != "": + self.svg_string = self.file_name_to_svg_string(file_name) + elif self.file_name != "": + self.file_name_to_svg_string(self.file_name) + else: + raise Exception("Must specify either a file_name or svg_string SVGMobject") + self.svg_default = dict(svg_default) self.path_string_config = dict(path_string_config) - super().__init__(**kwargs ) + super().__init__(**kwargs) self.init_svg_mobject() self.ensure_positive_orientation() @@ -101,7 +111,7 @@ class SVGMobject(VMobject): if hash_val in SVG_HASH_TO_MOB_MAP: submobs = [sm.copy() for sm in SVG_HASH_TO_MOB_MAP[hash_val]] else: - submobs = self.mobjects_from_file(self.get_file_path()) + submobs = self.mobjects_from_svg_string(self.svg_string) SVG_HASH_TO_MOB_MAP[hash_val] = [sm.copy() for sm in submobs] self.add(*submobs) @@ -115,11 +125,11 @@ class SVGMobject(VMobject): self.__class__.__name__, self.svg_default, self.path_string_config, - self.file_name + self.svg_string ) - def mobjects_from_file(self, file_path: str) -> list[VMobject]: - element_tree = ET.parse(file_path) + def mobjects_from_svg_string(self, svg_string: str) -> list[VMobject]: + element_tree = ET.ElementTree(ET.fromstring(svg_string)) new_tree = self.modify_xml_tree(element_tree) # New svg based on tree contents @@ -131,10 +141,8 @@ class SVGMobject(VMobject): return self.mobjects_from_svg(svg) - def get_file_path(self) -> str: - if self.file_name is None: - raise Exception("Must specify file for SVGMobject") - return get_full_vector_image_path(self.file_name) + def file_name_to_svg_string(self, file_name: str) -> str: + return Path(get_full_vector_image_path(file_name)).read_text() def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree: config_style_attrs = self.generate_config_style_dict() diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index 5211cbcb..9412830f 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from pathlib import Path from manimlib.mobject.svg.string_mobject import StringMobject from manimlib.mobject.types.vectorized_mobject import VGroup @@ -82,10 +83,12 @@ class Tex(StringMobject): self.additional_preamble ) - def get_file_path_by_content(self, content: str) -> str: - return tex_content_to_svg_file( + def get_svg_string_by_content(self, content: str) -> str: + # TODO, implement this without writing to a file + file_path = tex_content_to_svg_file( content, self.template, self.additional_preamble, self.tex_string ) + return Path(file_path).read_text() def _handle_scale_side_effects(self, scale_factor: float) -> Self: self.font_size *= scale_factor diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index a16d322f..6989515e 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -4,6 +4,8 @@ from contextlib import contextmanager import os from pathlib import Path import re +import tempfile +import hashlib import manimpango import pygments @@ -169,7 +171,8 @@ class MarkupText(StringMobject): self.disable_ligatures ) - def get_file_path_by_content(self, content: str) -> str: + def get_svg_string_by_content(self, content: str) -> str: + # TODO, check the cache hash_content = str(( content, self.justify, @@ -177,14 +180,11 @@ class MarkupText(StringMobject): self.alignment, self.line_width )) - svg_file = os.path.join( - get_text_dir(), hash_string(hash_content) + ".svg" - ) - if not os.path.exists(svg_file): - self.markup_to_svg(content, svg_file) - return svg_file + # hash_string(hash_content) + key = hashlib.sha256(hash_content.encode()).hexdigest() + return self.markup_to_svg_string(content) - def markup_to_svg(self, markup_str: str, file_name: str) -> str: + def markup_to_svg_string(self, markup_str: str) -> str: self.validate_markup_string(markup_str) # `manimpango` is under construction, @@ -195,25 +195,30 @@ class MarkupText(StringMobject): else: pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH - return manimpango.MarkupUtils.text2svg( - text=markup_str, - font="", # Already handled - slant="NORMAL", # Already handled - weight="NORMAL", # Already handled - size=1, # Already handled - _=0, # Empty parameter - disable_liga=False, - file_name=file_name, - START_X=0, - START_Y=0, - width=DEFAULT_CANVAS_WIDTH, - height=DEFAULT_CANVAS_HEIGHT, - justify=self.justify, - indent=self.indent, - line_spacing=None, # Already handled - alignment=alignment, - pango_width=pango_width - ) + with tempfile.NamedTemporaryFile(suffix='.svg', mode='r+') as tmp: + manimpango.MarkupUtils.text2svg( + text=markup_str, + font="", # Already handled + slant="NORMAL", # Already handled + weight="NORMAL", # Already handled + size=1, # Already handled + _=0, # Empty parameter + disable_liga=False, + file_name=tmp.name, + START_X=0, + START_Y=0, + width=DEFAULT_CANVAS_WIDTH, + height=DEFAULT_CANVAS_HEIGHT, + justify=self.justify, + indent=self.indent, + line_spacing=None, # Already handled + alignment=alignment, + pango_width=pango_width + ) + + # Read the contents + tmp.seek(0) + return tmp.read() @staticmethod def validate_markup_string(markup_str: str) -> None: