3b1b-manim/manimlib/mobject/svg/svg_mobject.py
Varniex 5d3f730824
Cleaning up some imports + Minor Bug fixed in VectorField (#2253)
* cleaning up imports

* sample_points -> sample_coords
2024-12-05 14:42:46 -08:00

332 lines
11 KiB
Python

from __future__ import annotations
from xml.etree import ElementTree as ET
import numpy as np
import svgelements as se
import io
from manimlib.constants import RIGHT
from manimlib.logger import log
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Polygon
from manimlib.mobject.geometry import Polyline
from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.geometry import RoundedRectangle
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.images import get_full_vector_image_path
from manimlib.utils.iterables import hash_obj
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.typing import ManimColor, Vect3Array
SVG_HASH_TO_MOB_MAP: dict[int, list[VMobject]] = {}
PATH_TO_POINTS: dict[str, Vect3Array] = {}
def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
return np.array([x, y, 0.0])
class SVGMobject(VMobject):
file_name: str = ""
height: float | None = 2.0
width: float | None = None
def __init__(
self,
file_name: str = "",
should_center: bool = True,
height: float | None = None,
width: float | None = None,
# Style that overrides the original svg
color: ManimColor = None,
fill_color: ManimColor = None,
fill_opacity: float | None = None,
stroke_width: float | None = 0.0,
stroke_color: ManimColor = None,
stroke_opacity: float | None = None,
# Style that fills only when not specified
# If None, regarded as default values from svg standard
svg_default: dict = dict(
color=None,
opacity=None,
fill_color=None,
fill_opacity=None,
stroke_width=None,
stroke_color=None,
stroke_opacity=None,
),
path_string_config: dict = dict(),
**kwargs
):
self.file_name = file_name or self.file_name
self.svg_default = dict(svg_default)
self.path_string_config = dict(path_string_config)
super().__init__(**kwargs )
self.init_svg_mobject()
self.ensure_positive_orientation()
# Rather than passing style into super().__init__
# do it after svg has been taken in
self.set_style(
fill_color=color or fill_color,
fill_opacity=fill_opacity,
stroke_color=color or stroke_color,
stroke_width=stroke_width,
stroke_opacity=stroke_opacity,
)
# Initialize position
height = height or self.height
width = width or self.width
if should_center:
self.center()
if height is not None:
self.set_height(height)
if width is not None:
self.set_width(width)
def init_svg_mobject(self) -> None:
hash_val = hash_obj(self.hash_seed)
if hash_val in SVG_HASH_TO_MOB_MAP:
submobs = [sm.copy() for sm in SVG_HASH_TO_MOB_MAP[hash_val]]
else:
submobs = self.mobjects_from_file(self.get_file_path())
SVG_HASH_TO_MOB_MAP[hash_val] = [sm.copy() for sm in submobs]
self.add(*submobs)
self.flip(RIGHT) # Flip y
@property
def hash_seed(self) -> tuple:
# Returns data which can uniquely represent the result of `init_points`.
# The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
return (
self.__class__.__name__,
self.svg_default,
self.path_string_config,
self.file_name
)
def mobjects_from_file(self, file_path: str) -> list[VMobject]:
element_tree = ET.parse(file_path)
new_tree = self.modify_xml_tree(element_tree)
# New svg based on tree contents
data_stream = io.BytesIO()
new_tree.write(data_stream)
data_stream.seek(0)
svg = se.SVG.parse(data_stream)
data_stream.close()
return self.mobjects_from_svg(svg)
def get_file_path(self) -> str:
if self.file_name is None:
raise Exception("Must specify file for SVGMobject")
return get_full_vector_image_path(self.file_name)
def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree:
config_style_attrs = self.generate_config_style_dict()
style_keys = (
"fill",
"fill-opacity",
"stroke",
"stroke-opacity",
"stroke-width",
"style"
)
root = element_tree.getroot()
style_attrs = {
k: v
for k, v in root.attrib.items()
if k in style_keys
}
# Ignore other attributes in case that svgelements cannot parse them
SVG_XMLNS = "{http://www.w3.org/2000/svg}"
new_root = ET.Element("svg")
config_style_node = ET.SubElement(new_root, f"{SVG_XMLNS}g", config_style_attrs)
root_style_node = ET.SubElement(config_style_node, f"{SVG_XMLNS}g", style_attrs)
root_style_node.extend(root)
return ET.ElementTree(new_root)
def generate_config_style_dict(self) -> dict[str, str]:
keys_converting_dict = {
"fill": ("color", "fill_color"),
"fill-opacity": ("opacity", "fill_opacity"),
"stroke": ("color", "stroke_color"),
"stroke-opacity": ("opacity", "stroke_opacity"),
"stroke-width": ("stroke_width",)
}
svg_default_dict = self.svg_default
result = {}
for svg_key, style_keys in keys_converting_dict.items():
for style_key in style_keys:
if svg_default_dict[style_key] is None:
continue
result[svg_key] = str(svg_default_dict[style_key])
return result
def mobjects_from_svg(self, svg: se.SVG) -> list[VMobject]:
result = []
for shape in svg.elements():
if isinstance(shape, (se.Group, se.Use)):
continue
elif isinstance(shape, se.Path):
mob = self.path_to_mobject(shape)
elif isinstance(shape, se.SimpleLine):
mob = self.line_to_mobject(shape)
elif isinstance(shape, se.Rect):
mob = self.rect_to_mobject(shape)
elif isinstance(shape, (se.Circle, se.Ellipse)):
mob = self.ellipse_to_mobject(shape)
elif isinstance(shape, se.Polygon):
mob = self.polygon_to_mobject(shape)
elif isinstance(shape, se.Polyline):
mob = self.polyline_to_mobject(shape)
# elif isinstance(shape, se.Text):
# mob = self.text_to_mobject(shape)
elif type(shape) == se.SVGElement:
continue
else:
log.warning("Unsupported element type: %s", type(shape))
continue
if not mob.has_points():
continue
if isinstance(shape, se.GraphicObject):
self.apply_style_to_mobject(mob, shape)
if isinstance(shape, se.Transformable) and shape.apply:
self.handle_transform(mob, shape.transform)
result.append(mob)
return result
@staticmethod
def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
mat = np.array([
[matrix.a, matrix.c],
[matrix.b, matrix.d]
])
vec = np.array([matrix.e, matrix.f, 0.0])
mob.apply_matrix(mat)
mob.shift(vec)
return mob
@staticmethod
def apply_style_to_mobject(
mob: VMobject,
shape: se.GraphicObject
) -> VMobject:
mob.set_style(
stroke_width=shape.stroke_width,
stroke_color=shape.stroke.hexrgb,
stroke_opacity=shape.stroke.opacity,
fill_color=shape.fill.hexrgb,
fill_opacity=shape.fill.opacity
)
return mob
def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath:
return VMobjectFromSVGPath(path, **self.path_string_config)
def line_to_mobject(self, line: se.SimpleLine) -> Line:
return Line(
start=_convert_point_to_3d(line.x1, line.y1),
end=_convert_point_to_3d(line.x2, line.y2)
)
def rect_to_mobject(self, rect: se.Rect) -> Rectangle:
if rect.rx == 0 or rect.ry == 0:
mob = Rectangle(
width=rect.width,
height=rect.height,
)
else:
mob = RoundedRectangle(
width=rect.width,
height=rect.height * rect.rx / rect.ry,
corner_radius=rect.rx
)
mob.stretch_to_fit_height(rect.height)
mob.shift(_convert_point_to_3d(
rect.x + rect.width / 2,
rect.y + rect.height / 2
))
return mob
def ellipse_to_mobject(self, ellipse: se.Circle | se.Ellipse) -> Circle:
mob = Circle(radius=ellipse.rx)
mob.stretch_to_fit_height(2 * ellipse.ry)
mob.shift(_convert_point_to_3d(
ellipse.cx, ellipse.cy
))
return mob
def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon:
points = [
_convert_point_to_3d(*point)
for point in polygon
]
return Polygon(*points)
def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline:
points = [
_convert_point_to_3d(*point)
for point in polyline
]
return Polyline(*points)
def text_to_mobject(self, text: se.Text):
pass
class VMobjectFromSVGPath(VMobject):
def __init__(
self,
path_obj: se.Path,
**kwargs
):
# Get rid of arcs
path_obj.approximate_arcs_with_quads()
self.path_obj = path_obj
super().__init__(**kwargs)
def init_points(self) -> None:
# After a given svg_path has been converted into points, the result
# will be saved so that future calls for the same pathdon't need to
# retrace the same computation.
path_string = self.path_obj.d()
if path_string not in PATH_TO_POINTS:
self.handle_commands()
# Save for future use
PATH_TO_POINTS[path_string] = self.get_points().copy()
else:
points = PATH_TO_POINTS[path_string]
self.set_points(points)
def handle_commands(self) -> None:
segment_class_to_func_map = {
se.Move: (self.start_new_path, ("end",)),
se.Close: (self.close_path, ()),
se.Line: (lambda p: self.add_line_to(p, allow_null_line=False), ("end",)),
se.QuadraticBezier: (lambda c, e: self.add_quadratic_bezier_curve_to(c, e, allow_null_curve=False), ("control", "end")),
se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end"))
}
for segment in self.path_obj:
segment_class = segment.__class__
func, attr_names = segment_class_to_func_map[segment_class]
points = [
_convert_point_to_3d(*segment.__getattribute__(attr_name))
for attr_name in attr_names
]
func(*points)
# Get rid of the side effect of trailing "Z M" commands.
if self.has_new_path_started():
self.resize_points(self.get_num_points() - 2)