Merge branch 'master' of github.com:3b1b/manim into video-work

This commit is contained in:
Grant Sanderson 2022-02-13 14:54:18 -08:00
commit e09a264cab
19 changed files with 865 additions and 1068 deletions

View file

@ -1,14 +1,23 @@
Changelog
=========
Unreleased
----------
v1.4.1
------
Fixed bugs
^^^^^^^^^^
- `#1724 <https://github.com/3b1b/manim/pull/1724>`__: Temporarily fixed boolean operations' bug
- `d2e0811 <https://github.com/3b1b/manim/commit/d2e0811285f7908e71a65e664fec88b1af1c6144>`__: Import ``Iterable`` from ``collections.abc`` instead of ``collections`` which is deprecated since python 3.9
v1.4.0
------
Fixed bugs
^^^^^^^^^^
- `f1996f8 <https://github.com/3b1b/manim/pull/1697/commits/f1996f8479f9e33d626b3b66e9eb6995ce231d86>`__: Temporarily fixed ``Lightbulb``
- `#1712 <https://github.com/3b1b/manim/pull/1712>`__: Fixed some bugs of ``SVGMobject``
- `#1717 <https://github.com/3b1b/manim/pull/1717>`__: Fixed some bugs of SVG path string parser
- `#1720 <https://github.com/3b1b/manim/pull/1720>`__: Fixed some bugs of ``MTex``
New Features
^^^^^^^^^^^^
@ -16,6 +25,8 @@ New Features
- `#1704 <https://github.com/3b1b/manim/pull/1704>`__: Added ``lable_buff`` config parameter for ``Brace``
- `#1712 <https://github.com/3b1b/manim/pull/1712>`__: Added support for ``rotate skewX skewY`` transform in SVG
- `#1717 <https://github.com/3b1b/manim/pull/1717>`__: Added style support to ``SVGMobject``
- `#1719 <https://github.com/3b1b/manim/pull/1719>`__: Added parser to <style> element of SVG
- `#1719 <https://github.com/3b1b/manim/pull/1719>`__: Added support for <line> element in ``SVGMobject``
Refactor
^^^^^^^^
@ -24,6 +35,12 @@ Refactor
- `#1712 <https://github.com/3b1b/manim/pull/1712>`__: Refactored SVG path string parser
- `#1712 <https://github.com/3b1b/manim/pull/1712>`__: Allowed ``Mobject.scale`` to receive iterable ``scale_factor``
- `#1716 <https://github.com/3b1b/manim/pull/1716>`__: Refactored ``MTex``
- `#1721 <https://github.com/3b1b/manim/pull/1721>`__: Improved config helper (``manimgl --config``)
- `#1723 <https://github.com/3b1b/manim/pull/1723>`__: Refactored ``MTex``
Dependencies
^^^^^^^^^^^^
- `#1719 <https://github.com/3b1b/manim/pull/1719>`__: Added dependency on python package `cssselect2 <https://github.com/Kozea/cssselect2>`__
v1.3.0
@ -88,7 +105,7 @@ Refactor
Dependencies
^^^^^^^^^^^^
- `#1675 <https://github.com/3b1b/manim/pull/1675>`__: Added dependency on python packages `skia-pathops <https://github.com/fonttools/skia-pathops>`__
- `#1675 <https://github.com/3b1b/manim/pull/1675>`__: Added dependency on python package `skia-pathops <https://github.com/fonttools/skia-pathops>`__
v1.2.0
------

View file

@ -1,7 +1,6 @@
from manimlib.animation.animation import Animation
from manimlib.animation.composition import Succession
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.mobject import Group
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.config_ops import digest_config
from manimlib.utils.rate_functions import linear

View file

@ -17,7 +17,6 @@ class Broadcast(LaggedStart):
"remover": True,
"lag_ratio": 0.2,
"run_time": 3,
"remover": True,
}
def __init__(self, focal_point, **kwargs):

View file

@ -1,16 +1,20 @@
import numpy as np
import itertools as it
from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.svg.mtex_mobject import MTex
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies
class TransformMatchingParts(AnimationGroup):
@ -139,3 +143,108 @@ class TransformMatchingTex(TransformMatchingParts):
@staticmethod
def get_mobject_key(mobject):
return mobject.get_tex()
class TransformMatchingMTex(AnimationGroup):
CONFIG = {
"key_map": dict(),
}
def __init__(self, source_mobject, target_mobject, **kwargs):
digest_config(self, kwargs)
assert isinstance(source_mobject, MTex)
assert isinstance(target_mobject, MTex)
anims = []
rest_source_submobs = source_mobject.submobjects.copy()
rest_target_submobs = target_mobject.submobjects.copy()
def add_anim_from(anim_class, func, source_attr, target_attr=None):
if target_attr is None:
target_attr = source_attr
source_parts = func(source_mobject, source_attr)
target_parts = func(target_mobject, target_attr)
filtered_source_parts = [
submob_part for submob_part in source_parts
if all([
submob in rest_source_submobs
for submob in submob_part
])
]
filtered_target_parts = [
submob_part for submob_part in target_parts
if all([
submob in rest_target_submobs
for submob in submob_part
])
]
if not (filtered_source_parts and filtered_target_parts):
return
anims.append(anim_class(
VGroup(*filtered_source_parts),
VGroup(*filtered_target_parts),
**kwargs
))
for submob in it.chain(*filtered_source_parts):
rest_source_submobs.remove(submob)
for submob in it.chain(*filtered_target_parts):
rest_target_submobs.remove(submob)
def get_submobs_from_keys(mobject, keys):
if not isinstance(keys, tuple):
keys = (keys,)
indices = []
for key in keys:
if isinstance(key, int):
indices.append(key)
elif isinstance(key, range):
indices.extend(key)
elif isinstance(key, str):
all_parts = mobject.get_parts_by_tex(key)
indices.extend(it.chain(*[
mobject.indices_of_part(part) for part in all_parts
]))
else:
raise TypeError(key)
return VGroup(VGroup(*[
mobject[i] for i in remove_list_redundancies(indices)
]))
for source_key, target_key in self.key_map.items():
add_anim_from(
ReplacementTransform, get_submobs_from_keys,
source_key, target_key
)
common_specified_substrings = sorted(list(
set(source_mobject.get_specified_substrings()).intersection(
target_mobject.get_specified_substrings()
)
), key=len, reverse=True)
for part_tex_string in common_specified_substrings:
add_anim_from(
FadeTransformPieces, MTex.get_parts_by_tex, part_tex_string
)
common_submob_tex_strings = {
source_submob.get_tex() for source_submob in source_mobject
}.intersection({
target_submob.get_tex() for target_submob in target_mobject
})
for tex_string in common_submob_tex_strings:
add_anim_from(
FadeTransformPieces,
lambda mobject, attr: VGroup(*[
VGroup(mob) for mob in mobject
if mob.get_tex() == attr
]),
tex_string
)
anims.append(FadeOutToPoint(
VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs
))
anims.append(FadeInFromPoint(
VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs
))
super().__init__(*anims)

View file

@ -41,7 +41,7 @@ def _convert_skia_path_to_vmobject(path, vmobject):
vmobject.add_quadratic_bezier_curve_to(*points)
else:
raise Exception(f"Unsupported: {path_verb}")
return vmobject
return vmobject.reverse_points()
class Union(VMobject):

View file

@ -1,3 +1,5 @@
from isosurfaces import plot_isoline
from manimlib.constants import *
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
@ -70,3 +72,40 @@ class FunctionGraph(ParametricCurve):
def get_point_from_function(self, x):
return self.t_func(x)
class ImplicitFunction(VMobject):
CONFIG = {
"x_range": [-FRAME_X_RADIUS, FRAME_X_RADIUS],
"y_range": [-FRAME_Y_RADIUS, FRAME_Y_RADIUS],
"min_depth": 5,
"max_quads": 1500,
"use_smoothing": True
}
def __init__(self, func, x_range=None, y_range=None, **kwargs):
digest_config(self, kwargs)
self.function = func
super().__init__(**kwargs)
def init_points(self):
p_min, p_max = (
np.array([self.x_range[0], self.y_range[0]]),
np.array([self.x_range[1], self.y_range[1]]),
)
curves = plot_isoline(
fn=lambda u: self.function(u[0], u[1]),
pmin=p_min,
pmax=p_max,
min_depth=self.min_depth,
max_quads=self.max_quads,
) # returns a list of lists of 2D points
curves = [
np.pad(curve, [(0, 0), (0, 1)]) for curve in curves if curve != []
] # add z coord as 0
for curve in curves:
self.start_new_path(curve[0])
self.add_points_as_corners(curve[1:])
if self.use_smoothing:
self.make_smooth()
return self

View file

@ -608,8 +608,8 @@ class Arrow(Line):
self.insert_tip_anchor()
return self
def init_colors(self, override=True):
super().init_colors(override)
def init_colors(self):
super().init_colors()
self.create_tip_with_stroke_width()
def get_arc_length(self):
@ -849,6 +849,11 @@ class Polygon(VMobject):
return self
class Polyline(Polygon):
def init_points(self):
self.set_points_as_corners(self.vertices)
class RegularPolygon(Polygon):
CONFIG = {
"start_angle": None,

View file

@ -112,7 +112,7 @@ class Matrix(VMobject):
"\\left[",
"\\begin{array}{c}",
*height * ["\\quad \\\\"],
"\\end{array}"
"\\end{array}",
"\\right]",
]))[0]
bracket_pair.set_height(

View file

@ -4,7 +4,7 @@ import random
import sys
import moderngl
from functools import wraps
from collections import Iterable
from collections.abc import Iterable
import numpy as np
@ -109,8 +109,8 @@ class Mobject(object):
"reflectiveness": self.reflectiveness,
}
def init_colors(self, override=True):
self.set_color(self.color, self.opacity, override)
def init_colors(self):
self.set_color(self.color, self.opacity)
def init_points(self):
# Typically implemented in subclass, unlpess purposefully left blank

View file

@ -1,11 +1,12 @@
import itertools as it
import re
import sys
from types import MethodType
from manimlib.constants import BLACK
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.tex_file_writing import tex_to_svg_file
@ -20,16 +21,14 @@ SCALE_FACTOR_PER_FONT_POINT = 0.001
TEX_HASH_TO_MOB_MAP = {}
def _contains(span_0, span_1):
return span_0[0] <= span_1[0] and span_1[1] <= span_0[1]
def _get_neighbouring_pairs(iterable):
return list(adjacent_pairs(iterable))[:-1]
class _PlainTex(SVGMobject):
class _TexSVG(SVGMobject):
CONFIG = {
"color": BLACK,
"stroke_width": 0,
"height": None,
"path_string_config": {
"should_subdivide_sharp_curves": True,
@ -38,239 +37,367 @@ class _PlainTex(SVGMobject):
}
class _LabelledTex(_PlainTex):
def __init__(self, file_name=None, **kwargs):
super().__init__(file_name, **kwargs)
for glyph in self:
glyph.glyph_label = _LabelledTex.color_str_to_label(glyph.fill_color)
@staticmethod
def color_str_to_label(color_str):
if len(color_str) == 4:
# "#RGB" => "#RRGGBB"
color_str = "#" + "".join([c * 2 for c in color_str[1:]])
return int(color_str[1:], 16) - 1
def get_mobjects_from(self, element, style):
result = super().get_mobjects_from(element, style)
for mob in result:
if not hasattr(mob, "glyph_label"):
mob.glyph_label = -1
try:
color_str = element.getAttribute("fill")
if color_str:
glyph_label = _LabelledTex.color_str_to_label(color_str)
for mob in result:
mob.glyph_label = glyph_label
except:
pass
return result
class _TexSpan(object):
def __init__(self, script_type, label):
# `script_type`: 0 for normal, 1 for subscript, 2 for superscript.
# Only those spans with `script_type == 0` will be colored.
self.script_type = script_type
self.label = label
self.containing_labels = []
def __repr__(self):
return "_TexSpan(" + ", ".join([
attrib_name + "=" + str(getattr(self, attrib_name))
for attrib_name in ["script_type", "label", "containing_labels"]
]) + ")"
class _TexParser(object):
def __init__(self, tex_string, additional_substrings):
self.tex_string = tex_string
self.tex_spans_dict = {}
self.whitespace_indices = self.get_whitespace_indices()
self.backslash_indices = self.get_backslash_indices()
self.script_indices = self.get_script_indices()
self.brace_indices_dict = self.get_brace_indices_dict()
self.tex_span_list = []
self.script_span_to_char_dict = {}
self.script_span_to_tex_span_dict = {}
self.neighbouring_script_span_pairs = []
self.specified_substrings = []
self.current_label = 0
self.brace_index_pairs = self.get_brace_index_pairs()
self.add_tex_span((0, len(tex_string)))
self.break_up_by_double_braces()
self.break_up_by_scripts()
self.break_up_by_double_braces()
self.break_up_by_additional_substrings(additional_substrings)
self.check_if_overlap()
self.analyse_containing_labels()
self.specified_substrings = remove_list_redundancies(self.specified_substrings)
self.tex_span_list.sort(key=lambda t: (t[0], -t[1]))
self.specified_substrings = remove_list_redundancies(
self.specified_substrings
)
self.containing_labels_dict = self.get_containing_labels_dict()
@staticmethod
def label_to_color_tuple(rgb):
# Get a unique color different from black,
# or the svg file will not include the color information.
rg, b = divmod(rgb, 256)
r, g = divmod(rg, 256)
return r, g, b
def add_tex_span(self, tex_span):
if tex_span not in self.tex_span_list:
self.tex_span_list.append(tex_span)
def add_tex_span(self, span_tuple, script_type=0, label=-1):
if span_tuple in self.tex_spans_dict:
return
def get_whitespace_indices(self):
return [
match_obj.start()
for match_obj in re.finditer(r"\s", self.tex_string)
]
if script_type == 0:
# Should be additionally labelled.
label = self.current_label
self.current_label += 1
def get_backslash_indices(self):
# Newlines (`\\`) don't count.
return [
match_obj.end() - 1
for match_obj in re.finditer(r"\\+", self.tex_string)
if len(match_obj.group()) % 2 == 1
]
tex_span = _TexSpan(script_type, label)
self.tex_spans_dict[span_tuple] = tex_span
def filter_out_escaped_characters(self, indices):
return list(filter(
lambda index: index - 1 not in self.backslash_indices,
indices
))
def add_specified_substring(self, span_tuple):
substring = self.tex_string[slice(*span_tuple)]
self.specified_substrings.append(substring)
def get_script_indices(self):
return self.filter_out_escaped_characters([
match_obj.start()
for match_obj in re.finditer(r"[_^]", self.tex_string)
])
def get_brace_index_pairs(self):
result = []
left_brace_indices = []
for match_obj in re.finditer(r"(\\*)(\{|\})", self.tex_string):
# Braces following even numbers of backslashes are counted.
if len(match_obj.group(1)) % 2 == 1:
continue
if match_obj.group(2) == "{":
left_brace_index = match_obj.span(2)[0]
left_brace_indices.append(left_brace_index)
def get_brace_indices_dict(self):
tex_string = self.tex_string
indices = self.filter_out_escaped_characters([
match_obj.start()
for match_obj in re.finditer(r"[{}]", tex_string)
])
result = {}
left_brace_indices_stack = []
for index in indices:
if tex_string[index] == "{":
left_brace_indices_stack.append(index)
else:
left_brace_index = left_brace_indices.pop()
right_brace_index = match_obj.span(2)[1]
result.append((left_brace_index, right_brace_index))
if left_brace_indices:
self.raise_tex_parsing_error("unmatched braces")
left_brace_index = left_brace_indices_stack.pop()
result[left_brace_index] = index
return result
def break_up_by_double_braces(self):
# Match paired double braces (`{{...}}`).
skip_pair = False
for prev_span_tuple, span_tuple in _get_neighbouring_pairs(
self.brace_index_pairs
):
if skip_pair:
skip_pair = False
continue
if all([
span_tuple[0] == prev_span_tuple[0] - 1,
span_tuple[1] == prev_span_tuple[1] + 1
]):
self.add_tex_span(span_tuple)
self.add_specified_substring(span_tuple)
skip_pair = True
def break_up_by_scripts(self):
# Match subscripts & superscripts.
tex_string = self.tex_string
brace_indices_dict = dict(self.brace_index_pairs)
for match_obj in re.finditer(r"((?<!\\)(_|\^)\s*)|(\s+(_|\^)\s*)", tex_string):
script_type = 1 if "_" in match_obj.group() else 2
token_begin, token_end = match_obj.span()
if token_end in brace_indices_dict:
content_span = (token_end, brace_indices_dict[token_end])
whitespace_indices = self.whitespace_indices
brace_indices_dict = self.brace_indices_dict
script_spans = []
for script_index in self.script_indices:
script_char = tex_string[script_index]
extended_begin = script_index
while extended_begin - 1 in whitespace_indices:
extended_begin -= 1
script_begin = script_index + 1
while script_begin in whitespace_indices:
script_begin += 1
if script_begin in brace_indices_dict.keys():
script_end = brace_indices_dict[script_begin] + 1
else:
content_match_obj = re.match(r"\w|\\[a-zA-Z]+", tex_string[token_end:])
if not content_match_obj:
self.raise_tex_parsing_error("unclear subscript/superscript")
content_span = tuple([
index + token_end for index in content_match_obj.span()
])
self.add_tex_span(content_span)
label = self.tex_spans_dict[content_span].label
self.add_tex_span(
(token_begin, content_span[1]),
script_type=script_type,
label=label
)
pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+")
match_obj = pattern.match(tex_string, pos=script_begin)
if not match_obj:
script_name = {
"_": "subscript",
"^": "superscript"
}[script_char]
log.warning(
f"Unclear {script_name} detected while parsing. "
"Please use braces to clarify"
)
continue
script_end = match_obj.end()
tex_span = (script_begin, script_end)
script_span = (extended_begin, script_end)
script_spans.append(script_span)
self.add_tex_span(tex_span)
self.script_span_to_char_dict[script_span] = script_char
self.script_span_to_tex_span_dict[script_span] = tex_span
if not script_spans:
return
_, sorted_script_spans = zip(*sorted([
(index, script_span)
for script_span in script_spans
for index in script_span
]))
for span_0, span_1 in _get_neighbouring_pairs(sorted_script_spans):
if span_0[1] == span_1[0]:
self.neighbouring_script_span_pairs.append((span_0, span_1))
def break_up_by_double_braces(self):
# Match paired double braces (`{{...}}`).
tex_string = self.tex_string
reversed_indices_dict = dict(
item[::-1] for item in self.brace_indices_dict.items()
)
skip = False
for prev_right_index, right_index in _get_neighbouring_pairs(
list(reversed_indices_dict.keys())
):
if skip:
skip = False
continue
if right_index != prev_right_index + 1:
continue
left_index = reversed_indices_dict[right_index]
prev_left_index = reversed_indices_dict[prev_right_index]
if left_index != prev_left_index - 1:
continue
tex_span = (left_index, right_index + 1)
self.add_tex_span(tex_span)
self.specified_substrings.append(tex_string[slice(*tex_span)])
skip = True
def break_up_by_additional_substrings(self, additional_substrings):
tex_string = self.tex_string
all_span_tuples = []
for string in additional_substrings:
# Only match non-crossing strings.
for match_obj in re.finditer(re.escape(string), tex_string):
all_span_tuples.append(match_obj.span())
stripped_substrings = sorted(remove_list_redundancies([
string.strip()
for string in additional_substrings
]))
if "" in stripped_substrings:
stripped_substrings.remove("")
script_spans_dict = dict([
span_tuple[::-1]
for span_tuple, tex_span in self.tex_spans_dict.items()
if tex_span.script_type != 0
tex_string = self.tex_string
all_tex_spans = []
for string in stripped_substrings:
match_objs = list(re.finditer(re.escape(string), tex_string))
if not match_objs:
continue
self.specified_substrings.append(string)
for match_obj in match_objs:
all_tex_spans.append(match_obj.span())
former_script_spans_dict = dict([
script_span_pair[0][::-1]
for script_span_pair in self.neighbouring_script_span_pairs
])
for span_begin, span_end in all_span_tuples:
if span_end in script_spans_dict.values():
# Deconstruct spans with subscripts & superscripts.
while span_end in script_spans_dict:
span_end = script_spans_dict[span_end]
for span_begin, span_end in all_tex_spans:
# Deconstruct spans containing one out of two scripts.
if span_end in former_script_spans_dict.keys():
span_end = former_script_spans_dict[span_end]
if span_begin >= span_end:
continue
span_tuple = (span_begin, span_end)
self.add_tex_span(span_tuple)
self.add_specified_substring(span_tuple)
self.add_tex_span((span_begin, span_end))
def check_if_overlap(self):
span_tuples = sorted(
self.tex_spans_dict.keys(),
key=lambda t: (t[0], -t[1])
)
overlapping_span_pairs = []
for i, span_0 in enumerate(span_tuples):
for span_1 in span_tuples[i + 1 :]:
def get_containing_labels_dict(self):
tex_span_list = self.tex_span_list
result = {
tex_span: []
for tex_span in tex_span_list
}
overlapping_tex_span_pairs = []
for index_0, span_0 in enumerate(tex_span_list):
for index_1, span_1 in enumerate(tex_span_list[index_0:]):
if span_0[1] <= span_1[0]:
continue
if span_0[1] < span_1[1]:
overlapping_span_pairs.append((span_0, span_1))
if overlapping_span_pairs:
overlapping_tex_span_pairs.append((span_0, span_1))
result[span_0].append(index_0 + index_1)
if overlapping_tex_span_pairs:
tex_string = self.tex_string
log.error("Overlapping substring pairs occur in MTex:")
for span_tuple_pair in overlapping_span_pairs:
log.error("Partially overlapping substrings detected:")
for tex_span_pair in overlapping_tex_span_pairs:
log.error(", ".join(
f"\"{tex_string[slice(*span_tuple)]}\""
for span_tuple in span_tuple_pair
f"\"{tex_string[slice(*tex_span)]}\""
for tex_span in tex_span_pair
))
sys.exit(2)
def analyse_containing_labels(self):
for span_0, tex_span_0 in self.tex_spans_dict.items():
if tex_span_0.script_type != 0:
continue
for span_1, tex_span_1 in self.tex_spans_dict.items():
if _contains(span_1, span_0):
tex_span_1.containing_labels.append(tex_span_0.label)
def get_labelled_expression(self):
tex_string = self.tex_string
if not self.tex_spans_dict:
return tex_string
# Remove the span of extire tex string.
indices_with_labels = sorted([
(span_tuple[i], i, span_tuple[1 - i], tex_span.label)
for span_tuple, tex_span in self.tex_spans_dict.items()
if tex_span.script_type == 0
for i in range(2)
], key=lambda t: (t[0], -t[1], -t[2]))[1:]
result = tex_string[: indices_with_labels[0][0]]
for index_with_label, next_index_with_label in _get_neighbouring_pairs(
indices_with_labels
):
index, flag, _, label = index_with_label
next_index, *_ = next_index_with_label
# Adding one more pair of braces will help maintain the glyghs of tex file...
if flag == 0:
color_tuple = _TexParser.label_to_color_tuple(label)
result += "".join([
"{{",
"\\color[RGB]",
"{",
",".join(map(str, color_tuple)),
"}"
])
else:
result += "}}"
result += tex_string[index : next_index]
raise ValueError
return result
def raise_tex_parsing_error(self, message):
raise ValueError(f"Failed to parse tex ({message}): \"{self.tex_string}\"")
def get_labelled_tex_string(self):
indices, _, flags, labels = zip(*sorted([
(*tex_span[::(1, -1)[flag]], flag, label)
for label, tex_span in enumerate(self.tex_span_list)
for flag in range(2)
], key=lambda t: (t[0], -t[2], -t[1])))
command_pieces = [
("{{" + self.get_color_command(label), "}}")[flag]
for flag, label in zip(flags, labels)
][1:-1]
command_pieces.insert(0, "")
string_pieces = [
self.tex_string[slice(*tex_span)]
for tex_span in _get_neighbouring_pairs(indices)
]
return "".join(it.chain(*zip(command_pieces, string_pieces)))
@staticmethod
def get_color_command(label):
rg, b = divmod(label, 256)
r, g = divmod(rg, 256)
return "".join([
"\\color[RGB]",
"{",
",".join(map(str, (r, g, b))),
"}"
])
def get_sorted_submob_indices(self, submob_labels):
def script_span_to_submob_range(script_span):
tex_span = self.script_span_to_tex_span_dict[script_span]
submob_indices = [
index for index, label in enumerate(submob_labels)
if label in self.containing_labels_dict[tex_span]
]
return range(submob_indices[0], submob_indices[-1] + 1)
filtered_script_span_pairs = filter(
lambda script_span_pair: all([
self.script_span_to_char_dict[script_span] == character
for script_span, character in zip(script_span_pair, "_^")
]),
self.neighbouring_script_span_pairs
)
switch_range_pairs = sorted([
tuple([
script_span_to_submob_range(script_span)
for script_span in script_span_pair
])
for script_span_pair in filtered_script_span_pairs
], key=lambda t: (t[0].stop, -t[0].start))
result = list(range(len(submob_labels)))
for range_0, range_1 in switch_range_pairs:
result = [
*result[:range_1.start],
*result[range_0.start:range_0.stop],
*result[range_1.stop:range_0.start],
*result[range_1.start:range_1.stop],
*result[range_0.stop:]
]
return result
def get_submob_tex_strings(self, submob_labels):
ordered_tex_spans = [
self.tex_span_list[label] for label in submob_labels
]
ordered_containing_labels = [
self.containing_labels_dict[tex_span]
for tex_span in ordered_tex_spans
]
ordered_span_begins, ordered_span_ends = zip(*ordered_tex_spans)
string_span_begins = [
prev_end if prev_label in containing_labels else curr_begin
for prev_end, prev_label, containing_labels, curr_begin in zip(
ordered_span_ends[:-1], submob_labels[:-1],
ordered_containing_labels[1:], ordered_span_begins[1:]
)
]
string_span_begins.insert(0, ordered_span_begins[0])
string_span_ends = [
next_begin if next_label in containing_labels else curr_end
for next_begin, next_label, containing_labels, curr_end in zip(
ordered_span_begins[1:], submob_labels[1:],
ordered_containing_labels[:-1], ordered_span_ends[:-1]
)
]
string_span_ends.append(ordered_span_ends[-1])
tex_string = self.tex_string
left_brace_indices = sorted(self.brace_indices_dict.keys())
right_brace_indices = sorted(self.brace_indices_dict.values())
ignored_indices = sorted(it.chain(
self.whitespace_indices,
left_brace_indices,
right_brace_indices,
self.script_indices
))
result = []
for span_begin, span_end in zip(string_span_begins, string_span_ends):
while span_begin in ignored_indices:
span_begin += 1
if span_begin >= span_end:
result.append("")
continue
while span_end - 1 in ignored_indices:
span_end -= 1
unclosed_left_brace = 0
unclosed_right_brace = 0
for index in range(span_begin, span_end):
if index in left_brace_indices:
unclosed_left_brace += 1
elif index in right_brace_indices:
if unclosed_left_brace == 0:
unclosed_right_brace += 1
else:
unclosed_left_brace -= 1
result.append("".join([
unclosed_right_brace * "{",
tex_string[span_begin:span_end],
unclosed_left_brace * "}"
]))
return result
def find_span_components_of_custom_span(self, custom_span):
skipped_indices = sorted(it.chain(
self.whitespace_indices,
self.script_indices
))
tex_span_choices = sorted(filter(
lambda tex_span: all([
tex_span[0] >= custom_span[0],
tex_span[1] <= custom_span[1]
]),
self.tex_span_list
))
# Choose spans that reach the farthest.
tex_span_choices_dict = dict(tex_span_choices)
span_begin, span_end = custom_span
result = []
while span_begin != span_end:
if span_begin not in tex_span_choices_dict.keys():
if span_begin in skipped_indices:
span_begin += 1
continue
return None
next_begin = tex_span_choices_dict[span_begin]
result.append((span_begin, next_begin))
span_begin = next_begin
return result
def get_containing_labels_by_tex_spans(self, tex_spans):
return remove_list_redundancies(list(it.chain(*[
self.containing_labels_dict[tex_span]
for tex_span in tex_spans
])))
def get_specified_substrings(self):
return self.specified_substrings
def get_isolated_substrings(self):
return remove_list_redundancies([
self.tex_string[slice(*tex_span)]
for tex_span in self.tex_span_list
])
class MTex(VMobject):
@ -282,7 +409,7 @@ class MTex(VMobject):
"tex_environment": "align*",
"isolate": [],
"tex_to_color_map": {},
"generate_plain_tex_file": False,
"use_plain_tex": False,
}
def __init__(self, tex_string, **kwargs):
@ -293,227 +420,138 @@ class MTex(VMobject):
tex_string = "\\quad"
self.tex_string = tex_string
self.generate_mobject()
self.__parser = _TexParser(
self.tex_string,
[*self.tex_to_color_map.keys(), *self.isolate]
)
mob = self.generate_mobject()
self.add(*mob.copy())
self.init_colors()
self.set_color_by_tex_to_color_map(self.tex_to_color_map)
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
def get_additional_substrings_to_break_up(self):
result = remove_list_redundancies([
*self.tex_to_color_map.keys(), *self.isolate
])
if "" in result:
result.remove("")
return result
def get_parser(self):
return _TexParser(self.tex_string, self.get_additional_substrings_to_break_up())
@staticmethod
def color_to_label(color):
r, g, b = color_to_int_rgb(color)
rg = r * 256 + g
return rg * 256 + b
def generate_mobject(self):
tex_string = self.tex_string
tex_parser = self.get_parser()
self.tex_spans_dict = tex_parser.tex_spans_dict
self.specified_substrings = tex_parser.specified_substrings
labelled_tex_string = self.__parser.get_labelled_tex_string()
labelled_tex_content = self.get_tex_file_content(labelled_tex_string)
hash_val = hash((labelled_tex_content, self.use_plain_tex))
plain_full_tex = self.get_tex_file_body(tex_string)
plain_hash_val = hash(plain_full_tex)
if plain_hash_val in TEX_HASH_TO_MOB_MAP:
self.add(*TEX_HASH_TO_MOB_MAP[plain_hash_val].copy())
return self
if hash_val in TEX_HASH_TO_MOB_MAP:
return TEX_HASH_TO_MOB_MAP[hash_val]
labelled_expression = tex_parser.get_labelled_expression()
full_tex = self.get_tex_file_body(labelled_expression)
hash_val = hash(full_tex)
if hash_val in TEX_HASH_TO_MOB_MAP and not self.generate_plain_tex_file:
self.add(*TEX_HASH_TO_MOB_MAP[hash_val].copy())
return self
if not self.use_plain_tex:
with display_during_execution(f"Writing \"{self.tex_string}\""):
labelled_svg_glyphs = self.tex_content_to_glyphs(
labelled_tex_content
)
glyph_labels = [
self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs
]
mob = self.build_mobject(labelled_svg_glyphs, glyph_labels)
TEX_HASH_TO_MOB_MAP[hash_val] = mob
return mob
with display_during_execution(f"Writing \"{tex_string}\""):
filename = tex_to_svg_file(full_tex)
svg_mob = _LabelledTex(filename)
self.add(*svg_mob.copy())
self.build_submobjects()
TEX_HASH_TO_MOB_MAP[hash_val] = self
if not self.generate_plain_tex_file:
return self
with display_during_execution(f"Writing \"{self.tex_string}\""):
labelled_svg_glyphs = self.tex_content_to_glyphs(
labelled_tex_content
)
tex_content = self.get_tex_file_content(self.tex_string)
svg_glyphs = self.tex_content_to_glyphs(tex_content)
glyph_labels = [
self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs
]
mob = self.build_mobject(svg_glyphs, glyph_labels)
TEX_HASH_TO_MOB_MAP[hash_val] = mob
return mob
with display_during_execution(f"Writing \"{tex_string}\""):
filename = tex_to_svg_file(plain_full_tex)
plain_svg_mob = _PlainTex(filename)
svg_mob = TEX_HASH_TO_MOB_MAP[hash_val]
for plain_submob, submob in zip(plain_svg_mob, svg_mob):
plain_submob.glyph_label = submob.glyph_label
self.add(*plain_svg_mob.copy())
self.build_submobjects()
TEX_HASH_TO_MOB_MAP[plain_hash_val] = self
return self
def get_tex_file_body(self, new_tex):
def get_tex_file_content(self, tex_string):
if self.tex_environment:
new_tex = "\n".join([
tex_string = "\n".join([
f"\\begin{{{self.tex_environment}}}",
new_tex,
tex_string,
f"\\end{{{self.tex_environment}}}"
])
if self.alignment:
new_tex = "\n".join([self.alignment, new_tex])
tex_string = "\n".join([self.alignment, tex_string])
return tex_string
@staticmethod
def tex_content_to_glyphs(tex_content):
tex_config = get_tex_config()
return tex_config["tex_body"].replace(
full_tex = tex_config["tex_body"].replace(
tex_config["text_to_replace"],
new_tex
tex_content
)
filename = tex_to_svg_file(full_tex)
return _TexSVG(filename)
def build_submobjects(self):
if not self.submobjects:
return
self.group_submobjects()
self.sort_scripts_in_tex_order()
self.assign_submob_tex_strings()
def build_mobject(self, svg_glyphs, glyph_labels):
if not svg_glyphs:
return VGroup()
def group_submobjects(self):
# Simply pack together adjacent mobjects with the same label.
new_submobjects = []
def append_new_submobject(glyphs):
if glyphs:
submobject = VGroup(*glyphs)
submobject.submob_label = glyphs[0].glyph_label
new_submobjects.append(submobject)
submobjects = []
submob_labels = []
new_glyphs = []
current_glyph_label = 0
for submob in self.submobjects:
if submob.glyph_label == current_glyph_label:
new_glyphs.append(submob)
current_glyph_label = glyph_labels[0]
for glyph, label in zip(svg_glyphs, glyph_labels):
if label == current_glyph_label:
new_glyphs.append(glyph)
else:
append_new_submobject(new_glyphs)
new_glyphs = [submob]
current_glyph_label = submob.glyph_label
append_new_submobject(new_glyphs)
self.set_submobjects(new_submobjects)
submobject = VGroup(*new_glyphs)
submob_labels.append(current_glyph_label)
submobjects.append(submobject)
new_glyphs = [glyph]
current_glyph_label = label
submobject = VGroup(*new_glyphs)
submob_labels.append(current_glyph_label)
submobjects.append(submobject)
def sort_scripts_in_tex_order(self):
# LaTeX always puts superscripts before subscripts.
# This function sorts the submobjects of scripts in the order of tex given.
tex_spans_dict = self.tex_spans_dict
index_and_span_list = sorted([
(index, span_tuple)
for span_tuple, tex_span in tex_spans_dict.items()
if tex_span.script_type != 0
for index in span_tuple
])
indices = self.__parser.get_sorted_submob_indices(submob_labels)
rearranged_submobjects = [submobjects[index] for index in indices]
rearranged_labels = [submob_labels[index] for index in indices]
switch_range_pairs = []
for index_and_span_0, index_and_span_1 in _get_neighbouring_pairs(
index_and_span_list
submob_tex_strings = self.__parser.get_submob_tex_strings(
rearranged_labels
)
for submob, label, submob_tex in zip(
rearranged_submobjects, rearranged_labels, submob_tex_strings
):
index_0, span_tuple_0 = index_and_span_0
index_1, span_tuple_1 = index_and_span_1
if index_0 != index_1:
continue
if not all([
tex_spans_dict[span_tuple_0].script_type == 1,
tex_spans_dict[span_tuple_1].script_type == 2
]):
continue
submob_range_0 = self.range_of_part(
self.get_part_by_span_tuples([span_tuple_0])
)
submob_range_1 = self.range_of_part(
self.get_part_by_span_tuples([span_tuple_1])
)
switch_range_pairs.append((submob_range_0, submob_range_1))
switch_range_pairs.sort(key=lambda pair: (pair[0].stop, -pair[0].start))
indices = list(range(len(self.submobjects)))
for submob_range_0, submob_range_1 in switch_range_pairs:
indices = [
*indices[: submob_range_1.start],
*indices[submob_range_0.start : submob_range_0.stop],
*indices[submob_range_1.stop : submob_range_0.start],
*indices[submob_range_1.start : submob_range_1.stop],
*indices[submob_range_0.stop :]
]
submobs = self.submobjects
self.set_submobjects([submobs[i] for i in indices])
def assign_submob_tex_strings(self):
# Not sure whether this is the best practice...
# This temporarily supports `TransformMatchingTex`.
tex_string = self.tex_string
tex_spans_dict = self.tex_spans_dict
# Use tex strings including "_", "^".
label_dict = {}
for span_tuple, tex_span in tex_spans_dict.items():
if tex_span.script_type != 0:
label_dict[tex_span.label] = span_tuple
else:
if tex_span.label not in label_dict:
label_dict[tex_span.label] = span_tuple
curr_labels = [submob.submob_label for submob in self.submobjects]
prev_labels = [curr_labels[-1], *curr_labels[:-1]]
next_labels = [*curr_labels[1:], curr_labels[0]]
tex_string_spans = []
for curr_label, prev_label, next_label in zip(
curr_labels, prev_labels, next_labels
):
curr_span_tuple = label_dict[curr_label]
prev_span_tuple = label_dict[prev_label]
next_span_tuple = label_dict[next_label]
containing_labels = tex_spans_dict[curr_span_tuple].containing_labels
tex_string_spans.append([
prev_span_tuple[1] if prev_label in containing_labels else curr_span_tuple[0],
next_span_tuple[0] if next_label in containing_labels else curr_span_tuple[1]
])
tex_string_spans[0][0] = label_dict[curr_labels[0]][0]
tex_string_spans[-1][1] = label_dict[curr_labels[-1]][1]
for submob, tex_string_span in zip(self.submobjects, tex_string_spans):
submob.tex_string = tex_string[slice(*tex_string_span)]
submob.submob_label = label
submob.tex_string = submob_tex
# Support `get_tex()` method here.
submob.get_tex = MethodType(lambda inst: inst.tex_string, submob)
return VGroup(*rearranged_submobjects)
def get_part_by_span_tuples(self, span_tuples):
tex_spans_dict = self.tex_spans_dict
labels = set(it.chain(*[
tex_spans_dict[span_tuple].containing_labels
for span_tuple in span_tuples
]))
def get_part_by_tex_spans(self, tex_spans):
labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans)
return VGroup(*filter(
lambda submob: submob.submob_label in labels,
self.submobjects
))
def find_span_components_of_custom_span(self, custom_span_tuple, partial_result=[]):
span_begin, span_end = custom_span_tuple
if span_begin == span_end:
return partial_result
next_begin_choices = sorted([
span_tuple[1]
for span_tuple in self.tex_spans_dict.keys()
if span_tuple[0] == span_begin and span_tuple[1] <= span_end
], reverse=True)
for next_begin in next_begin_choices:
result = self.find_span_components_of_custom_span(
(next_begin, span_end), [*partial_result, (span_begin, next_begin)]
)
if result is not None:
return result
return None
def get_part_by_custom_span_tuple(self, custom_span_tuple):
span_tuples = self.find_span_components_of_custom_span(custom_span_tuple)
if span_tuples is None:
tex = self.tex_string[slice(*custom_span_tuple)]
raise ValueError(f"Failed to get span of tex: \"{tex}\"")
return self.get_part_by_span_tuples(span_tuples)
def get_part_by_custom_span(self, custom_span):
tex_spans = self.__parser.find_span_components_of_custom_span(
custom_span
)
if tex_spans is None:
tex = self.tex_string[slice(*custom_span)]
raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"")
return self.get_part_by_tex_spans(tex_spans)
def get_parts_by_tex(self, tex):
return VGroup(*[
self.get_part_by_custom_span_tuple(match_obj.span())
for match_obj in re.finditer(re.escape(tex), self.tex_string)
self.get_part_by_custom_span(match_obj.span())
for match_obj in re.finditer(
re.escape(tex.strip()), self.tex_string
)
])
def get_part_by_tex(self, tex, index=0):
@ -525,16 +563,13 @@ class MTex(VMobject):
return self
def set_color_by_tex_to_color_map(self, tex_to_color_map):
for tex, color in list(tex_to_color_map.items()):
try:
self.set_color_by_tex(tex, color)
except:
pass
for tex, color in tex_to_color_map.items():
self.set_color_by_tex(tex, color)
return self
def indices_of_part(self, part):
indices = [
i for i, submob in enumerate(self.submobjects)
index for index, submob in enumerate(self.submobjects)
if submob in part
]
if not indices:
@ -545,42 +580,20 @@ class MTex(VMobject):
part = self.get_part_by_tex(tex, index=index)
return self.indices_of_part(part)
def indices_of_all_parts_by_tex(self, tex, index=0):
all_parts = self.get_parts_by_tex(tex)
return list(it.chain(*[
self.indices_of_part(part) for part in all_parts
]))
def range_of_part(self, part):
indices = self.indices_of_part(part)
return range(indices[0], indices[-1] + 1)
def range_of_part_by_tex(self, tex, index=0):
part = self.get_part_by_tex(tex, index=index)
return self.range_of_part(part)
def index_of_part(self, part):
return self.indices_of_part(part)[0]
def index_of_part_by_tex(self, tex, index=0):
part = self.get_part_by_tex(tex, index=index)
return self.index_of_part(part)
def get_tex(self):
return self.tex_string
def get_all_isolated_substrings(self):
tex_string = self.tex_string
return remove_list_redundancies([
tex_string[slice(*span_tuple)]
for span_tuple in self.tex_spans_dict.keys()
])
def get_submob_tex(self):
return [
submob.get_tex()
for submob in self.submobjects
]
def list_tex_strings_of_submobjects(self):
# Work with `index_labels()`.
log.debug(f"Submobjects of \"{self.get_tex()}\":")
for i, submob in enumerate(self.submobjects):
log.debug(f"{i}: \"{submob.get_tex()}\"")
def get_specified_substrings(self):
return self.__parser.get_specified_substrings()
def get_isolated_substrings(self):
return self.__parser.get_isolated_substrings()
class MTexText(MTex):

View file

@ -1,102 +1,27 @@
import itertools as it
import re
import string
import os
import re
import hashlib
import itertools as it
import cssselect2
from colour import web2hex
from xml.etree import ElementTree
from tinycss2 import serialize as css_serialize
from tinycss2 import parse_stylesheet, parse_declaration_list
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN
from manimlib.constants import DEGREES, PI
import svgelements as se
import numpy as np
from manimlib.constants import RIGHT
from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Polygon
from manimlib.mobject.geometry import Polyline
from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.geometry import RoundedRectangle
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.color import *
from manimlib.utils.config_ops import digest_config
from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path
from manimlib.utils.simple_functions import clip
from manimlib.logger import log
DEFAULT_STYLE = {
"fill": "black",
"stroke": "none",
"fill-opacity": "1",
"stroke-opacity": "1",
"stroke-width": 0,
}
def cascade_element_style(element, inherited):
style = inherited.copy()
for attr in DEFAULT_STYLE:
if element.get(attr):
style[attr] = element.get(attr)
if element.get("style"):
declarations = parse_declaration_list(element.get("style"))
for declaration in declarations:
style[declaration.name] = css_serialize(declaration.value)
return style
def parse_color(color):
color = color.strip()
if color[0:3] == "rgb":
splits = color[4:-1].strip().split(",")
if splits[0].strip()[-1] == "%":
parsed_rgbs = [float(i.strip()[:-1]) / 100.0 for i in splits]
else:
parsed_rgbs = [int(i) / 255.0 for i in splits]
return rgb_to_hex(parsed_rgbs)
else:
return web2hex(color)
def fill_default_values(style, default_style):
default = DEFAULT_STYLE.copy()
default.update(default_style)
for attr in default:
if attr not in style:
style[attr] = default[attr]
def parse_style(style, default_style):
manim_style = {}
fill_default_values(style, default_style)
for key in ("fill-opacity", "stroke-opacity", "stroke-width"):
value = style[key]
if isinstance(value, str) and value.endswith("px"):
value = float(value[:-2]) * 0 # HACKY, need to fix
manim_style[key.replace("-", "_")] = float(value)
if style["fill"] == "none":
manim_style["fill_opacity"] = 0
else:
manim_style["fill_color"] = parse_color(style["fill"])
if style["stroke"] == "none":
manim_style["stroke_width"] = 0
if "fill_color" in manim_style:
manim_style["stroke_color"] = manim_style["fill_color"]
else:
manim_style["stroke_color"] = parse_color(style["stroke"])
return manim_style
def _convert_point_to_3d(x, y):
return np.array([x, y, 0.0])
class SVGMobject(VMobject):
@ -104,11 +29,15 @@ class SVGMobject(VMobject):
"should_center": True,
"height": 2,
"width": None,
# Must be filled in in a subclass, or when called
# Must be filled in a subclass, or when called
"file_name": None,
"unpack_groups": True, # if False, creates a hierarchy of VGroups
"stroke_width": 0.0,
"fill_opacity": 1.0,
"color": None,
"opacity": None,
"fill_color": None,
"fill_opacity": None,
"stroke_width": None,
"stroke_color": None,
"stroke_opacity": None,
"path_string_config": {}
}
@ -130,338 +59,232 @@ class SVGMobject(VMobject):
if self.width is not None:
self.set_width(self.width)
def init_colors(self, override=False):
super().init_colors(override=override)
def init_colors(self):
# Remove fill_color, fill_opacity,
# stroke_width, stroke_color, stroke_opacity
# as each submobject may have those values specified in svg file
self.set_stroke(background=self.draw_stroke_behind_fill)
self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke)
return self
def init_points(self):
etree = ElementTree.parse(self.file_path)
wrapper = cssselect2.ElementWrapper.from_xml_root(etree)
svg = etree.getroot()
namespace = svg.tag.split("}")[0][1:]
self.ref_to_element = {}
self.css_matcher = cssselect2.Matcher()
with open(self.file_path, "r") as svg_file:
svg_string = svg_file.read()
for style in etree.findall(f"{{{namespace}}}style"):
self.parse_css_style(style.text)
# Create a temporary svg file to dump modified svg to be parsed
modified_svg_string = self.modify_svg_file(svg_string)
modified_file_path = self.file_path.replace(".svg", "_.svg")
with open(modified_file_path, "w") as modified_svg_file:
modified_svg_file.write(modified_svg_string)
mobjects = self.get_mobjects_from(wrapper, dict())
if self.unpack_groups:
self.add(*mobjects)
else:
self.add(*mobjects[0].submobjects)
def get_mobjects_from(self, wrapper, style):
result = []
element = wrapper.etree_element
if not isinstance(element, ElementTree.Element):
return result
matches = self.css_matcher.match(wrapper)
if matches:
for match in matches:
_, _, _, css_style = match
style.update(css_style)
style = cascade_element_style(element, style)
tag = element.tag.split("}")[-1]
if tag == 'defs':
self.update_ref_to_element(wrapper, style)
elif tag in ['g', 'svg', 'symbol']:
result += it.chain(*(
self.get_mobjects_from(child, style)
for child in wrapper.iter_children()
))
elif tag == 'path':
result.append(self.path_string_to_mobject(
element.get('d'), style
))
elif tag == 'use':
result += self.use_to_mobjects(element, style)
elif tag == 'line':
result.append(self.line_to_mobject(element, style))
elif tag == 'rect':
result.append(self.rect_to_mobject(element, style))
elif tag == 'circle':
result.append(self.circle_to_mobject(element, style))
elif tag == 'ellipse':
result.append(self.ellipse_to_mobject(element, style))
elif tag in ['polygon', 'polyline']:
result.append(self.polygon_to_mobject(element, style))
elif tag == 'style':
pass
else:
log.warning(f"Unsupported element type: {tag}")
pass # TODO, support <text> tag
result = [m for m in result if m is not None]
self.handle_transforms(element, VGroup(*result))
if len(result) > 1 and not self.unpack_groups:
result = [VGroup(*result)]
return result
def generate_default_style(self):
style = {
"fill-opacity": self.fill_opacity,
"stroke-width": self.stroke_width,
"stroke-opacity": self.stroke_opacity,
}
if self.color:
style["fill"] = style["stroke"] = self.color
# `color` attribute handles `currentColor` keyword
if self.fill_color:
style["fill"] = self.fill_color
if self.stroke_color:
style["stroke"] = self.stroke_color
return style
def parse_css_style(self, css):
rules = parse_stylesheet(css, True, True)
for rule in rules:
selectors = cssselect2.compile_selector_list(rule.prelude)
declarations = parse_declaration_list(rule.content)
style = {
declaration.name: css_serialize(declaration.value)
for declaration in declarations
if declaration.name in DEFAULT_STYLE
}
payload = style
for selector in selectors:
self.css_matcher.add_selector(selector, payload)
def path_string_to_mobject(self, path_string, style):
return VMobjectFromSVGPathstring(
path_string,
**self.path_string_config,
**parse_style(style, self.generate_default_style()),
color = self.fill_color
elif self.color:
color = self.color
else:
color = "black"
shapes = se.SVG.parse(
modified_file_path,
color=color
)
os.remove(modified_file_path)
def use_to_mobjects(self, use_element, local_style):
# Remove initial "#" character
ref = use_element.get(r"{http://www.w3.org/1999/xlink}href")[1:]
if ref not in self.ref_to_element:
log.warning(f"{ref} not recognized")
return VGroup()
def_element, def_style = self.ref_to_element[ref]
style = local_style.copy()
style.update(def_style)
return self.get_mobjects_from(def_element, style)
mobjects = self.get_mobjects_from(shapes)
self.add(*mobjects)
self.flip(RIGHT) # Flip y
self.scale(0.75)
def attribute_to_float(self, attr):
stripped_attr = "".join([
char for char in attr
if char in string.digits + "." + "-"
def modify_svg_file(self, svg_string):
# svgelements cannot handle em, ex units
# Convert them using 1em = 16px, 1ex = 0.5em = 8px
def convert_unit(match_obj):
number = float(match_obj.group(1))
unit = match_obj.group(2)
factor = 16 if unit == "em" else 8
return str(number * factor) + "px"
number_pattern = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)(ex|em)(?![a-zA-Z])"
result = re.sub(number_pattern, convert_unit, svg_string)
# Add a group tag to set style from configuration
style_dict = self.generate_context_values_from_config()
group_tag_begin = "<g " + " ".join([
f"{k}=\"{v}\""
for k, v in style_dict.items()
]) + ">"
group_tag_end = "</g>"
begin_insert_index = re.search(r"<svg[\s\S]*?>", result).end()
end_insert_index = re.search(r"[\s\S]*(</svg\s*>)", result).start(1)
result = "".join([
result[:begin_insert_index],
group_tag_begin,
result[begin_insert_index:end_insert_index],
group_tag_end,
result[end_insert_index:]
])
return float(stripped_attr)
def polygon_to_mobject(self, polygon_element, style):
path_string = polygon_element.get("points")
for digit in string.digits:
path_string = path_string.replace(f" {digit}", f"L {digit}")
path_string = path_string.replace("L", "M", 1)
return self.path_string_to_mobject(path_string, style)
def circle_to_mobject(self, circle_element, style):
x, y, r = (
self.attribute_to_float(circle_element.get(key, "0.0"))
for key in ("cx", "cy", "r")
)
return Circle(
radius=r,
**parse_style(style, self.generate_default_style())
).shift(x * RIGHT + y * DOWN)
def ellipse_to_mobject(self, circle_element, style):
x, y, rx, ry = (
self.attribute_to_float(circle_element.get(key, "0.0"))
for key in ("cx", "cy", "rx", "ry")
)
result = Circle(**parse_style(style, self.generate_default_style()))
result.stretch(rx, 0)
result.stretch(ry, 1)
result.shift(x * RIGHT + y * DOWN)
return result
def rect_to_mobject(self, rect_element, style):
stroke_width = rect_element.get("stroke-width", "")
corner_radius = rect_element.get("rx", "")
def generate_context_values_from_config(self):
result = {}
if self.stroke_width is not None:
result["stroke-width"] = self.stroke_width
if self.color is not None:
result["fill"] = result["stroke"] = self.color
if self.fill_color is not None:
result["fill"] = self.fill_color
if self.stroke_color is not None:
result["stroke"] = self.stroke_color
if self.opacity is not None:
result["fill-opacity"] = result["stroke-opacity"] = self.opacity
if self.fill_opacity is not None:
result["fill-opacity"] = self.fill_opacity
if self.stroke_opacity is not None:
result["stroke-opacity"] = self.stroke_opacity
return result
if stroke_width in ["", "none", "0"]:
stroke_width = 0
def get_mobjects_from(self, shape):
if isinstance(shape, se.Group):
return list(it.chain(*(
self.get_mobjects_from(child)
for child in shape
)))
if corner_radius in ["", "0", "none"]:
corner_radius = 0
mob = self.get_mobject_from(shape)
if mob is None:
return []
corner_radius = float(corner_radius)
if isinstance(shape, se.Transformable) and shape.apply:
self.handle_transform(mob, shape.transform)
return [mob]
parsed_style = parse_style(style, self.generate_default_style())
parsed_style["stroke_width"] = stroke_width
@staticmethod
def handle_transform(mob, matrix):
mat = np.array([
[matrix.a, matrix.c],
[matrix.b, matrix.d]
])
vec = np.array([matrix.e, matrix.f, 0.0])
mob.apply_matrix(mat)
mob.shift(vec)
return mob
if corner_radius == 0:
def get_mobject_from(self, shape):
shape_class_to_func_map = {
se.Path: self.path_to_mobject,
se.SimpleLine: self.line_to_mobject,
se.Rect: self.rect_to_mobject,
se.Circle: self.circle_to_mobject,
se.Ellipse: self.ellipse_to_mobject,
se.Polygon: self.polygon_to_mobject,
se.Polyline: self.polyline_to_mobject,
# se.Text: self.text_to_mobject, # TODO
}
for shape_class, func in shape_class_to_func_map.items():
if isinstance(shape, shape_class):
mob = func(shape)
self.apply_style_to_mobject(mob, shape)
return mob
shape_class_name = shape.__class__.__name__
if shape_class_name != "SVGElement":
log.warning(f"Unsupported element type: {shape_class_name}")
return None
@staticmethod
def apply_style_to_mobject(mob, shape):
mob.set_style(
stroke_width=shape.stroke_width,
stroke_color=shape.stroke.hex,
stroke_opacity=shape.stroke.opacity,
fill_color=shape.fill.hex,
fill_opacity=shape.fill.opacity
)
return mob
def path_to_mobject(self, path):
return VMobjectFromSVGPath(path, **self.path_string_config)
def line_to_mobject(self, line):
return Line(
start=_convert_point_to_3d(line.x1, line.y1),
end=_convert_point_to_3d(line.x2, line.y2)
)
def rect_to_mobject(self, rect):
if rect.rx == 0 or rect.ry == 0:
mob = Rectangle(
width=self.attribute_to_float(
rect_element.get("width", "")
),
height=self.attribute_to_float(
rect_element.get("height", "")
),
**parsed_style,
width=rect.width,
height=rect.height,
)
else:
mob = RoundedRectangle(
width=self.attribute_to_float(
rect_element.get("width", "")
),
height=self.attribute_to_float(
rect_element.get("height", "")
),
corner_radius=corner_radius,
**parsed_style
width=rect.width,
height=rect.height * rect.rx / rect.ry,
corner_radius=rect.rx
)
mob.shift(mob.get_center() - mob.get_corner(UP + LEFT))
mob.stretch_to_fit_height(rect.height)
mob.shift(_convert_point_to_3d(
rect.x + rect.width / 2,
rect.y + rect.height / 2
))
return mob
def line_to_mobject(self, line_element, style):
x1, y1, x2, y2 = (
self.attribute_to_float(line_element.get(key, "0.0"))
for key in ("x1", "y1", "x2", "y2")
)
return Line(
[x1, -y1, 0], [x2, -y2, 0],
**parse_style(style, self.generate_default_style())
)
def handle_transforms(self, element, mobject):
x, y = (
self.attribute_to_float(element.get(key, "0.0"))
for key in ("x", "y")
)
mobject.shift(x * RIGHT + y * DOWN)
def circle_to_mobject(self, circle):
# svgelements supports `rx` & `ry` but `r`
mob = Circle(radius=circle.rx)
mob.shift(_convert_point_to_3d(
circle.cx, circle.cy
))
return mob
transform_names = [
"matrix",
"translate", "translateX", "translateY",
"scale", "scaleX", "scaleY",
"rotate",
"skewX", "skewY"
def ellipse_to_mobject(self, ellipse):
mob = Circle(radius=ellipse.rx)
mob.stretch_to_fit_height(2 * ellipse.ry)
mob.shift(_convert_point_to_3d(
ellipse.cx, ellipse.cy
))
return mob
def polygon_to_mobject(self, polygon):
points = [
_convert_point_to_3d(*point)
for point in polygon
]
transform_pattern = re.compile("|".join([x + r"[^)]*\)" for x in transform_names]))
number_pattern = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
transforms = transform_pattern.findall(element.get("transform", ""))[::-1]
return Polygon(*points)
for transform in transforms:
op_name, op_args = transform.split("(")
op_name = op_name.strip()
op_args = [float(x) for x in number_pattern.findall(op_args)]
def polyline_to_mobject(self, polyline):
points = [
_convert_point_to_3d(*point)
for point in polyline
]
return Polyline(*points)
if op_name == "matrix":
self._handle_matrix_transform(mobject, op_name, op_args)
elif op_name.startswith("translate"):
self._handle_translate_transform(mobject, op_name, op_args)
elif op_name.startswith("scale"):
self._handle_scale_transform(mobject, op_name, op_args)
elif op_name == "rotate":
self._handle_rotate_transform(mobject, op_name, op_args)
elif op_name.startswith("skew"):
self._handle_skew_transform(mobject, op_name, op_args)
def _handle_matrix_transform(self, mobject, op_name, op_args):
transform = np.array(op_args).reshape([3, 2])
x = transform[2][0]
y = -transform[2][1]
matrix = np.identity(self.dim)
matrix[:2, :2] = transform[:2, :]
matrix[1] *= -1
matrix[:, 1] *= -1
for mob in mobject.family_members_with_points():
mob.apply_matrix(matrix.T)
mobject.shift(x * RIGHT + y * UP)
def _handle_translate_transform(self, mobject, op_name, op_args):
if op_name.endswith("X"):
x, y = op_args[0], 0
elif op_name.endswith("Y"):
x, y = 0, op_args[0]
else:
x, y = op_args
mobject.shift(x * RIGHT + y * DOWN)
def _handle_scale_transform(self, mobject, op_name, op_args):
if op_name.endswith("X"):
sx, sy = op_args[0], 1
elif op_name.endswith("Y"):
sx, sy = 1, op_args[0]
elif len(op_args) == 2:
sx, sy = op_args
else:
sx = sy = op_args[0]
if sx < 0:
mobject.flip(UP)
sx = -sx
if sy < 0:
mobject.flip(RIGHT)
sy = -sy
mobject.scale(np.array([sx, sy, 1]), about_point=ORIGIN)
def _handle_rotate_transform(self, mobject, op_name, op_args):
if len(op_args) == 1:
mobject.rotate(op_args[0] * DEGREES, axis=IN, about_point=ORIGIN)
else:
deg, x, y = op_args
mobject.rotate(deg * DEGREES, axis=IN, about_point=np.array([x, y, 0]))
def _handle_skew_transform(self, mobject, op_name, op_args):
rad = op_args[0] * DEGREES
if op_name == "skewX":
tana = np.tan(rad)
self._handle_matrix_transform(mobject, None, [1., 0., tana, 1., 0., 0.])
elif op_name == "skewY":
tana = np.tan(rad)
self._handle_matrix_transform(mobject, None, [1., tana, 0., 1., 0., 0.])
def flatten(self, input_list):
output_list = []
for i in input_list:
if isinstance(i, list):
output_list.extend(self.flatten(i))
else:
output_list.append(i)
return output_list
def get_all_childWrappers_have_id(self, wrapper):
all_childWrappers_have_id = []
element = wrapper.etree_element
if not isinstance(element, ElementTree.Element):
return
if element.get('id'):
return [wrapper]
for e in wrapper.iter_children():
all_childWrappers_have_id.append(self.get_all_childWrappers_have_id(e))
return self.flatten([e for e in all_childWrappers_have_id if e])
def update_ref_to_element(self, wrapper, style):
new_refs = {
e.etree_element.get('id', ''): (e, style)
for e in self.get_all_childWrappers_have_id(wrapper)
}
self.ref_to_element.update(new_refs)
def text_to_mobject(self, text):
pass
class VMobjectFromSVGPathstring(VMobject):
class VMobjectFromSVGPath(VMobject):
CONFIG = {
"long_lines": False,
"should_subdivide_sharp_curves": False,
"should_remove_null_curves": False,
}
def __init__(self, path_string, **kwargs):
self.path_string = path_string
def __init__(self, path_obj, **kwargs):
# Get rid of arcs
path_obj.approximate_arcs_with_quads()
self.path_obj = path_obj
super().__init__(**kwargs)
def init_points(self):
# After a given svg_path has been converted into points, the result
# will be saved to a file so that future calls for the same path
# don't need to retrace the same computation.
hasher = hashlib.sha256(self.path_string.encode())
path_string = self.path_obj.d()
hasher = hashlib.sha256(path_string.encode())
path_hash = hasher.hexdigest()[:16]
points_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_points.npy")
tris_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_tris.npy")
@ -478,239 +301,23 @@ class VMobjectFromSVGPathstring(VMobject):
if self.should_remove_null_curves:
# Get rid of any null curves
self.set_points(self.get_points_without_null_curves())
# SVG treats y-coordinate differently
self.stretch(-1, 1, about_point=ORIGIN)
# Save to a file for future use
np.save(points_filepath, self.get_points())
np.save(tris_filepath, self.get_triangulation())
def get_commands_and_coord_strings(self):
all_commands = list(self.get_command_to_function_map().keys())
all_commands += [c.lower() for c in all_commands]
pattern = "[{}]".format("".join(all_commands))
return zip(
re.findall(pattern, self.path_string),
re.split(pattern, self.path_string)[1:]
)
def handle_commands(self):
relative_point = ORIGIN
for command, coord_string in self.get_commands_and_coord_strings():
func, number_types_str = self.command_to_function(command)
upper_command = command.upper()
if upper_command == "Z":
func() # `close_path` takes no arguments
relative_point = self.get_last_point()
continue
number_types = np.array(list(number_types_str))
n_numbers = len(number_types_str)
number_list = _PathStringParser(coord_string, number_types_str).args
number_groups = np.array(number_list).reshape((-1, n_numbers))
for ind, numbers in enumerate(number_groups):
if command.islower():
# Treat it as a relative command
numbers[number_types == "x"] += relative_point[0]
numbers[number_types == "y"] += relative_point[1]
if upper_command == "A":
args = [*numbers[:5], np.array([*numbers[5:7], 0.0])]
elif upper_command == "H":
args = [np.array([numbers[0], relative_point[1], 0.0])]
elif upper_command == "V":
args = [np.array([relative_point[0], numbers[0], 0.0])]
else:
args = list(np.hstack((
numbers.reshape((-1, 2)), np.zeros((n_numbers // 2, 1))
)))
if upper_command == "M" and ind != 0:
# M x1 y1 x2 y2 is equal to M x1 y1 L x2 y2
func, _ = self.command_to_function("L")
func(*args)
relative_point = self.get_last_point()
def add_elliptical_arc_to(self, rx, ry, x_axis_rotation, large_arc_flag, sweep_flag, point):
def close_to_zero(a, threshold=1e-5):
return abs(a) < threshold
def solve_2d_linear_equation(a, b, c):
"""
Using Crammer's rule to solve the linear equation `[a b]x = c`
where `a`, `b` and `c` are all 2d vectors.
"""
def det(a, b):
return a[0] * b[1] - a[1] * b[0]
d = det(a, b)
if close_to_zero(d):
raise Exception("Cannot handle 0 determinant.")
return [det(c, b) / d, det(a, c) / d]
def get_arc_center_and_angles(x0, y0, rx, ry, phi, large_arc_flag, sweep_flag, x1, y1):
"""
The parameter functions of an ellipse rotated `phi` radians counterclockwise is (on `alpha`):
x = cx + rx * cos(alpha) * cos(phi) + ry * sin(alpha) * sin(phi),
y = cy + rx * cos(alpha) * sin(phi) - ry * sin(alpha) * cos(phi).
Now we have two points sitting on the ellipse: `(x0, y0)`, `(x1, y1)`, corresponding to 4 equations,
and we want to hunt for 4 variables: `cx`, `cy`, `alpha0` and `alpha_1`.
Let `d_alpha = alpha1 - alpha0`, then:
if `sweep_flag = 0` and `large_arc_flag = 1`, then `PI <= d_alpha < 2 * PI`;
if `sweep_flag = 0` and `large_arc_flag = 0`, then `0 < d_alpha <= PI`;
if `sweep_flag = 1` and `large_arc_flag = 0`, then `-PI <= d_alpha < 0`;
if `sweep_flag = 1` and `large_arc_flag = 1`, then `-2 * PI < d_alpha <= -PI`.
"""
xd = x1 - x0
yd = y1 - y0
if close_to_zero(xd) and close_to_zero(yd):
raise Exception("Cannot find arc center since the start point and the end point meet.")
# Find `p = cos(alpha1) - cos(alpha0)`, `q = sin(alpha1) - sin(alpha0)`
eq0 = [rx * np.cos(phi), ry * np.sin(phi), xd]
eq1 = [rx * np.sin(phi), -ry * np.cos(phi), yd]
p, q = solve_2d_linear_equation(*zip(eq0, eq1))
# Find `s = (alpha1 - alpha0) / 2`, `t = (alpha1 + alpha0) / 2`
# If `sin(s) = 0`, this requires `p = q = 0`,
# implying `xd = yd = 0`, which is impossible.
sin_s = (p ** 2 + q ** 2) ** 0.5 / 2
if sweep_flag:
sin_s = -sin_s
sin_s = clip(sin_s, -1, 1)
s = np.arcsin(sin_s)
if large_arc_flag:
if not sweep_flag:
s = PI - s
else:
s = -PI - s
sin_t = -p / (2 * sin_s)
cos_t = q / (2 * sin_s)
cos_t = clip(cos_t, -1, 1)
t = np.arccos(cos_t)
if sin_t <= 0:
t = -t
# We can make sure `0 < abs(s) < PI`, `-PI <= t < PI`.
alpha0 = t - s
alpha_1 = t + s
cx = x0 - rx * np.cos(alpha0) * np.cos(phi) - ry * np.sin(alpha0) * np.sin(phi)
cy = y0 - rx * np.cos(alpha0) * np.sin(phi) + ry * np.sin(alpha0) * np.cos(phi)
return cx, cy, alpha0, alpha_1
def get_point_on_ellipse(cx, cy, rx, ry, phi, angle):
return np.array([
cx + rx * np.cos(angle) * np.cos(phi) + ry * np.sin(angle) * np.sin(phi),
cy + rx * np.cos(angle) * np.sin(phi) - ry * np.sin(angle) * np.cos(phi),
0
])
def convert_elliptical_arc_to_quadratic_bezier_curve(
cx, cy, rx, ry, phi, start_angle, end_angle, n_components=8
):
theta = (end_angle - start_angle) / n_components / 2
handles = np.array([
get_point_on_ellipse(cx, cy, rx / np.cos(theta), ry / np.cos(theta), phi, a)
for a in np.linspace(
start_angle + theta,
end_angle - theta,
n_components,
)
])
anchors = np.array([
get_point_on_ellipse(cx, cy, rx, ry, phi, a)
for a in np.linspace(
start_angle + theta * 2,
end_angle,
n_components,
)
])
return handles, anchors
phi = x_axis_rotation * DEGREES
x0, y0 = self.get_last_point()[:2]
cx, cy, start_angle, end_angle = get_arc_center_and_angles(
x0, y0, rx, ry, phi, large_arc_flag, sweep_flag, point[0], point[1]
)
handles, anchors = convert_elliptical_arc_to_quadratic_bezier_curve(
cx, cy, rx, ry, phi, start_angle, end_angle
)
for handle, anchor in zip(handles, anchors):
self.add_quadratic_bezier_curve_to(handle, anchor)
def command_to_function(self, command):
return self.get_command_to_function_map()[command.upper()]
def get_command_to_function_map(self):
"""
Associates svg command to VMobject function, and
the types of arguments it takes in
"""
return {
"M": (self.start_new_path, "xy"),
"L": (self.add_line_to, "xy"),
"H": (self.add_line_to, "x"),
"V": (self.add_line_to, "y"),
"C": (self.add_cubic_bezier_curve_to, "xyxyxy"),
"S": (self.add_smooth_cubic_curve_to, "xyxy"),
"Q": (self.add_quadratic_bezier_curve_to, "xyxy"),
"T": (self.add_smooth_curve_to, "xy"),
"A": (self.add_elliptical_arc_to, "uuaffxy"),
"Z": (self.close_path, ""),
segment_class_to_func_map = {
se.Move: (self.start_new_path, ("end",)),
se.Close: (self.close_path, ()),
se.Line: (self.add_line_to, ("end",)),
se.QuadraticBezier: (self.add_quadratic_bezier_curve_to, ("control", "end")),
se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end"))
}
def get_original_path_string(self):
return self.path_string
class InvalidPathError(ValueError):
pass
class _PathStringParser:
# modified from https://github.com/regebro/svg.path/
def __init__(self, arguments, rules):
self.args = []
arguments = bytearray(arguments, "ascii")
self._strip_array(arguments)
while arguments:
for rule in rules:
self._rule_to_function_map[rule](arguments)
@property
def _rule_to_function_map(self):
return {
"x": self._get_number,
"y": self._get_number,
"a": self._get_number,
"u": self._get_unsigned_number,
"f": self._get_flag,
}
def _strip_array(self, arg_array):
# wsp: (0x9, 0x20, 0xA, 0xC, 0xD) with comma 0x2C
# https://www.w3.org/TR/SVG/paths.html#PathDataBNF
while arg_array and arg_array[0] in [0x9, 0x20, 0xA, 0xC, 0xD, 0x2C]:
arg_array[0:1] = b""
def _get_number(self, arg_array):
pattern = re.compile(rb"^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
res = pattern.search(arg_array)
if not res:
raise InvalidPathError(f"Expected a number, got '{arg_array}'")
number = float(res.group())
self.args.append(number)
arg_array[res.start():res.end()] = b""
self._strip_array(arg_array)
return number
def _get_unsigned_number(self, arg_array):
number = self._get_number(arg_array)
if number < 0:
raise InvalidPathError(f"Expected an unsigned number, got '{number}'")
return number
def _get_flag(self, arg_array):
flag = arg_array[0]
if flag != 48 and flag != 49:
raise InvalidPathError(f"Expected a flag (0/1), got '{chr(flag)}'")
flag -= 48
self.args.append(flag)
arg_array[0:1] = b""
self._strip_array(arg_array)
return flag
for segment in self.path_obj:
segment_class = segment.__class__
func, attr_names = segment_class_to_func_map[segment_class]
points = [
_convert_point_to_3d(*segment.__getattribute__(attr_name))
for attr_name in attr_names
]
func(*points)

View file

@ -53,13 +53,19 @@ class SingleStringTex(VMobject):
sm.copy()
for sm in tex_string_with_color_to_mob_map[(self.color, tex_string)]
))
self.init_colors(override=False)
self.init_colors()
if self.height is None:
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
if self.organize_left_to_right:
self.organize_submobjects_left_to_right()
def init_colors(self):
self.set_stroke(background=self.draw_stroke_behind_fill)
self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke)
return self
def get_tex_file_body(self, tex_string):
new_tex = self.get_modified_expression(tex_string)
if self.math_mode:

View file

@ -71,6 +71,8 @@ class Text(SVGMobject):
PangoUtils.remove_last_M(file_name)
self.remove_empty_path(file_name)
SVGMobject.__init__(self, file_name, **kwargs)
if self.color:
self.set_fill(self.color)
self.text = text
if self.disable_ligatures:
self.apply_space_chars()
@ -85,9 +87,6 @@ class Text(SVGMobject):
if self.height is None:
self.scale(TEXT_MOB_SCALE_FACTOR)
def init_colors(self, override=True):
super().init_colors(override=override)
def remove_empty_path(self, file_name):
with open(file_name, 'r') as fpr:
content = fpr.read()

View file

@ -260,7 +260,7 @@ class TexturedSurface(Surface):
super().init_uniforms()
self.uniforms["num_textures"] = self.num_textures
def init_colors(self, override=True):
def init_colors(self):
self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]])
def set_opacity(self, opacity, recurse=True):

View file

@ -90,7 +90,7 @@ class VMobject(Mobject):
})
# Colors
def init_colors(self, override=True):
def init_colors(self):
self.set_fill(
color=self.fill_color or self.color,
opacity=self.fill_opacity,
@ -103,9 +103,6 @@ class VMobject(Mobject):
)
self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke)
if not override:
for submobjects in self.submobjects:
submobjects.init_colors(override=False)
return self
def set_rgba_array(self, rgba_array, name=None, recurse=False):

View file

@ -277,7 +277,6 @@ class DiscreteGraphScene(Scene):
def trace_cycle(self, cycle=None, color="yellow", run_time=2.0):
if cycle is None:
cycle = self.graph.region_cycles[0]
time_per_edge = run_time / len(cycle)
next_in_cycle = it.cycle(cycle)
next(next_in_cycle) # jump one ahead
self.traced_cycle = Mobject(*[

View file

@ -287,9 +287,6 @@ class LinearTransformationScene(VectorScene):
},
"background_plane_kwargs": {
"color": GREY,
"axis_config": {
"stroke_color": GREY_B,
},
"axis_config": {
"color": GREY,
},

View file

@ -1,4 +1,3 @@
argparse
colour
numpy
Pillow
@ -15,9 +14,9 @@ pygments
pyyaml
rich
screeninfo
pyreadline; sys_platform == 'win32'
validators
ipython
PyOpenGL
manimpango>=0.2.0,<0.4.0
cssselect2
isosurfaces
svgelements

View file

@ -1,6 +1,6 @@
[metadata]
name = manimgl
version = 1.3.0
version = 1.4.1
author = Grant Sanderson
author_email= grant@3blue1brown.com
description = Animation engine for explanatory math videos
@ -12,12 +12,24 @@ project_urls =
Documentation = https://3b1b.github.io/manim/
Source Code = https://github.com/3b1b/manim
license = MIT
classifiers =
Development Status :: 4 - Beta
License :: OSI Approved :: MIT License
Topic :: Scientific/Engineering
Topic :: Multimedia :: Video
Topic :: Multimedia :: Graphics
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3 :: Only
Natural Language :: English
[options]
packages = find:
include_package_data=True
install_requires =
argparse
include_package_data = True
install_requires =
colour
numpy
Pillow
@ -34,12 +46,12 @@ install_requires =
pyyaml
rich
screeninfo
pyreadline; sys_platform == 'win32'
validators
ipython
PyOpenGL
manimpango>=0.2.0,<0.4.0
cssselect2
isosurfaces
svgelements
[options.entry_points]
console_scripts =