This commit is contained in:
TonyCrane 2022-02-16 11:21:20 +08:00
commit 8ef42fae24
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
9 changed files with 239 additions and 230 deletions

View file

@ -2,6 +2,7 @@ from manimlib.constants import *
from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.text_mobject import Text
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.iterables import hash_obj
string_to_mob_map = {}
@ -20,23 +21,23 @@ class DecimalNumber(VMobject):
"include_background_rectangle": False,
"edge_to_fix": LEFT,
"font_size": 48,
"text_config": {} # Do not pass in font_size here
}
def __init__(self, number=0, **kwargs):
super().__init__(**kwargs)
self.set_submobjects_from_number(number)
self.init_colors()
def set_submobjects_from_number(self, number):
self.number = number
self.set_submobjects([])
string_to_mob_ = lambda s: self.string_to_mob(s, **self.text_config)
num_string = self.get_num_string(number)
self.add(*map(self.string_to_mob, num_string))
self.add(*map(string_to_mob_, num_string))
# Add non-numerical bits
if self.show_ellipsis:
dots = self.string_to_mob("...")
dots = string_to_mob_("...")
dots.arrange(RIGHT, buff=2 * dots[0].get_width())
self.add(dots)
if self.unit is not None:
@ -85,10 +86,10 @@ class DecimalNumber(VMobject):
def get_font_size(self):
return self.data["font_size"][0]
def string_to_mob(self, string, mob_class=Text):
if string not in string_to_mob_map:
string_to_mob_map[string] = mob_class(string, font_size=1)
mob = string_to_mob_map[string].copy()
def string_to_mob(self, string, mob_class=Text, **kwargs):
if (string, hash_obj(kwargs)) not in string_to_mob_map:
string_to_mob_map[(string, hash_obj(kwargs))] = mob_class(string, font_size=1, **kwargs)
mob = string_to_mob_map[(string, hash_obj(kwargs))].copy()
mob.scale(self.get_font_size())
return mob

View file

@ -318,9 +318,6 @@ class Bubble(SVGMobject):
self.content = Mobject()
self.refresh_triangulation()
def init_colors(self):
VMobject.init_colors(self)
def get_tip(self):
# TODO, find a better way
return self.get_corner(DOWN + self.direction) - 0.6 * self.direction

View file

@ -2,11 +2,11 @@ import itertools as it
import re
from types import MethodType
from manimlib.constants import BLACK
from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.tex_file_writing import tex_to_svg_file
@ -18,25 +18,10 @@ from manimlib.logger import log
SCALE_FACTOR_PER_FONT_POINT = 0.001
TEX_HASH_TO_MOB_MAP = {}
def _get_neighbouring_pairs(iterable):
return list(adjacent_pairs(iterable))[:-1]
class _TexSVG(SVGMobject):
CONFIG = {
"color": BLACK,
"stroke_width": 0,
"height": None,
"path_string_config": {
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
}
class _TexParser(object):
def __init__(self, tex_string, additional_substrings):
self.tex_string = tex_string
@ -400,10 +385,21 @@ class _TexParser(object):
])
class MTex(VMobject):
class _TexSVG(SVGMobject):
CONFIG = {
"height": None,
"fill_opacity": 1.0,
"stroke_width": 0,
"path_string_config": {
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
}
class MTex(_TexSVG):
CONFIG = {
"color": WHITE,
"font_size": 48,
"alignment": "\\centering",
"tex_environment": "align*",
@ -413,65 +409,49 @@ class MTex(VMobject):
}
def __init__(self, tex_string, **kwargs):
super().__init__(**kwargs)
digest_config(self, kwargs)
tex_string = tex_string.strip()
# Prevent from passing an empty string.
if not tex_string:
tex_string = "\\quad"
self.tex_string = tex_string
self.__parser = _TexParser(
self.parser = _TexParser(
self.tex_string,
[*self.tex_to_color_map.keys(), *self.isolate]
)
mob = self.generate_mobject()
self.add(*mob.copy())
self.init_colors()
super().__init__(**kwargs)
self.set_color_by_tex_to_color_map(self.tex_to_color_map)
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
@staticmethod
def color_to_label(color):
r, g, b = color_to_int_rgb(color)
rg = r * 256 + g
return rg * 256 + b
@property
def hash_seed(self):
return (
self.__class__.__name__,
self.svg_default,
self.path_string_config,
self.tex_string,
self.parser.specified_substrings,
self.alignment,
self.tex_environment,
self.use_plain_tex
)
def generate_mobject(self):
labelled_tex_string = self.__parser.get_labelled_tex_string()
labelled_tex_content = self.get_tex_file_content(labelled_tex_string)
hash_val = hash((labelled_tex_content, self.use_plain_tex))
def get_file_path(self):
return self._get_file_path(self.use_plain_tex)
if hash_val in TEX_HASH_TO_MOB_MAP:
return TEX_HASH_TO_MOB_MAP[hash_val]
if not self.use_plain_tex:
with display_during_execution(f"Writing \"{self.tex_string}\""):
labelled_svg_glyphs = self.tex_content_to_glyphs(
labelled_tex_content
)
glyph_labels = [
self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs
]
mob = self.build_mobject(labelled_svg_glyphs, glyph_labels)
TEX_HASH_TO_MOB_MAP[hash_val] = mob
return mob
def _get_file_path(self, use_plain_tex):
if use_plain_tex:
tex_string = self.tex_string
else:
tex_string = self.parser.get_labelled_tex_string()
full_tex = self.get_tex_file_body(tex_string)
with display_during_execution(f"Writing \"{self.tex_string}\""):
labelled_svg_glyphs = self.tex_content_to_glyphs(
labelled_tex_content
)
tex_content = self.get_tex_file_content(self.tex_string)
svg_glyphs = self.tex_content_to_glyphs(tex_content)
glyph_labels = [
self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs
]
mob = self.build_mobject(svg_glyphs, glyph_labels)
TEX_HASH_TO_MOB_MAP[hash_val] = mob
return mob
file_path = self.tex_to_svg_file_path(full_tex)
return file_path
def get_tex_file_content(self, tex_string):
def get_tex_file_body(self, tex_string):
if self.tex_environment:
tex_string = "\n".join([
f"\\begin{{{self.tex_environment}}}",
@ -480,17 +460,38 @@ class MTex(VMobject):
])
if self.alignment:
tex_string = "\n".join([self.alignment, tex_string])
return tex_string
tex_config = get_tex_config()
return tex_config["tex_body"].replace(
tex_config["text_to_replace"],
tex_string
)
@staticmethod
def tex_content_to_glyphs(tex_content):
tex_config = get_tex_config()
full_tex = tex_config["tex_body"].replace(
tex_config["text_to_replace"],
tex_content
)
filename = tex_to_svg_file(full_tex)
return _TexSVG(filename)
def tex_to_svg_file_path(tex_file_content):
return tex_to_svg_file(tex_file_content)
def generate_mobject(self):
super().generate_mobject()
if not self.use_plain_tex:
labelled_svg_glyphs = self
else:
file_path = self._get_file_path(use_plain_tex=False)
labelled_svg_glyphs = _TexSVG(file_path)
glyph_labels = [
self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs
]
mob = self.build_mobject(self, glyph_labels)
self.set_submobjects(mob.submobjects)
@staticmethod
def color_to_label(color):
r, g, b = color_to_int_rgb(color)
rg = r * 256 + g
return rg * 256 + b
def build_mobject(self, svg_glyphs, glyph_labels):
if not svg_glyphs:
@ -514,11 +515,11 @@ class MTex(VMobject):
submob_labels.append(current_glyph_label)
submobjects.append(submobject)
indices = self.__parser.get_sorted_submob_indices(submob_labels)
indices = self.parser.get_sorted_submob_indices(submob_labels)
rearranged_submobjects = [submobjects[index] for index in indices]
rearranged_labels = [submob_labels[index] for index in indices]
submob_tex_strings = self.__parser.get_submob_tex_strings(
submob_tex_strings = self.parser.get_submob_tex_strings(
rearranged_labels
)
for submob, label, submob_tex in zip(
@ -531,14 +532,14 @@ class MTex(VMobject):
return VGroup(*rearranged_submobjects)
def get_part_by_tex_spans(self, tex_spans):
labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans)
labels = self.parser.get_containing_labels_by_tex_spans(tex_spans)
return VGroup(*filter(
lambda submob: submob.submob_label in labels,
self.submobjects
))
def get_part_by_custom_span(self, custom_span):
tex_spans = self.__parser.find_span_components_of_custom_span(
tex_spans = self.parser.find_span_components_of_custom_span(
custom_span
)
if tex_spans is None:
@ -590,10 +591,10 @@ class MTex(VMobject):
]
def get_specified_substrings(self):
return self.__parser.get_specified_substrings()
return self.parser.get_specified_substrings()
def get_isolated_substrings(self):
return self.__parser.get_isolated_substrings()
return self.parser.get_isolated_substrings()
class MTexText(MTex):

View file

@ -1,7 +1,7 @@
import os
import re
import hashlib
import itertools as it
from xml.etree import ElementTree as ET
import svgelements as se
import numpy as np
@ -17,9 +17,13 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path
from manimlib.utils.iterables import hash_obj
from manimlib.logger import log
SVG_HASH_TO_MOB_MAP = {}
def _convert_point_to_3d(x, y):
return np.array([x, y, 0.0])
@ -29,8 +33,8 @@ class SVGMobject(VMobject):
"should_center": True,
"height": 2,
"width": None,
# Must be filled in a subclass, or when called
"file_name": None,
# Style that overrides the original svg
"color": None,
"opacity": None,
"fill_color": None,
@ -38,127 +42,119 @@ class SVGMobject(VMobject):
"stroke_width": None,
"stroke_color": None,
"stroke_opacity": None,
"path_string_config": {}
# Style that fills only when not specified
# If None, regarded as default values from svg standard
"svg_default": {
"color": None,
"opacity": None,
"fill_color": None,
"fill_opacity": None,
"stroke_width": None,
"stroke_color": None,
"stroke_opacity": None,
},
"path_string_config": {},
}
def __init__(self, file_name=None, **kwargs):
digest_config(self, kwargs)
self.file_name = file_name or self.file_name
if file_name is None:
raise Exception("Must specify file for SVGMobject")
self.file_path = get_full_vector_image_path(file_name)
super().__init__(**kwargs)
self.file_name = file_name or self.file_name
self.init_svg_mobject()
self.init_colors()
self.move_into_position()
def move_into_position(self):
if self.should_center:
self.center()
if self.height is not None:
self.set_height(self.height)
if self.width is not None:
self.set_width(self.width)
def init_svg_mobject(self):
hash_val = hash_obj(self.hash_seed)
if hash_val in SVG_HASH_TO_MOB_MAP:
mob = SVG_HASH_TO_MOB_MAP[hash_val].copy()
self.add(*mob)
return
def init_colors(self):
# Remove fill_color, fill_opacity,
# stroke_width, stroke_color, stroke_opacity
# as each submobject may have those values specified in svg file
self.set_stroke(background=self.draw_stroke_behind_fill)
self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke)
return self
self.generate_mobject()
SVG_HASH_TO_MOB_MAP[hash_val] = self.copy()
def init_points(self):
with open(self.file_path, "r") as svg_file:
svg_string = svg_file.read()
# Create a temporary svg file to dump modified svg to be parsed
modified_svg_string = self.modify_svg_file(svg_string)
modified_file_path = self.file_path.replace(".svg", "_.svg")
with open(modified_file_path, "w") as modified_svg_file:
modified_svg_file.write(modified_svg_string)
# `color` attribute handles `currentColor` keyword
if self.fill_color:
color = self.fill_color
elif self.color:
color = self.color
else:
color = "black"
shapes = se.SVG.parse(
modified_file_path,
color=color
@property
def hash_seed(self):
# Returns data which can uniquely represent the result of `init_points`.
# The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
return (
self.__class__.__name__,
self.svg_default,
self.path_string_config,
self.file_name
)
def generate_mobject(self):
file_path = self.get_file_path()
element_tree = ET.parse(file_path)
new_tree = self.modify_xml_tree(element_tree)
# Create a temporary svg file to dump modified svg to be parsed
modified_file_path = file_path.replace(".svg", "_.svg")
new_tree.write(modified_file_path)
svg = se.SVG.parse(modified_file_path)
os.remove(modified_file_path)
mobjects = self.get_mobjects_from(shapes)
mobjects = self.get_mobjects_from(svg)
self.add(*mobjects)
self.flip(RIGHT) # Flip y
self.scale(0.75)
def modify_svg_file(self, svg_string):
# svgelements cannot handle em, ex units
# Convert them using 1em = 16px, 1ex = 0.5em = 8px
def convert_unit(match_obj):
number = float(match_obj.group(1))
unit = match_obj.group(2)
factor = 16 if unit == "em" else 8
return str(number * factor) + "px"
def get_file_path(self):
if self.file_name is None:
raise Exception("Must specify file for SVGMobject")
return get_full_vector_image_path(self.file_name)
number_pattern = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)(ex|em)(?![a-zA-Z])"
result = re.sub(number_pattern, convert_unit, svg_string)
def modify_xml_tree(self, element_tree):
config_style_dict = self.generate_config_style_dict()
style_keys = (
"fill",
"fill-opacity",
"stroke",
"stroke-opacity",
"stroke-width",
"style"
)
root = element_tree.getroot()
root_style_dict = {
k: v for k, v in root.attrib.items()
if k in style_keys
}
# Add a group tag to set style from configuration
style_dict = self.generate_context_values_from_config()
group_tag_begin = "<g " + " ".join([
f"{k}=\"{v}\""
for k, v in style_dict.items()
]) + ">"
group_tag_end = "</g>"
begin_insert_index = re.search(r"<svg[\s\S]*?>", result).end()
end_insert_index = re.search(r"[\s\S]*(</svg\s*>)", result).start(1)
result = "".join([
result[:begin_insert_index],
group_tag_begin,
result[begin_insert_index:end_insert_index],
group_tag_end,
result[end_insert_index:]
])
new_root = ET.Element("svg", {})
config_style_node = ET.SubElement(new_root, "g", config_style_dict)
root_style_node = ET.SubElement(config_style_node, "g", root_style_dict)
root_style_node.extend(root)
return ET.ElementTree(new_root)
return result
def generate_context_values_from_config(self):
def generate_config_style_dict(self):
keys_converting_dict = {
"fill": ("color", "fill_color"),
"fill-opacity": ("opacity", "fill_opacity"),
"stroke": ("color", "stroke_color"),
"stroke-opacity": ("opacity", "stroke_opacity"),
"stroke-width": ("stroke_width",)
}
svg_default_dict = self.svg_default
result = {}
if self.stroke_width is not None:
result["stroke-width"] = self.stroke_width
if self.color is not None:
result["fill"] = result["stroke"] = self.color
if self.fill_color is not None:
result["fill"] = self.fill_color
if self.stroke_color is not None:
result["stroke"] = self.stroke_color
if self.opacity is not None:
result["fill-opacity"] = result["stroke-opacity"] = self.opacity
if self.fill_opacity is not None:
result["fill-opacity"] = self.fill_opacity
if self.stroke_opacity is not None:
result["stroke-opacity"] = self.stroke_opacity
for svg_key, style_keys in keys_converting_dict.items():
for style_key in style_keys:
if svg_default_dict[style_key] is None:
continue
result[svg_key] = str(svg_default_dict[style_key])
return result
def get_mobjects_from(self, shape):
if isinstance(shape, se.Group):
return list(it.chain(*(
self.get_mobjects_from(child)
for child in shape
)))
mob = self.get_mobject_from(shape)
if mob is None:
return []
if isinstance(shape, se.Transformable) and shape.apply:
self.handle_transform(mob, shape.transform)
return [mob]
def get_mobjects_from(self, svg):
result = []
for shape in svg.elements():
if isinstance(shape, se.Group):
continue
mob = self.get_mobject_from(shape)
if mob is None:
continue
if isinstance(shape, se.Transformable) and shape.apply:
self.handle_transform(mob, shape.transform)
result.append(mob)
return result
@staticmethod
def handle_transform(mob, matrix):
@ -265,6 +261,14 @@ class SVGMobject(VMobject):
def text_to_mobject(self, text):
pass
def move_into_position(self):
if self.should_center:
self.center()
if self.height is not None:
self.set_height(self.height)
if self.width is not None:
self.set_width(self.width)
class VMobjectFromSVGPath(VMobject):
CONFIG = {
@ -320,4 +324,4 @@ class VMobjectFromSVGPath(VMobject):
_convert_point_to_3d(*segment.__getattribute__(attr_name))
for attr_name in attr_names
]
func(*points)
func(*points)

View file

@ -5,7 +5,6 @@ import re
from manimlib.constants import *
from manimlib.mobject.geometry import Line
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.config_ops import digest_config
from manimlib.utils.tex_file_writing import tex_to_svg_file
@ -16,55 +15,50 @@ from manimlib.utils.tex_file_writing import display_during_execution
SCALE_FACTOR_PER_FONT_POINT = 0.001
tex_string_with_color_to_mob_map = {}
class SingleStringTex(VMobject):
class SingleStringTex(SVGMobject):
CONFIG = {
"height": None,
"fill_opacity": 1.0,
"stroke_width": 0,
"should_center": True,
"svg_default": {
"color": WHITE,
},
"path_string_config": {
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
"font_size": 48,
"height": None,
"organize_left_to_right": False,
"alignment": "\\centering",
"math_mode": True,
"organize_left_to_right": False,
}
def __init__(self, tex_string, **kwargs):
super().__init__(**kwargs)
assert(isinstance(tex_string, str))
assert isinstance(tex_string, str)
self.tex_string = tex_string
if tex_string not in tex_string_with_color_to_mob_map:
full_tex = self.get_tex_file_body(tex_string)
filename = tex_to_svg_file(full_tex)
svg_mob = SVGMobject(
filename,
height=None,
color=self.color,
stroke_width=self.stroke_width,
path_string_config={
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
}
)
tex_string_with_color_to_mob_map[(self.color, tex_string)] = svg_mob
self.add(*(
sm.copy()
for sm in tex_string_with_color_to_mob_map[(self.color, tex_string)]
))
self.init_colors()
super().__init__(**kwargs)
if self.height is None:
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
if self.organize_left_to_right:
self.organize_submobjects_left_to_right()
def init_colors(self):
self.set_stroke(background=self.draw_stroke_behind_fill)
self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke)
return self
@property
def hash_seed(self):
return (
self.__class__.__name__,
self.svg_default,
self.path_string_config,
self.tex_string,
self.alignment,
self.math_mode
)
def get_file_path(self):
full_tex = self.get_tex_file_body(self.tex_string)
with display_during_execution(f"Writing \"{self.tex_string}\""):
file_path = tex_to_svg_file(full_tex)
return file_path
def get_tex_file_body(self, tex_string):
new_tex = self.get_modified_expression(tex_string)

View file

@ -71,8 +71,6 @@ class Text(SVGMobject):
PangoUtils.remove_last_M(file_name)
self.remove_empty_path(file_name)
SVGMobject.__init__(self, file_name, **kwargs)
if self.color:
self.set_fill(self.color)
self.text = text
if self.disable_ligatures:
self.apply_space_chars()

View file

@ -63,7 +63,7 @@ class Scene(object):
# Items associated with interaction
self.mouse_point = Point()
self.mouse_drag_point = Point()
self.hold_on_wait = not self.presenter_mode
self.hold_on_wait = self.presenter_mode
# Much nicer to work with deterministic scenes
if self.random_seed is not None:
@ -629,9 +629,9 @@ class Scene(object):
self.camera.frame.to_default_state()
elif char == "q":
self.quit_interaction = True
elif char == " ":
elif char == " " or symbol == 65363: # Space or right arrow
self.hold_on_wait = False
elif char == "e":
elif char == "e" and modifiers == 3: # ctrl + shift + e
self.embed(close_scene_on_exit=False)
def on_resize(self, width: int, height: int):

View file

@ -139,3 +139,14 @@ def remove_nones(sequence):
def concatenate_lists(*list_of_lists):
return [item for l in list_of_lists for item in l]
def hash_obj(obj):
if isinstance(obj, dict):
new_obj = {k: hash_obj(v) for k, v in obj.items()}
return hash(tuple(frozenset(sorted(new_obj.items()))))
if isinstance(obj, (set, tuple, list)):
return hash(tuple(hash_obj(e) for e in obj))
return hash(obj)

View file

@ -126,6 +126,9 @@ def dvi_to_svg(dvi_file, regen_if_exists=False):
def display_during_execution(message):
# Only show top line
to_print = message.split("\n")[0]
max_characters = os.get_terminal_size().columns - 1
if len(to_print) > max_characters:
to_print = to_print[:max_characters - 3] + "..."
try:
print(to_print, end="\r")
yield