3b1b-manim/manimlib/mobject/svg/svg_mobject.py

354 lines
12 KiB
Python
Raw Normal View History

from __future__ import annotations
2022-04-12 19:19:59 +08:00
import os
from xml.etree import ElementTree as ET
import numpy as np
2022-04-12 19:19:59 +08:00
import svgelements as se
2020-02-18 22:31:29 -08:00
from manimlib.constants import RIGHT
2022-04-12 19:19:59 +08:00
from manimlib.logger import log
from manimlib.mobject.geometry import Circle
2022-04-12 19:19:59 +08:00
from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Polygon
from manimlib.mobject.geometry import Polyline
from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.geometry import RoundedRectangle
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path
2022-02-15 21:38:22 +08:00
from manimlib.utils.iterables import hash_obj
2022-05-22 10:29:20 +08:00
from manimlib.utils.simple_functions import hash_string
2016-04-17 00:31:38 -07:00
2022-12-15 12:48:06 -08:00
from typing import TYPE_CHECKING
if TYPE_CHECKING:
2022-12-16 20:19:18 -08:00
from manimlib.typing import ManimColor
2022-12-15 12:48:06 -08:00
# Process-wide cache filled by SVGMobject.init_svg_mobject: maps
# hash_obj(hash_seed) to a prototype mobject, a copy of which is reused by
# later SVGMobjects that produce the same seed.
SVG_HASH_TO_MOB_MAP: dict[int, VMobject] = dict()
def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
    """Embed the 2D SVG point (x, y) in 3D space on the z = 0 plane."""
    z = 0.0
    return np.array((x, y, z))
2022-01-26 13:53:53 +08:00
2016-04-17 00:31:38 -07:00
class SVGMobject(VMobject):
    """
    VMobject assembled from the shapes of an SVG file.

    Loading pipeline (see `generate_mobject`): the file is parsed with
    ElementTree, its content is wrapped in `<g>` elements carrying the
    `svg_default` style and the root element's own style attributes, the
    rewritten tree is parsed by svgelements, and every supported shape is
    converted into a submobject. Generated submobjects are cached per
    process in `SVG_HASH_TO_MOB_MAP`, keyed on `hash_obj(self.hash_seed)`.
    """
    # Class-level fallbacks used when the matching __init__ argument is
    # falsy (see `__init__`).
    file_name: str = ""
    height: float | None = 2.0
    width: float | None = None

    def __init__(
        self,
        file_name: str = "",
        should_center: bool = True,
        height: float | None = None,
        width: float | None = None,
        # Style that overrides the original svg
        color: ManimColor | None = None,
        fill_color: ManimColor | None = None,
        fill_opacity: float | None = None,
        stroke_width: float | None = 0.0,
        stroke_color: ManimColor | None = None,
        stroke_opacity: float | None = None,
        # Style that fills only when not specified
        # If None, regarded as default values from svg standard
        svg_default: dict = dict(
            color=None,
            opacity=None,
            fill_color=None,
            fill_opacity=None,
            stroke_width=None,
            stroke_color=None,
            stroke_opacity=None,
        ),
        path_string_config: dict = dict(),
        **kwargs
    ):
        """
        Load the SVG, then apply the override style, centering and sizing.

        Args:
            file_name: SVG file to load; falls back to the class attribute
                `file_name` when empty.
            should_center: If True, call `center()` after loading.
            height: Target height; falls back to the class attribute
                `height` when falsy. Not applied if the result is None.
            width: Target width; falls back to the class attribute `width`
                when falsy. Applied after `height`; not applied if None.
            color: When truthy, used as both fill and stroke color, taking
                precedence over `fill_color` / `stroke_color`.
            fill_color, fill_opacity, stroke_width, stroke_color,
            stroke_opacity: Passed to `set_style` once the SVG has been
                loaded, i.e. on top of the styles read from the file.
            svg_default: Style written onto a wrapper `<g>` around the
                file's content, so it only fills in values the file does
                not specify itself. Recognised keys: color, opacity,
                fill_color, fill_opacity, stroke_width, stroke_color,
                stroke_opacity. Missing or None entries are ignored.
            path_string_config: Keyword arguments forwarded to every
                `VMobjectFromSVGPath` built from a `<path>` element.
            **kwargs: Forwarded to `VMobject.__init__`.
        """
        self.file_name = file_name or self.file_name
        # Copy so the shared default dicts are never mutated through self
        self.svg_default = dict(svg_default)
        self.path_string_config = dict(path_string_config)

        super().__init__(**kwargs)
        self.init_svg_mobject()
        # Rather than passing style into super().__init__
        # do it after svg has been taken in
        self.set_style(
            fill_color=color or fill_color,
            fill_opacity=fill_opacity,
            stroke_color=color or stroke_color,
            stroke_width=stroke_width,
            stroke_opacity=stroke_opacity,
        )

        # Initialize position
        height = height or self.height
        width = width or self.width
        if should_center:
            self.center()
        if height is not None:
            self.set_height(height)
        if width is not None:
            self.set_width(width)

    def init_svg_mobject(self) -> None:
        """Populate submobjects, reusing a cached result when one exists."""
        hash_val = hash_obj(self.hash_seed)
        if hash_val in SVG_HASH_TO_MOB_MAP:
            # Add the submobjects of a copy, not of the cached prototype
            mob = SVG_HASH_TO_MOB_MAP[hash_val].copy()
            self.add(*mob)
            return

        self.generate_mobject()
        # Cached before __init__ applies override style and sizing
        SVG_HASH_TO_MOB_MAP[hash_val] = self.copy()

    @property
    def hash_seed(self) -> tuple:
        # Returns data which can uniquely represent the result of `init_points`.
        # The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
        return (
            self.__class__.__name__,
            self.svg_default,
            self.path_string_config,
            self.file_name
        )

    def generate_mobject(self) -> None:
        """Parse the SVG file and add one submobject per supported shape."""
        file_path = self.get_file_path()
        element_tree = ET.parse(file_path)
        new_tree = self.modify_xml_tree(element_tree)

        # Create a temporary svg file to dump modified svg to be parsed
        root, ext = os.path.splitext(file_path)
        modified_file_path = root + "_" + ext
        new_tree.write(modified_file_path)
        try:
            svg = se.SVG.parse(modified_file_path)
        finally:
            # Remove the temporary file even when svgelements fails to parse
            os.remove(modified_file_path)

        mobjects = self.get_mobjects_from(svg)
        self.add(*mobjects)
        self.flip(RIGHT)  # Flip y

    def get_file_path(self) -> str:
        """
        Resolve `self.file_name` to a full path.

        Raises:
            Exception: if no file name was given (None or empty string).
        """
        # `file_name` defaults to "", so test falsiness rather than `is None`
        if not self.file_name:
            raise Exception("Must specify file for SVGMobject")
        return get_full_vector_image_path(self.file_name)

    def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree:
        """
        Return a new tree wrapping the original content in two `<g>` nodes.

        The outer group carries the `svg_default` style, the inner one the
        style attributes copied from the original root element, so that
        styles set in the file take precedence over the defaults.
        """
        config_style_attrs = self.generate_config_style_dict()
        style_keys = (
            "fill",
            "fill-opacity",
            "stroke",
            "stroke-opacity",
            "stroke-width",
            "style"
        )
        root = element_tree.getroot()
        style_attrs = {
            k: v
            for k, v in root.attrib.items()
            if k in style_keys
        }

        # Ignore other attributes in case that svgelements cannot parse them
        SVG_XMLNS = "{http://www.w3.org/2000/svg}"
        new_root = ET.Element("svg")
        config_style_node = ET.SubElement(new_root, f"{SVG_XMLNS}g", config_style_attrs)
        root_style_node = ET.SubElement(config_style_node, f"{SVG_XMLNS}g", style_attrs)
        root_style_node.extend(root)
        return ET.ElementTree(new_root)

    def generate_config_style_dict(self) -> dict[str, str]:
        """
        Translate `svg_default` into SVG presentation attributes.

        For each SVG attribute, later keys in its tuple override earlier
        ones (e.g. `fill_color` wins over `color` for "fill"). Keys that are
        absent from `svg_default` or set to None are skipped.
        """
        keys_converting_dict = {
            "fill": ("color", "fill_color"),
            "fill-opacity": ("opacity", "fill_opacity"),
            "stroke": ("color", "stroke_color"),
            "stroke-opacity": ("opacity", "stroke_opacity"),
            "stroke-width": ("stroke_width",)
        }
        svg_default_dict = self.svg_default
        result = {}
        for svg_key, style_keys in keys_converting_dict.items():
            for style_key in style_keys:
                # .get so a partial svg_default dict doesn't raise KeyError
                value = svg_default_dict.get(style_key)
                if value is None:
                    continue
                result[svg_key] = str(value)
        return result

    def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]:
        """
        Convert the shapes of a parsed SVG into mobjects.

        Groups, `<use>` nodes and bare `SVGElement`s are skipped silently;
        other unsupported element types (including text) are logged and
        skipped. Shapes without points are dropped. Each kept mobject gets
        the shape's style and, when applicable, its transform.
        """
        result = []
        for shape in svg.elements():
            if isinstance(shape, (se.Group, se.Use)):
                continue
            elif isinstance(shape, se.Path):
                mob = self.path_to_mobject(shape)
            elif isinstance(shape, se.SimpleLine):
                mob = self.line_to_mobject(shape)
            elif isinstance(shape, se.Rect):
                mob = self.rect_to_mobject(shape)
            elif isinstance(shape, (se.Circle, se.Ellipse)):
                mob = self.ellipse_to_mobject(shape)
            elif isinstance(shape, se.Polygon):
                mob = self.polygon_to_mobject(shape)
            elif isinstance(shape, se.Polyline):
                mob = self.polyline_to_mobject(shape)
            # NOTE: se.Text is not supported yet (see `text_to_mobject`);
            # it falls through to the warning branch below.
            elif type(shape) is se.SVGElement:
                # Exact type check: subclasses are handled above
                continue
            else:
                log.warning("Unsupported element type: %s", type(shape))
                continue
            if not mob.has_points():
                continue
            if isinstance(shape, se.GraphicObject):
                self.apply_style_to_mobject(mob, shape)
            if isinstance(shape, se.Transformable) and shape.apply:
                self.handle_transform(mob, shape.transform)
            result.append(mob)
        return result

    @staticmethod
    def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
        """Apply the 2D affine `matrix` (linear part, then shift) to `mob`."""
        mat = np.array([
            [matrix.a, matrix.c],
            [matrix.b, matrix.d]
        ])
        vec = np.array([matrix.e, matrix.f, 0.0])
        mob.apply_matrix(mat)
        mob.shift(vec)
        return mob

    @staticmethod
    def apply_style_to_mobject(
        mob: VMobject,
        shape: se.GraphicObject
    ) -> VMobject:
        """Copy the stroke and fill style of `shape` onto `mob`."""
        mob.set_style(
            stroke_width=shape.stroke_width,
            stroke_color=shape.stroke.hexrgb,
            stroke_opacity=shape.stroke.opacity,
            fill_color=shape.fill.hexrgb,
            fill_opacity=shape.fill.opacity
        )
        return mob

    def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath:
        """Build a `VMobjectFromSVGPath` using `path_string_config`."""
        return VMobjectFromSVGPath(path, **self.path_string_config)

    def line_to_mobject(self, line: se.SimpleLine) -> Line:
        """Build a `Line` between the element's two endpoints."""
        return Line(
            start=_convert_point_to_3d(line.x1, line.y1),
            end=_convert_point_to_3d(line.x2, line.y2)
        )

    def rect_to_mobject(self, rect: se.Rect) -> Rectangle:
        """
        Build a `Rectangle`, or a `RoundedRectangle` when both corner radii
        are non-zero, centered at the rect's center.
        """
        if rect.rx == 0 or rect.ry == 0:
            mob = Rectangle(
                width=rect.width,
                height=rect.height,
            )
        else:
            # Build with circular corners of radius rx in a pre-scaled
            # height, then stretch vertically so the corners become
            # ellipses with vertical radius ry.
            mob = RoundedRectangle(
                width=rect.width,
                height=rect.height * rect.rx / rect.ry,
                corner_radius=rect.rx
            )
            mob.stretch_to_fit_height(rect.height)
        mob.shift(_convert_point_to_3d(
            rect.x + rect.width / 2,
            rect.y + rect.height / 2
        ))
        return mob

    def ellipse_to_mobject(self, ellipse: se.Circle | se.Ellipse) -> Circle:
        """Build a circle of radius rx, stretched to height 2*ry, at (cx, cy)."""
        mob = Circle(radius=ellipse.rx)
        mob.stretch_to_fit_height(2 * ellipse.ry)
        mob.shift(_convert_point_to_3d(
            ellipse.cx, ellipse.cy
        ))
        return mob

    def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon:
        """Build a `Polygon` through the element's vertices."""
        points = [
            _convert_point_to_3d(*point)
            for point in polygon
        ]
        return Polygon(*points)

    def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline:
        """Build a `Polyline` through the element's vertices."""
        points = [
            _convert_point_to_3d(*point)
            for point in polyline
        ]
        return Polyline(*points)

    def text_to_mobject(self, text: se.Text):
        """Placeholder for text support; currently does nothing (returns None)."""
        pass
2016-04-17 00:31:38 -07:00
class VMobjectFromSVGPath(VMobject):
    """
    VMobject tracing a single svgelements `Path`.

    Computed points and triangulation are cached as .npy files in the
    mobject data directory. A later build of the same path with the same
    post-processing options loads them instead of re-running
    `handle_commands`.
    """
    def __init__(
        self,
        path_obj: se.Path,
        long_lines: bool = False,
        should_subdivide_sharp_curves: bool = False,
        should_remove_null_curves: bool = False,
        **kwargs
    ):
        """
        Args:
            path_obj: Path to trace. `approximate_arcs_with_quads()` is
                called on it, so the caller's object is modified.
            long_lines: Stored on the instance; not read by this class.
            should_subdivide_sharp_curves: If True, call
                `subdivide_sharp_curves()` after tracing.
            should_remove_null_curves: If True, drop null curves from the
                traced points.
            **kwargs: Forwarded to `VMobject.__init__`.
        """
        # Get rid of arcs; `handle_commands` has no entry for arc segments
        path_obj.approximate_arcs_with_quads()
        self.path_obj = path_obj
        self.long_lines = long_lines
        self.should_subdivide_sharp_curves = should_subdivide_sharp_curves
        self.should_remove_null_curves = should_remove_null_curves
        # Attributes are set before super().__init__(), presumably because
        # it ends up calling init_points(), which reads them.
        super().__init__(**kwargs)

    def init_points(self) -> None:
        """
        Trace `path_obj` into points, reusing the on-disk cache if present.

        After a given svg path has been converted into points, the points
        and triangulation are saved so that future builds of the same path
        with the same options don't need to retrace the computation.
        """
        path_string = self.path_obj.d()
        # The post-processing flags change the resulting points, so they
        # are part of the cache key; otherwise different configurations of
        # the same path would load each other's cached results.
        cache_key = (
            f"{path_string}"
            f"|subdivide_sharp_curves={bool(self.should_subdivide_sharp_curves)}"
            f"|remove_null_curves={bool(self.should_remove_null_curves)}"
        )
        path_hash = hash_string(cache_key)
        data_dir = get_mobject_data_dir()
        points_filepath = os.path.join(data_dir, f"{path_hash}_points.npy")
        tris_filepath = os.path.join(data_dir, f"{path_hash}_tris.npy")

        if os.path.exists(points_filepath) and os.path.exists(tris_filepath):
            self.set_points(np.load(points_filepath))
            self.triangulation = np.load(tris_filepath)
            self.needs_new_triangulation = False
        else:
            self.handle_commands()
            if self.should_subdivide_sharp_curves:
                # For a healthy triangulation later
                self.subdivide_sharp_curves()
            if self.should_remove_null_curves:
                # Get rid of any null curves
                self.set_points(self.get_points_without_null_curves())
            # Save to a file for future use
            np.save(points_filepath, self.get_points())
            np.save(tris_filepath, self.get_triangulation())

    def handle_commands(self) -> None:
        """
        Replay each segment of `path_obj` as the matching VMobject command.

        Raises:
            KeyError: if a segment's class is not in the dispatch table below.
        """
        # segment class -> (builder method, segment attributes passed as points)
        segment_class_to_func_map = {
            se.Move: (self.start_new_path, ("end",)),
            se.Close: (self.close_path, ()),
            se.Line: (self.add_line_to, ("end",)),
            se.QuadraticBezier: (self.add_quadratic_bezier_curve_to, ("control", "end")),
            se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end"))
        }
        for segment in self.path_obj:
            func, attr_names = segment_class_to_func_map[type(segment)]
            points = [
                _convert_point_to_3d(*getattr(segment, attr_name))
                for attr_name in attr_names
            ]
            func(*points)

        # Get rid of the side effect of trailing "Z M" commands.
        if self.has_new_path_started():
            self.resize_points(self.get_num_points() - 1)