Have StringMobject work with svg strings rather than necessarily writing to file

Change SVGMobject to allow taking in a string of svg code as an input
This commit is contained in:
Grant Sanderson 2024-12-04 19:11:21 -06:00
parent 671a31b298
commit 88370d4d5d
4 changed files with 63 additions and 46 deletions

View file

@ -66,17 +66,18 @@ class StringMobject(SVGMobject, ABC):
self.use_labelled_svg = use_labelled_svg self.use_labelled_svg = use_labelled_svg
self.parse() self.parse()
super().__init__(**kwargs) svg_string = self.get_svg_string()
super().__init__(svg_string=svg_string, **kwargs)
self.set_stroke(stroke_color, stroke_width) self.set_stroke(stroke_color, stroke_width)
self.set_fill(fill_color, border_width=fill_border_width) self.set_fill(fill_color, border_width=fill_border_width)
self.labels = [submob.label for submob in self.submobjects] self.labels = [submob.label for submob in self.submobjects]
def get_file_path(self, is_labelled: bool = False) -> str: def get_svg_string(self, is_labelled: bool = False) -> str:
is_labelled = is_labelled or self.use_labelled_svg content = self.get_content(is_labelled or self.use_labelled_svg)
return self.get_file_path_by_content(self.get_content(is_labelled)) return self.get_svg_string_by_content(content)
@abstractmethod @abstractmethod
def get_file_path_by_content(self, content: str) -> str: def get_svg_string_by_content(self, content: str) -> str:
return "" return ""
def assign_labels_by_color(self, mobjects: list[VMobject]) -> None: def assign_labels_by_color(self, mobjects: list[VMobject]) -> None:
@ -109,8 +110,8 @@ class StringMobject(SVGMobject, ABC):
) )
) )
def mobjects_from_file(self, file_path: str) -> list[VMobject]: def mobjects_from_svg_string(self, svg_string: str) -> list[VMobject]:
submobs = super().mobjects_from_file(file_path) submobs = super().mobjects_from_svg_string(svg_string)
if self.use_labelled_svg: if self.use_labelled_svg:
# This means submobjects are colored according to spans # This means submobjects are colored according to spans

View file

@ -6,6 +6,7 @@ from xml.etree import ElementTree as ET
import numpy as np import numpy as np
import svgelements as se import svgelements as se
import io import io
from pathlib import Path
from manimlib.constants import RIGHT from manimlib.constants import RIGHT
from manimlib.logger import log from manimlib.logger import log
@ -43,6 +44,7 @@ class SVGMobject(VMobject):
def __init__( def __init__(
self, self,
file_name: str = "", file_name: str = "",
svg_string: str = "",
should_center: bool = True, should_center: bool = True,
height: float | None = None, height: float | None = None,
width: float | None = None, width: float | None = None,
@ -67,11 +69,19 @@ class SVGMobject(VMobject):
path_string_config: dict = dict(), path_string_config: dict = dict(),
**kwargs **kwargs
): ):
self.file_name = file_name or self.file_name if svg_string != "":
self.svg_string = svg_string
elif file_name != "":
self.svg_string = self.file_name_to_svg_string(file_name)
elif self.file_name != "":
self.file_name_to_svg_string(self.file_name)
else:
raise Exception("Must specify either a file_name or svg_string SVGMobject")
self.svg_default = dict(svg_default) self.svg_default = dict(svg_default)
self.path_string_config = dict(path_string_config) self.path_string_config = dict(path_string_config)
super().__init__(**kwargs ) super().__init__(**kwargs)
self.init_svg_mobject() self.init_svg_mobject()
self.ensure_positive_orientation() self.ensure_positive_orientation()
@ -101,7 +111,7 @@ class SVGMobject(VMobject):
if hash_val in SVG_HASH_TO_MOB_MAP: if hash_val in SVG_HASH_TO_MOB_MAP:
submobs = [sm.copy() for sm in SVG_HASH_TO_MOB_MAP[hash_val]] submobs = [sm.copy() for sm in SVG_HASH_TO_MOB_MAP[hash_val]]
else: else:
submobs = self.mobjects_from_file(self.get_file_path()) submobs = self.mobjects_from_svg_string(self.svg_string)
SVG_HASH_TO_MOB_MAP[hash_val] = [sm.copy() for sm in submobs] SVG_HASH_TO_MOB_MAP[hash_val] = [sm.copy() for sm in submobs]
self.add(*submobs) self.add(*submobs)
@ -115,11 +125,11 @@ class SVGMobject(VMobject):
self.__class__.__name__, self.__class__.__name__,
self.svg_default, self.svg_default,
self.path_string_config, self.path_string_config,
self.file_name self.svg_string
) )
def mobjects_from_file(self, file_path: str) -> list[VMobject]: def mobjects_from_svg_string(self, svg_string: str) -> list[VMobject]:
element_tree = ET.parse(file_path) element_tree = ET.ElementTree(ET.fromstring(svg_string))
new_tree = self.modify_xml_tree(element_tree) new_tree = self.modify_xml_tree(element_tree)
# New svg based on tree contents # New svg based on tree contents
@ -131,10 +141,8 @@ class SVGMobject(VMobject):
return self.mobjects_from_svg(svg) return self.mobjects_from_svg(svg)
def get_file_path(self) -> str: def file_name_to_svg_string(self, file_name: str) -> str:
if self.file_name is None: return Path(get_full_vector_image_path(file_name)).read_text()
raise Exception("Must specify file for SVGMobject")
return get_full_vector_image_path(self.file_name)
def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree: def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree:
config_style_attrs = self.generate_config_style_dict() config_style_attrs = self.generate_config_style_dict()

View file

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from pathlib import Path
from manimlib.mobject.svg.string_mobject import StringMobject from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
@ -82,10 +83,12 @@ class Tex(StringMobject):
self.additional_preamble self.additional_preamble
) )
def get_file_path_by_content(self, content: str) -> str: def get_svg_string_by_content(self, content: str) -> str:
return tex_content_to_svg_file( # TODO, implement this without writing to a file
file_path = tex_content_to_svg_file(
content, self.template, self.additional_preamble, self.tex_string content, self.template, self.additional_preamble, self.tex_string
) )
return Path(file_path).read_text()
def _handle_scale_side_effects(self, scale_factor: float) -> Self: def _handle_scale_side_effects(self, scale_factor: float) -> Self:
self.font_size *= scale_factor self.font_size *= scale_factor

View file

@ -4,6 +4,8 @@ from contextlib import contextmanager
import os import os
from pathlib import Path from pathlib import Path
import re import re
import tempfile
import hashlib
import manimpango import manimpango
import pygments import pygments
@ -169,7 +171,8 @@ class MarkupText(StringMobject):
self.disable_ligatures self.disable_ligatures
) )
def get_file_path_by_content(self, content: str) -> str: def get_svg_string_by_content(self, content: str) -> str:
# TODO, check the cache
hash_content = str(( hash_content = str((
content, content,
self.justify, self.justify,
@ -177,14 +180,11 @@ class MarkupText(StringMobject):
self.alignment, self.alignment,
self.line_width self.line_width
)) ))
svg_file = os.path.join( # hash_string(hash_content)
get_text_dir(), hash_string(hash_content) + ".svg" key = hashlib.sha256(hash_content.encode()).hexdigest()
) return self.markup_to_svg_string(content)
if not os.path.exists(svg_file):
self.markup_to_svg(content, svg_file)
return svg_file
def markup_to_svg(self, markup_str: str, file_name: str) -> str: def markup_to_svg_string(self, markup_str: str) -> str:
self.validate_markup_string(markup_str) self.validate_markup_string(markup_str)
# `manimpango` is under construction, # `manimpango` is under construction,
@ -195,25 +195,30 @@ class MarkupText(StringMobject):
else: else:
pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH
return manimpango.MarkupUtils.text2svg( with tempfile.NamedTemporaryFile(suffix='.svg', mode='r+') as tmp:
text=markup_str, manimpango.MarkupUtils.text2svg(
font="", # Already handled text=markup_str,
slant="NORMAL", # Already handled font="", # Already handled
weight="NORMAL", # Already handled slant="NORMAL", # Already handled
size=1, # Already handled weight="NORMAL", # Already handled
_=0, # Empty parameter size=1, # Already handled
disable_liga=False, _=0, # Empty parameter
file_name=file_name, disable_liga=False,
START_X=0, file_name=tmp.name,
START_Y=0, START_X=0,
width=DEFAULT_CANVAS_WIDTH, START_Y=0,
height=DEFAULT_CANVAS_HEIGHT, width=DEFAULT_CANVAS_WIDTH,
justify=self.justify, height=DEFAULT_CANVAS_HEIGHT,
indent=self.indent, justify=self.justify,
line_spacing=None, # Already handled indent=self.indent,
alignment=alignment, line_spacing=None, # Already handled
pango_width=pango_width alignment=alignment,
) pango_width=pango_width
)
# Read the contents
tmp.seek(0)
return tmp.read()
@staticmethod @staticmethod
def validate_markup_string(markup_str: str) -> None: def validate_markup_string(markup_str: str) -> None: