Have StringMobject work with svg strings rather than necessarily writing to file

Change SVGMobject to allow taking in a string of svg code as an input
This commit is contained in:
Grant Sanderson 2024-12-04 19:11:21 -06:00
parent 671a31b298
commit 88370d4d5d
4 changed files with 63 additions and 46 deletions

View file

@ -66,17 +66,18 @@ class StringMobject(SVGMobject, ABC):
self.use_labelled_svg = use_labelled_svg
self.parse()
super().__init__(**kwargs)
svg_string = self.get_svg_string()
super().__init__(svg_string=svg_string, **kwargs)
self.set_stroke(stroke_color, stroke_width)
self.set_fill(fill_color, border_width=fill_border_width)
self.labels = [submob.label for submob in self.submobjects]
def get_file_path(self, is_labelled: bool = False) -> str:
is_labelled = is_labelled or self.use_labelled_svg
return self.get_file_path_by_content(self.get_content(is_labelled))
def get_svg_string(self, is_labelled: bool = False) -> str:
content = self.get_content(is_labelled or self.use_labelled_svg)
return self.get_svg_string_by_content(content)
@abstractmethod
def get_file_path_by_content(self, content: str) -> str:
def get_svg_string_by_content(self, content: str) -> str:
return ""
def assign_labels_by_color(self, mobjects: list[VMobject]) -> None:
@ -109,8 +110,8 @@ class StringMobject(SVGMobject, ABC):
)
)
def mobjects_from_file(self, file_path: str) -> list[VMobject]:
submobs = super().mobjects_from_file(file_path)
def mobjects_from_svg_string(self, svg_string: str) -> list[VMobject]:
submobs = super().mobjects_from_svg_string(svg_string)
if self.use_labelled_svg:
# This means submobjects are colored according to spans

View file

@ -6,6 +6,7 @@ from xml.etree import ElementTree as ET
import numpy as np
import svgelements as se
import io
from pathlib import Path
from manimlib.constants import RIGHT
from manimlib.logger import log
@ -43,6 +44,7 @@ class SVGMobject(VMobject):
def __init__(
self,
file_name: str = "",
svg_string: str = "",
should_center: bool = True,
height: float | None = None,
width: float | None = None,
@ -67,11 +69,19 @@ class SVGMobject(VMobject):
path_string_config: dict = dict(),
**kwargs
):
self.file_name = file_name or self.file_name
if svg_string != "":
self.svg_string = svg_string
elif file_name != "":
self.svg_string = self.file_name_to_svg_string(file_name)
elif self.file_name != "":
self.file_name_to_svg_string(self.file_name)
else:
raise Exception("Must specify either a file_name or svg_string SVGMobject")
self.svg_default = dict(svg_default)
self.path_string_config = dict(path_string_config)
super().__init__(**kwargs )
super().__init__(**kwargs)
self.init_svg_mobject()
self.ensure_positive_orientation()
@ -101,7 +111,7 @@ class SVGMobject(VMobject):
if hash_val in SVG_HASH_TO_MOB_MAP:
submobs = [sm.copy() for sm in SVG_HASH_TO_MOB_MAP[hash_val]]
else:
submobs = self.mobjects_from_file(self.get_file_path())
submobs = self.mobjects_from_svg_string(self.svg_string)
SVG_HASH_TO_MOB_MAP[hash_val] = [sm.copy() for sm in submobs]
self.add(*submobs)
@ -115,11 +125,11 @@ class SVGMobject(VMobject):
self.__class__.__name__,
self.svg_default,
self.path_string_config,
self.file_name
self.svg_string
)
def mobjects_from_file(self, file_path: str) -> list[VMobject]:
element_tree = ET.parse(file_path)
def mobjects_from_svg_string(self, svg_string: str) -> list[VMobject]:
element_tree = ET.ElementTree(ET.fromstring(svg_string))
new_tree = self.modify_xml_tree(element_tree)
# New svg based on tree contents
@ -131,10 +141,8 @@ class SVGMobject(VMobject):
return self.mobjects_from_svg(svg)
def get_file_path(self) -> str:
if self.file_name is None:
raise Exception("Must specify file for SVGMobject")
return get_full_vector_image_path(self.file_name)
def file_name_to_svg_string(self, file_name: str) -> str:
return Path(get_full_vector_image_path(file_name)).read_text()
def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree:
config_style_attrs = self.generate_config_style_dict()

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import re
from pathlib import Path
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
@ -82,10 +83,12 @@ class Tex(StringMobject):
self.additional_preamble
)
def get_file_path_by_content(self, content: str) -> str:
return tex_content_to_svg_file(
def get_svg_string_by_content(self, content: str) -> str:
# TODO, implement this without writing to a file
file_path = tex_content_to_svg_file(
content, self.template, self.additional_preamble, self.tex_string
)
return Path(file_path).read_text()
def _handle_scale_side_effects(self, scale_factor: float) -> Self:
self.font_size *= scale_factor

View file

@ -4,6 +4,8 @@ from contextlib import contextmanager
import os
from pathlib import Path
import re
import tempfile
import hashlib
import manimpango
import pygments
@ -169,7 +171,8 @@ class MarkupText(StringMobject):
self.disable_ligatures
)
def get_file_path_by_content(self, content: str) -> str:
def get_svg_string_by_content(self, content: str) -> str:
# TODO, check the cache
hash_content = str((
content,
self.justify,
@ -177,14 +180,11 @@ class MarkupText(StringMobject):
self.alignment,
self.line_width
))
svg_file = os.path.join(
get_text_dir(), hash_string(hash_content) + ".svg"
)
if not os.path.exists(svg_file):
self.markup_to_svg(content, svg_file)
return svg_file
# hash_string(hash_content)
key = hashlib.sha256(hash_content.encode()).hexdigest()
return self.markup_to_svg_string(content)
def markup_to_svg(self, markup_str: str, file_name: str) -> str:
def markup_to_svg_string(self, markup_str: str) -> str:
self.validate_markup_string(markup_str)
# `manimpango` is under construction,
@ -195,25 +195,30 @@ class MarkupText(StringMobject):
else:
pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH
return manimpango.MarkupUtils.text2svg(
text=markup_str,
font="", # Already handled
slant="NORMAL", # Already handled
weight="NORMAL", # Already handled
size=1, # Already handled
_=0, # Empty parameter
disable_liga=False,
file_name=file_name,
START_X=0,
START_Y=0,
width=DEFAULT_CANVAS_WIDTH,
height=DEFAULT_CANVAS_HEIGHT,
justify=self.justify,
indent=self.indent,
line_spacing=None, # Already handled
alignment=alignment,
pango_width=pango_width
)
with tempfile.NamedTemporaryFile(suffix='.svg', mode='r+') as tmp:
manimpango.MarkupUtils.text2svg(
text=markup_str,
font="", # Already handled
slant="NORMAL", # Already handled
weight="NORMAL", # Already handled
size=1, # Already handled
_=0, # Empty parameter
disable_liga=False,
file_name=tmp.name,
START_X=0,
START_Y=0,
width=DEFAULT_CANVAS_WIDTH,
height=DEFAULT_CANVAS_HEIGHT,
justify=self.justify,
indent=self.indent,
line_spacing=None, # Already handled
alignment=alignment,
pango_width=pango_width
)
# Read the contents
tmp.seek(0)
return tmp.read()
@staticmethod
def validate_markup_string(markup_str: str) -> None: