add style support to svg

This commit is contained in:
TonyCrane 2022-01-26 13:53:53 +08:00
parent 6c8dd14adc
commit 92adcd75d4
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
5 changed files with 143 additions and 56 deletions

View file

@ -608,8 +608,8 @@ class Arrow(Line):
self.insert_tip_anchor() self.insert_tip_anchor()
return self return self
def init_colors(self): def init_colors(self, override=True):
super().init_colors() super().init_colors(override)
self.create_tip_with_stroke_width() self.create_tip_with_stroke_width()
def get_arc_length(self): def get_arc_length(self):

View file

@ -109,8 +109,8 @@ class Mobject(object):
"reflectiveness": self.reflectiveness, "reflectiveness": self.reflectiveness,
} }
def init_colors(self): def init_colors(self, override=True):
self.set_color(self.color, self.opacity) self.set_color(self.color, self.opacity, override)
def init_points(self): def init_points(self):
# Typically implemented in subclass, unlpess purposefully left blank # Typically implemented in subclass, unlpess purposefully left blank

View file

@ -4,6 +4,7 @@ import string
import os import os
import hashlib import hashlib
from colour import web2hex
from xml.dom import minidom from xml.dom import minidom
from manimlib.constants import DEFAULT_STROKE_WIDTH from manimlib.constants import DEFAULT_STROKE_WIDTH
@ -25,6 +26,82 @@ from manimlib.utils.simple_functions import clip
from manimlib.logger import log from manimlib.logger import log
DEFAULT_STYLE = {
"fill": "black",
"stroke": "none",
"fill-opacity": "1",
"stroke-opacity": "1",
"stroke-width": 0,
}
def cascade_element_style(element, inherited):
style = inherited.copy()
for attr in DEFAULT_STYLE:
if element.hasAttribute(attr):
style[attr] = element.getAttribute(attr)
if element.hasAttribute("style"):
for style_spec in element.getAttribute("style").split(";"):
style_spec = style_spec.strip()
try:
key, value = style_spec.split(":")
except ValueError as e:
if not style_spec.strip():
pass
else:
raise e
else:
style[key.strip()] = value.strip()
return style
def parse_color(color):
color = color.strip()
if color[0:3] == "rgb":
splits = color[4:-1].strip().split(",")
if splits[0].strip()[-1] == "%":
parsed_rgbs = [float(i.strip()[:-1]) / 100.0 for i in splits]
else:
parsed_rgbs = [int(i) / 255.0 for i in splits]
return rgb_to_hex(parsed_rgbs)
else:
return web2hex(color)
def fill_default_values(style):
for attr in DEFAULT_STYLE:
if attr not in style:
style[attr] = DEFAULT_STYLE[attr]
def parse_style(style):
manim_style = {}
fill_default_values(style)
manim_style["fill_opacity"] = float(style["fill-opacity"])
manim_style["stroke_opacity"] = float(style["stroke-opacity"])
manim_style["stroke_width"] = float(style["stroke-width"])
if style["fill"] == "none":
manim_style["fill_opacity"] = 0
else:
manim_style["fill_color"] = parse_color(style["fill"])
if style["stroke"] == "none":
manim_style["stroke_width"] = 0
if "fill_color" in manim_style:
manim_style["stroke_color"] = manim_style["fill_color"]
else:
manim_style["stroke_color"] = parse_color(style["stroke"])
return manim_style
class SVGMobject(VMobject): class SVGMobject(VMobject):
CONFIG = { CONFIG = {
"should_center": True, "should_center": True,
@ -57,6 +134,9 @@ class SVGMobject(VMobject):
if self.width is not None: if self.width is not None:
self.set_width(self.width) self.set_width(self.width)
def init_colors(self, override=False):
super().init_colors(override=False)
def init_points(self): def init_points(self):
doc = minidom.parse(self.file_path) doc = minidom.parse(self.file_path)
self.ref_to_element = {} self.ref_to_element = {}
@ -64,64 +144,84 @@ class SVGMobject(VMobject):
for child in doc.childNodes: for child in doc.childNodes:
if not isinstance(child, minidom.Element): continue if not isinstance(child, minidom.Element): continue
if child.tagName != 'svg': continue if child.tagName != 'svg': continue
mobjects = self.get_mobjects_from(child) mobjects = self.get_mobjects_from(child, self.generate_style())
if self.unpack_groups: if self.unpack_groups:
self.add(*mobjects) self.add(*mobjects)
else: else:
self.add(*mobjects[0].submobjects) self.add(*mobjects[0].submobjects)
doc.unlink() doc.unlink()
def get_mobjects_from(self, element): def get_mobjects_from(self, element, style):
result = [] result = []
if not isinstance(element, minidom.Element): if not isinstance(element, minidom.Element):
return result return result
style = cascade_element_style(element, style)
if element.tagName == 'defs': if element.tagName == 'defs':
self.update_ref_to_element(element) self.update_ref_to_element(element, style)
elif element.tagName == 'style': elif element.tagName == 'style':
pass # TODO, handle style pass # TODO, handle style
elif element.tagName in ['g', 'svg', 'symbol']: elif element.tagName in ['g', 'svg', 'symbol']:
result += it.chain(*( result += it.chain(*(
self.get_mobjects_from(child) self.get_mobjects_from(child, style)
for child in element.childNodes for child in element.childNodes
)) ))
elif element.tagName == 'path': elif element.tagName == 'path':
result.append(self.path_string_to_mobject( result.append(self.path_string_to_mobject(
element.getAttribute('d') element.getAttribute('d'), style
)) ))
elif element.tagName == 'use': elif element.tagName == 'use':
result += self.use_to_mobjects(element) result += self.use_to_mobjects(element, style)
elif element.tagName == 'rect': elif element.tagName == 'rect':
result.append(self.rect_to_mobject(element)) result.append(self.rect_to_mobject(element, style))
elif element.tagName == 'circle': elif element.tagName == 'circle':
result.append(self.circle_to_mobject(element)) result.append(self.circle_to_mobject(element, style))
elif element.tagName == 'ellipse': elif element.tagName == 'ellipse':
result.append(self.ellipse_to_mobject(element)) result.append(self.ellipse_to_mobject(element, style))
elif element.tagName in ['polygon', 'polyline']: elif element.tagName in ['polygon', 'polyline']:
result.append(self.polygon_to_mobject(element)) result.append(self.polygon_to_mobject(element, style))
else: else:
log.warning(f"Unsupported element type: {element.tagName}") log.warning(f"Unsupported element type: {element.tagName}")
pass # TODO pass # TODO
result = [m for m in result if m is not None] result = [m.insert_n_curves(0) for m in result if m is not None]
self.handle_transforms(element, VGroup(*result)) self.handle_transforms(element, VGroup(*result))
if len(result) > 1 and not self.unpack_groups: if len(result) > 1 and not self.unpack_groups:
result = [VGroup(*result)] result = [VGroup(*result)]
return result return result
def path_string_to_mobject(self, path_string): def generate_style(self):
style = {
"fill-opacity": self.fill_opacity,
"stroke-width": self.stroke_width,
"stroke-opacity": self.stroke_opacity,
}
if self.color:
style["fill"] = style["stroke"] = self.color
if self.fill_color:
style["fill"] = self.fill_color
if self.stroke_color:
style["stroke"] = self.stroke_color
return style
def path_string_to_mobject(self, path_string, style):
return VMobjectFromSVGPathstring( return VMobjectFromSVGPathstring(
path_string, path_string,
**self.path_string_config, **self.path_string_config,
**parse_style(style),
) )
def use_to_mobjects(self, use_element): def use_to_mobjects(self, use_element, local_style):
# Remove initial "#" character # Remove initial "#" character
ref = use_element.getAttribute("xlink:href")[1:] ref = use_element.getAttribute("xlink:href")[1:]
if ref not in self.ref_to_element: if ref not in self.ref_to_element:
log.warning(f"{ref} not recognized") log.warning(f"{ref} not recognized")
return VGroup() return VGroup()
def_element, def_style = self.ref_to_element[ref]
style = local_style.copy()
style.update(def_style)
return self.get_mobjects_from( return self.get_mobjects_from(
self.ref_to_element[ref] def_element, style
) )
def attribute_to_float(self, attr): def attribute_to_float(self, attr):
@ -131,57 +231,43 @@ class SVGMobject(VMobject):
]) ])
return float(stripped_attr) return float(stripped_attr)
def polygon_to_mobject(self, polygon_element): def polygon_to_mobject(self, polygon_element, style):
path_string = polygon_element.getAttribute("points") path_string = polygon_element.getAttribute("points")
for digit in string.digits: for digit in string.digits:
path_string = path_string.replace(f" {digit}", f"L {digit}") path_string = path_string.replace(f" {digit}", f"L {digit}")
path_string = path_string.replace("L", "M", 1) path_string = path_string.replace("L", "M", 1)
return self.path_string_to_mobject(path_string) return self.path_string_to_mobject(path_string, style)
def circle_to_mobject(self, circle_element): def circle_to_mobject(self, circle_element, style):
x, y, r = [ x, y, r = (
self.attribute_to_float( self.attribute_to_float(
circle_element.getAttribute(key) circle_element.getAttribute(key)
) )
if circle_element.hasAttribute(key) if circle_element.hasAttribute(key)
else 0.0 else 0.0
for key in ("cx", "cy", "r") for key in ("cx", "cy", "r")
] )
return Circle(radius=r).shift(x * RIGHT + y * DOWN) return Circle(radius=r, **parse_style(style)).shift(x * RIGHT + y * DOWN)
def ellipse_to_mobject(self, circle_element): def ellipse_to_mobject(self, circle_element, style):
x, y, rx, ry = [ x, y, rx, ry = (
self.attribute_to_float( self.attribute_to_float(
circle_element.getAttribute(key) circle_element.getAttribute(key)
) )
if circle_element.hasAttribute(key) if circle_element.hasAttribute(key)
else 0.0 else 0.0
for key in ("cx", "cy", "rx", "ry") for key in ("cx", "cy", "rx", "ry")
] )
result = Circle() result = Circle(**parse_style(style))
result.stretch(rx, 0) result.stretch(rx, 0)
result.stretch(ry, 1) result.stretch(ry, 1)
result.shift(x * RIGHT + y * DOWN) result.shift(x * RIGHT + y * DOWN)
return result return result
def rect_to_mobject(self, rect_element): def rect_to_mobject(self, rect_element, style):
fill_color = rect_element.getAttribute("fill")
stroke_color = rect_element.getAttribute("stroke")
stroke_width = rect_element.getAttribute("stroke-width") stroke_width = rect_element.getAttribute("stroke-width")
corner_radius = rect_element.getAttribute("rx") corner_radius = rect_element.getAttribute("rx")
# input preprocessing
fill_opacity = 1
if fill_color in ["", "none", "#FFF", "#FFFFFF"] or Color(fill_color) == Color(WHITE):
fill_opacity = 0
fill_color = BLACK # shdn't be necessary but avoids error msgs
if fill_color in ["#000", "#000000"]:
fill_color = WHITE
if stroke_color in ["", "none", "#FFF", "#FFFFFF"] or Color(stroke_color) == Color(WHITE):
stroke_width = 0
stroke_color = BLACK
if stroke_color in ["#000", "#000000"]:
stroke_color = WHITE
if stroke_width in ["", "none", "0"]: if stroke_width in ["", "none", "0"]:
stroke_width = 0 stroke_width = 0
@ -190,6 +276,9 @@ class SVGMobject(VMobject):
corner_radius = float(corner_radius) corner_radius = float(corner_radius)
parsed_style = parse_style(style)
parsed_style["stroke_width"] = stroke_width
if corner_radius == 0: if corner_radius == 0:
mob = Rectangle( mob = Rectangle(
width=self.attribute_to_float( width=self.attribute_to_float(
@ -198,10 +287,7 @@ class SVGMobject(VMobject):
height=self.attribute_to_float( height=self.attribute_to_float(
rect_element.getAttribute("height") rect_element.getAttribute("height")
), ),
stroke_width=stroke_width, **parsed_style,
stroke_color=stroke_color,
fill_color=fill_color,
fill_opacity=fill_opacity
) )
else: else:
mob = RoundedRectangle( mob = RoundedRectangle(
@ -211,11 +297,8 @@ class SVGMobject(VMobject):
height=self.attribute_to_float( height=self.attribute_to_float(
rect_element.getAttribute("height") rect_element.getAttribute("height")
), ),
stroke_width=stroke_width, corner_radius=corner_radius,
stroke_color=stroke_color, **parsed_style
fill_color=fill_color,
fill_opacity=fill_opacity,
corner_radius=corner_radius
) )
mob.shift(mob.get_center() - mob.get_corner(UP + LEFT)) mob.shift(mob.get_center() - mob.get_corner(UP + LEFT))
@ -330,8 +413,8 @@ class SVGMobject(VMobject):
all_childNodes_have_id.append(self.get_all_childNodes_have_id(e)) all_childNodes_have_id.append(self.get_all_childNodes_have_id(e))
return self.flatten([e for e in all_childNodes_have_id if e]) return self.flatten([e for e in all_childNodes_have_id if e])
def update_ref_to_element(self, defs): def update_ref_to_element(self, defs, style):
new_refs = dict([(e.getAttribute('id'), e) for e in self.get_all_childNodes_have_id(defs)]) new_refs = dict([(e.getAttribute('id'), (e, style)) for e in self.get_all_childNodes_have_id(defs)])
self.ref_to_element.update(new_refs) self.ref_to_element.update(new_refs)

View file

@ -260,7 +260,7 @@ class TexturedSurface(Surface):
super().init_uniforms() super().init_uniforms()
self.uniforms["num_textures"] = self.num_textures self.uniforms["num_textures"] = self.num_textures
def init_colors(self): def init_colors(self, override=True):
self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]]) self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]])
def set_opacity(self, opacity, recurse=True): def set_opacity(self, opacity, recurse=True):

View file

@ -90,7 +90,7 @@ class VMobject(Mobject):
}) })
# Colors # Colors
def init_colors(self): def init_colors(self, override=True):
self.set_fill( self.set_fill(
color=self.fill_color or self.color, color=self.fill_color or self.color,
opacity=self.fill_opacity, opacity=self.fill_opacity,
@ -103,6 +103,9 @@ class VMobject(Mobject):
) )
self.set_gloss(self.gloss) self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke) self.set_flat_stroke(self.flat_stroke)
if not override:
for submobjects in self.submobjects:
submobjects.init_colors(override=False)
return self return self
def set_rgba_array(self, rgba_array, name=None, recurse=False): def set_rgba_array(self, rgba_array, name=None, recurse=False):
@ -386,6 +389,7 @@ class VMobject(Mobject):
new_handle = self.get_points()[-1] new_handle = self.get_points()[-1]
else: else:
new_handle = self.get_reflection_of_last_handle() new_handle = self.get_reflection_of_last_handle()
print(new_handle, handle, point)
self.add_cubic_bezier_curve_to(new_handle, handle, point) self.add_cubic_bezier_curve_to(new_handle, handle, point)
def has_new_path_started(self): def has_new_path_started(self):