diff --git a/manimlib/mobject/geometry.py b/manimlib/mobject/geometry.py index da0b288d..c9d820e0 100644 --- a/manimlib/mobject/geometry.py +++ b/manimlib/mobject/geometry.py @@ -608,8 +608,8 @@ class Arrow(Line): self.insert_tip_anchor() return self - def init_colors(self): - super().init_colors() + def init_colors(self, override=True): + super().init_colors(override) self.create_tip_with_stroke_width() def get_arc_length(self): diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index b1621c4f..ce0807a2 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -109,8 +109,8 @@ class Mobject(object): "reflectiveness": self.reflectiveness, } - def init_colors(self): - self.set_color(self.color, self.opacity) + def init_colors(self, override=True): + self.set_color(self.color, self.opacity, override) def init_points(self): # Typically implemented in subclass, unlpess purposefully left blank diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index f4df412f..b7854dd3 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -49,7 +49,23 @@ class _LabelledTex(_PlainTex): if len(color_str) == 4: # "#RGB" => "#RRGGBB" color_str = "#" + "".join([c * 2 for c in color_str[1:]]) - return int(color_str[1:], 16) + + return int(color_str[1:], 16) - 1 + + def get_mobjects_from(self, element, style): + result = super().get_mobjects_from(element, style) + for mob in result: + if not hasattr(mob, "glyph_label"): + mob.glyph_label = -1 + try: + color_str = element.getAttribute("fill") + if color_str: + glyph_label = _LabelledTex.color_str_to_label(color_str) + for mob in result: + mob.glyph_label = glyph_label + except: + pass + return result class _TexSpan(object): diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 82afa227..c45e6153 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -4,12 +4,11 @@ import string import os import hashlib +from colour import web2hex from xml.dom import minidom from manimlib.constants import DEFAULT_STROKE_WIDTH from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN -from manimlib.constants import BLACK -from manimlib.constants import WHITE from manimlib.constants import DEGREES, PI from manimlib.mobject.geometry import Circle @@ -25,14 +24,82 @@ from manimlib.utils.simple_functions import clip from manimlib.logger import log -def string_to_numbers(num_string): - num_string = num_string.replace("-", ",-") - num_string = num_string.replace("e,-", "e-") - return [ - float(s) - for s in re.split("[ ,]", num_string) - if s != "" - ] +DEFAULT_STYLE = { + "fill": "black", + "stroke": "none", + "fill-opacity": "1", + "stroke-opacity": "1", + "stroke-width": 0, +} + + +def cascade_element_style(element, inherited): + style = inherited.copy() + + for attr in DEFAULT_STYLE: + if element.hasAttribute(attr): + style[attr] = element.getAttribute(attr) + + if element.hasAttribute("style"): + for style_spec in element.getAttribute("style").split(";"): + style_spec = style_spec.strip() + try: + key, value = style_spec.split(":") + except ValueError as e: + if not style_spec.strip(): + pass + else: + raise e + else: + style[key.strip()] = value.strip() + + return style + + +def parse_color(color): + color = color.strip() + + if color[0:3] == "rgb": + splits = color[4:-1].strip().split(",") + if splits[0].strip()[-1] == "%": + parsed_rgbs = [float(i.strip()[:-1]) / 100.0 for i in splits] + else: + parsed_rgbs = [int(i) / 255.0 for i in splits] + return rgb_to_hex(parsed_rgbs) + + else: + return web2hex(color) + + +def fill_default_values(style, default_style): + default = DEFAULT_STYLE.copy() + default.update(default_style) + for attr in default: + if attr not in style: + style[attr] = default[attr] + + +def parse_style(style, default_style): + manim_style = {} + fill_default_values(style, default_style) + + manim_style["fill_opacity"] = float(style["fill-opacity"]) + manim_style["stroke_opacity"] = float(style["stroke-opacity"]) + manim_style["stroke_width"] = float(style["stroke-width"]) + + if style["fill"] == "none": + manim_style["fill_opacity"] = 0 + else: + manim_style["fill_color"] = parse_color(style["fill"]) + + if style["stroke"] == "none": + manim_style["stroke_width"] = 0 + if "fill_color" in manim_style: + manim_style["stroke_color"] = manim_style["fill_color"] + else: + manim_style["stroke_color"] = parse_color(style["stroke"]) + + return manim_style class SVGMobject(VMobject): @@ -43,7 +110,6 @@ class SVGMobject(VMobject): # Must be filled in in a subclass, or when called "file_name": None, "unpack_groups": True, # if False, creates a hierarchy of VGroups - # TODO, style components should be read in, not defaulted "stroke_width": DEFAULT_STROKE_WIDTH, "fill_opacity": 1.0, "path_string_config": {} @@ -67,76 +133,96 @@ class SVGMobject(VMobject): if self.width is not None: self.set_width(self.width) + def init_colors(self, override=False): + super().init_colors(override=override) + def init_points(self): doc = minidom.parse(self.file_path) self.ref_to_element = {} for child in doc.childNodes: - if not isinstance(child, minidom.Element): continue - if child.tagName != 'svg': continue - mobjects = self.get_mobjects_from(child) + if not isinstance(child, minidom.Element): + continue + if child.tagName != 'svg': + continue + mobjects = self.get_mobjects_from(child, dict()) if self.unpack_groups: self.add(*mobjects) else: self.add(*mobjects[0].submobjects) doc.unlink() - def get_mobjects_from(self, element): + def get_mobjects_from(self, element, style): result = [] if not isinstance(element, minidom.Element): return result + style = cascade_element_style(element, style) + if element.tagName == 'defs': - self.update_ref_to_element(element) + self.update_ref_to_element(element, style) elif element.tagName == 'style': pass # TODO, handle style elif element.tagName in ['g', 'svg', 'symbol']: result += it.chain(*( - self.get_mobjects_from(child) + self.get_mobjects_from(child, style) for child in element.childNodes )) elif element.tagName == 'path': result.append(self.path_string_to_mobject( - element.getAttribute('d') + element.getAttribute('d'), style )) elif element.tagName == 'use': - result += self.use_to_mobjects(element) + result += self.use_to_mobjects(element, style) elif element.tagName == 'rect': - result.append(self.rect_to_mobject(element)) + result.append(self.rect_to_mobject(element, style)) elif element.tagName == 'circle': - result.append(self.circle_to_mobject(element)) + result.append(self.circle_to_mobject(element, style)) elif element.tagName == 'ellipse': - result.append(self.ellipse_to_mobject(element)) + result.append(self.ellipse_to_mobject(element, style)) elif element.tagName in ['polygon', 'polyline']: - result.append(self.polygon_to_mobject(element)) + result.append(self.polygon_to_mobject(element, style)) else: log.warning(f"Unsupported element type: {element.tagName}") pass # TODO - result = [m for m in result if m is not None] + result = [m.insert_n_curves(0) for m in result if m is not None] self.handle_transforms(element, VGroup(*result)) if len(result) > 1 and not self.unpack_groups: result = [VGroup(*result)] return result - def g_to_mobjects(self, g_element): - mob = VGroup(*self.get_mobjects_from(g_element)) - self.handle_transforms(g_element, mob) - return mob.submobjects + def generate_default_style(self): + style = { + "fill-opacity": self.fill_opacity, + "stroke-width": self.stroke_width, + "stroke-opacity": self.stroke_opacity, + } + if self.color: + style["fill"] = style["stroke"] = self.color + if self.fill_color: + style["fill"] = self.fill_color + if self.stroke_color: + style["stroke"] = self.stroke_color + return style - def path_string_to_mobject(self, path_string): + def path_string_to_mobject(self, path_string, style): return VMobjectFromSVGPathstring( path_string, **self.path_string_config, + **parse_style(style, self.generate_default_style()), ) - def use_to_mobjects(self, use_element): + def use_to_mobjects(self, use_element, local_style): # Remove initial "#" character ref = use_element.getAttribute("xlink:href")[1:] if ref not in self.ref_to_element: log.warning(f"{ref} not recognized") return VGroup() + def_element, def_style = self.ref_to_element[ref] + style = local_style.copy() + style.update(def_style) return self.get_mobjects_from( - self.ref_to_element[ref] + def_element, style ) def attribute_to_float(self, attr): @@ -146,57 +232,46 @@ class SVGMobject(VMobject): ]) return float(stripped_attr) - def polygon_to_mobject(self, polygon_element): + def polygon_to_mobject(self, polygon_element, style): path_string = polygon_element.getAttribute("points") for digit in string.digits: path_string = path_string.replace(f" {digit}", f"L {digit}") path_string = path_string.replace("L", "M", 1) - return self.path_string_to_mobject(path_string) + return self.path_string_to_mobject(path_string, style) - def circle_to_mobject(self, circle_element): - x, y, r = [ + def circle_to_mobject(self, circle_element, style): + x, y, r = ( self.attribute_to_float( circle_element.getAttribute(key) ) if circle_element.hasAttribute(key) else 0.0 for key in ("cx", "cy", "r") - ] - return Circle(radius=r).shift(x * RIGHT + y * DOWN) + ) + return Circle( + radius=r, + **parse_style(style, self.generate_default_style()) + ).shift(x * RIGHT + y * DOWN) - def ellipse_to_mobject(self, circle_element): - x, y, rx, ry = [ + def ellipse_to_mobject(self, circle_element, style): + x, y, rx, ry = ( self.attribute_to_float( circle_element.getAttribute(key) ) if circle_element.hasAttribute(key) else 0.0 for key in ("cx", "cy", "rx", "ry") - ] - result = Circle() + ) + result = Circle(**parse_style(style, self.generate_default_style())) result.stretch(rx, 0) result.stretch(ry, 1) result.shift(x * RIGHT + y * DOWN) return result - def rect_to_mobject(self, rect_element): - fill_color = rect_element.getAttribute("fill") - stroke_color = rect_element.getAttribute("stroke") + def rect_to_mobject(self, rect_element, style): stroke_width = rect_element.getAttribute("stroke-width") corner_radius = rect_element.getAttribute("rx") - # input preprocessing - fill_opacity = 1 - if fill_color in ["", "none", "#FFF", "#FFFFFF"] or Color(fill_color) == Color(WHITE): - fill_opacity = 0 - fill_color = BLACK # shdn't be necessary but avoids error msgs - if fill_color in ["#000", "#000000"]: - fill_color = WHITE - if stroke_color in ["", "none", "#FFF", "#FFFFFF"] or Color(stroke_color) == Color(WHITE): - stroke_width = 0 - stroke_color = BLACK - if stroke_color in ["#000", "#000000"]: - stroke_color = WHITE if stroke_width in ["", "none", "0"]: stroke_width = 0 @@ -205,6 +280,9 @@ class SVGMobject(VMobject): corner_radius = float(corner_radius) + parsed_style = parse_style(style, self.generate_default_style()) + parsed_style["stroke_width"] = stroke_width + if corner_radius == 0: mob = Rectangle( width=self.attribute_to_float( @@ -213,10 +291,7 @@ class SVGMobject(VMobject): height=self.attribute_to_float( rect_element.getAttribute("height") ), - stroke_width=stroke_width, - stroke_color=stroke_color, - fill_color=fill_color, - fill_opacity=fill_opacity + **parsed_style, ) else: mob = RoundedRectangle( @@ -226,11 +301,8 @@ class SVGMobject(VMobject): height=self.attribute_to_float( rect_element.getAttribute("height") ), - stroke_width=stroke_width, - stroke_color=stroke_color, - fill_color=fill_color, - fill_opacity=fill_opacity, - corner_radius=corner_radius + corner_radius=corner_radius, + **parsed_style ) mob.shift(mob.get_center() - mob.get_corner(UP + LEFT)) @@ -246,10 +318,10 @@ class SVGMobject(VMobject): mobject.shift(x * RIGHT + y * DOWN) transform_names = [ - "matrix", - "translate", "translateX", "translateY", - "scale", "scaleX", "scaleY", - "rotate", + "matrix", + "translate", "translateX", "translateY", + "scale", "scaleX", "scaleY", + "rotate", "skewX", "skewY" ] transform_pattern = re.compile("|".join([x + r"[^)]*\)" for x in transform_names])) @@ -260,7 +332,7 @@ class SVGMobject(VMobject): op_name, op_args = transform.split("(") op_name = op_name.strip() op_args = [float(x) for x in number_pattern.findall(op_args)] - + if op_name == "matrix": self._handle_matrix_transform(mobject, op_name, op_args) elif op_name.startswith("translate"): @@ -271,7 +343,7 @@ class SVGMobject(VMobject): self._handle_rotate_transform(mobject, op_name, op_args) elif op_name.startswith("skew"): self._handle_skew_transform(mobject, op_name, op_args) - + def _handle_matrix_transform(self, mobject, op_name, op_args): transform = np.array(op_args).reshape([3, 2]) x = transform[2][0] @@ -292,31 +364,31 @@ class SVGMobject(VMobject): else: x, y = op_args mobject.shift(x * RIGHT + y * DOWN) - + def _handle_scale_transform(self, mobject, op_name, op_args): if op_name.endswith("X"): sx, sy = op_args[0], 1 elif op_name.endswith("Y"): sx, sy = 1, op_args[0] elif len(op_args) == 2: - sx, sy = op_args + sx, sy = op_args else: sx = sy = op_args[0] if sx < 0: mobject.flip(UP) - sx = -sx + sx = -sx if sy < 0: mobject.flip(RIGHT) sy = -sy mobject.scale(np.array([sx, sy, 1]), about_point=ORIGIN) - + def _handle_rotate_transform(self, mobject, op_name, op_args): if len(op_args) == 1: mobject.rotate(op_args[0] * DEGREES, axis=IN, about_point=ORIGIN) else: - deg, x, y = op_args + deg, x, y = op_args mobject.rotate(deg * DEGREES, axis=IN, about_point=np.array([x, y, 0])) - + def _handle_skew_transform(self, mobject, op_name, op_args): rad = op_args[0] * DEGREES if op_name == "skewX": @@ -345,8 +417,8 @@ class SVGMobject(VMobject): all_childNodes_have_id.append(self.get_all_childNodes_have_id(e)) return self.flatten([e for e in all_childNodes_have_id if e]) - def update_ref_to_element(self, defs): - new_refs = dict([(e.getAttribute('id'), e) for e in self.get_all_childNodes_have_id(defs)]) + def update_ref_to_element(self, defs, style): + new_refs = dict([(e.getAttribute('id'), (e, style)) for e in self.get_all_childNodes_have_id(defs)]) self.ref_to_element.update(new_refs) @@ -404,6 +476,7 @@ class VMobjectFromSVGPathstring(VMobject): upper_command = command.upper() if upper_command == "Z": func() # `close_path` takes no arguments + relative_point = self.get_last_point() continue number_types = np.array(list(number_types_str)) @@ -411,7 +484,7 @@ class VMobjectFromSVGPathstring(VMobject): number_list = _PathStringParser(coord_string, number_types_str).args number_groups = np.array(number_list).reshape((-1, n_numbers)) - for numbers in number_groups: + for ind, numbers in enumerate(number_groups): if command.islower(): # Treat it as a relative command numbers[number_types == "x"] += relative_point[0] @@ -427,10 +500,12 @@ class VMobjectFromSVGPathstring(VMobject): args = list(np.hstack(( numbers.reshape((-1, 2)), np.zeros((n_numbers // 2, 1)) ))) + if upper_command == "M" and ind != 0: + # M x1 y1 x2 y2 is equal to M x1 y1 L x2 y2 + func, _ = self.command_to_function("L") func(*args) relative_point = self.get_last_point() - def add_elliptical_arc_to(self, rx, ry, x_axis_rotation, large_arc_flag, sweep_flag, point): def close_to_zero(a, threshold=1e-5): return abs(a) < threshold @@ -572,7 +647,7 @@ class _PathStringParser: while arguments: for rule in rules: self._rule_to_function_map[rule](arguments) - + @property def _rule_to_function_map(self): return { @@ -605,7 +680,7 @@ class _PathStringParser: if number < 0: raise InvalidPathError(f"Expected an unsigned number, got '{number}'") return number - + def _get_flag(self, arg_array): flag = arg_array[0] if flag != 48 and flag != 49: diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index 6f1054ca..c0774f63 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -16,7 +16,7 @@ from manimlib.utils.tex_file_writing import display_during_execution SCALE_FACTOR_PER_FONT_POINT = 0.001 -tex_string_to_mob_map = {} +tex_string_with_color_to_mob_map = {} class SingleStringTex(VMobject): @@ -35,24 +35,26 @@ class SingleStringTex(VMobject): super().__init__(**kwargs) assert(isinstance(tex_string, str)) self.tex_string = tex_string - if tex_string not in tex_string_to_mob_map: + if tex_string not in tex_string_with_color_to_mob_map: with display_during_execution(f" Writing \"{tex_string}\""): full_tex = self.get_tex_file_body(tex_string) filename = tex_to_svg_file(full_tex) svg_mob = SVGMobject( filename, height=None, + color=self.color, + stroke_width=self.stroke_width, path_string_config={ "should_subdivide_sharp_curves": True, "should_remove_null_curves": True, } ) - tex_string_to_mob_map[tex_string] = svg_mob + tex_string_with_color_to_mob_map[(self.color, tex_string)] = svg_mob self.add(*( sm.copy() - for sm in tex_string_to_mob_map[tex_string] + for sm in tex_string_with_color_to_mob_map[(self.color, tex_string)] )) - self.init_colors() + self.init_colors(override=False) if self.height is None: self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index cf78d111..ae0d69a9 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -3,7 +3,6 @@ import os import re import io import typing -import warnings import xml.etree.ElementTree as ET import functools import pygments @@ -14,6 +13,7 @@ from contextlib import contextmanager from pathlib import Path import manimpango +from manimlib.logger import log from manimlib.constants import * from manimlib.mobject.geometry import Dot from manimlib.mobject.svg.svg_mobject import SVGMobject @@ -54,10 +54,9 @@ class Text(SVGMobject): self.full2short(kwargs) digest_config(self, kwargs) if self.size: - warnings.warn( - "self.size has been deprecated and will " + log.warning( + "`self.size` has been deprecated and will " "be removed in future.", - DeprecationWarning ) self.font_size = self.size if self.lsh == -1: @@ -86,6 +85,9 @@ class Text(SVGMobject): if self.height is None: self.scale(TEXT_MOB_SCALE_FACTOR) + def init_colors(self, override=True): + super().init_colors(override=override) + def remove_empty_path(self, file_name): with open(file_name, 'r') as fpr: content = fpr.read() diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index 1160c1ae..c3fb6b77 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -260,7 +260,7 @@ class TexturedSurface(Surface): super().init_uniforms() self.uniforms["num_textures"] = self.num_textures - def init_colors(self): + def init_colors(self, override=True): self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]]) def set_opacity(self, opacity, recurse=True): diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 3c7a4326..ea8d650f 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -90,7 +90,7 @@ class VMobject(Mobject): }) # Colors - def init_colors(self): + def init_colors(self, override=True): self.set_fill( color=self.fill_color or self.color, opacity=self.fill_opacity, @@ -103,6 +103,9 @@ class VMobject(Mobject): ) self.set_gloss(self.gloss) self.set_flat_stroke(self.flat_stroke) + if not override: + for submobjects in self.submobjects: + submobjects.init_colors(override=False) return self def set_rgba_array(self, rgba_array, name=None, recurse=False):