Refactor Tex

This commit is contained in:
Michael W 2021-11-12 21:22:42 +08:00 committed by GitHub
parent da1cc44d90
commit 1b695e1c19
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 333 additions and 354 deletions

View file

@ -6,7 +6,6 @@ from manimlib.constants import *
from manimlib.animation.fading import FadeIn from manimlib.animation.fading import FadeIn
from manimlib.animation.growing import GrowFromCenter from manimlib.animation.growing import GrowFromCenter
from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.tex_mobject import TexText from manimlib.mobject.svg.tex_mobject import TexText
from manimlib.mobject.svg.text_mobject import Text from manimlib.mobject.svg.text_mobject import Text
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
@ -14,7 +13,7 @@ from manimlib.utils.config_ops import digest_config
from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import get_norm
class Brace(SingleStringTex): class Brace(Tex):
CONFIG = { CONFIG = {
"buff": 0.2, "buff": 0.2,
"tex_string": r"\underbrace{\qquad}" "tex_string": r"\underbrace{\qquad}"

View file

@ -1,352 +1,332 @@
from functools import reduce from functools import reduce
import operator as op import hashlib
import re import operator as op
import re
from manimlib.constants import * from types import MethodType
from manimlib.mobject.geometry import Line
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.constants import *
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.geometry import Line
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.utils.config_ops import digest_config from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.tex_file_writing import tex_to_svg_file from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.tex_file_writing import display_during_execution from manimlib.utils.tex_file_writing import tex_to_svg_file
from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution
SCALE_FACTOR_PER_FONT_POINT = 0.001
SCALE_FACTOR_PER_FONT_POINT = 0.001
tex_string_to_mob_map = {}
tex_hash_to_mob_map = {}
class SingleStringTex(VMobject):
CONFIG = {
"fill_opacity": 1.0, class SVGTex(SVGMobject):
"stroke_width": 0, CONFIG = {
"should_center": True, "height": None,
"font_size": 48, # The hierachy structure is needed for the `break_up_by_substrings` method
"height": None, "unpack_groups": False,
"organize_left_to_right": False, "path_string_config": {
"alignment": "\\centering", "should_subdivide_sharp_curves": True,
"math_mode": True, "should_remove_null_curves": True,
} }
}
def __init__(self, tex_string, **kwargs):
super().__init__(**kwargs) def __init__(self, tex_obj, **kwargs):
assert(isinstance(tex_string, str)) self.tex_obj = tex_obj
self.tex_string = tex_string full_tex = self.get_tex_file_body()
if tex_string not in tex_string_to_mob_map: filename = tex_to_svg_file(full_tex)
with display_during_execution(f" Writing \"{tex_string}\""): super().__init__(filename, **kwargs)
full_tex = self.get_tex_file_body(tex_string) self.break_up_by_substrings()
filename = tex_to_svg_file(full_tex) self.init_colors()
svg_mob = SVGMobject(
filename, def get_mobjects_from(self, element):
height=None, result = super().get_mobjects_from(element)
path_string_config={ if len(result) == 0:
"should_subdivide_sharp_curves": True, return result
"should_remove_null_curves": True, result[0].fill_color = None
} try:
) fill_color = element.getAttribute("fill")
tex_string_to_mob_map[tex_string] = svg_mob if fill_color:
self.add(*( result[0].fill_color = fill_color
sm.copy() except:
for sm in tex_string_to_mob_map[tex_string] pass
)) return result
self.init_colors()
def get_tex_file_body(self):
if self.height is None: new_tex = self.get_modified_expression()
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) if self.tex_obj.math_mode:
if self.organize_left_to_right: new_tex = "\\begin{align*}\n" + new_tex + "\n\\end{align*}"
self.organize_submobjects_left_to_right()
new_tex = self.tex_obj.alignment + "\n" + new_tex
def get_tex_file_body(self, tex_string):
new_tex = self.get_modified_expression(tex_string) tex_config = get_tex_config()
if self.math_mode: return tex_config["tex_body"].replace(
new_tex = "\\begin{align*}\n" + new_tex + "\n\\end{align*}" tex_config["text_to_replace"],
new_tex
new_tex = self.alignment + "\n" + new_tex )
tex_config = get_tex_config() def get_modified_expression(self):
return tex_config["tex_body"].replace( tex_components = []
tex_config["text_to_replace"], to_isolate = self.tex_obj.substrings_to_isolate
new_tex for tex_substring in self.tex_obj.tex_substrings:
) if tex_substring not in to_isolate:
tex_components.append(tex_substring)
def get_modified_expression(self, tex_string): else:
return self.modify_special_strings(tex_string.strip()) color_index = to_isolate.index(tex_substring)
tex_components.append("".join([
def modify_special_strings(self, tex): "{\\color[RGB]{",
tex = tex.strip() str(self.get_nth_color_tuple(color_index))[1:-1],
should_add_filler = reduce(op.or_, [ "}",
# Fraction line needs something to be over tex_substring,
tex == "\\over", "}"
tex == "\\overline", ]))
# Makesure sqrt has overbar return self.tex_obj.arg_separator.join(tex_components)
tex == "\\sqrt",
tex == "\\sqrt{", def break_up_by_substrings(self):
# Need to add blank subscript or superscript """
tex.endswith("_"), Reorganize existing submojects one layer
tex.endswith("^"), deeper based on the structure of tex_substrings (as a list
tex.endswith("dot"), of tex_substrings)
]) """
if should_add_filler: if len(self.tex_obj.tex_substrings) == 1:
filler = "{\\quad}" submob = self.copy()
tex += filler self.set_submobjects([submob])
return self
if tex == "\\substack": new_submobjects = []
tex = "\\quad" new_submobject_components = []
for part in self.submobjects:
if tex == "": if part.fill_color is not None:
tex = "\\quad" if new_submobject_components:
new_submobjects.append(VGroup(*new_submobject_components))
# To keep files from starting with a line break new_submobject_components = []
if tex.startswith("\\\\"): new_submobjects.append(part)
tex = tex.replace("\\\\", "\\quad\\\\") else:
new_submobject_components.append(part)
tex = self.balance_braces(tex) if new_submobject_components:
new_submobjects.append(VGroup(*new_submobject_components))
# Handle imbalanced \left and \right
num_lefts, num_rights = [ for submob, tex_substring in zip(new_submobjects, self.tex_obj.tex_substrings):
len([ fill_color = submob.fill_color
s for s in tex.split(substr)[1:] if fill_color is not None:
if s and s[0] in "(){}[]|.\\" submob_tex_string = self.tex_obj.substrings_to_isolate[int(fill_color[1:], 16) - 1]
]) else:
for substr in ("\\left", "\\right") submob_tex_string = tex_substring
] submob.tex_string = submob_tex_string
if num_lefts != num_rights: # Prevent methods and classes using `get_tex()` from breaking.
tex = tex.replace("\\left", "\\big") submob.get_tex = MethodType(lambda sm: sm.tex_string, submob)
tex = tex.replace("\\right", "\\big") self.set_submobjects(new_submobjects)
for context in ["array"]: return self
begin_in = ("\\begin{%s}" % context) in tex
end_in = ("\\end{%s}" % context) in tex @staticmethod
if begin_in ^ end_in: def get_nth_color_tuple(n): ## TODO: Refactor
# Just turn this into a blank string, # Get a unique color different from black,
# which means caller should leave a # or the svg file will not include the color information.
# stray \\begin{...} with other symbols return (
tex = "" (n + 1) // 256 // 256,
return tex (n + 1) // 256 % 256,
(n + 1) % 256
def balance_braces(self, tex): )
"""
Makes Tex resiliant to unmatched braces
""" class Tex(VMobject):
num_unclosed_brackets = 0 CONFIG = {
for char in tex: "fill_opacity": 1.0,
if char == "{": "stroke_width": 0,
num_unclosed_brackets += 1 "should_center": True,
elif char == "}": "font_size": 48,
if num_unclosed_brackets == 0: "height": None,
tex = "{" + tex "organize_left_to_right": False,
else: "alignment": "\\centering",
num_unclosed_brackets -= 1 "math_mode": True,
tex += num_unclosed_brackets * "}" "arg_separator": "",
return tex "isolate": [],
"tex_to_color_map": {},
def get_tex(self): }
return self.tex_string
def __init__(self, *tex_strings, **kwargs):
def organize_submobjects_left_to_right(self): super().__init__(**kwargs)
self.sort(lambda p: p[0]) self.substrings_to_isolate = [*self.isolate, *self.tex_to_color_map.keys()]
return self self.tex_substrings = self.break_up_tex_strings(tex_strings)
tex_string = self.arg_separator.join(tex_strings)
self.tex_string = tex_string
class Tex(SingleStringTex):
CONFIG = { hash_val = self.tex2hash()
"arg_separator": "", if hash_val not in tex_hash_to_mob_map: ## TODO
"isolate": [], with display_during_execution(f" Writing \"{tex_string}\""):
"tex_to_color_map": {}, svg_mob = SVGTex(self)
} tex_hash_to_mob_map[hash_val] = svg_mob
self.add(*(
def __init__(self, *tex_strings, **kwargs): sm.copy()
digest_config(self, kwargs) for sm in tex_hash_to_mob_map[hash_val]
self.tex_strings = self.break_up_tex_strings(tex_strings) ))
full_string = self.arg_separator.join(self.tex_strings) self.init_colors()
super().__init__(full_string, **kwargs) self.set_color_by_tex_to_color_map(self.tex_to_color_map, substring=False)
self.break_up_by_substrings()
self.set_color_by_tex_to_color_map(self.tex_to_color_map) if self.height is None:
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
if self.organize_left_to_right: if self.organize_left_to_right:
self.organize_submobjects_left_to_right() self.organize_submobjects_left_to_right()
def break_up_tex_strings(self, tex_strings): def tex2hash(self):
# Separate out any strings specified in the isolate id_str = self.tex_string + str(self.substrings_to_isolate)
# or tex_to_color_map lists. hasher = hashlib.sha256()
substrings_to_isolate = [*self.isolate, *self.tex_to_color_map.keys()] hasher.update(id_str.encode())
if len(substrings_to_isolate) == 0: return hasher.hexdigest()[:16]
return tex_strings
patterns = ( def break_up_tex_strings(self, tex_strings):
"({})".format(re.escape(ss)) # Separate out any strings specified in the isolate
for ss in substrings_to_isolate # or tex_to_color_map lists.
) if len(self.substrings_to_isolate) == 0:
pattern = "|".join(patterns) return tex_strings
pieces = [] patterns = (
for s in tex_strings: "({})".format(re.escape(ss))
if pattern: for ss in self.substrings_to_isolate
pieces.extend(re.split(pattern, s)) )
else: pattern = "|".join(patterns)
pieces.append(s) pieces = []
return list(filter(lambda s: s, pieces)) for s in tex_strings:
if pattern:
def break_up_by_substrings(self): pieces.extend(re.split(pattern, s))
""" else:
Reorganize existing submojects one layer pieces.append(s)
deeper based on the structure of tex_strings (as a list return list(filter(lambda s: s, pieces))
of tex_strings)
"""
if len(self.tex_strings) == 1: def organize_submobjects_left_to_right(self):
submob = self.copy() self.sort(lambda p: p[0])
self.set_submobjects([submob]) return self
return self
new_submobjects = [] def get_parts_by_tex(self, tex, substring=True, case_sensitive=True):
curr_index = 0 def test(tex1, tex2):
config = dict(self.CONFIG) if not case_sensitive:
config["alignment"] = "" tex1 = tex1.lower()
for tex_string in self.tex_strings: tex2 = tex2.lower()
tex_string = tex_string.strip() if substring:
if len(tex_string) == 0: return tex1 in tex2
continue else:
sub_tex_mob = SingleStringTex(tex_string, **config) return tex1 == tex2
num_submobs = len(sub_tex_mob)
if num_submobs == 0: return VGroup(*[
continue mob for mob in self.submobjects if test(tex, mob.get_tex())
new_index = curr_index + num_submobs ])
sub_tex_mob.set_submobjects(self[curr_index:new_index])
new_submobjects.append(sub_tex_mob) def get_part_by_tex(self, tex, **kwargs):
curr_index = new_index all_parts = self.get_parts_by_tex(tex, **kwargs)
self.set_submobjects(new_submobjects) return all_parts[0] if all_parts else None
return self
def set_color_by_tex(self, tex, color, **kwargs):
def get_parts_by_tex(self, tex, substring=True, case_sensitive=True): self.get_parts_by_tex(tex, **kwargs).set_color(color)
def test(tex1, tex2): return self
if not case_sensitive:
tex1 = tex1.lower() def set_color_by_tex_to_color_map(self, tex_to_color_map, **kwargs):
tex2 = tex2.lower() for tex, color in list(tex_to_color_map.items()):
if substring: self.set_color_by_tex(tex, color, **kwargs)
return tex1 in tex2 return self
else:
return tex1 == tex2 def index_of_part(self, part, start=0):
return self.submobjects.index(part, start)
return VGroup(*filter(
lambda m: isinstance(m, SingleStringTex) and test(tex, m.get_tex()), def index_of_part_by_tex(self, tex, start=0, **kwargs):
self.submobjects part = self.get_part_by_tex(tex, **kwargs)
)) return self.index_of_part(part, start)
def get_part_by_tex(self, tex, **kwargs): def slice_by_tex(self, start_tex=None, stop_tex=None, **kwargs):
all_parts = self.get_parts_by_tex(tex, **kwargs) if start_tex is None:
return all_parts[0] if all_parts else None start_index = 0
else:
def set_color_by_tex(self, tex, color, **kwargs): start_index = self.index_of_part_by_tex(start_tex, **kwargs)
self.get_parts_by_tex(tex, **kwargs).set_color(color)
return self if stop_tex is None:
return self[start_index:]
def set_color_by_tex_to_color_map(self, tex_to_color_map, **kwargs): else:
for tex, color in list(tex_to_color_map.items()): stop_index = self.index_of_part_by_tex(stop_tex, start=start_index, **kwargs)
self.set_color_by_tex(tex, color, **kwargs) return self[start_index:stop_index]
return self
def sort_alphabetically(self):
def index_of_part(self, part, start=0): self.submobjects.sort(key=lambda m: m.get_tex())
return self.submobjects.index(part, start)
def set_bstroke(self, color=BLACK, width=4):
def index_of_part_by_tex(self, tex, start=0, **kwargs): self.set_stroke(color, width, background=True)
part = self.get_part_by_tex(tex, **kwargs) return self
return self.index_of_part(part, start)
def slice_by_tex(self, start_tex=None, stop_tex=None, **kwargs): class TexText(Tex):
if start_tex is None: CONFIG = {
start_index = 0 "math_mode": False,
else: "arg_separator": "",
start_index = self.index_of_part_by_tex(start_tex, **kwargs) }
if stop_tex is None:
return self[start_index:] class BulletedList(TexText):
else: CONFIG = {
stop_index = self.index_of_part_by_tex(stop_tex, start=start_index, **kwargs) "buff": MED_LARGE_BUFF,
return self[start_index:stop_index] "dot_scale_factor": 2,
"alignment": "",
def sort_alphabetically(self): }
self.submobjects.sort(key=lambda m: m.get_tex())
def __init__(self, *items, **kwargs):
def set_bstroke(self, color=BLACK, width=4): line_separated_items = [s + "\\\\" for s in items]
self.set_stroke(color, width, background=True) TexText.__init__(self, *line_separated_items, **kwargs)
return self for part in self:
dot = Tex("\\cdot").scale(self.dot_scale_factor)
dot.next_to(part[0], LEFT, SMALL_BUFF)
class TexText(Tex): part.add_to_back(dot)
CONFIG = { self.arrange(
"math_mode": False, DOWN,
"arg_separator": "", aligned_edge=LEFT,
} buff=self.buff
)
class BulletedList(TexText): def fade_all_but(self, index_or_string, opacity=0.5):
CONFIG = { arg = index_or_string
"buff": MED_LARGE_BUFF, if isinstance(arg, str):
"dot_scale_factor": 2, part = self.get_part_by_tex(arg)
"alignment": "", elif isinstance(arg, int):
} part = self.submobjects[arg]
else:
def __init__(self, *items, **kwargs): raise Exception("Expected int or string, got {0}".format(arg))
line_separated_items = [s + "\\\\" for s in items] for other_part in self.submobjects:
TexText.__init__(self, *line_separated_items, **kwargs) if other_part is part:
for part in self: other_part.set_fill(opacity=1)
dot = Tex("\\cdot").scale(self.dot_scale_factor) else:
dot.next_to(part[0], LEFT, SMALL_BUFF) other_part.set_fill(opacity=opacity)
part.add_to_back(dot)
self.arrange(
DOWN, class TexFromPresetString(Tex):
aligned_edge=LEFT, CONFIG = {
buff=self.buff # To be filled by subclasses
) "tex": None,
"color": None,
def fade_all_but(self, index_or_string, opacity=0.5): }
arg = index_or_string
if isinstance(arg, str): def __init__(self, **kwargs):
part = self.get_part_by_tex(arg) digest_config(self, kwargs)
elif isinstance(arg, int): Tex.__init__(self, self.tex, **kwargs)
part = self.submobjects[arg] self.set_color(self.color)
else:
raise Exception("Expected int or string, got {0}".format(arg))
for other_part in self.submobjects: class Title(TexText):
if other_part is part: CONFIG = {
other_part.set_fill(opacity=1) "scale_factor": 1,
else: "include_underline": True,
other_part.set_fill(opacity=opacity) "underline_width": FRAME_WIDTH - 2,
# This will override underline_width
"match_underline_width_to_text": False,
class TexFromPresetString(Tex): "underline_buff": MED_SMALL_BUFF,
CONFIG = { }
# To be filled by subclasses
"tex": None, def __init__(self, *text_parts, **kwargs):
"color": None, TexText.__init__(self, *text_parts, **kwargs)
} self.scale(self.scale_factor)
self.to_edge(UP)
def __init__(self, **kwargs): if self.include_underline:
digest_config(self, kwargs) underline = Line(LEFT, RIGHT)
Tex.__init__(self, self.tex, **kwargs) underline.next_to(self, DOWN, buff=self.underline_buff)
self.set_color(self.color) if self.match_underline_width_to_text:
underline.match_width(self)
else:
class Title(TexText): underline.set_width(self.underline_width)
CONFIG = { self.add(underline)
"scale_factor": 1, self.underline = underline
"include_underline": True,
"underline_width": FRAME_WIDTH - 2,
# This will override underline_width
"match_underline_width_to_text": False,
"underline_buff": MED_SMALL_BUFF,
}
def __init__(self, *text_parts, **kwargs):
TexText.__init__(self, *text_parts, **kwargs)
self.scale(self.scale_factor)
self.to_edge(UP)
if self.include_underline:
underline = Line(LEFT, RIGHT)
underline.next_to(self, DOWN, buff=self.underline_buff)
if self.match_underline_width_to_text:
underline.match_width(self)
else:
underline.set_width(self.underline_width)
self.add(underline)
self.underline = underline