# Mirror of https://github.com/3b1b/manim.git
# Synced 2025-04-13 09:47:07 +00:00 (source file: 701 lines, 26 KiB, Python)
import itertools as it
|
|
import re
|
|
import string
|
|
import os
|
|
import hashlib
|
|
|
|
import cssselect2
|
|
from colour import web2hex
|
|
from xml.etree import ElementTree
|
|
from tinycss2 import serialize as css_serialize
|
|
from tinycss2 import parse_stylesheet, parse_declaration_list
|
|
|
|
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN
|
|
from manimlib.constants import DEGREES, PI
|
|
|
|
from manimlib.mobject.geometry import Circle
|
|
from manimlib.mobject.geometry import Rectangle
|
|
from manimlib.mobject.geometry import RoundedRectangle
|
|
from manimlib.mobject.types.vectorized_mobject import VGroup
|
|
from manimlib.mobject.types.vectorized_mobject import VMobject
|
|
from manimlib.utils.color import *
|
|
from manimlib.utils.config_ops import digest_config
|
|
from manimlib.utils.directories import get_mobject_data_dir
|
|
from manimlib.utils.images import get_full_vector_image_path
|
|
from manimlib.utils.simple_functions import clip
|
|
from manimlib.logger import log
|
|
|
|
|
|
# Baseline SVG presentation values, used when neither the document nor the
# SVGMobject configuration provides one.  The keys also act as the list of
# presentation attributes / CSS properties that this module reads.
DEFAULT_STYLE = dict([
    ("fill", "black"),
    ("stroke", "none"),
    ("fill-opacity", "1"),
    ("stroke-opacity", "1"),
    ("stroke-width", 0),
])
|
|
|
|
|
|
def cascade_element_style(element, inherited):
    """
    Return the style dict of ``element``, layered on top of ``inherited``.

    Precedence, lowest to highest:
        1. the inherited style,
        2. the element's presentation attributes named in ``DEFAULT_STYLE``,
        3. the declarations of its inline ``style="..."`` attribute.

    ``inherited`` is not modified; a new dict is returned.
    """
    style = inherited.copy()

    for attr in DEFAULT_STYLE:
        value = element.get(attr)
        if value:
            style[attr] = value

    inline_style = element.get("style")
    if inline_style:
        # Whitespace and comment nodes carry no ``name``.  For example,
        # "fill:red; stroke:blue" yields a whitespace node after the ";".
        # Skip those, and skip parse errors from malformed declarations.
        declarations = parse_declaration_list(
            inline_style, skip_comments=True, skip_whitespace=True
        )
        for declaration in declarations:
            if declaration.type != "declaration":
                continue
            # The serialized value may keep the whitespace that followed ":".
            # Strip it so that comparisons such as == "none" work.
            style[declaration.lower_name] = css_serialize(declaration.value).strip()

    return style
|
|
|
|
|
|
def parse_color(color):
    """
    Convert an SVG/CSS color value to a hex color string.

    ``rgb(...)`` and ``rgba(...)`` are parsed here.  Each of the first three
    components may be a number on the 0-255 scale or a percentage, decided
    per component.  Out-of-range values are clamped.  An alpha component, if
    present, is ignored: opacity comes from the separate ``fill-opacity`` /
    ``stroke-opacity`` properties.  Any other value (named colors, ``#rgb``,
    ``#rrggbb``) is handed to ``colour.web2hex``.
    """
    color = color.strip()

    if color[:3].lower() != "rgb":
        return web2hex(color)

    # Take everything between the parentheses, so "rgb(" and "rgba(" both work.
    inner = color[color.find("(") + 1:color.rfind(")")]
    # Components are separated by commas and/or whitespace.  CSS Color 4
    # also allows "/" before the alpha value.
    components = [c for c in re.split(r"[\s,/]+", inner) if c][:3]
    parsed_rgbs = []
    for component in components:
        if component.endswith("%"):
            value = float(component[:-1]) / 100.0
        else:
            value = float(component) / 255.0
        parsed_rgbs.append(clip(value, 0, 1))
    return rgb_to_hex(parsed_rgbs)
|
|
|
|
|
|
def fill_default_values(style, default_style):
    """
    Add, in place, every style key that ``style`` is missing.

    Values for missing keys come from ``default_style``, falling back to the
    module-level ``DEFAULT_STYLE``.  Keys already in ``style`` are left
    untouched.  Returns None.
    """
    merged_defaults = {**DEFAULT_STYLE, **default_style}
    for key, value in merged_defaults.items():
        style.setdefault(key, value)
|
|
|
|
|
|
def parse_style(style, default_style):
    """
    Translate an SVG style dict into keyword arguments for a manim VMobject.

    ``style`` is first completed in place by ``fill_default_values``.  The
    result always contains ``fill_opacity``, ``stroke_opacity`` and
    ``stroke_width``.  ``fill_color`` is present unless the fill is "none".
    ``stroke_color`` is present unless both fill and stroke are "none".
    """
    fill_default_values(style, default_style)

    fill = style["fill"]
    stroke = style["stroke"]

    manim_style = {
        "fill_opacity": float(style["fill-opacity"]),
        "stroke_opacity": float(style["stroke-opacity"]),
        "stroke_width": float(style["stroke-width"]),
    }

    if fill != "none":
        manim_style["fill_color"] = parse_color(fill)
    else:
        manim_style["fill_opacity"] = 0

    if stroke != "none":
        manim_style["stroke_color"] = parse_color(stroke)
    else:
        manim_style["stroke_width"] = 0
        # Without a stroke, the stroke color is still set to match the fill.
        if "fill_color" in manim_style:
            manim_style["stroke_color"] = manim_style["fill_color"]

    return manim_style
|
|
|
|
|
|
class SVGMobject(VMobject):
    """
    A VMobject built from the shapes of an SVG file.

    Handled elements: <path>, <rect>, <circle>, <ellipse>, <polygon>,
    <polyline>, <use> (resolved through <defs>), and the grouping elements
    <g>, <svg> and <symbol>.  Style comes from presentation attributes,
    inline ``style`` attributes and <style> sheets.  ``x``/``y`` offsets and
    ``transform`` attributes are applied.  Other elements are skipped with a
    warning.
    """
    CONFIG = {
        "should_center": True,
        "height": 2,
        "width": None,
        # Must be filled in in a subclass, or when called
        "file_name": None,
        "unpack_groups": True,  # if False, creates a hierarchy of VGroups
        "stroke_width": 0.0,
        "fill_opacity": 1.0,
        "path_string_config": {}
    }

    # SVG transform functions recognised inside a ``transform`` attribute.
    _TRANSFORM_NAMES = (
        "matrix",
        "translate", "translateX", "translateY",
        "scale", "scaleX", "scaleY",
        "rotate",
        "skewX", "skewY",
    )
    _TRANSFORM_PATTERN = re.compile(
        "|".join([name + r"[^)]*\)" for name in _TRANSFORM_NAMES])
    )
    # A signed decimal number with optional exponent, e.g. "-1.5e-3".
    _NUMBER_PATTERN = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")

    def __init__(self, file_name=None, **kwargs):
        digest_config(self, kwargs)
        # An explicit argument wins; otherwise use a file_name supplied via
        # CONFIG (e.g. by a subclass).
        self.file_name = file_name or self.file_name
        if self.file_name is None:
            raise Exception("Must specify file for SVGMobject")
        self.file_path = get_full_vector_image_path(self.file_name)

        super().__init__(**kwargs)
        self.move_into_position()

    def move_into_position(self):
        """Optionally center, then apply the configured height and/or width."""
        if self.should_center:
            self.center()
        if self.height is not None:
            self.set_height(self.height)
        if self.width is not None:
            self.set_width(self.width)

    def init_colors(self, override=False):
        super().init_colors(override=override)

    def init_points(self):
        """Parse ``self.file_path`` and add the resulting mobjects as submobjects."""
        etree = ElementTree.parse(self.file_path)
        wrapper = cssselect2.ElementWrapper.from_xml_root(etree)
        svg = etree.getroot()
        # Namespaced tags look like "{http://www.w3.org/2000/svg}svg".  A file
        # without a default namespace has plain tags such as "svg".
        if svg.tag.startswith("{"):
            namespace = svg.tag[1:].split("}")[0]
            style_tag = f"{{{namespace}}}style"
        else:
            style_tag = "style"
        self.ref_to_element = {}
        self.css_matcher = cssselect2.Matcher()

        # <style> sheets may sit anywhere in the tree (commonly inside
        # <defs>), not only directly under the root.  Empty ones are skipped.
        for style in svg.iter(style_tag):
            if style.text:
                self.parse_css_style(style.text)

        mobjects = self.get_mobjects_from(wrapper, dict())
        if self.unpack_groups:
            self.add(*mobjects)
        elif mobjects:
            # With unpack_groups=False the root yields at most one mobject:
            # either a VGroup wrapping several children, which this
            # SVGMobject stands in for, or a lone child, which is added as-is.
            root = mobjects[0]
            if isinstance(root, VGroup):
                self.add(*root.submobjects)
            else:
                self.add(root)

    def get_mobjects_from(self, wrapper, style):
        """
        Recursively convert the element behind ``wrapper`` into mobjects.

        ``style`` is the style inherited from the parent and is not modified.
        Returns a list of mobjects.  When ``unpack_groups`` is False and
        there are several, they are wrapped in a single VGroup.
        """
        result = []
        element = wrapper.etree_element
        if not isinstance(element, ElementTree.Element):
            return result

        # Copy before layering matched CSS rules on top.  The parent passes
        # the same dict to every child, so an in-place update would leak this
        # element's CSS into its siblings.
        style = style.copy()
        for match in self.css_matcher.match(wrapper) or ():
            # The last item is the payload registered via add_selector.
            _, _, _, css_style = match
            style.update(css_style)
        style = cascade_element_style(element, style)

        tag = element.tag.split("}")[-1]
        if tag == "defs":
            self.update_ref_to_element(wrapper, style)
        elif tag in ("g", "svg", "symbol"):
            result += it.chain.from_iterable(
                self.get_mobjects_from(child, style)
                for child in wrapper.iter_children()
            )
        elif tag == "path":
            path_string = element.get("d")
            # A <path> without data renders nothing.
            if path_string:
                result.append(self.path_string_to_mobject(path_string, style))
        elif tag == "use":
            result += self.use_to_mobjects(element, style)
        elif tag == "rect":
            result.append(self.rect_to_mobject(element, style))
        elif tag == "circle":
            result.append(self.circle_to_mobject(element, style))
        elif tag == "ellipse":
            result.append(self.ellipse_to_mobject(element, style))
        elif tag in ("polygon", "polyline"):
            result.append(self.polygon_to_mobject(element, style))
        elif tag == "style":
            pass
        else:
            log.warning(f"Unsupported element type: {tag}")
            pass  # TODO, support <line> and <text> tag
        result = [m.insert_n_curves(0) for m in result if m is not None]
        self.handle_transforms(element, VGroup(*result))
        if len(result) > 1 and not self.unpack_groups:
            result = [VGroup(*result)]

        return result

    def generate_default_style(self):
        """Build the fallback SVG style dict from this mobject's own config."""
        style = {
            "fill-opacity": self.fill_opacity,
            "stroke-width": self.stroke_width,
            "stroke-opacity": self.stroke_opacity,
        }
        if self.color:
            style["fill"] = style["stroke"] = self.color
        if self.fill_color:
            style["fill"] = self.fill_color
        if self.stroke_color:
            style["stroke"] = self.stroke_color
        return style

    def parse_css_style(self, css):
        """Register each rule of the stylesheet ``css`` with ``self.css_matcher``."""
        rules = parse_stylesheet(css, skip_comments=True, skip_whitespace=True)
        for rule in rules:
            # Only "selector { ... }" rules carry styles.  At-rules (e.g.
            # @font-face) and parse errors are ignored.
            if rule.type != "qualified-rule":
                continue
            selectors = cssselect2.compile_selector_list(rule.prelude)
            # Whitespace/comment/error nodes carry no ``name``; skip them.
            declarations = parse_declaration_list(
                rule.content, skip_comments=True, skip_whitespace=True
            )
            style = {
                declaration.lower_name: css_serialize(declaration.value).strip()
                for declaration in declarations
                if declaration.type == "declaration"
                and declaration.lower_name in DEFAULT_STYLE
            }
            for selector in selectors:
                self.css_matcher.add_selector(selector, style)

    def path_string_to_mobject(self, path_string, style):
        """Build a VMobjectFromSVGPathstring for ``path_string`` in ``style``."""
        return VMobjectFromSVGPathstring(
            path_string,
            **self.path_string_config,
            **parse_style(style, self.generate_default_style()),
        )

    def use_to_mobjects(self, use_element, local_style):
        """
        Return the list of mobjects for the element a <use> references.

        Returns an empty list (with a warning) when the reference is missing
        or unknown.
        """
        # SVG 1.1 uses xlink:href; SVG 2 also allows a plain href.
        href = (
            use_element.get("{http://www.w3.org/1999/xlink}href")
            or use_element.get("href")
        )
        if not href:
            log.warning("<use> element without an href")
            return []
        # Remove initial "#" character
        ref = href[1:] if href.startswith("#") else href
        if ref not in self.ref_to_element:
            log.warning(f"{ref} not recognized")
            return []
        def_element, def_style = self.ref_to_element[ref]
        style = local_style.copy()
        style.update(def_style)
        return self.get_mobjects_from(def_element, style)

    def attribute_to_float(self, attr):
        """
        Parse the first number in an SVG number/length attribute.

        Unit suffixes are ignored ("12.5px" -> 12.5) and scientific notation
        is supported ("1e-3" -> 0.001).  Raises ValueError when ``attr``
        contains no number.
        """
        match = self._NUMBER_PATTERN.search(attr)
        if match is None:
            raise ValueError(f"could not convert string to float: {attr!r}")
        return float(match.group())

    def polygon_to_mobject(self, polygon_element, style):
        """
        Convert a <polygon> or <polyline> to a path mobject.

        Returns None when the element has no points.
        """
        points = polygon_element.get("points", "").strip()
        if not points:
            return None
        # In path data, coordinate pairs after a moveto are implicit linetos,
        # so the point list can be used verbatim after a single "M".
        path_string = "M " + points
        # A polygon is closed; a polyline is left open.
        if polygon_element.tag.split("}")[-1] == "polygon":
            path_string += " Z"
        return self.path_string_to_mobject(path_string, style)

    def circle_to_mobject(self, circle_element, style):
        """Convert a <circle> (cx, cy, r) into a Circle."""
        x, y, r = (
            self.attribute_to_float(circle_element.get(key, "0.0"))
            for key in ("cx", "cy", "r")
        )
        return Circle(
            radius=r,
            **parse_style(style, self.generate_default_style())
        ).shift(x * RIGHT + y * DOWN)

    def ellipse_to_mobject(self, circle_element, style):
        """Convert an <ellipse> (cx, cy, rx, ry) into a stretched Circle."""
        x, y, rx, ry = (
            self.attribute_to_float(circle_element.get(key, "0.0"))
            for key in ("cx", "cy", "rx", "ry")
        )
        result = Circle(**parse_style(style, self.generate_default_style()))
        result.stretch(rx, 0)
        result.stretch(ry, 1)
        result.shift(x * RIGHT + y * DOWN)
        return result

    def _get_rect_corner_radius(self, rect_element):
        """Return a <rect>'s corner radius: ``rx``, falling back to ``ry``, else 0."""
        for key in ("rx", "ry"):
            value = rect_element.get(key, "").strip()
            if value not in ("", "none", "auto"):
                return self.attribute_to_float(value)
        return 0.0

    def rect_to_mobject(self, rect_element, style):
        """
        Convert a <rect> into a Rectangle, or a RoundedRectangle if it has a
        corner radius.

        The top-left corner is placed at the origin; ``x``/``y`` are applied
        afterwards by ``handle_transforms``.  Returns None for a rect with
        zero (or missing) width or height, which SVG does not render.
        """
        width = self.attribute_to_float(rect_element.get("width", "0"))
        height = self.attribute_to_float(rect_element.get("height", "0"))
        if width <= 0 or height <= 0:
            return None
        corner_radius = self._get_rect_corner_radius(rect_element)

        # The stroke width comes from the cascaded style (attribute, inline
        # style or CSS), as it does for every other shape.
        parsed_style = parse_style(style, self.generate_default_style())

        if corner_radius == 0:
            mob = Rectangle(width=width, height=height, **parsed_style)
        else:
            mob = RoundedRectangle(
                width=width,
                height=height,
                corner_radius=corner_radius,
                **parsed_style
            )

        mob.shift(mob.get_center() - mob.get_corner(UP + LEFT))
        return mob

    def handle_transforms(self, element, mobject):
        """
        Shift ``mobject`` by ``element``'s ``x``/``y``, then apply the
        ``transform`` attribute.  SVG's y axis points down, so y is negated.
        """
        x, y = (
            self.attribute_to_float(element.get(key, "0.0"))
            for key in ("x", "y")
        )
        mobject.shift(x * RIGHT + y * DOWN)

        # The listed transforms compose right to left, so apply the last first.
        transforms = self._TRANSFORM_PATTERN.findall(element.get("transform", ""))[::-1]

        for transform in transforms:
            op_name, op_args = transform.split("(")
            op_name = op_name.strip()
            op_args = [float(x) for x in self._NUMBER_PATTERN.findall(op_args)]

            if op_name == "matrix":
                self._handle_matrix_transform(mobject, op_name, op_args)
            elif op_name.startswith("translate"):
                self._handle_translate_transform(mobject, op_name, op_args)
            elif op_name.startswith("scale"):
                self._handle_scale_transform(mobject, op_name, op_args)
            elif op_name == "rotate":
                self._handle_rotate_transform(mobject, op_name, op_args)
            elif op_name.startswith("skew"):
                self._handle_skew_transform(mobject, op_name, op_args)

    def _handle_matrix_transform(self, mobject, op_name, op_args):
        # SVG matrix(a, b, c, d, e, f) maps (x, y) -> (a x + c y + e, b x + d y + f).
        # The sign flips conjugate it into manim's y-up frame.
        transform = np.array(op_args).reshape([3, 2])
        x = transform[2][0]
        y = -transform[2][1]
        matrix = np.identity(self.dim)
        matrix[:2, :2] = transform[:2, :]
        matrix[1] *= -1
        matrix[:, 1] *= -1
        for mob in mobject.family_members_with_points():
            mob.apply_matrix(matrix.T)
        mobject.shift(x * RIGHT + y * UP)

    def _handle_translate_transform(self, mobject, op_name, op_args):
        if op_name.endswith("X"):
            x, y = op_args[0], 0
        elif op_name.endswith("Y"):
            x, y = 0, op_args[0]
        else:
            # translate(tx) is shorthand for translate(tx, 0).
            x = op_args[0]
            y = op_args[1] if len(op_args) > 1 else 0
        mobject.shift(x * RIGHT + y * DOWN)

    def _handle_scale_transform(self, mobject, op_name, op_args):
        if op_name.endswith("X"):
            sx, sy = op_args[0], 1
        elif op_name.endswith("Y"):
            sx, sy = 1, op_args[0]
        elif len(op_args) == 2:
            sx, sy = op_args
        else:
            sx = sy = op_args[0]
        # Negative factors are applied as mirror flips.  The flips are about
        # ORIGIN so that the combined effect equals scaling about ORIGIN below.
        if sx < 0:
            mobject.flip(UP, about_point=ORIGIN)
            sx = -sx
        if sy < 0:
            mobject.flip(RIGHT, about_point=ORIGIN)
            sy = -sy
        mobject.scale(np.array([sx, sy, 1]), about_point=ORIGIN)

    def _handle_rotate_transform(self, mobject, op_name, op_args):
        if len(op_args) == 1:
            mobject.rotate(op_args[0] * DEGREES, axis=IN, about_point=ORIGIN)
        else:
            deg, x, y = op_args
            # (x, y) is in SVG's y-down coordinates; negate y to match the points.
            mobject.rotate(deg * DEGREES, axis=IN, about_point=np.array([x, -y, 0]))

    def _handle_skew_transform(self, mobject, op_name, op_args):
        rad = op_args[0] * DEGREES
        if op_name == "skewX":
            tana = np.tan(rad)
            self._handle_matrix_transform(mobject, None, [1., 0., tana, 1., 0., 0.])
        elif op_name == "skewY":
            tana = np.tan(rad)
            self._handle_matrix_transform(mobject, None, [1., tana, 0., 1., 0., 0.])

    def flatten(self, input_list):
        """Recursively flatten nested lists into a single flat list."""
        output_list = []
        for i in input_list:
            if isinstance(i, list):
                output_list.extend(self.flatten(i))
            else:
                output_list.append(i)
        return output_list

    def get_all_childWrappers_have_id(self, wrapper):
        """
        Return the wrappers of the outermost descendants (or ``wrapper``
        itself) that carry an ``id``.  Returns None for non-Element nodes.
        """
        all_childWrappers_have_id = []
        element = wrapper.etree_element
        if not isinstance(element, ElementTree.Element):
            return
        if element.get('id'):
            return [wrapper]
        for e in wrapper.iter_children():
            all_childWrappers_have_id.append(self.get_all_childWrappers_have_id(e))
        return self.flatten([e for e in all_childWrappers_have_id if e])

    def update_ref_to_element(self, wrapper, style):
        """Record every id-carrying element under ``wrapper`` for later <use> lookups."""
        new_refs = {
            e.etree_element.get('id', ''): (e, style)
            for e in self.get_all_childWrappers_have_id(wrapper)
        }
        self.ref_to_element.update(new_refs)
|
|
|
|
|
|
class VMobjectFromSVGPathstring(VMobject):
    """
    A VMobject traced from the data of an SVG ``<path d="...">`` string.

    The converted points and triangulation are cached as .npy files in the
    mobject data directory, so that later constructions of the same path
    (with the same post-processing flags) skip the parsing.
    """
    CONFIG = {
        "long_lines": False,
        "should_subdivide_sharp_curves": False,
        "should_remove_null_curves": False,
    }

    def __init__(self, path_string, **kwargs):
        self.path_string = path_string
        super().__init__(**kwargs)

    def _get_cache_key(self):
        """Return the hex digest that names this path's cache files."""
        key = self.path_string
        # The cached points also depend on these post-processing flags.  They
        # are folded in only when set, so caches already written for the
        # default configuration stay valid.
        if self.should_subdivide_sharp_curves or self.should_remove_null_curves:
            key += "|subdivide={}|remove_null={}".format(
                bool(self.should_subdivide_sharp_curves),
                bool(self.should_remove_null_curves),
            )
        return hashlib.sha256(key.encode()).hexdigest()[:16]

    def init_points(self):
        # After a given svg_path has been converted into points, the result
        # will be saved to a file so that future calls for the same path
        # don't need to retrace the same computation.
        path_hash = self._get_cache_key()
        points_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_points.npy")
        tris_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_tris.npy")

        if os.path.exists(points_filepath) and os.path.exists(tris_filepath):
            self.set_points(np.load(points_filepath))
            self.triangulation = np.load(tris_filepath)
            self.needs_new_triangulation = False
        else:
            self.handle_commands()
            if self.should_subdivide_sharp_curves:
                # For a healthy triangulation later
                self.subdivide_sharp_curves()
            if self.should_remove_null_curves:
                # Get rid of any null curves
                self.set_points(self.get_points_without_null_curves())
            # SVG treats y-coordinate differently
            self.stretch(-1, 1, about_point=ORIGIN)
            # Save to a file for future use
            np.save(points_filepath, self.get_points())
            np.save(tris_filepath, self.get_triangulation())

    def get_commands_and_coord_strings(self):
        """Pair each command letter in the path with the argument text after it."""
        all_commands = list(self.get_command_to_function_map().keys())
        all_commands += [c.lower() for c in all_commands]
        pattern = "[{}]".format("".join(all_commands))
        return zip(
            re.findall(pattern, self.path_string),
            re.split(pattern, self.path_string)[1:]
        )

    def handle_commands(self):
        """Replay every path command onto this VMobject (in SVG coordinates)."""
        relative_point = ORIGIN
        for command, coord_string in self.get_commands_and_coord_strings():
            func, number_types_str = self.command_to_function(command)
            upper_command = command.upper()
            if upper_command == "Z":
                func()  # `close_path` takes no arguments
                relative_point = self.get_last_point()
                continue

            number_types = np.array(list(number_types_str))
            n_numbers = len(number_types_str)
            number_list = _PathStringParser(coord_string, number_types_str).args
            number_groups = np.array(number_list).reshape((-1, n_numbers))

            for ind, numbers in enumerate(number_groups):
                if command.islower():
                    # Treat it as a relative command
                    numbers[number_types == "x"] += relative_point[0]
                    numbers[number_types == "y"] += relative_point[1]

                if upper_command == "A":
                    args = [*numbers[:5], np.array([*numbers[5:7], 0.0])]
                elif upper_command == "H":
                    args = [np.array([numbers[0], relative_point[1], 0.0])]
                elif upper_command == "V":
                    args = [np.array([relative_point[0], numbers[0], 0.0])]
                else:
                    args = list(np.hstack((
                        numbers.reshape((-1, 2)), np.zeros((n_numbers // 2, 1))
                    )))
                if upper_command == "M" and ind != 0:
                    # M x1 y1 x2 y2 is equal to M x1 y1 L x2 y2
                    func, _ = self.command_to_function("L")
                func(*args)
                relative_point = self.get_last_point()

    def add_elliptical_arc_to(self, rx, ry, x_axis_rotation, large_arc_flag, sweep_flag, point):
        """
        Append an SVG elliptical arc from the current point to ``point``,
        approximated by quadratic Bezier curves.

        Out-of-range parameters follow the SVG implementation notes:
        coincident endpoints omit the arc, a zero radius turns it into a
        straight line, and radii too small to reach ``point`` are scaled up
        uniformly until they do.
        """
        def close_to_zero(a, threshold=1e-5):
            return abs(a) < threshold

        def solve_2d_linear_equation(a, b, c):
            """
            Use Cramer's rule to solve the linear equation `[a b]x = c`
            where `a`, `b` and `c` are all 2d vectors.
            """
            def det(a, b):
                return a[0] * b[1] - a[1] * b[0]
            d = det(a, b)
            if close_to_zero(d):
                raise Exception("Cannot handle 0 determinant.")
            return [det(c, b) / d, det(a, c) / d]

        def get_arc_center_and_angles(x0, y0, rx, ry, phi, large_arc_flag, sweep_flag, x1, y1):
            """
            The parameter functions of an ellipse rotated `phi` radians counterclockwise is (on `alpha`):
                x = cx + rx * cos(alpha) * cos(phi) + ry * sin(alpha) * sin(phi),
                y = cy + rx * cos(alpha) * sin(phi) - ry * sin(alpha) * cos(phi).
            Now we have two points sitting on the ellipse: `(x0, y0)`, `(x1, y1)`, corresponding to 4 equations,
            and we want to hunt for 4 variables: `cx`, `cy`, `alpha0` and `alpha_1`.
            Let `d_alpha = alpha1 - alpha0`, then:
            if `sweep_flag = 0` and `large_arc_flag = 1`, then `PI <= d_alpha < 2 * PI`;
            if `sweep_flag = 0` and `large_arc_flag = 0`, then `0 < d_alpha <= PI`;
            if `sweep_flag = 1` and `large_arc_flag = 0`, then `-PI <= d_alpha < 0`;
            if `sweep_flag = 1` and `large_arc_flag = 1`, then `-2 * PI < d_alpha <= -PI`.

            Returns `(cx, cy, alpha0, alpha_1, rx, ry)`; the radii are the
            inputs, enlarged when they were too small to span the endpoints.
            """
            xd = x1 - x0
            yd = y1 - y0
            if close_to_zero(xd) and close_to_zero(yd):
                raise Exception("Cannot find arc center since the start point and the end point meet.")
            # Find `p = cos(alpha1) - cos(alpha0)`, `q = sin(alpha1) - sin(alpha0)`
            eq0 = [rx * np.cos(phi), ry * np.sin(phi), xd]
            eq1 = [rx * np.sin(phi), -ry * np.cos(phi), yd]
            p, q = solve_2d_linear_equation(*zip(eq0, eq1))
            # |sin(s)| is half the length of (p, q).  A value above 1 means
            # the radii are too small.  Scaling both radii by that factor
            # divides p and q by it, making the endpoints reachable exactly
            # (the SVG radius correction).
            half_chord = (p ** 2 + q ** 2) ** 0.5 / 2
            if half_chord > 1:
                rx *= half_chord
                ry *= half_chord
                p /= half_chord
                q /= half_chord
                half_chord = 1.0
            # Find `s = (alpha1 - alpha0) / 2`, `t = (alpha1 + alpha0) / 2`
            # If `sin(s) = 0`, this requires `p = q = 0`,
            # implying `xd = yd = 0`, which is impossible.
            sin_s = -half_chord if sweep_flag else half_chord
            sin_s = clip(sin_s, -1, 1)
            s = np.arcsin(sin_s)
            if large_arc_flag:
                if not sweep_flag:
                    s = PI - s
                else:
                    s = -PI - s
            sin_t = -p / (2 * sin_s)
            cos_t = q / (2 * sin_s)
            cos_t = clip(cos_t, -1, 1)
            t = np.arccos(cos_t)
            if sin_t <= 0:
                t = -t
            # We can make sure `0 < abs(s) < PI`, `-PI <= t < PI`.
            alpha0 = t - s
            alpha_1 = t + s
            cx = x0 - rx * np.cos(alpha0) * np.cos(phi) - ry * np.sin(alpha0) * np.sin(phi)
            cy = y0 - rx * np.cos(alpha0) * np.sin(phi) + ry * np.sin(alpha0) * np.cos(phi)
            return cx, cy, alpha0, alpha_1, rx, ry

        def get_point_on_ellipse(cx, cy, rx, ry, phi, angle):
            return np.array([
                cx + rx * np.cos(angle) * np.cos(phi) + ry * np.sin(angle) * np.sin(phi),
                cy + rx * np.cos(angle) * np.sin(phi) - ry * np.sin(angle) * np.cos(phi),
                0
            ])

        def convert_elliptical_arc_to_quadratic_bezier_curve(
            cx, cy, rx, ry, phi, start_angle, end_angle, n_components=8
        ):
            theta = (end_angle - start_angle) / n_components / 2
            handles = np.array([
                get_point_on_ellipse(cx, cy, rx / np.cos(theta), ry / np.cos(theta), phi, a)
                for a in np.linspace(
                    start_angle + theta,
                    end_angle - theta,
                    n_components,
                )
            ])
            anchors = np.array([
                get_point_on_ellipse(cx, cy, rx, ry, phi, a)
                for a in np.linspace(
                    start_angle + theta * 2,
                    end_angle,
                    n_components,
                )
            ])
            return handles, anchors

        phi = x_axis_rotation * DEGREES
        x0, y0 = self.get_last_point()[:2]
        x1, y1 = point[0], point[1]
        if close_to_zero(x1 - x0) and close_to_zero(y1 - y0):
            # Coincident endpoints: the arc is omitted entirely.
            return
        if close_to_zero(rx) or close_to_zero(ry):
            # A zero radius degrades the arc to a straight line.
            self.add_line_to(point)
            return
        cx, cy, start_angle, end_angle, rx, ry = get_arc_center_and_angles(
            x0, y0, rx, ry, phi, large_arc_flag, sweep_flag, x1, y1
        )
        handles, anchors = convert_elliptical_arc_to_quadratic_bezier_curve(
            cx, cy, rx, ry, phi, start_angle, end_angle
        )
        for handle, anchor in zip(handles, anchors):
            self.add_quadratic_bezier_curve_to(handle, anchor)

    def command_to_function(self, command):
        """Return (method, argument rule string) for a path command letter."""
        return self.get_command_to_function_map()[command.upper()]

    def get_command_to_function_map(self):
        """
        Associates svg command to VMobject function, and
        the types of arguments it takes in
        """
        return {
            "M": (self.start_new_path, "xy"),
            "L": (self.add_line_to, "xy"),
            "H": (self.add_line_to, "x"),
            "V": (self.add_line_to, "y"),
            "C": (self.add_cubic_bezier_curve_to, "xyxyxy"),
            "S": (self.add_smooth_cubic_curve_to, "xyxy"),
            "Q": (self.add_quadratic_bezier_curve_to, "xyxy"),
            "T": (self.add_smooth_curve_to, "xy"),
            "A": (self.add_elliptical_arc_to, "uuaffxy"),
            "Z": (self.close_path, ""),
        }

    def get_original_path_string(self):
        return self.path_string
|
|
|
|
|
|
class InvalidPathError(ValueError):
    """Signals SVG path data that does not match the expected argument grammar."""
|
|
|
|
|
|
class _PathStringParser:
|
|
# modified from https://github.com/regebro/svg.path/
|
|
def __init__(self, arguments, rules):
|
|
self.args = []
|
|
arguments = bytearray(arguments, "ascii")
|
|
self._strip_array(arguments)
|
|
while arguments:
|
|
for rule in rules:
|
|
self._rule_to_function_map[rule](arguments)
|
|
|
|
@property
|
|
def _rule_to_function_map(self):
|
|
return {
|
|
"x": self._get_number,
|
|
"y": self._get_number,
|
|
"a": self._get_number,
|
|
"u": self._get_unsigned_number,
|
|
"f": self._get_flag,
|
|
}
|
|
|
|
def _strip_array(self, arg_array):
|
|
# wsp: (0x9, 0x20, 0xA, 0xC, 0xD) with comma 0x2C
|
|
# https://www.w3.org/TR/SVG/paths.html#PathDataBNF
|
|
while arg_array and arg_array[0] in [0x9, 0x20, 0xA, 0xC, 0xD, 0x2C]:
|
|
arg_array[0:1] = b""
|
|
|
|
def _get_number(self, arg_array):
|
|
pattern = re.compile(rb"^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
|
|
res = pattern.search(arg_array)
|
|
if not res:
|
|
raise InvalidPathError(f"Expected a number, got '{arg_array}'")
|
|
number = float(res.group())
|
|
self.args.append(number)
|
|
arg_array[res.start():res.end()] = b""
|
|
self._strip_array(arg_array)
|
|
return number
|
|
|
|
def _get_unsigned_number(self, arg_array):
|
|
number = self._get_number(arg_array)
|
|
if number < 0:
|
|
raise InvalidPathError(f"Expected an unsigned number, got '{number}'")
|
|
return number
|
|
|
|
def _get_flag(self, arg_array):
|
|
flag = arg_array[0]
|
|
if flag != 48 and flag != 49:
|
|
raise InvalidPathError(f"Expected a flag (0/1), got '{chr(flag)}'")
|
|
flag -= 48
|
|
self.args.append(flag)
|
|
arg_array[0:1] = b""
|
|
self._strip_array(arg_array)
|
|
return flag
|