3b1b-manim/manimlib/mobject/svg/svg_mobject.py
2022-12-19 16:03:16 -08:00

354 lines
12 KiB
Python

from __future__ import annotations
import os
from xml.etree import ElementTree as ET
import numpy as np
import svgelements as se
import io
from manimlib.constants import RIGHT
from manimlib.logger import log
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Polygon
from manimlib.mobject.geometry import Polyline
from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.geometry import RoundedRectangle
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.directories import get_mobject_data_dir
from manimlib.utils.images import get_full_vector_image_path
from manimlib.utils.iterables import hash_obj
from manimlib.utils.simple_functions import hash_string
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.typing import ManimColor
# In-memory cache of already-built SVG mobjects. Keys are the hash of an
# SVGMobject's `hash_seed`; `SVGMobject.init_svg_mobject` stores a copy here
# after generating one and copies submobjects back out on later hits.
SVG_HASH_TO_MOB_MAP: dict[int, VMobject] = {}
def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
return np.array([x, y, 0.0])
class SVGMobject(VMobject):
    """VMobject assembled from the shapes found in an SVG file.

    The file is parsed with `svgelements`; every supported shape becomes a
    submobject. Built results are cached in `SVG_HASH_TO_MOB_MAP`, keyed by
    the hash of `hash_seed`, so the same file/config is only parsed once
    per process.
    """
    # Class-level defaults, used when the constructor arguments are falsy.
    file_name: str = ""
    height: float | None = 2.0
    width: float | None = None

    def __init__(
        self,
        file_name: str = "",
        should_center: bool = True,
        height: float | None = None,
        width: float | None = None,
        # Style that overrides the original svg
        color: ManimColor = None,
        fill_color: ManimColor = None,
        fill_opacity: float | None = None,
        stroke_width: float | None = 0.0,
        stroke_color: ManimColor = None,
        stroke_opacity: float | None = None,
        # Style that fills only when not specified
        # If None, regarded as default values from svg standard
        svg_default: dict = dict(
            color=None,
            opacity=None,
            fill_color=None,
            fill_opacity=None,
            stroke_width=None,
            stroke_color=None,
            stroke_opacity=None,
        ),
        path_string_config: dict = dict(),
        **kwargs
    ):
        """Load the SVG, apply overriding style, then position and size it.

        Args:
            file_name: SVG file to load; falls back to the class attribute
                `file_name` when empty.
            should_center: Whether to call `center()` after loading.
            height: Target height; falls back to the class attribute
                `height` when falsy (so 0 behaves like None).
            width: Target width; falls back to the class attribute `width`
                when falsy. Applied after the height.
            color: When truthy, used for both fill and stroke colour in
                place of `fill_color` / `stroke_color`.
            fill_color, fill_opacity, stroke_width, stroke_color,
            stroke_opacity: Passed to `set_style` after the SVG is loaded.
            svg_default: Style values placed on a wrapping `<g>` element,
                i.e. defaults for elements that do not specify their own.
                Missing keys are treated the same as None.
            path_string_config: Extra keyword arguments forwarded to
                `VMobjectFromSVGPath` for every path element.
            **kwargs: Forwarded to `VMobject.__init__`.
        """
        self.file_name = file_name or self.file_name
        # Shallow copies: the dicts in the signature are shared between
        # calls, so the instance must not hold references to them.
        self.svg_default = dict(svg_default)
        self.path_string_config = dict(path_string_config)

        super().__init__(**kwargs)
        self.init_svg_mobject()

        # Rather than passing style into super().__init__
        # do it after svg has been taken in
        self.set_style(
            fill_color=color or fill_color,
            fill_opacity=fill_opacity,
            stroke_color=color or stroke_color,
            stroke_width=stroke_width,
            stroke_opacity=stroke_opacity,
        )

        # Initialize position
        height = height or self.height
        width = width or self.width
        if should_center:
            self.center()
        if height is not None:
            self.set_height(height)
        if width is not None:
            self.set_width(width)

    def init_svg_mobject(self) -> None:
        """Populate submobjects, from the cache if this seed was seen before."""
        hash_val = hash_obj(self.hash_seed)
        if hash_val in SVG_HASH_TO_MOB_MAP:
            # Copy so the cached mobject is never shared or mutated
            mob = SVG_HASH_TO_MOB_MAP[hash_val].copy()
            self.add(*mob)
            return

        self.generate_mobject()
        SVG_HASH_TO_MOB_MAP[hash_val] = self.copy()

    @property
    def hash_seed(self) -> tuple:
        """Data identifying what `generate_mobject` would produce.

        The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
        """
        return (
            self.__class__.__name__,
            self.svg_default,
            self.path_string_config,
            self.file_name
        )

    def generate_mobject(self) -> None:
        """Parse the SVG file and add one submobject per supported shape."""
        file_path = self.get_file_path()
        element_tree = ET.parse(file_path)
        new_tree = self.modify_xml_tree(element_tree)

        # New svg based on tree contents. The context manager guarantees
        # the buffer is released even if writing or parsing raises.
        with io.BytesIO() as data_stream:
            new_tree.write(data_stream)
            data_stream.seek(0)
            svg = se.SVG.parse(data_stream)

        mobjects = self.get_mobjects_from(svg)
        self.add(*mobjects)
        self.flip(RIGHT)  # Flip y

    def get_file_path(self) -> str:
        """Resolve `self.file_name` to a full path.

        Raises:
            Exception: If no file name was given. The class default is the
                empty string, so any falsy name counts as "not given".
        """
        if not self.file_name:
            raise Exception("Must specify file for SVGMobject")
        return get_full_vector_image_path(self.file_name)

    def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree:
        """Rebuild the tree as <svg><g config><g root-style>children</g></g></svg>.

        The outer group carries the `svg_default` styles and the inner group
        carries the style attributes copied from the original root, so the
        root's own styles are nested inside (and can override) the defaults.
        """
        config_style_attrs = self.generate_config_style_dict()
        style_keys = (
            "fill",
            "fill-opacity",
            "stroke",
            "stroke-opacity",
            "stroke-width",
            "style"
        )
        root = element_tree.getroot()
        style_attrs = {
            k: v
            for k, v in root.attrib.items()
            if k in style_keys
        }

        # Ignore other attributes in case that svgelements cannot parse them
        SVG_XMLNS = "{http://www.w3.org/2000/svg}"
        new_root = ET.Element("svg")
        config_style_node = ET.SubElement(new_root, f"{SVG_XMLNS}g", config_style_attrs)
        root_style_node = ET.SubElement(config_style_node, f"{SVG_XMLNS}g", style_attrs)
        root_style_node.extend(root)
        return ET.ElementTree(new_root)

    def generate_config_style_dict(self) -> dict[str, str]:
        """Translate `svg_default` into SVG presentation attributes.

        For each SVG attribute the candidate keys are tried in order and the
        last non-None value wins, so e.g. `fill_color` takes precedence over
        the generic `color`. Keys absent from `svg_default` are treated as
        None, so a partial dict is accepted.
        """
        keys_converting_dict = {
            "fill": ("color", "fill_color"),
            "fill-opacity": ("opacity", "fill_opacity"),
            "stroke": ("color", "stroke_color"),
            "stroke-opacity": ("opacity", "stroke_opacity"),
            "stroke-width": ("stroke_width",)
        }
        svg_default_dict = self.svg_default
        result = {}
        for svg_key, style_keys in keys_converting_dict.items():
            for style_key in style_keys:
                value = svg_default_dict.get(style_key)
                if value is None:
                    continue
                result[svg_key] = str(value)
        return result

    def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]:
        """Convert every supported element of `svg` into a styled mobject.

        Groups, `use` elements and bare `SVGElement`s are skipped silently;
        any other unsupported type (including text) is skipped with a
        warning. Mobjects that end up without points are dropped.
        """
        result = []
        for shape in svg.elements():
            if isinstance(shape, (se.Group, se.Use)):
                continue
            elif isinstance(shape, se.Path):
                mob = self.path_to_mobject(shape)
            elif isinstance(shape, se.SimpleLine):
                mob = self.line_to_mobject(shape)
            elif isinstance(shape, se.Rect):
                mob = self.rect_to_mobject(shape)
            elif isinstance(shape, (se.Circle, se.Ellipse)):
                mob = self.ellipse_to_mobject(shape)
            elif isinstance(shape, se.Polygon):
                mob = self.polygon_to_mobject(shape)
            elif isinstance(shape, se.Polyline):
                mob = self.polyline_to_mobject(shape)
            # NOTE: se.Text is not handled yet; `text_to_mobject` is a stub.
            elif type(shape) is se.SVGElement:
                # Exact-type match on purpose: only the bare base class is
                # ignored quietly, subclasses fall through to the warning.
                continue
            else:
                log.warning("Unsupported element type: %s", type(shape))
                continue

            if not mob.has_points():
                continue
            if isinstance(shape, se.GraphicObject):
                self.apply_style_to_mobject(mob, shape)
            if isinstance(shape, se.Transformable) and shape.apply:
                self.handle_transform(mob, shape.transform)
            result.append(mob)
        return result

    @staticmethod
    def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
        """Apply the 2x2 part [[a, c], [b, d]] of `matrix`, then shift by (e, f, 0)."""
        mat = np.array([
            [matrix.a, matrix.c],
            [matrix.b, matrix.d]
        ])
        vec = np.array([matrix.e, matrix.f, 0.0])
        mob.apply_matrix(mat)
        mob.shift(vec)
        return mob

    @staticmethod
    def apply_style_to_mobject(
        mob: VMobject,
        shape: se.GraphicObject
    ) -> VMobject:
        """Copy the stroke and fill settings of `shape` onto `mob`."""
        mob.set_style(
            stroke_width=shape.stroke_width,
            stroke_color=shape.stroke.hexrgb,
            stroke_opacity=shape.stroke.opacity,
            fill_color=shape.fill.hexrgb,
            fill_opacity=shape.fill.opacity
        )
        return mob

    def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath:
        """Wrap a path element, forwarding `path_string_config` as keywords."""
        return VMobjectFromSVGPath(path, **self.path_string_config)

    def line_to_mobject(self, line: se.SimpleLine) -> Line:
        """Build a Line between the element's two end points."""
        return Line(
            start=_convert_point_to_3d(line.x1, line.y1),
            end=_convert_point_to_3d(line.x2, line.y2)
        )

    def rect_to_mobject(self, rect: se.Rect) -> Rectangle:
        """Build a (possibly rounded) rectangle centred on the SVG rect."""
        if rect.rx == 0 or rect.ry == 0:
            mob = Rectangle(
                width=rect.width,
                height=rect.height,
            )
        else:
            # Built with a single corner radius rx and a height pre-scaled
            # by rx / ry; stretching back to rect.height scales the vertical
            # extent of the corners by ry / rx (presumably yielding a
            # vertical radius of ry - confirm against RoundedRectangle).
            mob = RoundedRectangle(
                width=rect.width,
                height=rect.height * rect.rx / rect.ry,
                corner_radius=rect.rx
            )
            mob.stretch_to_fit_height(rect.height)
        # (rect.x, rect.y) is offset by half the size to get the centre
        mob.shift(_convert_point_to_3d(
            rect.x + rect.width / 2,
            rect.y + rect.height / 2
        ))
        return mob

    def ellipse_to_mobject(self, ellipse: se.Circle | se.Ellipse) -> Circle:
        """Build a circle of radius rx, stretched to height 2 * ry, at (cx, cy)."""
        mob = Circle(radius=ellipse.rx)
        mob.stretch_to_fit_height(2 * ellipse.ry)
        mob.shift(_convert_point_to_3d(
            ellipse.cx, ellipse.cy
        ))
        return mob

    def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon:
        """Build a Polygon through the element's points."""
        points = [
            _convert_point_to_3d(*point)
            for point in polygon
        ]
        return Polygon(*points)

    def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline:
        """Build a Polyline through the element's points."""
        points = [
            _convert_point_to_3d(*point)
            for point in polyline
        ]
        return Polyline(*points)

    def text_to_mobject(self, text: se.Text):
        """Placeholder: text elements are not converted (returns None)."""
        pass
class VMobjectFromSVGPath(VMobject):
    """VMobject whose points are traced from a single `svgelements` path.

    Computed points and triangulation are cached as `.npy` files in the
    mobject data directory so repeated paths are not retraced.
    """

    def __init__(
        self,
        path_obj: se.Path,
        long_lines: bool = False,
        should_subdivide_sharp_curves: bool = False,
        should_remove_null_curves: bool = False,
        **kwargs
    ):
        """Store the path and its processing options, then build the mobject.

        Args:
            path_obj: Path to trace. It is modified in place: its arcs are
                replaced by quadratic approximations.
            long_lines: Stored on the instance; not read by this class.
            should_subdivide_sharp_curves: Subdivide sharp curves after
                tracing.
            should_remove_null_curves: Drop null curves after tracing.
            **kwargs: Forwarded to `VMobject.__init__`.
        """
        # Get rid of arcs
        path_obj.approximate_arcs_with_quads()
        self.path_obj = path_obj
        self.long_lines = long_lines
        self.should_subdivide_sharp_curves = should_subdivide_sharp_curves
        self.should_remove_null_curves = should_remove_null_curves
        # Attributes are assigned first; presumably the base constructor
        # triggers init_points, which reads them - confirm in VMobject.
        super().__init__(**kwargs)

    def init_points(self) -> None:
        """Load this path's points from the file cache, or compute and save them."""
        # After a given svg_path has been converted into points, the result
        # will be saved to a file so that future calls for the same path
        # don't need to retrace the same computation.
        path_string = self.path_obj.d()

        # The stored points depend on the post-processing flags as well as
        # on the path, so non-default flags must be part of the cache key.
        # With both flags off the key is the bare path string, which keeps
        # previously written cache files valid.
        cache_key = path_string
        subdivide = bool(self.should_subdivide_sharp_curves)
        remove_null = bool(self.should_remove_null_curves)
        if subdivide or remove_null:
            cache_key += f"|subdivide={subdivide}|remove_null={remove_null}"
        path_hash = hash_string(cache_key)

        data_dir = get_mobject_data_dir()
        points_filepath = os.path.join(data_dir, f"{path_hash}_points.npy")
        tris_filepath = os.path.join(data_dir, f"{path_hash}_tris.npy")

        if os.path.exists(points_filepath) and os.path.exists(tris_filepath):
            self.set_points(np.load(points_filepath))
            self.triangulation = np.load(tris_filepath)
            self.needs_new_triangulation = False
        else:
            self.handle_commands()
            if self.should_subdivide_sharp_curves:
                # For a healthy triangulation later
                self.subdivide_sharp_curves()
            if self.should_remove_null_curves:
                # Get rid of any null curves
                self.set_points(self.get_points_without_null_curves())
            # Save to a file for future use
            np.save(points_filepath, self.get_points())
            np.save(tris_filepath, self.get_triangulation())

    def handle_commands(self) -> None:
        """Replay every segment of the path through the matching builder method.

        Raises:
            KeyError: If the path contains a segment type other than Move,
                Close, Line, QuadraticBezier or CubicBezier.
        """
        # segment class -> (builder method, names of the point attributes
        # to pass to it, in order)
        segment_class_to_func_map = {
            se.Move: (self.start_new_path, ("end",)),
            se.Close: (self.close_path, ()),
            se.Line: (self.add_line_to, ("end",)),
            se.QuadraticBezier: (self.add_quadratic_bezier_curve_to, ("control", "end")),
            se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end"))
        }
        for segment in self.path_obj:
            segment_class = segment.__class__
            func, attr_names = segment_class_to_func_map[segment_class]
            points = [
                _convert_point_to_3d(*getattr(segment, attr_name))
                for attr_name in attr_names
            ]
            func(*points)

        # Get rid of the side effect of trailing "Z M" commands.
        if self.has_new_path_started():
            self.resize_points(self.get_num_points() - 1)