mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Merge pull request #1712 from 3b1b/fix-svg
Improve handling of SVG transform and Some refactors
This commit is contained in:
commit
07f84e2676
3 changed files with 164 additions and 68 deletions
|
@ -4,6 +4,7 @@ import random
|
||||||
import sys
|
import sys
|
||||||
import moderngl
|
import moderngl
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from collections import Iterable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -596,7 +597,10 @@ class Mobject(object):
|
||||||
Otherwise, if about_point is given a value, scaling is done with
|
Otherwise, if about_point is given a value, scaling is done with
|
||||||
respect to that point.
|
respect to that point.
|
||||||
"""
|
"""
|
||||||
scale_factor = max(scale_factor, min_scale_factor)
|
if isinstance(scale_factor, Iterable):
|
||||||
|
scale_factor = np.array(scale_factor).clip(min=min_scale_factor)
|
||||||
|
else:
|
||||||
|
scale_factor = max(scale_factor, min_scale_factor)
|
||||||
self.apply_points_function(
|
self.apply_points_function(
|
||||||
lambda points: scale_factor * points,
|
lambda points: scale_factor * points,
|
||||||
about_point=about_point,
|
about_point=about_point,
|
||||||
|
|
|
@ -1,14 +1,13 @@
|
||||||
import itertools as it
|
import itertools as it
|
||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
import warnings
|
|
||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
from xml.dom import minidom
|
from xml.dom import minidom
|
||||||
|
|
||||||
from manimlib.constants import DEFAULT_STROKE_WIDTH
|
from manimlib.constants import DEFAULT_STROKE_WIDTH
|
||||||
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT
|
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN
|
||||||
from manimlib.constants import BLACK
|
from manimlib.constants import BLACK
|
||||||
from manimlib.constants import WHITE
|
from manimlib.constants import WHITE
|
||||||
from manimlib.constants import DEGREES, PI
|
from manimlib.constants import DEGREES, PI
|
||||||
|
@ -23,6 +22,7 @@ from manimlib.utils.config_ops import digest_config
|
||||||
from manimlib.utils.directories import get_mobject_data_dir
|
from manimlib.utils.directories import get_mobject_data_dir
|
||||||
from manimlib.utils.images import get_full_vector_image_path
|
from manimlib.utils.images import get_full_vector_image_path
|
||||||
from manimlib.utils.simple_functions import clip
|
from manimlib.utils.simple_functions import clip
|
||||||
|
from manimlib.logger import log
|
||||||
|
|
||||||
|
|
||||||
def string_to_numbers(num_string):
|
def string_to_numbers(num_string):
|
||||||
|
@ -71,8 +71,10 @@ class SVGMobject(VMobject):
|
||||||
doc = minidom.parse(self.file_path)
|
doc = minidom.parse(self.file_path)
|
||||||
self.ref_to_element = {}
|
self.ref_to_element = {}
|
||||||
|
|
||||||
for svg in doc.getElementsByTagName("svg"):
|
for child in doc.childNodes:
|
||||||
mobjects = self.get_mobjects_from(svg)
|
if not isinstance(child, minidom.Element): continue
|
||||||
|
if child.tagName != 'svg': continue
|
||||||
|
mobjects = self.get_mobjects_from(child)
|
||||||
if self.unpack_groups:
|
if self.unpack_groups:
|
||||||
self.add(*mobjects)
|
self.add(*mobjects)
|
||||||
else:
|
else:
|
||||||
|
@ -107,8 +109,8 @@ class SVGMobject(VMobject):
|
||||||
elif element.tagName in ['polygon', 'polyline']:
|
elif element.tagName in ['polygon', 'polyline']:
|
||||||
result.append(self.polygon_to_mobject(element))
|
result.append(self.polygon_to_mobject(element))
|
||||||
else:
|
else:
|
||||||
|
log.warning(f"Unsupported element type: {element.tagName}")
|
||||||
pass # TODO
|
pass # TODO
|
||||||
# warnings.warn("Unknown element type: " + element.tagName)
|
|
||||||
result = [m for m in result if m is not None]
|
result = [m for m in result if m is not None]
|
||||||
self.handle_transforms(element, VGroup(*result))
|
self.handle_transforms(element, VGroup(*result))
|
||||||
if len(result) > 1 and not self.unpack_groups:
|
if len(result) > 1 and not self.unpack_groups:
|
||||||
|
@ -131,7 +133,7 @@ class SVGMobject(VMobject):
|
||||||
# Remove initial "#" character
|
# Remove initial "#" character
|
||||||
ref = use_element.getAttribute("xlink:href")[1:]
|
ref = use_element.getAttribute("xlink:href")[1:]
|
||||||
if ref not in self.ref_to_element:
|
if ref not in self.ref_to_element:
|
||||||
warnings.warn(f"{ref} not recognized")
|
log.warning(f"{ref} not recognized")
|
||||||
return VGroup()
|
return VGroup()
|
||||||
return self.get_mobjects_from(
|
return self.get_mobjects_from(
|
||||||
self.ref_to_element[ref]
|
self.ref_to_element[ref]
|
||||||
|
@ -227,7 +229,7 @@ class SVGMobject(VMobject):
|
||||||
stroke_width=stroke_width,
|
stroke_width=stroke_width,
|
||||||
stroke_color=stroke_color,
|
stroke_color=stroke_color,
|
||||||
fill_color=fill_color,
|
fill_color=fill_color,
|
||||||
fill_opacity=opacity,
|
fill_opacity=fill_opacity,
|
||||||
corner_radius=corner_radius
|
corner_radius=corner_radius
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -235,66 +237,94 @@ class SVGMobject(VMobject):
|
||||||
return mob
|
return mob
|
||||||
|
|
||||||
def handle_transforms(self, element, mobject):
|
def handle_transforms(self, element, mobject):
|
||||||
# TODO, this could use some cleaning...
|
x, y = (
|
||||||
x, y = 0, 0
|
self.attribute_to_float(element.getAttribute(key))
|
||||||
try:
|
if element.hasAttribute(key)
|
||||||
x = self.attribute_to_float(element.getAttribute('x'))
|
else 0.0
|
||||||
# Flip y
|
for key in ("x", "y")
|
||||||
y = -self.attribute_to_float(element.getAttribute('y'))
|
)
|
||||||
mobject.shift([x, y, 0])
|
mobject.shift(x * RIGHT + y * DOWN)
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
transform = element.getAttribute('transform')
|
transform_names = [
|
||||||
|
"matrix",
|
||||||
|
"translate", "translateX", "translateY",
|
||||||
|
"scale", "scaleX", "scaleY",
|
||||||
|
"rotate",
|
||||||
|
"skewX", "skewY"
|
||||||
|
]
|
||||||
|
transform_pattern = re.compile("|".join([x + r"[^)]*\)" for x in transform_names]))
|
||||||
|
number_pattern = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
|
||||||
|
transforms = transform_pattern.findall(element.getAttribute('transform'))[::-1]
|
||||||
|
|
||||||
try: # transform matrix
|
for transform in transforms:
|
||||||
prefix = "matrix("
|
op_name, op_args = transform.split("(")
|
||||||
suffix = ")"
|
op_name = op_name.strip()
|
||||||
if not transform.startswith(prefix) or not transform.endswith(suffix):
|
op_args = [float(x) for x in number_pattern.findall(op_args)]
|
||||||
raise Exception()
|
|
||||||
transform = transform[len(prefix):-len(suffix)]
|
if op_name == "matrix":
|
||||||
transform = string_to_numbers(transform)
|
self._handle_matrix_transform(mobject, op_name, op_args)
|
||||||
transform = np.array(transform).reshape([3, 2])
|
elif op_name.startswith("translate"):
|
||||||
x = transform[2][0]
|
self._handle_translate_transform(mobject, op_name, op_args)
|
||||||
y = -transform[2][1]
|
elif op_name.startswith("scale"):
|
||||||
matrix = np.identity(self.dim)
|
self._handle_scale_transform(mobject, op_name, op_args)
|
||||||
matrix[:2, :2] = transform[:2, :]
|
elif op_name == "rotate":
|
||||||
matrix[1] *= -1
|
self._handle_rotate_transform(mobject, op_name, op_args)
|
||||||
matrix[:, 1] *= -1
|
elif op_name.startswith("skew"):
|
||||||
|
self._handle_skew_transform(mobject, op_name, op_args)
|
||||||
|
|
||||||
|
def _handle_matrix_transform(self, mobject, op_name, op_args):
|
||||||
|
transform = np.array(op_args).reshape([3, 2])
|
||||||
|
x = transform[2][0]
|
||||||
|
y = -transform[2][1]
|
||||||
|
matrix = np.identity(self.dim)
|
||||||
|
matrix[:2, :2] = transform[:2, :]
|
||||||
|
matrix[1] *= -1
|
||||||
|
matrix[:, 1] *= -1
|
||||||
|
for mob in mobject.family_members_with_points():
|
||||||
|
mob.apply_matrix(matrix.T)
|
||||||
|
mobject.shift(x * RIGHT + y * UP)
|
||||||
|
|
||||||
for mob in mobject.family_members_with_points():
|
def _handle_translate_transform(self, mobject, op_name, op_args):
|
||||||
mob.apply_matrix(matrix.T)
|
if op_name.endswith("X"):
|
||||||
mobject.shift(x * RIGHT + y * UP)
|
x, y = op_args[0], 0
|
||||||
except:
|
elif op_name.endswith("Y"):
|
||||||
pass
|
x, y = 0, op_args[0]
|
||||||
|
else:
|
||||||
try: # transform scale
|
x, y = op_args
|
||||||
prefix = "scale("
|
mobject.shift(x * RIGHT + y * DOWN)
|
||||||
suffix = ")"
|
|
||||||
if not transform.startswith(prefix) or not transform.endswith(suffix):
|
def _handle_scale_transform(self, mobject, op_name, op_args):
|
||||||
raise Exception()
|
if op_name.endswith("X"):
|
||||||
transform = transform[len(prefix):-len(suffix)]
|
sx, sy = op_args[0], 1
|
||||||
scale_values = string_to_numbers(transform)
|
elif op_name.endswith("Y"):
|
||||||
if len(scale_values) == 2:
|
sx, sy = 1, op_args[0]
|
||||||
scale_x, scale_y = scale_values
|
elif len(op_args) == 2:
|
||||||
mobject.scale(np.array([scale_x, scale_y, 1]), about_point=ORIGIN)
|
sx, sy = op_args
|
||||||
elif len(scale_values) == 1:
|
else:
|
||||||
scale = scale_values[0]
|
sx = sy = op_args[0]
|
||||||
mobject.scale(np.array([scale, scale, 1]), about_point=ORIGIN)
|
if sx < 0:
|
||||||
except:
|
mobject.flip(UP)
|
||||||
pass
|
sx = -sx
|
||||||
|
if sy < 0:
|
||||||
try: # transform translate
|
mobject.flip(RIGHT)
|
||||||
prefix = "translate("
|
sy = -sy
|
||||||
suffix = ")"
|
mobject.scale(np.array([sx, sy, 1]), about_point=ORIGIN)
|
||||||
if not transform.startswith(prefix) or not transform.endswith(suffix):
|
|
||||||
raise Exception()
|
def _handle_rotate_transform(self, mobject, op_name, op_args):
|
||||||
transform = transform[len(prefix):-len(suffix)]
|
if len(op_args) == 1:
|
||||||
x, y = string_to_numbers(transform)
|
mobject.rotate(op_args[0] * DEGREES, axis=IN, about_point=ORIGIN)
|
||||||
mobject.shift(x * RIGHT + y * DOWN)
|
else:
|
||||||
except:
|
deg, x, y = op_args
|
||||||
pass
|
mobject.rotate(deg * DEGREES, axis=IN, about_point=np.array([x, y, 0]))
|
||||||
# TODO, ...
|
|
||||||
|
def _handle_skew_transform(self, mobject, op_name, op_args):
|
||||||
|
rad = op_args[0] * DEGREES
|
||||||
|
if op_name == "skewX":
|
||||||
|
tana = np.tan(rad)
|
||||||
|
self._handle_matrix_transform(mobject, None, [1., 0., tana, 1., 0., 0.])
|
||||||
|
elif op_name == "skewY":
|
||||||
|
tana = np.tan(rad)
|
||||||
|
self._handle_matrix_transform(mobject, None, [1., tana, 0., 1., 0., 0.])
|
||||||
|
|
||||||
def flatten(self, input_list):
|
def flatten(self, input_list):
|
||||||
output_list = []
|
output_list = []
|
||||||
|
@ -378,7 +408,8 @@ class VMobjectFromSVGPathstring(VMobject):
|
||||||
|
|
||||||
number_types = np.array(list(number_types_str))
|
number_types = np.array(list(number_types_str))
|
||||||
n_numbers = len(number_types_str)
|
n_numbers = len(number_types_str)
|
||||||
number_groups = np.array(string_to_numbers(coord_string)).reshape((-1, n_numbers))
|
number_list = _PathStringParser(coord_string, number_types_str).args
|
||||||
|
number_groups = np.array(number_list).reshape((-1, n_numbers))
|
||||||
|
|
||||||
for numbers in number_groups:
|
for numbers in number_groups:
|
||||||
if command.islower():
|
if command.islower():
|
||||||
|
@ -520,9 +551,67 @@ class VMobjectFromSVGPathstring(VMobject):
|
||||||
"S": (self.add_smooth_cubic_curve_to, "xyxy"),
|
"S": (self.add_smooth_cubic_curve_to, "xyxy"),
|
||||||
"Q": (self.add_quadratic_bezier_curve_to, "xyxy"),
|
"Q": (self.add_quadratic_bezier_curve_to, "xyxy"),
|
||||||
"T": (self.add_smooth_curve_to, "xy"),
|
"T": (self.add_smooth_curve_to, "xy"),
|
||||||
"A": (self.add_elliptical_arc_to, "-----xy"),
|
"A": (self.add_elliptical_arc_to, "uuaffxy"),
|
||||||
"Z": (self.close_path, ""),
|
"Z": (self.close_path, ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_original_path_string(self):
|
def get_original_path_string(self):
|
||||||
return self.path_string
|
return self.path_string
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPathError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _PathStringParser:
|
||||||
|
# modified from https://github.com/regebro/svg.path/
|
||||||
|
def __init__(self, arguments, rules):
|
||||||
|
self.args = []
|
||||||
|
arguments = bytearray(arguments, "ascii")
|
||||||
|
self._strip_array(arguments)
|
||||||
|
while arguments:
|
||||||
|
for rule in rules:
|
||||||
|
self._rule_to_function_map[rule](arguments)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _rule_to_function_map(self):
|
||||||
|
return {
|
||||||
|
"x": self._get_number,
|
||||||
|
"y": self._get_number,
|
||||||
|
"a": self._get_number,
|
||||||
|
"u": self._get_unsigned_number,
|
||||||
|
"f": self._get_flag,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _strip_array(self, arg_array):
|
||||||
|
# wsp: (0x9, 0x20, 0xA, 0xC, 0xD) with comma 0x2C
|
||||||
|
# https://www.w3.org/TR/SVG/paths.html#PathDataBNF
|
||||||
|
while arg_array and arg_array[0] in [0x9, 0x20, 0xA, 0xC, 0xD, 0x2C]:
|
||||||
|
arg_array[0:1] = b""
|
||||||
|
|
||||||
|
def _get_number(self, arg_array):
|
||||||
|
pattern = re.compile(rb"^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
|
||||||
|
res = pattern.search(arg_array)
|
||||||
|
if not res:
|
||||||
|
raise InvalidPathError(f"Expected a number, got '{arg_array}'")
|
||||||
|
number = float(res.group())
|
||||||
|
self.args.append(number)
|
||||||
|
arg_array[res.start():res.end()] = b""
|
||||||
|
self._strip_array(arg_array)
|
||||||
|
return number
|
||||||
|
|
||||||
|
def _get_unsigned_number(self, arg_array):
|
||||||
|
number = self._get_number(arg_array)
|
||||||
|
if number < 0:
|
||||||
|
raise InvalidPathError(f"Expected an unsigned number, got '{number}'")
|
||||||
|
return number
|
||||||
|
|
||||||
|
def _get_flag(self, arg_array):
|
||||||
|
flag = arg_array[0]
|
||||||
|
if flag != 48 and flag != 49:
|
||||||
|
raise InvalidPathError(f"Expected a flag (0/1), got '{chr(flag)}'")
|
||||||
|
flag -= 48
|
||||||
|
self.args.append(flag)
|
||||||
|
arg_array[0:1] = b""
|
||||||
|
self._strip_array(arg_array)
|
||||||
|
return flag
|
||||||
|
|
|
@ -382,7 +382,10 @@ class VMobject(Mobject):
|
||||||
|
|
||||||
def add_smooth_cubic_curve_to(self, handle, point):
|
def add_smooth_cubic_curve_to(self, handle, point):
|
||||||
self.throw_error_if_no_points()
|
self.throw_error_if_no_points()
|
||||||
new_handle = self.get_reflection_of_last_handle()
|
if self.get_num_points() == 1:
|
||||||
|
new_handle = self.get_points()[-1]
|
||||||
|
else:
|
||||||
|
new_handle = self.get_reflection_of_last_handle()
|
||||||
self.add_cubic_bezier_curve_to(new_handle, handle, point)
|
self.add_cubic_bezier_curve_to(new_handle, handle, point)
|
||||||
|
|
||||||
def has_new_path_started(self):
|
def has_new_path_started(self):
|
||||||
|
|
Loading…
Add table
Reference in a new issue