Attempt to refactor SVGMobject with svgelements (#1731)

* Some small refactors

* Refactor MTex

* Implement TransformMatchingMTex

* Some refactors

* Some refactors

* Some small refactors

* Strip strings before matching

* Implement get_submob_tex

* Use RGB color mode

* Some small refactors

* Try refactoring SVGMobject with svglib

* Refactor SVGMobject using svgelements

* Refactor SVGMobject using svgelements

* Use functions instead of func names as dict values

* style: modify import order to conform to PEP8

* Set default values to None

* modify import order

* Remove unused import

Co-authored-by: TonyCrane <tonycrane@foxmail.com>
This commit is contained in:
YishiMichael 2022-02-11 23:53:21 +08:00 committed by GitHub
parent baba6929df
commit 67f5b10626
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 240 additions and 622 deletions

View file

@ -608,8 +608,8 @@ class Arrow(Line):
self.insert_tip_anchor() self.insert_tip_anchor()
return self return self
def init_colors(self, override=True): def init_colors(self):
super().init_colors(override) super().init_colors()
self.create_tip_with_stroke_width() self.create_tip_with_stroke_width()
def get_arc_length(self): def get_arc_length(self):
@ -849,6 +849,11 @@ class Polygon(VMobject):
return self return self
class Polyline(Polygon):
def init_points(self):
self.set_points_as_corners(self.vertices)
class RegularPolygon(Polygon): class RegularPolygon(Polygon):
CONFIG = { CONFIG = {
"start_angle": None, "start_angle": None,

View file

@ -109,8 +109,8 @@ class Mobject(object):
"reflectiveness": self.reflectiveness, "reflectiveness": self.reflectiveness,
} }
def init_colors(self, override=True): def init_colors(self):
self.set_color(self.color, self.opacity, override) self.set_color(self.color, self.opacity)
def init_points(self): def init_points(self):
# Typically implemented in subclass, unlpess purposefully left blank # Typically implemented in subclass, unlpess purposefully left blank

View file

@ -28,6 +28,7 @@ def _get_neighbouring_pairs(iterable):
class _TexSVG(SVGMobject): class _TexSVG(SVGMobject):
CONFIG = { CONFIG = {
"color": BLACK, "color": BLACK,
"stroke_width": 0,
"height": None, "height": None,
"path_string_config": { "path_string_config": {
"should_subdivide_sharp_curves": True, "should_subdivide_sharp_curves": True,
@ -449,7 +450,7 @@ class MTex(VMobject):
labelled_tex_content labelled_tex_content
) )
glyph_labels = [ glyph_labels = [
self.color_to_label(labelled_glyph.fill_color) self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs for labelled_glyph in labelled_svg_glyphs
] ]
mob = self.build_mobject(labelled_svg_glyphs, glyph_labels) mob = self.build_mobject(labelled_svg_glyphs, glyph_labels)
@ -463,7 +464,7 @@ class MTex(VMobject):
tex_content = self.get_tex_file_content(self.tex_string) tex_content = self.get_tex_file_content(self.tex_string)
svg_glyphs = self.tex_content_to_glyphs(tex_content) svg_glyphs = self.tex_content_to_glyphs(tex_content)
glyph_labels = [ glyph_labels = [
self.color_to_label(labelled_glyph.fill_color) self.color_to_label(labelled_glyph.get_fill_color())
for labelled_glyph in labelled_svg_glyphs for labelled_glyph in labelled_svg_glyphs
] ]
mob = self.build_mobject(svg_glyphs, glyph_labels) mob = self.build_mobject(svg_glyphs, glyph_labels)

View file

@ -1,100 +1,28 @@
import itertools as it
import re
import string
import os import os
import re
import hashlib import hashlib
import itertools as it
import cssselect2 import svgelements as se
from colour import web2hex import numpy as np
from xml.etree import ElementTree
from tinycss2 import serialize as css_serialize
from tinycss2 import parse_stylesheet, parse_declaration_list
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN
from manimlib.constants import DEGREES, PI
from manimlib.constants import RIGHT
from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Circle from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Polygon
from manimlib.mobject.geometry import Polyline
from manimlib.mobject.geometry import Rectangle from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.geometry import RoundedRectangle from manimlib.mobject.geometry import RoundedRectangle
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.color import *
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.directories import get_mobject_data_dir from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path from manimlib.utils.images import get_full_vector_image_path
from manimlib.utils.simple_functions import clip
from manimlib.logger import log from manimlib.logger import log
DEFAULT_STYLE = { def _convert_point_to_3d(x, y):
"fill": "black", return np.array([x, y, 0.0])
"stroke": "none",
"fill-opacity": "1",
"stroke-opacity": "1",
"stroke-width": 0,
}
def cascade_element_style(element, inherited):
style = inherited.copy()
for attr in DEFAULT_STYLE:
if element.get(attr):
style[attr] = element.get(attr)
if element.get("style"):
declarations = parse_declaration_list(element.get("style"))
for declaration in declarations:
style[declaration.name] = css_serialize(declaration.value)
return style
def parse_color(color):
color = color.strip()
if color[0:3] == "rgb":
splits = color[4:-1].strip().split(",")
if splits[0].strip()[-1] == "%":
parsed_rgbs = [float(i.strip()[:-1]) / 100.0 for i in splits]
else:
parsed_rgbs = [int(i) / 255.0 for i in splits]
return rgb_to_hex(parsed_rgbs)
else:
return web2hex(color)
def fill_default_values(style, default_style):
default = DEFAULT_STYLE.copy()
default.update(default_style)
for attr in default:
if attr not in style:
style[attr] = default[attr]
def parse_style(style, default_style):
manim_style = {}
fill_default_values(style, default_style)
manim_style["fill_opacity"] = float(style["fill-opacity"])
manim_style["stroke_opacity"] = float(style["stroke-opacity"])
manim_style["stroke_width"] = float(style["stroke-width"])
if style["fill"] == "none":
manim_style["fill_opacity"] = 0
else:
manim_style["fill_color"] = parse_color(style["fill"])
if style["stroke"] == "none":
manim_style["stroke_width"] = 0
if "fill_color" in manim_style:
manim_style["stroke_color"] = manim_style["fill_color"]
else:
manim_style["stroke_color"] = parse_color(style["stroke"])
return manim_style
class SVGMobject(VMobject): class SVGMobject(VMobject):
@ -102,11 +30,15 @@ class SVGMobject(VMobject):
"should_center": True, "should_center": True,
"height": 2, "height": 2,
"width": None, "width": None,
# Must be filled in in a subclass, or when called # Must be filled in a subclass, or when called
"file_name": None, "file_name": None,
"unpack_groups": True, # if False, creates a hierarchy of VGroups "color": None,
"stroke_width": 0.0, "opacity": None,
"fill_opacity": 1.0, "fill_color": None,
"fill_opacity": None,
"stroke_width": None,
"stroke_color": None,
"stroke_opacity": None,
"path_string_config": {} "path_string_config": {}
} }
@ -128,338 +60,232 @@ class SVGMobject(VMobject):
if self.width is not None: if self.width is not None:
self.set_width(self.width) self.set_width(self.width)
def init_colors(self, override=False): def init_colors(self):
super().init_colors(override=override) # Remove fill_color, fill_opacity,
# stroke_width, stroke_color, stroke_opacity
# as each submobject may have those values specified in svg file
self.set_stroke(background=self.draw_stroke_behind_fill)
self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke)
return self
def init_points(self): def init_points(self):
etree = ElementTree.parse(self.file_path) with open(self.file_path, "r") as svg_file:
wrapper = cssselect2.ElementWrapper.from_xml_root(etree) svg_string = svg_file.read()
svg = etree.getroot()
namespace = svg.tag.split("}")[0][1:]
self.ref_to_element = {}
self.css_matcher = cssselect2.Matcher()
for style in etree.findall(f"{{{namespace}}}style"): # Create a temporary svg file to dump modified svg to be parsed
self.parse_css_style(style.text) modified_svg_string = self.modify_svg_file(svg_string)
modified_file_path = self.file_path.replace(".svg", "_.svg")
with open(modified_file_path, "w") as modified_svg_file:
modified_svg_file.write(modified_svg_string)
mobjects = self.get_mobjects_from(wrapper, dict()) # `color` attribute handles `currentColor` keyword
if self.unpack_groups:
self.add(*mobjects)
else:
self.add(*mobjects[0].submobjects)
def get_mobjects_from(self, wrapper, style):
result = []
element = wrapper.etree_element
if not isinstance(element, ElementTree.Element):
return result
matches = self.css_matcher.match(wrapper)
if matches:
for match in matches:
_, _, _, css_style = match
style.update(css_style)
style = cascade_element_style(element, style)
tag = element.tag.split("}")[-1]
if tag == 'defs':
self.update_ref_to_element(wrapper, style)
elif tag in ['g', 'svg', 'symbol']:
result += it.chain(*(
self.get_mobjects_from(child, style)
for child in wrapper.iter_children()
))
elif tag == 'path':
result.append(self.path_string_to_mobject(
element.get('d'), style
))
elif tag == 'use':
result += self.use_to_mobjects(element, style)
elif tag == 'line':
result.append(self.line_to_mobject(element, style))
elif tag == 'rect':
result.append(self.rect_to_mobject(element, style))
elif tag == 'circle':
result.append(self.circle_to_mobject(element, style))
elif tag == 'ellipse':
result.append(self.ellipse_to_mobject(element, style))
elif tag in ['polygon', 'polyline']:
result.append(self.polygon_to_mobject(element, style))
elif tag == 'style':
pass
else:
log.warning(f"Unsupported element type: {tag}")
pass # TODO, support <text> tag
result = [m for m in result if m is not None]
self.handle_transforms(element, VGroup(*result))
if len(result) > 1 and not self.unpack_groups:
result = [VGroup(*result)]
return result
def generate_default_style(self):
style = {
"fill-opacity": self.fill_opacity,
"stroke-width": self.stroke_width,
"stroke-opacity": self.stroke_opacity,
}
if self.color:
style["fill"] = style["stroke"] = self.color
if self.fill_color: if self.fill_color:
style["fill"] = self.fill_color color = self.fill_color
if self.stroke_color: elif self.color:
style["stroke"] = self.stroke_color color = self.color
return style else:
color = "black"
def parse_css_style(self, css): shapes = se.SVG.parse(
rules = parse_stylesheet(css, True, True) modified_file_path,
for rule in rules: color=color
selectors = cssselect2.compile_selector_list(rule.prelude)
declarations = parse_declaration_list(rule.content)
style = {
declaration.name: css_serialize(declaration.value)
for declaration in declarations
if declaration.name in DEFAULT_STYLE
}
payload = style
for selector in selectors:
self.css_matcher.add_selector(selector, payload)
def path_string_to_mobject(self, path_string, style):
return VMobjectFromSVGPathstring(
path_string,
**self.path_string_config,
**parse_style(style, self.generate_default_style()),
) )
os.remove(modified_file_path)
def use_to_mobjects(self, use_element, local_style): mobjects = self.get_mobjects_from(shapes)
# Remove initial "#" character self.add(*mobjects)
ref = use_element.get(r"{http://www.w3.org/1999/xlink}href")[1:] self.flip(RIGHT) # Flip y
if ref not in self.ref_to_element: self.scale(0.75)
log.warning(f"{ref} not recognized")
return VGroup()
def_element, def_style = self.ref_to_element[ref]
style = local_style.copy()
style.update(def_style)
return self.get_mobjects_from(def_element, style)
def attribute_to_float(self, attr): def modify_svg_file(self, svg_string):
stripped_attr = "".join([ # svgelements cannot handle em, ex units
char for char in attr # Convert them using 1em = 16px, 1ex = 0.5em = 8px
if char in string.digits + "." + "-" def convert_unit(match_obj):
number = float(match_obj.group(1))
unit = match_obj.group(2)
factor = 16 if unit == "em" else 8
return str(number * factor) + "px"
number_pattern = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)(ex|em)(?![a-zA-Z])"
result = re.sub(number_pattern, convert_unit, svg_string)
# Add a group tag to set style from configuration
style_dict = self.generate_context_values_from_config()
group_tag_begin = "<g " + " ".join([
f"{k}=\"{v}\""
for k, v in style_dict.items()
]) + ">"
group_tag_end = "</g>"
begin_insert_index = re.search(r"<svg[\s\S]*?>", result).end()
end_insert_index = re.search(r"[\s\S]*(</svg\s*>)", result).start(1)
result = "".join([
result[:begin_insert_index],
group_tag_begin,
result[begin_insert_index:end_insert_index],
group_tag_end,
result[end_insert_index:]
]) ])
return float(stripped_attr)
def polygon_to_mobject(self, polygon_element, style):
path_string = polygon_element.get("points")
for digit in string.digits:
path_string = path_string.replace(f" {digit}", f"L {digit}")
path_string = path_string.replace("L", "M", 1)
return self.path_string_to_mobject(path_string, style)
def circle_to_mobject(self, circle_element, style):
x, y, r = (
self.attribute_to_float(circle_element.get(key, "0.0"))
for key in ("cx", "cy", "r")
)
return Circle(
radius=r,
**parse_style(style, self.generate_default_style())
).shift(x * RIGHT + y * DOWN)
def ellipse_to_mobject(self, circle_element, style):
x, y, rx, ry = (
self.attribute_to_float(circle_element.get(key, "0.0"))
for key in ("cx", "cy", "rx", "ry")
)
result = Circle(**parse_style(style, self.generate_default_style()))
result.stretch(rx, 0)
result.stretch(ry, 1)
result.shift(x * RIGHT + y * DOWN)
return result return result
def rect_to_mobject(self, rect_element, style): def generate_context_values_from_config(self):
stroke_width = rect_element.get("stroke-width", "") result = {}
corner_radius = rect_element.get("rx", "") if self.stroke_width is not None:
result["stroke-width"] = self.stroke_width
if self.color is not None:
result["fill"] = result["stroke"] = self.color
if self.fill_color is not None:
result["fill"] = self.fill_color
if self.stroke_color is not None:
result["stroke"] = self.stroke_color
if self.opacity is not None:
result["fill-opacity"] = result["stroke-opacity"] = self.opacity
if self.fill_opacity is not None:
result["fill-opacity"] = self.fill_opacity
if self.stroke_opacity is not None:
result["stroke-opacity"] = self.stroke_opacity
return result
if stroke_width in ["", "none", "0"]: def get_mobjects_from(self, shape):
stroke_width = 0 if isinstance(shape, se.Group):
return list(it.chain(*(
self.get_mobjects_from(child)
for child in shape
)))
if corner_radius in ["", "0", "none"]: mob = self.get_mobject_from(shape)
corner_radius = 0 if mob is None:
return []
corner_radius = float(corner_radius) if isinstance(shape, se.Transformable) and shape.apply:
self.handle_transform(mob, shape.transform)
return [mob]
parsed_style = parse_style(style, self.generate_default_style()) @staticmethod
parsed_style["stroke_width"] = stroke_width def handle_transform(mob, matrix):
mat = np.array([
[matrix.a, matrix.c],
[matrix.b, matrix.d]
])
vec = np.array([matrix.e, matrix.f, 0.0])
mob.apply_matrix(mat)
mob.shift(vec)
return mob
if corner_radius == 0: def get_mobject_from(self, shape):
shape_class_to_func_map = {
se.Path: self.path_to_mobject,
se.SimpleLine: self.line_to_mobject,
se.Rect: self.rect_to_mobject,
se.Circle: self.circle_to_mobject,
se.Ellipse: self.ellipse_to_mobject,
se.Polygon: self.polygon_to_mobject,
se.Polyline: self.polyline_to_mobject,
# se.Text: self.text_to_mobject, # TODO
}
for shape_class, func in shape_class_to_func_map.items():
if isinstance(shape, shape_class):
mob = func(shape)
self.apply_style_to_mobject(mob, shape)
return mob
shape_class_name = shape.__class__.__name__
if shape_class_name != "SVGElement":
log.warning(f"Unsupported element type: {shape_class_name}")
return None
@staticmethod
def apply_style_to_mobject(mob, shape):
mob.set_style(
stroke_width=shape.stroke_width,
stroke_color=shape.stroke.hex,
stroke_opacity=shape.stroke.opacity,
fill_color=shape.fill.hex,
fill_opacity=shape.fill.opacity
)
return mob
def path_to_mobject(self, path):
return VMobjectFromSVGPath(path, **self.path_string_config)
def line_to_mobject(self, line):
return Line(
start=_convert_point_to_3d(line.x1, line.y1),
end=_convert_point_to_3d(line.x2, line.y2)
)
def rect_to_mobject(self, rect):
if rect.rx == 0 or rect.ry == 0:
mob = Rectangle( mob = Rectangle(
width=self.attribute_to_float( width=rect.width,
rect_element.get("width", "") height=rect.height,
),
height=self.attribute_to_float(
rect_element.get("height", "")
),
**parsed_style,
) )
else: else:
mob = RoundedRectangle( mob = RoundedRectangle(
width=self.attribute_to_float( width=rect.width,
rect_element.get("width", "") height=rect.height * rect.rx / rect.ry,
), corner_radius=rect.rx
height=self.attribute_to_float(
rect_element.get("height", "")
),
corner_radius=corner_radius,
**parsed_style
) )
mob.stretch_to_fit_height(rect.height)
mob.shift(mob.get_center() - mob.get_corner(UP + LEFT)) mob.shift(_convert_point_to_3d(
rect.x + rect.width / 2,
rect.y + rect.height / 2
))
return mob return mob
def line_to_mobject(self, line_element, style):
x1, y1, x2, y2 = (
self.attribute_to_float(line_element.get(key, "0.0"))
for key in ("x1", "y1", "x2", "y2")
)
return Line(
[x1, -y1, 0], [x2, -y2, 0],
**parse_style(style, self.generate_default_style())
)
def handle_transforms(self, element, mobject): def circle_to_mobject(self, circle):
x, y = ( # svgelements supports `rx` & `ry` but `r`
self.attribute_to_float(element.get(key, "0.0")) mob = Circle(radius=circle.rx)
for key in ("x", "y") mob.shift(_convert_point_to_3d(
) circle.cx, circle.cy
mobject.shift(x * RIGHT + y * DOWN) ))
return mob
transform_names = [ def ellipse_to_mobject(self, ellipse):
"matrix", mob = Circle(radius=ellipse.rx)
"translate", "translateX", "translateY", mob.stretch_to_fit_height(2 * ellipse.ry)
"scale", "scaleX", "scaleY", mob.shift(_convert_point_to_3d(
"rotate", ellipse.cx, ellipse.cy
"skewX", "skewY" ))
return mob
def polygon_to_mobject(self, polygon):
points = [
_convert_point_to_3d(*point)
for point in polygon
] ]
transform_pattern = re.compile("|".join([x + r"[^)]*\)" for x in transform_names])) return Polygon(*points)
number_pattern = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
transforms = transform_pattern.findall(element.get("transform", ""))[::-1]
for transform in transforms: def polyline_to_mobject(self, polyline):
op_name, op_args = transform.split("(") points = [
op_name = op_name.strip() _convert_point_to_3d(*point)
op_args = [float(x) for x in number_pattern.findall(op_args)] for point in polyline
]
return Polyline(*points)
if op_name == "matrix": def text_to_mobject(self, text):
self._handle_matrix_transform(mobject, op_name, op_args) pass
elif op_name.startswith("translate"):
self._handle_translate_transform(mobject, op_name, op_args)
elif op_name.startswith("scale"):
self._handle_scale_transform(mobject, op_name, op_args)
elif op_name == "rotate":
self._handle_rotate_transform(mobject, op_name, op_args)
elif op_name.startswith("skew"):
self._handle_skew_transform(mobject, op_name, op_args)
def _handle_matrix_transform(self, mobject, op_name, op_args):
transform = np.array(op_args).reshape([3, 2])
x = transform[2][0]
y = -transform[2][1]
matrix = np.identity(self.dim)
matrix[:2, :2] = transform[:2, :]
matrix[1] *= -1
matrix[:, 1] *= -1
for mob in mobject.family_members_with_points():
mob.apply_matrix(matrix.T)
mobject.shift(x * RIGHT + y * UP)
def _handle_translate_transform(self, mobject, op_name, op_args):
if op_name.endswith("X"):
x, y = op_args[0], 0
elif op_name.endswith("Y"):
x, y = 0, op_args[0]
else:
x, y = op_args
mobject.shift(x * RIGHT + y * DOWN)
def _handle_scale_transform(self, mobject, op_name, op_args):
if op_name.endswith("X"):
sx, sy = op_args[0], 1
elif op_name.endswith("Y"):
sx, sy = 1, op_args[0]
elif len(op_args) == 2:
sx, sy = op_args
else:
sx = sy = op_args[0]
if sx < 0:
mobject.flip(UP)
sx = -sx
if sy < 0:
mobject.flip(RIGHT)
sy = -sy
mobject.scale(np.array([sx, sy, 1]), about_point=ORIGIN)
def _handle_rotate_transform(self, mobject, op_name, op_args):
if len(op_args) == 1:
mobject.rotate(op_args[0] * DEGREES, axis=IN, about_point=ORIGIN)
else:
deg, x, y = op_args
mobject.rotate(deg * DEGREES, axis=IN, about_point=np.array([x, y, 0]))
def _handle_skew_transform(self, mobject, op_name, op_args):
rad = op_args[0] * DEGREES
if op_name == "skewX":
tana = np.tan(rad)
self._handle_matrix_transform(mobject, None, [1., 0., tana, 1., 0., 0.])
elif op_name == "skewY":
tana = np.tan(rad)
self._handle_matrix_transform(mobject, None, [1., tana, 0., 1., 0., 0.])
def flatten(self, input_list):
output_list = []
for i in input_list:
if isinstance(i, list):
output_list.extend(self.flatten(i))
else:
output_list.append(i)
return output_list
def get_all_childWrappers_have_id(self, wrapper):
all_childWrappers_have_id = []
element = wrapper.etree_element
if not isinstance(element, ElementTree.Element):
return
if element.get('id'):
return [wrapper]
for e in wrapper.iter_children():
all_childWrappers_have_id.append(self.get_all_childWrappers_have_id(e))
return self.flatten([e for e in all_childWrappers_have_id if e])
def update_ref_to_element(self, wrapper, style):
new_refs = {
e.etree_element.get('id', ''): (e, style)
for e in self.get_all_childWrappers_have_id(wrapper)
}
self.ref_to_element.update(new_refs)
class VMobjectFromSVGPathstring(VMobject): class VMobjectFromSVGPath(VMobject):
CONFIG = { CONFIG = {
"long_lines": False, "long_lines": False,
"should_subdivide_sharp_curves": False, "should_subdivide_sharp_curves": False,
"should_remove_null_curves": False, "should_remove_null_curves": False,
} }
def __init__(self, path_string, **kwargs): def __init__(self, path_obj, **kwargs):
self.path_string = path_string # Get rid of arcs
path_obj.approximate_arcs_with_quads()
self.path_obj = path_obj
super().__init__(**kwargs) super().__init__(**kwargs)
def init_points(self): def init_points(self):
# After a given svg_path has been converted into points, the result # After a given svg_path has been converted into points, the result
# will be saved to a file so that future calls for the same path # will be saved to a file so that future calls for the same path
# don't need to retrace the same computation. # don't need to retrace the same computation.
hasher = hashlib.sha256(self.path_string.encode()) path_string = self.path_obj.d()
hasher = hashlib.sha256(path_string.encode())
path_hash = hasher.hexdigest()[:16] path_hash = hasher.hexdigest()[:16]
points_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_points.npy") points_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_points.npy")
tris_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_tris.npy") tris_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_tris.npy")
@ -476,239 +302,23 @@ class VMobjectFromSVGPathstring(VMobject):
if self.should_remove_null_curves: if self.should_remove_null_curves:
# Get rid of any null curves # Get rid of any null curves
self.set_points(self.get_points_without_null_curves()) self.set_points(self.get_points_without_null_curves())
# SVG treats y-coordinate differently
self.stretch(-1, 1, about_point=ORIGIN)
# Save to a file for future use # Save to a file for future use
np.save(points_filepath, self.get_points()) np.save(points_filepath, self.get_points())
np.save(tris_filepath, self.get_triangulation()) np.save(tris_filepath, self.get_triangulation())
def get_commands_and_coord_strings(self):
all_commands = list(self.get_command_to_function_map().keys())
all_commands += [c.lower() for c in all_commands]
pattern = "[{}]".format("".join(all_commands))
return zip(
re.findall(pattern, self.path_string),
re.split(pattern, self.path_string)[1:]
)
def handle_commands(self): def handle_commands(self):
relative_point = ORIGIN segment_class_to_func_map = {
for command, coord_string in self.get_commands_and_coord_strings(): se.Move: (self.start_new_path, ("end",)),
func, number_types_str = self.command_to_function(command) se.Close: (self.close_path, ()),
upper_command = command.upper() se.Line: (self.add_line_to, ("end",)),
if upper_command == "Z": se.QuadraticBezier: (self.add_quadratic_bezier_curve_to, ("control", "end")),
func() # `close_path` takes no arguments se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end"))
relative_point = self.get_last_point()
continue
number_types = np.array(list(number_types_str))
n_numbers = len(number_types_str)
number_list = _PathStringParser(coord_string, number_types_str).args
number_groups = np.array(number_list).reshape((-1, n_numbers))
for ind, numbers in enumerate(number_groups):
if command.islower():
# Treat it as a relative command
numbers[number_types == "x"] += relative_point[0]
numbers[number_types == "y"] += relative_point[1]
if upper_command == "A":
args = [*numbers[:5], np.array([*numbers[5:7], 0.0])]
elif upper_command == "H":
args = [np.array([numbers[0], relative_point[1], 0.0])]
elif upper_command == "V":
args = [np.array([relative_point[0], numbers[0], 0.0])]
else:
args = list(np.hstack((
numbers.reshape((-1, 2)), np.zeros((n_numbers // 2, 1))
)))
if upper_command == "M" and ind != 0:
# M x1 y1 x2 y2 is equal to M x1 y1 L x2 y2
func, _ = self.command_to_function("L")
func(*args)
relative_point = self.get_last_point()
def add_elliptical_arc_to(self, rx, ry, x_axis_rotation, large_arc_flag, sweep_flag, point):
def close_to_zero(a, threshold=1e-5):
return abs(a) < threshold
def solve_2d_linear_equation(a, b, c):
"""
Using Crammer's rule to solve the linear equation `[a b]x = c`
where `a`, `b` and `c` are all 2d vectors.
"""
def det(a, b):
return a[0] * b[1] - a[1] * b[0]
d = det(a, b)
if close_to_zero(d):
raise Exception("Cannot handle 0 determinant.")
return [det(c, b) / d, det(a, c) / d]
def get_arc_center_and_angles(x0, y0, rx, ry, phi, large_arc_flag, sweep_flag, x1, y1):
"""
The parameter functions of an ellipse rotated `phi` radians counterclockwise is (on `alpha`):
x = cx + rx * cos(alpha) * cos(phi) + ry * sin(alpha) * sin(phi),
y = cy + rx * cos(alpha) * sin(phi) - ry * sin(alpha) * cos(phi).
Now we have two points sitting on the ellipse: `(x0, y0)`, `(x1, y1)`, corresponding to 4 equations,
and we want to hunt for 4 variables: `cx`, `cy`, `alpha0` and `alpha_1`.
Let `d_alpha = alpha1 - alpha0`, then:
if `sweep_flag = 0` and `large_arc_flag = 1`, then `PI <= d_alpha < 2 * PI`;
if `sweep_flag = 0` and `large_arc_flag = 0`, then `0 < d_alpha <= PI`;
if `sweep_flag = 1` and `large_arc_flag = 0`, then `-PI <= d_alpha < 0`;
if `sweep_flag = 1` and `large_arc_flag = 1`, then `-2 * PI < d_alpha <= -PI`.
"""
xd = x1 - x0
yd = y1 - y0
if close_to_zero(xd) and close_to_zero(yd):
raise Exception("Cannot find arc center since the start point and the end point meet.")
# Find `p = cos(alpha1) - cos(alpha0)`, `q = sin(alpha1) - sin(alpha0)`
eq0 = [rx * np.cos(phi), ry * np.sin(phi), xd]
eq1 = [rx * np.sin(phi), -ry * np.cos(phi), yd]
p, q = solve_2d_linear_equation(*zip(eq0, eq1))
# Find `s = (alpha1 - alpha0) / 2`, `t = (alpha1 + alpha0) / 2`
# If `sin(s) = 0`, this requires `p = q = 0`,
# implying `xd = yd = 0`, which is impossible.
sin_s = (p ** 2 + q ** 2) ** 0.5 / 2
if sweep_flag:
sin_s = -sin_s
sin_s = clip(sin_s, -1, 1)
s = np.arcsin(sin_s)
if large_arc_flag:
if not sweep_flag:
s = PI - s
else:
s = -PI - s
sin_t = -p / (2 * sin_s)
cos_t = q / (2 * sin_s)
cos_t = clip(cos_t, -1, 1)
t = np.arccos(cos_t)
if sin_t <= 0:
t = -t
# We can make sure `0 < abs(s) < PI`, `-PI <= t < PI`.
alpha0 = t - s
alpha_1 = t + s
cx = x0 - rx * np.cos(alpha0) * np.cos(phi) - ry * np.sin(alpha0) * np.sin(phi)
cy = y0 - rx * np.cos(alpha0) * np.sin(phi) + ry * np.sin(alpha0) * np.cos(phi)
return cx, cy, alpha0, alpha_1
def get_point_on_ellipse(cx, cy, rx, ry, phi, angle):
return np.array([
cx + rx * np.cos(angle) * np.cos(phi) + ry * np.sin(angle) * np.sin(phi),
cy + rx * np.cos(angle) * np.sin(phi) - ry * np.sin(angle) * np.cos(phi),
0
])
def convert_elliptical_arc_to_quadratic_bezier_curve(
cx, cy, rx, ry, phi, start_angle, end_angle, n_components=8
):
theta = (end_angle - start_angle) / n_components / 2
handles = np.array([
get_point_on_ellipse(cx, cy, rx / np.cos(theta), ry / np.cos(theta), phi, a)
for a in np.linspace(
start_angle + theta,
end_angle - theta,
n_components,
)
])
anchors = np.array([
get_point_on_ellipse(cx, cy, rx, ry, phi, a)
for a in np.linspace(
start_angle + theta * 2,
end_angle,
n_components,
)
])
return handles, anchors
phi = x_axis_rotation * DEGREES
x0, y0 = self.get_last_point()[:2]
cx, cy, start_angle, end_angle = get_arc_center_and_angles(
x0, y0, rx, ry, phi, large_arc_flag, sweep_flag, point[0], point[1]
)
handles, anchors = convert_elliptical_arc_to_quadratic_bezier_curve(
cx, cy, rx, ry, phi, start_angle, end_angle
)
for handle, anchor in zip(handles, anchors):
self.add_quadratic_bezier_curve_to(handle, anchor)
def command_to_function(self, command):
return self.get_command_to_function_map()[command.upper()]
def get_command_to_function_map(self):
"""
Associates svg command to VMobject function, and
the types of arguments it takes in
"""
return {
"M": (self.start_new_path, "xy"),
"L": (self.add_line_to, "xy"),
"H": (self.add_line_to, "x"),
"V": (self.add_line_to, "y"),
"C": (self.add_cubic_bezier_curve_to, "xyxyxy"),
"S": (self.add_smooth_cubic_curve_to, "xyxy"),
"Q": (self.add_quadratic_bezier_curve_to, "xyxy"),
"T": (self.add_smooth_curve_to, "xy"),
"A": (self.add_elliptical_arc_to, "uuaffxy"),
"Z": (self.close_path, ""),
} }
for segment in self.path_obj:
def get_original_path_string(self): segment_class = segment.__class__
return self.path_string func, attr_names = segment_class_to_func_map[segment_class]
points = [
_convert_point_to_3d(*segment.__getattribute__(attr_name))
class InvalidPathError(ValueError): for attr_name in attr_names
pass ]
func(*points)
class _PathStringParser:
# modified from https://github.com/regebro/svg.path/
def __init__(self, arguments, rules):
self.args = []
arguments = bytearray(arguments, "ascii")
self._strip_array(arguments)
while arguments:
for rule in rules:
self._rule_to_function_map[rule](arguments)
@property
def _rule_to_function_map(self):
return {
"x": self._get_number,
"y": self._get_number,
"a": self._get_number,
"u": self._get_unsigned_number,
"f": self._get_flag,
}
def _strip_array(self, arg_array):
# wsp: (0x9, 0x20, 0xA, 0xC, 0xD) with comma 0x2C
# https://www.w3.org/TR/SVG/paths.html#PathDataBNF
while arg_array and arg_array[0] in [0x9, 0x20, 0xA, 0xC, 0xD, 0x2C]:
arg_array[0:1] = b""
def _get_number(self, arg_array):
pattern = re.compile(rb"^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
res = pattern.search(arg_array)
if not res:
raise InvalidPathError(f"Expected a number, got '{arg_array}'")
number = float(res.group())
self.args.append(number)
arg_array[res.start():res.end()] = b""
self._strip_array(arg_array)
return number
def _get_unsigned_number(self, arg_array):
number = self._get_number(arg_array)
if number < 0:
raise InvalidPathError(f"Expected an unsigned number, got '{number}'")
return number
def _get_flag(self, arg_array):
flag = arg_array[0]
if flag != 48 and flag != 49:
raise InvalidPathError(f"Expected a flag (0/1), got '{chr(flag)}'")
flag -= 48
self.args.append(flag)
arg_array[0:1] = b""
self._strip_array(arg_array)
return flag

View file

@ -54,13 +54,19 @@ class SingleStringTex(VMobject):
sm.copy() sm.copy()
for sm in tex_string_with_color_to_mob_map[(self.color, tex_string)] for sm in tex_string_with_color_to_mob_map[(self.color, tex_string)]
)) ))
self.init_colors(override=False) self.init_colors()
if self.height is None: if self.height is None:
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
if self.organize_left_to_right: if self.organize_left_to_right:
self.organize_submobjects_left_to_right() self.organize_submobjects_left_to_right()
def init_colors(self):
self.set_stroke(background=self.draw_stroke_behind_fill)
self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke)
return self
def get_tex_file_body(self, tex_string): def get_tex_file_body(self, tex_string):
new_tex = self.get_modified_expression(tex_string) new_tex = self.get_modified_expression(tex_string)
if self.math_mode: if self.math_mode:

View file

@ -71,6 +71,8 @@ class Text(SVGMobject):
PangoUtils.remove_last_M(file_name) PangoUtils.remove_last_M(file_name)
self.remove_empty_path(file_name) self.remove_empty_path(file_name)
SVGMobject.__init__(self, file_name, **kwargs) SVGMobject.__init__(self, file_name, **kwargs)
if self.color:
self.set_fill(self.color)
self.text = text self.text = text
if self.disable_ligatures: if self.disable_ligatures:
self.apply_space_chars() self.apply_space_chars()
@ -85,9 +87,6 @@ class Text(SVGMobject):
if self.height is None: if self.height is None:
self.scale(TEXT_MOB_SCALE_FACTOR) self.scale(TEXT_MOB_SCALE_FACTOR)
def init_colors(self, override=True):
super().init_colors(override=override)
def remove_empty_path(self, file_name): def remove_empty_path(self, file_name):
with open(file_name, 'r') as fpr: with open(file_name, 'r') as fpr:
content = fpr.read() content = fpr.read()

View file

@ -260,7 +260,7 @@ class TexturedSurface(Surface):
super().init_uniforms() super().init_uniforms()
self.uniforms["num_textures"] = self.num_textures self.uniforms["num_textures"] = self.num_textures
def init_colors(self, override=True): def init_colors(self):
self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]]) self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]])
def set_opacity(self, opacity, recurse=True): def set_opacity(self, opacity, recurse=True):

View file

@ -90,7 +90,7 @@ class VMobject(Mobject):
}) })
# Colors # Colors
def init_colors(self, override=True): def init_colors(self):
self.set_fill( self.set_fill(
color=self.fill_color or self.color, color=self.fill_color or self.color,
opacity=self.fill_opacity, opacity=self.fill_opacity,
@ -103,9 +103,6 @@ class VMobject(Mobject):
) )
self.set_gloss(self.gloss) self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke) self.set_flat_stroke(self.flat_stroke)
if not override:
for submobjects in self.submobjects:
submobjects.init_colors(override=False)
return self return self
def set_rgba_array(self, rgba_array, name=None, recurse=False): def set_rgba_array(self, rgba_array, name=None, recurse=False):

View file

@ -18,5 +18,5 @@ validators
ipython ipython
PyOpenGL PyOpenGL
manimpango>=0.2.0,<0.4.0 manimpango>=0.2.0,<0.4.0
cssselect2
isosurfaces isosurfaces
svgelements

View file

@ -50,8 +50,8 @@ install_requires =
ipython ipython
PyOpenGL PyOpenGL
manimpango>=0.2.0,<0.4.0 manimpango>=0.2.0,<0.4.0
cssselect2
isosurfaces isosurfaces
svgelements
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =