3b1b-manim/manimlib/mobject/svg/svg_mobject.py

399 lines
14 KiB
Python
Raw Normal View History

import itertools as it
import re
import string
import warnings
from xml.dom import minidom
2016-04-17 00:31:38 -07:00
from manimlib.constants import *
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.geometry import RoundedRectangle
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.color import *
from manimlib.utils.config_ops import digest_config
from manimlib.utils.config_ops import digest_locals
2016-04-17 00:31:38 -07:00
# One SVG number: optional sign, integer and/or fractional part, optional
# exponent.  Because a match ends as soon as the number cannot be extended,
# "1-2" yields "1" and "-2", and ".5.5" yields ".5" and ".5", as the SVG
# number-list grammar requires.
_SVG_NUMBER_PATTERN = re.compile(
    r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
)


def string_to_numbers(num_string):
    """
    Extract every number from an SVG number list (path data, points, ...).

    Numbers may be separated by commas, any whitespace (spaces, tabs,
    newlines), or by nothing at all when the next number starts with a sign
    or a second decimal point, e.g. "1-2.5.5" -> [1.0, -2.5, 0.5].
    Exponents in either case ("1e-3", "2E+4") are recognised.

    Returns a list of floats; a string containing no numbers gives [].
    """
    return [float(token) for token in _SVG_NUMBER_PATTERN.findall(num_string)]
2016-04-17 00:31:38 -07:00
class SVGMobject(VMobject):
    """
    A VMobject assembled from the shapes in an SVG file.

    Handles <path>, <use>, <rect>, <circle>, <ellipse>, <polygon> and
    <polyline>, recursing into <g>, <svg> and <symbol>.  Elements inside
    <defs> are not drawn; they are recorded so that <use> can reference them.
    Other element types are ignored.  Apart from the fill/stroke handling in
    rect_to_mobject, styling attributes are not read (see the CONFIG TODO).
    """
    CONFIG = {
        "should_center": True,
        "height": 2,
        "width": None,
        # Must be filled in in a subclass, or when called
        "file_name": None,
        "unpack_groups": True,  # if False, creates a hierarchy of VGroups
        # TODO, style components should be read in, not defaulted
        "stroke_width": DEFAULT_STROKE_WIDTH,
        "fill_opacity": 1.0,
    }

    # Leading SVG number of an attribute value (units such as "px" follow it)
    _NUMBER_PATTERN = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?")

    def __init__(self, file_name=None, **kwargs):
        """
        file_name: name or path of the SVG file; when omitted, the value from
        CONFIG (e.g. set by a subclass) is used.
        """
        digest_config(self, kwargs)
        self.file_name = file_name or self.file_name
        self.ensure_valid_file()
        VMobject.__init__(self, **kwargs)
        self.move_into_position()

    def ensure_valid_file(self):
        """
        Resolve self.file_name to an existing path, stored in self.file_path.

        Tries assets/svg_images/<name>, then the same with ".svg" and ".xdv"
        appended, then <name> as a plain path.

        Raises Exception when no file name was given and IOError when none
        of the candidates exists.
        """
        file_name = self.file_name
        if file_name is None:
            raise Exception("Must specify file for SVGMobject")
        image_dir = os.path.join("assets", "svg_images")
        possible_paths = [
            os.path.join(image_dir, file_name),
            os.path.join(image_dir, file_name + ".svg"),
            os.path.join(image_dir, file_name + ".xdv"),
            file_name,
        ]
        for path in possible_paths:
            if os.path.exists(path):
                self.file_path = path
                return
        raise IOError(f"No file matching {file_name} in image directory")

    def init_points(self):
        """Parse self.file_path and add the resulting shapes as submobjects."""
        doc = minidom.parse(self.file_path)
        self.ref_to_element = {}
        try:
            for svg in doc.getElementsByTagName("svg"):
                mobjects = self.get_mobjects_from(svg)
                if self.unpack_groups:
                    self.add(*mobjects)
                elif mobjects:
                    group = mobjects[0]
                    # get_mobjects_from only wraps its result in a VGroup when
                    # it produced more than one mobject; a lone shape is added
                    # itself rather than its (empty) list of submobjects.
                    if isinstance(group, VGroup):
                        self.add(*group.submobjects)
                    else:
                        self.add(group)
        finally:
            # Release the DOM's internal references even if conversion fails
            doc.unlink()

    def get_mobjects_from(self, element):
        """
        Convert an SVG DOM node and its descendants into a list of mobjects.

        Non-element nodes (text, comments) give [].  <defs> contents are
        recorded in self.ref_to_element instead of being drawn.  The
        element's own x/y and transform attributes are applied to the
        result.  When unpack_groups is False and more than one mobject was
        produced, they are returned wrapped in a single VGroup.
        """
        result = []
        if not isinstance(element, minidom.Element):
            return result
        tag = element.tagName
        if tag == 'defs':
            self.update_ref_to_element(element)
        elif tag == 'style':
            pass  # TODO, handle style
        elif tag in ['g', 'svg', 'symbol']:
            result += it.chain.from_iterable(
                self.get_mobjects_from(child)
                for child in element.childNodes
            )
        elif tag == 'path':
            result.append(self.path_string_to_mobject(
                element.getAttribute('d')
            ))
        elif tag == 'use':
            result += self.use_to_mobjects(element)
        elif tag == 'rect':
            result.append(self.rect_to_mobject(element))
        elif tag == 'circle':
            result.append(self.circle_to_mobject(element))
        elif tag == 'ellipse':
            result.append(self.ellipse_to_mobject(element))
        elif tag in ['polygon', 'polyline']:
            result.append(self.polygon_to_mobject(element))
        else:
            pass  # TODO: other element types (text, line, image, ...) are ignored
        # Converters may return None, e.g. a polygon/polyline without points
        result = [m for m in result if m is not None]
        self.handle_transforms(element, VGroup(*result))
        if len(result) > 1 and not self.unpack_groups:
            result = [VGroup(*result)]
        return result

    def g_to_mobjects(self, g_element):
        """
        Convert a <g> element to the list of mobjects it contains.

        get_mobjects_from already applies the group's x/y and transform, so
        they are not applied a second time here.
        """
        return VGroup(*self.get_mobjects_from(g_element)).submobjects

    def path_string_to_mobject(self, path_string):
        """Build a mobject from the "d" attribute of a <path>."""
        return VMobjectFromSVGPathstring(path_string)

    def use_to_mobjects(self, use_element):
        """
        Return the mobjects for the element a <use> refers to.

        The reference is read from xlink:href, falling back to the plain
        href attribute used by SVG 2.  An unknown reference produces a
        warning and an empty list.
        """
        href = (
            use_element.getAttribute("xlink:href")
            or use_element.getAttribute("href")
        )
        # Remove initial "#" character
        ref = href[1:] if href.startswith("#") else href
        if ref not in self.ref_to_element:
            warnings.warn(f"{ref} not recognized")
            return []
        return self.get_mobjects_from(self.ref_to_element[ref])

    def attribute_to_float(self, attr):
        """
        Parse the leading number of an SVG attribute value, ignoring units.

        "12.5px" -> 12.5, "-3" -> -3.0, "1e-2" -> 0.01.

        Raises ValueError when attr contains no number, e.g. the empty string
        that getAttribute returns for a missing attribute.
        """
        match = self._NUMBER_PATTERN.search(attr)
        if match is None:
            raise ValueError(f"Could not parse a number from {attr!r}")
        return float(match.group())

    def polygon_to_mobject(self, polygon_element):
        """
        Convert a <polygon> or <polyline> to a path mobject.

        The points attribute is read as consecutive x, y pairs.  A <polygon>
        is closed back to its first point; a <polyline> is left open.
        Returns None when the element has no points.
        """
        numbers = string_to_numbers(polygon_element.getAttribute("points"))
        coords = [
            f"{x} {y}"
            for x, y in zip(numbers[0::2], numbers[1::2])
        ]
        if not coords:
            return None
        path_string = "M " + " L ".join(coords)
        if polygon_element.tagName == "polygon":
            path_string += " Z"
        return self.path_string_to_mobject(path_string)

    def circle_to_mobject(self, circle_element):
        """
        Convert a <circle> to a Circle; missing cx, cy or r count as 0.

        SVG's downward y axis is flipped when positioning the centre.
        """
        x, y, r = [
            self.attribute_to_float(circle_element.getAttribute(key))
            if circle_element.hasAttribute(key)
            else 0.0
            for key in ("cx", "cy", "r")
        ]
        return Circle(radius=r).shift(x * RIGHT + y * DOWN)

    def ellipse_to_mobject(self, circle_element):
        """
        Convert an <ellipse> by stretching a Circle() by rx and ry.

        Missing cx, cy, rx or ry count as 0.  Assumes Circle() has unit
        radius (its default, defined elsewhere -- TODO confirm).
        """
        x, y, rx, ry = [
            self.attribute_to_float(circle_element.getAttribute(key))
            if circle_element.hasAttribute(key)
            else 0.0
            for key in ("cx", "cy", "rx", "ry")
        ]
        # Per-axis factors; z is left unchanged
        return Circle().scale(np.array([rx, ry, 1])).shift(x * RIGHT + y * DOWN)

    def rect_to_mobject(self, rect_element):
        """
        Convert a <rect> to a Rectangle, or a RoundedRectangle when rx is set.

        Colour handling:
        - A missing/"none"/white fill is made fully transparent.
        - Any other fill is drawn opaque.
        - A literal "#000"/"#000000" fill or stroke is drawn white.
        - A missing/"none"/white stroke gets zero width.

        The rectangle is placed with its top-left corner at the origin; the
        element's x/y are applied afterwards by handle_transforms.

        Raises ValueError if width or height is missing or not a number.
        """
        fill_color = rect_element.getAttribute("fill")
        stroke_color = rect_element.getAttribute("stroke")
        stroke_width = rect_element.getAttribute("stroke-width")
        corner_radius = rect_element.getAttribute("rx")
        # Fills that are neither none nor white are drawn opaque
        opacity = 1

        # input preprocessing
        if fill_color in ["", "none", "#FFF", "#FFFFFF"] or Color(fill_color) == Color(WHITE):
            opacity = 0
            fill_color = BLACK  # shdn't be necessary but avoids error msgs
        if fill_color in ["#000", "#000000"]:
            fill_color = WHITE
        if stroke_color in ["", "none", "#FFF", "#FFFFFF"] or Color(stroke_color) == Color(WHITE):
            stroke_width = 0
            stroke_color = BLACK
        if stroke_color in ["#000", "#000000"]:
            stroke_color = WHITE
        if stroke_width in ["", "none", "0", 0]:
            stroke_width = 0
        else:
            # Hand a number, not the raw attribute string, to the Rectangle
            try:
                stroke_width = self.attribute_to_float(stroke_width)
            except ValueError:
                stroke_width = 0
        if corner_radius in ["", "0", "none"]:
            corner_radius = 0.0
        else:
            # attribute_to_float tolerates units such as "5px"
            corner_radius = self.attribute_to_float(corner_radius)

        width = self.attribute_to_float(rect_element.getAttribute("width"))
        height = self.attribute_to_float(rect_element.getAttribute("height"))
        style = dict(
            stroke_width=stroke_width,
            stroke_color=stroke_color,
            fill_color=fill_color,
            fill_opacity=opacity,
        )
        if corner_radius == 0:
            mob = Rectangle(width=width, height=height, **style)
        else:
            mob = RoundedRectangle(
                width=width,
                height=height,
                corner_radius=corner_radius,
                **style
            )
        # Put the top-left corner at the origin, where SVG anchors (x, y)
        mob.shift(mob.get_center() - mob.get_corner(UP + LEFT))
        return mob

    def handle_transforms(self, element, mobject):
        """
        Apply an element's x/y offset, then its transform attribute, to mobject.

        A missing x or y counts as 0, and y is negated because SVG's y axis
        points down.  Every function in the transform attribute is applied
        (see _parse_transform), composed in SVG order.
        """
        def read_offset(key):
            if not element.hasAttribute(key):
                return 0.0
            try:
                return self.attribute_to_float(element.getAttribute(key))
            except ValueError:
                warnings.warn(
                    f"Could not parse {key}={element.getAttribute(key)!r}"
                )
                return 0.0

        x = read_offset('x')
        y = read_offset('y')
        if x or y:
            # DOWN flips SVG's y axis
            mobject.shift(x * RIGHT + y * DOWN)

        transform = element.getAttribute('transform')
        if not transform:
            return
        linear, offset = self._parse_transform(transform)
        if not np.array_equal(linear, np.identity(2)):
            flip = np.diag([1.0, -1.0])
            matrix = np.identity(self.dim)
            # Conjugating by the y-flip turns the SVG-space map into a
            # manim-space one; transposed because points are stored as rows.
            matrix[:2, :2] = (flip @ linear @ flip).T
            for mob in mobject.family_members_with_points():
                mob.points = np.dot(mob.points, matrix)
        if offset.any():
            mobject.shift(offset[0] * RIGHT + offset[1] * DOWN)

    def _parse_transform(self, transform):
        """
        Compose an SVG transform list into one affine map in SVG coordinates.

        Returns (linear, offset), a 2x2 array and a length-2 array, such that
        a point p maps to linear @ p + offset.  As in SVG, the rightmost
        function in the list acts on points first.  Unknown or malformed
        functions are skipped with a warning.
        """
        linear = np.identity(2)
        offset = np.zeros(2)
        for name, arg_string in re.findall(r"([A-Za-z]+)\s*\(([^)]*)\)", transform):
            try:
                step = self._transform_function_to_affine(
                    name, string_to_numbers(arg_string)
                )
            except ValueError:
                step = None
            if step is None:
                warnings.warn(
                    f"Unsupported or malformed SVG transform: {name}({arg_string})"
                )
                continue
            step_linear, step_offset = step
            # current o step: the newly read (more rightward) step acts first
            offset = linear @ step_offset + offset
            linear = linear @ step_linear
        return linear, offset

    def _transform_function_to_affine(self, name, args):
        """
        Return (linear, offset) for one SVG transform function.

        Returns None if the name is unknown or the argument count is invalid.
        """
        n_args = len(args)
        if name == "matrix" and n_args == 6:
            a, b, c, d, e, f = args
            return np.array([[a, c], [b, d]]), np.array([e, f])
        if name == "translate" and n_args in (1, 2):
            tx, ty = args if n_args == 2 else (args[0], 0.0)
            return np.identity(2), np.array([tx, ty])
        if name == "scale" and n_args in (1, 2):
            sx, sy = args if n_args == 2 else (args[0], args[0])
            return np.diag([sx, sy]), np.zeros(2)
        if name == "rotate" and n_args in (1, 3):
            angle = np.radians(args[0])
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
            center = np.array(args[1:]) if n_args == 3 else np.zeros(2)
            # rotate(a, cx, cy) == translate(cx, cy) rotate(a) translate(-cx, -cy)
            return rotation, center - rotation @ center
        if name == "skewX" and n_args == 1:
            shear = np.tan(np.radians(args[0]))
            return np.array([[1.0, shear], [0.0, 1.0]]), np.zeros(2)
        if name == "skewY" and n_args == 1:
            shear = np.tan(np.radians(args[0]))
            return np.array([[1.0, 0.0], [shear, 1.0]]), np.zeros(2)
        return None

    def flatten(self, input_list):
        """Recursively flatten nested lists into a single flat list."""
        output_list = []
        for i in input_list:
            if isinstance(i, list):
                output_list.extend(self.flatten(i))
            else:
                output_list.append(i)
        return output_list

    def get_all_childNodes_have_id(self, element):
        """
        Return the outermost descendants of element (or element itself) that
        carry an id attribute.  Returns None for non-element nodes.
        """
        all_childNodes_have_id = []
        if not isinstance(element, minidom.Element):
            return
        if element.hasAttribute('id'):
            return [element]
        for e in element.childNodes:
            all_childNodes_have_id.append(self.get_all_childNodes_have_id(e))
        return self.flatten([e for e in all_childNodes_have_id if e])

    def update_ref_to_element(self, defs):
        """
        Record every element with an id inside defs, at any depth, in
        self.ref_to_element so that <use> elements can reference it.
        """
        new_refs = {
            element.getAttribute('id'): element
            for element in defs.getElementsByTagName('*')
            if element.hasAttribute('id')
        }
        self.ref_to_element.update(new_refs)

    def move_into_position(self):
        """Center and resize according to should_center, height and width."""
        if self.should_center:
            self.center()
        if self.height is not None:
            self.set_height(self.height)
        if self.width is not None:
            self.set_width(self.width)
2017-06-20 14:05:48 -07:00
2016-04-17 00:31:38 -07:00
class VMobjectFromSVGPathstring(VMobject):
    """
    A VMobject built from the "d" attribute of an SVG <path>.

    Supports M, L, H, V, C, S, Q, T and Z in absolute (upper-case) and
    relative (lower-case) form.  Arc commands (A) raise
    Exception("Not implemented").
    """
    def __init__(self, path_string, **kwargs):
        self.path_string = path_string
        VMobject.__init__(self, **kwargs)

    def init_points(self):
        """Build the points by replaying every path command in order."""
        # Current point: origin for relative (lower-case) commands
        self.relative_point = ORIGIN
        for command, coord_string in self.get_commands_and_coord_strings():
            new_points = self.string_to_points(command, coord_string)
            self.handle_command(command, new_points)
        # For a healthy triangulation later
        self.subdivide_sharp_curves()
        # SVG treats y-coordinate differently
        self.stretch(-1, 1, about_point=ORIGIN)

    def get_commands_and_coord_strings(self):
        """Split path_string into (command letter, coordinate string) pairs."""
        all_commands = list(self.get_command_to_function_map().keys())
        all_commands += [c.lower() for c in all_commands]
        pattern = "[{}]".format("".join(all_commands))
        return zip(
            re.findall(pattern, self.path_string),
            re.split(pattern, self.path_string)[1:]
        )

    def handle_command(self, command, new_points):
        """
        Apply one path command to all of its coordinate points.

        A command may carry several argument groups ("L 1 2 3 4" draws two
        lines); each group is applied in turn.  For a lower-case (relative)
        command, each group is offset from the current point as it stands
        when that group starts.  Extra pairs after a moveto are implicit
        linetos, relative only if the moveto was ("m" -> "l", "M" -> "L").
        """
        start = 0
        while True:
            func, n_points = self.command_to_function(command)
            group = new_points[start:start + n_points]
            if command.islower():
                # Treat it as a relative command
                group = group + self.relative_point
            func(*group)
            start += n_points
            # Advance the current point; copied so later edits to
            # self.points cannot change it
            self.relative_point = np.array(self.points[-1])
            if n_points == 0 or start >= len(new_points):
                break
            if command.upper() == "M":
                command = "l" if command.islower() else "L"

    def string_to_points(self, command, coord_string):
        """
        Turn a command's coordinate string into an (n, dim) array of points.

        H/V take one number per point. For absolute H/V, the other coordinate
        is copied from the current point; for relative h/v it is 0.
        """
        numbers = string_to_numbers(coord_string)
        if command.upper() in ["H", "V"]:
            i = {"H": 0, "V": 1}[command.upper()]
            xy = np.zeros((len(numbers), 2))
            xy[:, i] = numbers
            if command.isupper():
                xy[:, 1 - i] = self.relative_point[1 - i]
        elif command.upper() == "A":
            raise Exception("Not implemented")
        else:
            xy = np.array(numbers).reshape((len(numbers) // 2, 2))
        result = np.zeros((xy.shape[0], self.dim))
        result[:, :2] = xy
        return result

    def command_to_function(self, command):
        """Return (function, points per call) for a command letter, any case."""
        return self.get_command_to_function_map()[command.upper()]

    def get_command_to_function_map(self):
        """
        Associates svg command to VMobject function, and
        the number of arguments it takes in
        """
        return {
            "M": (self.start_new_path, 1),
            "L": (self.add_line_to, 1),
            "H": (self.add_line_to, 1),
            "V": (self.add_line_to, 1),
            "C": (self.add_cubic_bezier_curve_to, 3),
            "S": (self.add_smooth_cubic_curve_to, 2),
            "Q": (self.add_quadratic_bezier_curve_to, 2),
            "T": (self.add_smooth_curve_to, 1),
            "A": (self.add_quadratic_bezier_curve_to, 2),  # TODO
            "Z": (self.close_path, 0),
        }

    def get_original_path_string(self):
        """Return the path data this mobject was built from."""
        return self.path_string