3b1b-manim/manimlib/mobject/svg/svg_mobject.py

323 lines
11 KiB
Python
Raw Normal View History

2020-02-18 22:31:29 -08:00
import os
import re
2020-02-18 22:31:29 -08:00
import hashlib
import itertools as it
import svgelements as se
import numpy as np
2020-02-18 22:31:29 -08:00
from manimlib.constants import RIGHT
2022-01-27 17:23:58 +08:00
from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Polygon
from manimlib.mobject.geometry import Polyline
from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.geometry import RoundedRectangle
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path
2022-01-25 14:09:05 +08:00
from manimlib.logger import log
2016-04-17 00:31:38 -07:00
def _convert_point_to_3d(x, y):
return np.array([x, y, 0.0])
2022-01-26 13:53:53 +08:00
2016-04-17 00:31:38 -07:00
class SVGMobject(VMobject):
2016-04-23 23:36:05 -07:00
CONFIG = {
"should_center": True,
"height": 2,
"width": None,
# Must be filled in a subclass, or when called
"file_name": None,
"color": None,
"opacity": None,
"fill_color": None,
"fill_opacity": None,
"stroke_width": None,
"stroke_color": None,
"stroke_opacity": None,
"path_string_config": {}
2016-04-23 23:36:05 -07:00
}
2018-09-27 17:37:25 -07:00
def __init__(self, file_name=None, **kwargs):
digest_config(self, kwargs)
self.file_name = file_name or self.file_name
2020-02-06 10:02:42 -08:00
if file_name is None:
raise Exception("Must specify file for SVGMobject")
self.file_path = get_full_vector_image_path(file_name)
super().__init__(**kwargs)
self.move_into_position()
2016-07-12 10:34:35 -07:00
2020-06-26 21:53:26 -07:00
def move_into_position(self):
if self.should_center:
self.center()
if self.height is not None:
self.set_height(self.height)
if self.width is not None:
self.set_width(self.width)
2022-01-26 08:20:45 -08:00
def init_colors(self):
# Remove fill_color, fill_opacity,
# stroke_width, stroke_color, stroke_opacity
# as each submobject may have those values specified in svg file
self.set_stroke(background=self.draw_stroke_behind_fill)
self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke)
return self
2020-06-26 21:53:26 -07:00
2020-02-11 19:55:00 -08:00
def init_points(self):
with open(self.file_path, "r") as svg_file:
svg_string = svg_file.read()
# Create a temporary svg file to dump modified svg to be parsed
modified_svg_string = self.modify_svg_file(svg_string)
modified_file_path = self.file_path.replace(".svg", "_.svg")
with open(modified_file_path, "w") as modified_svg_file:
modified_svg_file.write(modified_svg_string)
2022-01-26 08:20:45 -08:00
# `color` attribute handles `currentColor` keyword
2022-01-26 13:53:53 +08:00
if self.fill_color:
color = self.fill_color
elif self.color:
color = self.color
else:
color = "black"
shapes = se.SVG.parse(
modified_file_path,
color=color
)
os.remove(modified_file_path)
mobjects = self.get_mobjects_from(shapes)
self.add(*mobjects)
self.flip(RIGHT) # Flip y
self.scale(0.75)
def modify_svg_file(self, svg_string):
# svgelements cannot handle em, ex units
# Convert them using 1em = 16px, 1ex = 0.5em = 8px
def convert_unit(match_obj):
number = float(match_obj.group(1))
unit = match_obj.group(2)
factor = 16 if unit == "em" else 8
return str(number * factor) + "px"
number_pattern = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)(ex|em)(?![a-zA-Z])"
result = re.sub(number_pattern, convert_unit, svg_string)
# Add a group tag to set style from configuration
style_dict = self.generate_context_values_from_config()
group_tag_begin = "<g " + " ".join([
f"{k}=\"{v}\""
for k, v in style_dict.items()
]) + ">"
group_tag_end = "</g>"
begin_insert_index = re.search(r"<svg[\s\S]*?>", result).end()
end_insert_index = re.search(r"[\s\S]*(</svg\s*>)", result).start(1)
result = "".join([
result[:begin_insert_index],
group_tag_begin,
result[begin_insert_index:end_insert_index],
group_tag_end,
result[end_insert_index:]
2018-09-04 16:14:11 -07:00
])
2021-03-24 14:00:46 -07:00
return result
2018-01-28 14:55:17 +01:00
def generate_context_values_from_config(self):
result = {}
if self.stroke_width is not None:
result["stroke-width"] = self.stroke_width
if self.color is not None:
result["fill"] = result["stroke"] = self.color
if self.fill_color is not None:
result["fill"] = self.fill_color
if self.stroke_color is not None:
result["stroke"] = self.stroke_color
if self.opacity is not None:
result["fill-opacity"] = result["stroke-opacity"] = self.opacity
if self.fill_opacity is not None:
result["fill-opacity"] = self.fill_opacity
if self.stroke_opacity is not None:
result["stroke-opacity"] = self.stroke_opacity
return result
def get_mobjects_from(self, shape):
if isinstance(shape, se.Group):
return list(it.chain(*(
self.get_mobjects_from(child)
for child in shape
)))
mob = self.get_mobject_from(shape)
if mob is None:
return []
if isinstance(shape, se.Transformable) and shape.apply:
self.handle_transform(mob, shape.transform)
return [mob]
@staticmethod
def handle_transform(mob, matrix):
mat = np.array([
[matrix.a, matrix.c],
[matrix.b, matrix.d]
])
vec = np.array([matrix.e, matrix.f, 0.0])
mob.apply_matrix(mat)
mob.shift(vec)
return mob
def get_mobject_from(self, shape):
shape_class_to_func_map = {
se.Path: self.path_to_mobject,
se.SimpleLine: self.line_to_mobject,
se.Rect: self.rect_to_mobject,
se.Circle: self.circle_to_mobject,
se.Ellipse: self.ellipse_to_mobject,
se.Polygon: self.polygon_to_mobject,
se.Polyline: self.polyline_to_mobject,
# se.Text: self.text_to_mobject, # TODO
}
for shape_class, func in shape_class_to_func_map.items():
if isinstance(shape, shape_class):
mob = func(shape)
self.apply_style_to_mobject(mob, shape)
return mob
shape_class_name = shape.__class__.__name__
if shape_class_name != "SVGElement":
log.warning(f"Unsupported element type: {shape_class_name}")
return None
@staticmethod
def apply_style_to_mobject(mob, shape):
mob.set_style(
stroke_width=shape.stroke_width,
stroke_color=shape.stroke.hex,
stroke_opacity=shape.stroke.opacity,
fill_color=shape.fill.hex,
fill_opacity=shape.fill.opacity
)
return mob
def path_to_mobject(self, path):
return VMobjectFromSVGPath(path, **self.path_string_config)
def line_to_mobject(self, line):
return Line(
start=_convert_point_to_3d(line.x1, line.y1),
end=_convert_point_to_3d(line.x2, line.y2)
)
2022-01-26 13:53:53 +08:00
def rect_to_mobject(self, rect):
if rect.rx == 0 or rect.ry == 0:
mob = Rectangle(
width=rect.width,
height=rect.height,
)
else:
mob = RoundedRectangle(
width=rect.width,
height=rect.height * rect.rx / rect.ry,
corner_radius=rect.rx
)
mob.stretch_to_fit_height(rect.height)
mob.shift(_convert_point_to_3d(
rect.x + rect.width / 2,
rect.y + rect.height / 2
))
return mob
def circle_to_mobject(self, circle):
# svgelements supports `rx` & `ry` but `r`
mob = Circle(radius=circle.rx)
mob.shift(_convert_point_to_3d(
circle.cx, circle.cy
))
2016-04-17 00:31:38 -07:00
return mob
def ellipse_to_mobject(self, ellipse):
mob = Circle(radius=ellipse.rx)
mob.stretch_to_fit_height(2 * ellipse.ry)
mob.shift(_convert_point_to_3d(
ellipse.cx, ellipse.cy
))
return mob
def polygon_to_mobject(self, polygon):
points = [
_convert_point_to_3d(*point)
for point in polygon
2022-01-25 14:04:35 +08:00
]
return Polygon(*points)
def polyline_to_mobject(self, polyline):
points = [
_convert_point_to_3d(*point)
for point in polyline
]
return Polyline(*points)
def text_to_mobject(self, text):
pass
2016-04-17 00:31:38 -07:00
class VMobjectFromSVGPath(VMobject):
2020-02-20 15:51:04 -08:00
CONFIG = {
2021-10-01 12:32:04 -07:00
"long_lines": False,
"should_subdivide_sharp_curves": False,
"should_remove_null_curves": False,
2020-02-20 15:51:04 -08:00
}
def __init__(self, path_obj, **kwargs):
# Get rid of arcs
path_obj.approximate_arcs_with_quads()
self.path_obj = path_obj
super().__init__(**kwargs)
2016-04-17 00:31:38 -07:00
2020-02-11 19:55:00 -08:00
def init_points(self):
2021-01-30 17:51:14 -08:00
# After a given svg_path has been converted into points, the result
# will be saved to a file so that future calls for the same path
# don't need to retrace the same computation.
path_string = self.path_obj.d()
hasher = hashlib.sha256(path_string.encode())
2020-02-18 22:31:29 -08:00
path_hash = hasher.hexdigest()[:16]
2021-01-11 16:37:01 -10:00
points_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_points.npy")
tris_filepath = os.path.join(get_mobject_data_dir(), f"{path_hash}_tris.npy")
2020-02-18 22:31:29 -08:00
2022-01-25 20:25:30 +08:00
if os.path.exists(points_filepath) and os.path.exists(tris_filepath):
2021-01-11 16:37:01 -10:00
self.set_points(np.load(points_filepath))
self.triangulation = np.load(tris_filepath)
self.needs_new_triangulation = False
2020-02-18 22:31:29 -08:00
else:
self.handle_commands()
if self.should_subdivide_sharp_curves:
2020-02-20 15:51:04 -08:00
# For a healthy triangulation later
self.subdivide_sharp_curves()
if self.should_remove_null_curves:
# Get rid of any null curves
2021-01-10 18:51:47 -08:00
self.set_points(self.get_points_without_null_curves())
2020-02-18 22:31:29 -08:00
# Save to a file for future use
2021-01-11 16:37:01 -10:00
np.save(points_filepath, self.get_points())
np.save(tris_filepath, self.get_triangulation())
2020-02-06 10:02:42 -08:00
def handle_commands(self):
segment_class_to_func_map = {
se.Move: (self.start_new_path, ("end",)),
se.Close: (self.close_path, ()),
se.Line: (self.add_line_to, ("end",)),
se.QuadraticBezier: (self.add_quadratic_bezier_curve_to, ("control", "end")),
se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end"))
2020-02-06 10:02:42 -08:00
}
for segment in self.path_obj:
segment_class = segment.__class__
func, attr_names = segment_class_to_func_map[segment_class]
points = [
_convert_point_to_3d(*segment.__getattribute__(attr_name))
for attr_name in attr_names
]
func(*points)