# 3b1b-manim/manimlib/mobject/svg/svg_mobject.py (456 lines, 16 KiB, Python)

import itertools as it
import re
import string
import warnings
2020-02-18 22:31:29 -08:00
import os
import hashlib
from xml.dom import minidom
2016-04-17 00:31:38 -07:00
2020-02-18 22:31:29 -08:00
from manimlib.constants import DEFAULT_STROKE_WIDTH
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT
from manimlib.constants import BLACK
from manimlib.constants import WHITE
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.geometry import RoundedRectangle
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.color import *
from manimlib.utils.config_ops import digest_config
from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path
2016-04-17 00:31:38 -07:00
def check_and_fix_percent_bug(sym):
    """
    Split a (probable) percent-sign glyph into two mobjects, in place.

    This is a stop-gap patch: the svg path for percent symbols has a
    known bug, so a symbol is recognised heuristically by its total
    point count together with having exactly four subpaths.  When it
    matches, part of its points are moved onto a new VMobject that is
    added as a submobject, and the triangulation is refreshed.

    Nothing is returned; symbols that do not match are left untouched.
    """
    # Point counts that identify the two known layouts of the glyph.
    two_then_two_counts = (315, 324, 372)
    outer_inner_counts = (468, 483)

    if len(sym.get_points()) not in two_then_two_counts + outer_inner_counts:
        return
    if len(sym.get_subpaths()) != 4:
        return

    sym = sym.family_members_with_points()[0]
    new_sym = VMobject()
    subpath_sizes = [len(subpath) for subpath in sym.get_subpaths()]

    all_points = sym.get_points()
    n_points = len(all_points)
    if n_points in two_then_two_counts:
        # First two subpaths stay on `sym`, the remainder moves out.
        cut = sum(subpath_sizes[:2])
        p1, p2 = all_points[:cut], all_points[cut:]
    elif n_points in outer_inner_counts:
        # First and last subpaths stay on `sym`; the middle two move out.
        head = subpath_sizes[0]
        tail = subpath_sizes[3]
        p1 = np.vstack([
            all_points[:head],
            all_points[-tail:],
        ])
        p2 = all_points[head:sum(subpath_sizes[:3])]

    sym.set_points(p1)
    new_sym.set_points(p2)
    sym.add(new_sym)
    sym.refresh_triangulation()
# One SVG-style number: optional sign, digits with an optional fractional
# part (or a bare ".5"), and an optional exponent in either letter case.
_SVG_NUMBER_PATTERN = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?")


def string_to_numbers(num_string):
    """
    Parse a string of SVG-style numbers into a list of floats.

    Numbers may be separated by commas and/or whitespace (including
    tabs and newlines), or not separated at all when a sign or a second
    decimal point unambiguously starts the next number, e.g.
    ``"10-20"`` -> ``[10.0, -20.0]`` and ``"0.5.5"`` -> ``[0.5, 0.5]``.
    Exponents are accepted in both cases (``1e-3`` and ``1E-3``).

    Raises ValueError if anything other than numbers and separators
    is present, so malformed input is not silently truncated.
    """
    # Blank out every number; whatever remains must be pure separators.
    leftover = _SVG_NUMBER_PATTERN.sub(" ", num_string)
    if leftover.strip(", \t\r\n\f"):
        raise ValueError(f"could not parse numbers from {num_string!r}")
    return [float(token) for token in _SVG_NUMBER_PATTERN.findall(num_string)]
2016-04-17 00:31:38 -07:00
class SVGMobject(VMobject):
    """
    A VMobject assembled from the elements of an SVG file.

    The file is parsed with ``xml.dom.minidom``; ``path``, ``rect``,
    ``circle``, ``ellipse``, ``polygon``/``polyline`` and ``use``
    elements are converted to mobjects, with ``g``/``svg``/``symbol``
    containers walked recursively.  ``style`` elements and unknown tags
    are currently ignored.
    """
    CONFIG = {
        "should_center": True,
        "height": 2,
        "width": None,
        # Must be filled in in a subclass, or when called
        "file_name": None,
        "unpack_groups": True,  # if False, creates a hierarchy of VGroups
        # TODO, style components should be read in, not defaulted
        "stroke_width": DEFAULT_STROKE_WIDTH,
        "fill_opacity": 1.0,
        "path_string_config": {}
    }

    def __init__(self, file_name=None, **kwargs):
        """
        file_name: name of the svg file, resolved through
            get_full_vector_image_path.  May be omitted when a subclass
            supplies it via CONFIG["file_name"].

        Raises Exception if no file name is available from either source.
        """
        digest_config(self, kwargs)
        self.file_name = file_name or self.file_name
        # Check the resolved name, not the argument, so that a subclass
        # providing CONFIG["file_name"] works without passing it again.
        if self.file_name is None:
            raise Exception("Must specify file for SVGMobject")
        self.file_path = get_full_vector_image_path(self.file_name)
        super().__init__(**kwargs)
        self.move_into_position()

    def move_into_position(self):
        """Center and resize according to should_center/height/width."""
        if self.should_center:
            self.center()
        if self.height is not None:
            self.set_height(self.height)
        if self.width is not None:
            self.set_width(self.width)

    def init_points(self):
        """Parse the svg file and add the resulting mobjects to self."""
        doc = minidom.parse(self.file_path)
        self.ref_to_element = {}
        try:
            # NOTE(review): getElementsByTagName also returns nested <svg>
            # elements, which get_mobjects_from already recurses into, so
            # nested svgs look like they are processed twice -- confirm.
            for svg in doc.getElementsByTagName("svg"):
                mobjects = self.get_mobjects_from(svg)
                if self.unpack_groups:
                    self.add(*mobjects)
                elif mobjects:
                    # Guarded: an svg without drawable content yields [].
                    self.add(*mobjects[0].submobjects)
        finally:
            # Release the DOM even if conversion of some element fails.
            doc.unlink()

    def get_mobjects_from(self, element):
        """
        Convert a DOM node (and, for containers, its children) into a
        list of mobjects.

        Non-element nodes give an empty list.  ``defs`` only registers
        referenceable elements and produces no mobjects.  After
        conversion the element's own x/y/transform attributes are
        applied to everything produced.  When unpack_groups is False
        and more than one mobject was produced, they are wrapped in a
        single VGroup.
        """
        result = []
        if not isinstance(element, minidom.Element):
            return result

        tag = element.tagName
        if tag == 'defs':
            self.update_ref_to_element(element)
        elif tag == 'style':
            pass  # TODO, handle style
        elif tag in ['g', 'svg', 'symbol']:
            result.extend(it.chain.from_iterable(
                self.get_mobjects_from(child)
                for child in element.childNodes
            ))
        elif tag == 'path':
            result.append(self.path_string_to_mobject(
                element.getAttribute('d')
            ))
        elif tag == 'use':
            result.extend(self.use_to_mobjects(element))
        elif tag == 'rect':
            result.append(self.rect_to_mobject(element))
        elif tag == 'circle':
            result.append(self.circle_to_mobject(element))
        elif tag == 'ellipse':
            result.append(self.ellipse_to_mobject(element))
        elif tag in ['polygon', 'polyline']:
            result.append(self.polygon_to_mobject(element))
        else:
            pass  # TODO
            # warnings.warn("Unknown element type: " + element.tagName)

        result = [m for m in result if m is not None]
        self.handle_transforms(element, VGroup(*result))
        if len(result) > 1 and not self.unpack_groups:
            result = [VGroup(*result)]
        return result

    def g_to_mobjects(self, g_element):
        """
        Return the mobjects of a group element as a list.

        NOTE(review): get_mobjects_from already applies the element's
        transforms, so this applies them a second time.  No caller is
        visible in this file; behaviour is left unchanged.
        """
        mob = VGroup(*self.get_mobjects_from(g_element))
        self.handle_transforms(g_element, mob)
        return mob.submobjects

    def path_string_to_mobject(self, path_string):
        """Build a VMobjectFromSVGPathstring using path_string_config."""
        return VMobjectFromSVGPathstring(
            path_string,
            **self.path_string_config,
        )

    def use_to_mobjects(self, use_element):
        """
        Resolve a ``<use>`` element to mobjects built from the element
        it references.

        Both ``xlink:href`` and the plain ``href`` attribute are
        accepted.  An unknown reference emits a warning and yields an
        empty list.
        """
        href = (
            use_element.getAttribute("xlink:href")
            or use_element.getAttribute("href")
        )
        # Remove initial "#" character
        ref = href[1:]
        if ref not in self.ref_to_element:
            warnings.warn(f"{ref} not recognized")
            # A list, to match what get_mobjects_from returns.
            return []
        return self.get_mobjects_from(
            self.ref_to_element[ref]
        )

    def attribute_to_float(self, attr):
        """
        Convert an attribute string to float, discarding every
        character that is not a digit, "." or "-" (e.g. unit suffixes).

        Raises ValueError if what remains is not a valid float.
        """
        allowed = string.digits + "." + "-"
        stripped_attr = "".join(char for char in attr if char in allowed)
        return float(stripped_attr)

    def _float_attributes(self, element, keys):
        """Read each attribute in keys as a float, 0.0 when absent."""
        return [
            self.attribute_to_float(element.getAttribute(key))
            if element.hasAttribute(key)
            else 0.0
            for key in keys
        ]

    def polygon_to_mobject(self, polygon_element):
        """
        Convert a ``<polygon>`` or ``<polyline>`` element to a path
        mobject.

        The ``points`` attribute is parsed into coordinate pairs and
        rewritten as an explicit "M ... L ..." path string, so the first
        vertex and negative coordinates are kept regardless of how the
        attribute is spaced.  A ``polygon`` is closed with "Z"; a
        ``polyline`` is left open.  A trailing unpaired number is
        dropped.
        """
        coords = string_to_numbers(polygon_element.getAttribute("points"))
        vertices = [
            f"{x},{y}"
            for x, y in zip(coords[0::2], coords[1::2])
        ]
        path_string = ""
        if vertices:
            path_string = "M" + " L".join(vertices)
            if polygon_element.tagName == 'polygon':
                path_string += " Z"
        return self.path_string_to_mobject(path_string)

    def circle_to_mobject(self, circle_element):
        """Convert a ``<circle>`` (cx, cy, r; missing values are 0)."""
        x, y, r = self._float_attributes(circle_element, ("cx", "cy", "r"))
        # SVG's y axis points down.
        return Circle(radius=r).shift(x * RIGHT + y * DOWN)

    def ellipse_to_mobject(self, circle_element):
        """Convert an ``<ellipse>`` (cx, cy, rx, ry; missing are 0)."""
        x, y, rx, ry = self._float_attributes(
            circle_element, ("cx", "cy", "rx", "ry")
        )
        return Circle().scale(rx * RIGHT + ry * UP).shift(x * RIGHT + y * DOWN)

    def rect_to_mobject(self, rect_element):
        """
        Convert a ``<rect>`` to a Rectangle, or to a RoundedRectangle
        when ``rx`` is non-zero.

        Missing, "none" or white fill/stroke is treated as invisible
        (zero opacity / zero width); pure black is swapped for white.
        The result is shifted so that its top-left corner sits at the
        origin; the element's x/y are applied later by handle_transforms.
        """
        fill_color = rect_element.getAttribute("fill")
        stroke_color = rect_element.getAttribute("stroke")
        stroke_width = rect_element.getAttribute("stroke-width")
        corner_radius = rect_element.getAttribute("rx")

        # input preprocessing
        # Default for a visible fill; previously left unbound, which made
        # every rect with a non-white fill raise UnboundLocalError.
        opacity = 1
        if fill_color in ["", "none", "#FFF", "#FFFFFF"] or Color(fill_color) == Color(WHITE):
            opacity = 0
            fill_color = BLACK  # shdn't be necessary but avoids error msgs
        if fill_color in ["#000", "#000000"]:
            fill_color = WHITE
        if stroke_color in ["", "none", "#FFF", "#FFFFFF"] or Color(stroke_color) == Color(WHITE):
            stroke_width = 0
            stroke_color = BLACK
        if stroke_color in ["#000", "#000000"]:
            stroke_color = WHITE
        if stroke_width in ["", "none", "0"]:
            stroke_width = 0
        elif isinstance(stroke_width, str):
            # Hand a number, not the raw attribute text, to the mobject.
            stroke_width = self.attribute_to_float(stroke_width)
        if corner_radius in ["", "0", "none"]:
            corner_radius = 0
        corner_radius = float(corner_radius)

        rect_kwargs = dict(
            width=self.attribute_to_float(
                rect_element.getAttribute("width")
            ),
            height=self.attribute_to_float(
                rect_element.getAttribute("height")
            ),
            stroke_width=stroke_width,
            stroke_color=stroke_color,
            fill_color=fill_color,
            fill_opacity=opacity,
        )
        if corner_radius == 0:
            mob = Rectangle(**rect_kwargs)
        else:
            mob = RoundedRectangle(
                corner_radius=corner_radius,
                **rect_kwargs
            )

        mob.shift(mob.get_center() - mob.get_corner(UP + LEFT))
        return mob

    def handle_transforms(self, element, mobject):
        """
        Apply the element's ``x``/``y`` offset and its ``transform``
        attribute to mobject.

        Supported transforms are a single ``matrix(...)``,
        ``scale(...)`` or ``translate(...)``.  Anything else is ignored;
        a supported transform whose arguments cannot be applied emits a
        warning instead of raising.
        """
        # TODO, this could use some cleaning...
        if element.hasAttribute('x') or element.hasAttribute('y'):
            x = self._offset_attribute(element, 'x')
            # Flip y
            y = -self._offset_attribute(element, 'y')
            mobject.shift([x, y, 0])

        transform = element.getAttribute('transform').strip()
        suffix = ")"
        if not transform.endswith(suffix):
            return
        handlers = (
            ("matrix(", self._apply_matrix_transform),
            ("scale(", self._apply_scale_transform),
            ("translate(", self._apply_translate_transform),
        )
        for prefix, handler in handlers:
            if not transform.startswith(prefix):
                continue
            try:
                values = string_to_numbers(
                    transform[len(prefix):-len(suffix)]
                )
                handler(mobject, values)
            except Exception as err:
                # Best effort: one bad transform should not abort the
                # whole svg, but it should not vanish silently either.
                warnings.warn(
                    f"Could not apply svg transform {transform!r}: {err}"
                )
            return
        # TODO, ...

    def _offset_attribute(self, element, name):
        """Read an x/y offset; absent or unparsable values count as 0."""
        if not element.hasAttribute(name):
            return 0.0
        try:
            return self.attribute_to_float(element.getAttribute(name))
        except ValueError:
            return 0.0

    def _apply_matrix_transform(self, mobject, values):
        """Apply ``matrix(a b c d e f)``, flipping for SVG's y axis."""
        coeffs = np.array(values).reshape([3, 2])
        x = coeffs[2][0]
        y = -coeffs[2][1]
        matrix = np.identity(self.dim)
        matrix[:2, :2] = coeffs[:2, :]
        matrix[1] *= -1
        matrix[:, 1] *= -1
        for mob in mobject.family_members_with_points():
            mob.apply_matrix(matrix.T)
        mobject.shift(x * RIGHT + y * UP)

    def _apply_scale_transform(self, mobject, values):
        """Apply ``scale(s)`` or ``scale(sx sy)`` about the origin."""
        if len(values) == 2:
            scale_x, scale_y = values
        elif len(values) == 1:
            scale_x = scale_y = values[0]
        else:
            raise ValueError("scale expects one or two numbers")
        mobject.scale(np.array([scale_x, scale_y, 1]), about_point=ORIGIN)

    def _apply_translate_transform(self, mobject, values):
        """Apply ``translate(tx)`` or ``translate(tx ty)``."""
        if len(values) == 1:
            # A single argument means no vertical translation.
            x, y = values[0], 0.0
        elif len(values) == 2:
            x, y = values
        else:
            raise ValueError("translate expects one or two numbers")
        mobject.shift(x * RIGHT + y * DOWN)

    def flatten(self, input_list):
        """Return input_list with all nested lists expanded, in order."""
        output_list = []
        for item in input_list:
            if isinstance(item, list):
                output_list.extend(self.flatten(item))
            else:
                output_list.append(item)
        return output_list

    def get_all_childNodes_have_id(self, element):
        """
        Collect the outermost descendants of element (or element
        itself) that carry an ``id`` attribute.

        The search does not descend into an element once it has an id.
        Always returns a list; non-element nodes give an empty one.
        """
        if not isinstance(element, minidom.Element):
            return []
        if element.hasAttribute('id'):
            return [element]
        found = []
        for child in element.childNodes:
            found.extend(self.get_all_childNodes_have_id(child))
        return found

    def update_ref_to_element(self, defs):
        """Register the id-carrying elements under defs for ``<use>``."""
        new_refs = {
            e.getAttribute('id'): e
            for e in self.get_all_childNodes_have_id(defs)
        }
        self.ref_to_element.update(new_refs)
2016-04-17 00:31:38 -07:00
2016-04-17 00:31:38 -07:00
class VMobjectFromSVGPathstring(VMobject):
    """
    A VMobject whose points are built from an SVG path data string
    (the ``d`` attribute of a ``<path>`` element).

    Computed points and triangulation are cached on disk, keyed by a
    hash of the path string.
    """
    CONFIG = {
        "long_lines": True,
        "should_subdivide_sharp_curves": False,
        "should_remove_null_curves": False,
    }

    def __init__(self, path_string, **kwargs):
        self.path_string = path_string
        super().__init__(**kwargs)

    def init_points(self):
        """
        Load points and triangulation from the cache if present,
        otherwise trace the path string and write them to the cache.
        """
        # TODO, move this caching operation
        # higher up to Mobject somehow.
        # NOTE(review): the cache key is derived from the path string
        # only, so entries written under different CONFIG flags would be
        # shared -- confirm whether that is intended.
        hasher = hashlib.sha256(self.path_string.encode())
        path_hash = hasher.hexdigest()[:16]
        points_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_points.npy")
        tris_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_tris.npy")

        if os.path.exists(points_filepath) and os.path.exists(tris_filepath):
            self.set_points(np.load(points_filepath))
            self.triangulation = np.load(tris_filepath)
            self.needs_new_triangulation = False
        else:
            self.relative_point = np.array(ORIGIN)
            for command, coord_string in self.get_commands_and_coord_strings():
                new_points = self.string_to_points(command, coord_string)
                self.handle_command(command, new_points)
            if self.should_subdivide_sharp_curves:
                # For a healthy triangulation later
                self.subdivide_sharp_curves()
            if self.should_remove_null_curves:
                # Get rid of any null curves
                self.set_points(self.get_points_without_null_curves())
            # SVG treats y-coordinate differently
            self.stretch(-1, 1, about_point=ORIGIN)
            # Save to a file for future use
            np.save(points_filepath, self.get_points())
            np.save(tris_filepath, self.get_triangulation())
        # The cache holds the points as saved above, before this patch
        # runs, so it is applied on both branches.
        check_and_fix_percent_bug(self)

    def get_commands_and_coord_strings(self):
        """
        Split the path string into (command letter, coordinate string)
        pairs, in order of appearance.
        """
        upper_letters = "".join(self.get_command_to_function_map())
        pattern = "[{}]".format(upper_letters + upper_letters.lower())
        return zip(
            re.findall(pattern, self.path_string),
            re.split(pattern, self.path_string)[1:]
        )

    def handle_command(self, command, new_points):
        """
        Apply one path command to self.

        new_points holds every point that followed the command letter;
        when more are given than a single invocation consumes, the
        command is repeated on the remainder.  Per the SVG path rules,
        extra pairs after a moveto are implicit linetos: absolute after
        "M", relative after "m".  new_points is modified in place.

        Implemented as a loop rather than recursion so that a command
        followed by a very long coordinate list cannot exhaust the
        recursion limit.
        """
        while True:
            if command.islower():
                # Treat it as a relative command
                new_points += self.relative_point
            func, n_points = self.command_to_function(command)
            func(*new_points[:n_points])
            leftover_points = new_points[n_points:]

            # n_points == 0 (close path) consumes nothing, so stray
            # numbers after it are ignored rather than looped on forever.
            if n_points == 0 or len(leftover_points) == 0:
                # Command is over, reset for future relative commands
                self.relative_point = self.get_last_point()
                return

            if command == "M":
                command = "L"
            elif command == "m":
                command = "l"
            if command.islower():
                # Undo the offset applied to the whole array above, then
                # measure the next segment from the point just reached.
                leftover_points -= self.relative_point
                self.relative_point = self.get_last_point()
            new_points = leftover_points

    def string_to_points(self, command, coord_string):
        """
        Convert a command's coordinate string into an array with one
        row per point and self.dim columns (extra columns zero).

        For H/V only one coordinate per point is given; for the
        absolute forms the other coordinate is taken from
        self.relative_point.

        Raises NotImplementedError for arc commands (A/a).
        """
        numbers = string_to_numbers(coord_string)
        letter = command.upper()
        if letter in ["H", "V"]:
            i = {"H": 0, "V": 1}[letter]
            xy = np.zeros((len(numbers), 2))
            xy[:, i] = numbers
            if command.isupper():
                xy[:, 1 - i] = self.relative_point[1 - i]
        elif letter == "A":
            raise NotImplementedError("Not implemented")
        else:
            xy = np.array(numbers).reshape((len(numbers) // 2, 2))
        result = np.zeros((xy.shape[0], self.dim))
        result[:, :2] = xy
        return result

    def command_to_function(self, command):
        """Return (method, number of points) for a command letter."""
        return self.get_command_to_function_map()[command.upper()]

    def get_command_to_function_map(self):
        """
        Associates svg command to VMobject function, and
        the number of arguments it takes in
        """
        return {
            "M": (self.start_new_path, 1),
            "L": (self.add_line_to, 1),
            "H": (self.add_line_to, 1),
            "V": (self.add_line_to, 1),
            "C": (self.add_cubic_bezier_curve_to, 3),
            "S": (self.add_smooth_cubic_curve_to, 2),
            "Q": (self.add_quadratic_bezier_curve_to, 2),
            "T": (self.add_smooth_curve_to, 1),
            "A": (self.add_quadratic_bezier_curve_to, 2),  # TODO
            "Z": (self.close_path, 0),
        }

    def get_original_path_string(self):
        """Return the path string this mobject was constructed from."""
        return self.path_string