Merge branch 'master' into master

This commit is contained in:
Grant Sanderson 2022-01-26 08:54:18 -08:00 committed by GitHub
commit c315300ff1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 197 additions and 99 deletions

View file

@ -608,8 +608,8 @@ class Arrow(Line):
self.insert_tip_anchor() self.insert_tip_anchor()
return self return self
def init_colors(self): def init_colors(self, override=True):
super().init_colors() super().init_colors(override)
self.create_tip_with_stroke_width() self.create_tip_with_stroke_width()
def get_arc_length(self): def get_arc_length(self):

View file

@ -109,8 +109,8 @@ class Mobject(object):
"reflectiveness": self.reflectiveness, "reflectiveness": self.reflectiveness,
} }
def init_colors(self): def init_colors(self, override=True):
self.set_color(self.color, self.opacity) self.set_color(self.color, self.opacity, override)
def init_points(self): def init_points(self):
# Typically implemented in subclass, unlpess purposefully left blank # Typically implemented in subclass, unlpess purposefully left blank

View file

@ -49,7 +49,23 @@ class _LabelledTex(_PlainTex):
if len(color_str) == 4: if len(color_str) == 4:
# "#RGB" => "#RRGGBB" # "#RGB" => "#RRGGBB"
color_str = "#" + "".join([c * 2 for c in color_str[1:]]) color_str = "#" + "".join([c * 2 for c in color_str[1:]])
return int(color_str[1:], 16)
return int(color_str[1:], 16) - 1
def get_mobjects_from(self, element, style):
result = super().get_mobjects_from(element, style)
for mob in result:
if not hasattr(mob, "glyph_label"):
mob.glyph_label = -1
try:
color_str = element.getAttribute("fill")
if color_str:
glyph_label = _LabelledTex.color_str_to_label(color_str)
for mob in result:
mob.glyph_label = glyph_label
except:
pass
return result
class _TexSpan(object): class _TexSpan(object):

View file

@ -4,12 +4,11 @@ import string
import os import os
import hashlib import hashlib
from colour import web2hex
from xml.dom import minidom from xml.dom import minidom
from manimlib.constants import DEFAULT_STROKE_WIDTH from manimlib.constants import DEFAULT_STROKE_WIDTH
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN
from manimlib.constants import BLACK
from manimlib.constants import WHITE
from manimlib.constants import DEGREES, PI from manimlib.constants import DEGREES, PI
from manimlib.mobject.geometry import Circle from manimlib.mobject.geometry import Circle
@ -25,14 +24,82 @@ from manimlib.utils.simple_functions import clip
from manimlib.logger import log from manimlib.logger import log
def string_to_numbers(num_string): DEFAULT_STYLE = {
num_string = num_string.replace("-", ",-") "fill": "black",
num_string = num_string.replace("e,-", "e-") "stroke": "none",
return [ "fill-opacity": "1",
float(s) "stroke-opacity": "1",
for s in re.split("[ ,]", num_string) "stroke-width": 0,
if s != "" }
]
def cascade_element_style(element, inherited):
style = inherited.copy()
for attr in DEFAULT_STYLE:
if element.hasAttribute(attr):
style[attr] = element.getAttribute(attr)
if element.hasAttribute("style"):
for style_spec in element.getAttribute("style").split(";"):
style_spec = style_spec.strip()
try:
key, value = style_spec.split(":")
except ValueError as e:
if not style_spec.strip():
pass
else:
raise e
else:
style[key.strip()] = value.strip()
return style
def parse_color(color):
color = color.strip()
if color[0:3] == "rgb":
splits = color[4:-1].strip().split(",")
if splits[0].strip()[-1] == "%":
parsed_rgbs = [float(i.strip()[:-1]) / 100.0 for i in splits]
else:
parsed_rgbs = [int(i) / 255.0 for i in splits]
return rgb_to_hex(parsed_rgbs)
else:
return web2hex(color)
def fill_default_values(style, default_style):
default = DEFAULT_STYLE.copy()
default.update(default_style)
for attr in default:
if attr not in style:
style[attr] = default[attr]
def parse_style(style, default_style):
manim_style = {}
fill_default_values(style, default_style)
manim_style["fill_opacity"] = float(style["fill-opacity"])
manim_style["stroke_opacity"] = float(style["stroke-opacity"])
manim_style["stroke_width"] = float(style["stroke-width"])
if style["fill"] == "none":
manim_style["fill_opacity"] = 0
else:
manim_style["fill_color"] = parse_color(style["fill"])
if style["stroke"] == "none":
manim_style["stroke_width"] = 0
if "fill_color" in manim_style:
manim_style["stroke_color"] = manim_style["fill_color"]
else:
manim_style["stroke_color"] = parse_color(style["stroke"])
return manim_style
class SVGMobject(VMobject): class SVGMobject(VMobject):
@ -43,7 +110,6 @@ class SVGMobject(VMobject):
# Must be filled in in a subclass, or when called # Must be filled in in a subclass, or when called
"file_name": None, "file_name": None,
"unpack_groups": True, # if False, creates a hierarchy of VGroups "unpack_groups": True, # if False, creates a hierarchy of VGroups
# TODO, style components should be read in, not defaulted
"stroke_width": DEFAULT_STROKE_WIDTH, "stroke_width": DEFAULT_STROKE_WIDTH,
"fill_opacity": 1.0, "fill_opacity": 1.0,
"path_string_config": {} "path_string_config": {}
@ -67,76 +133,96 @@ class SVGMobject(VMobject):
if self.width is not None: if self.width is not None:
self.set_width(self.width) self.set_width(self.width)
def init_colors(self, override=False):
super().init_colors(override=override)
def init_points(self): def init_points(self):
doc = minidom.parse(self.file_path) doc = minidom.parse(self.file_path)
self.ref_to_element = {} self.ref_to_element = {}
for child in doc.childNodes: for child in doc.childNodes:
if not isinstance(child, minidom.Element): continue if not isinstance(child, minidom.Element):
if child.tagName != 'svg': continue continue
mobjects = self.get_mobjects_from(child) if child.tagName != 'svg':
continue
mobjects = self.get_mobjects_from(child, dict())
if self.unpack_groups: if self.unpack_groups:
self.add(*mobjects) self.add(*mobjects)
else: else:
self.add(*mobjects[0].submobjects) self.add(*mobjects[0].submobjects)
doc.unlink() doc.unlink()
def get_mobjects_from(self, element): def get_mobjects_from(self, element, style):
result = [] result = []
if not isinstance(element, minidom.Element): if not isinstance(element, minidom.Element):
return result return result
style = cascade_element_style(element, style)
if element.tagName == 'defs': if element.tagName == 'defs':
self.update_ref_to_element(element) self.update_ref_to_element(element, style)
elif element.tagName == 'style': elif element.tagName == 'style':
pass # TODO, handle style pass # TODO, handle style
elif element.tagName in ['g', 'svg', 'symbol']: elif element.tagName in ['g', 'svg', 'symbol']:
result += it.chain(*( result += it.chain(*(
self.get_mobjects_from(child) self.get_mobjects_from(child, style)
for child in element.childNodes for child in element.childNodes
)) ))
elif element.tagName == 'path': elif element.tagName == 'path':
result.append(self.path_string_to_mobject( result.append(self.path_string_to_mobject(
element.getAttribute('d') element.getAttribute('d'), style
)) ))
elif element.tagName == 'use': elif element.tagName == 'use':
result += self.use_to_mobjects(element) result += self.use_to_mobjects(element, style)
elif element.tagName == 'rect': elif element.tagName == 'rect':
result.append(self.rect_to_mobject(element)) result.append(self.rect_to_mobject(element, style))
elif element.tagName == 'circle': elif element.tagName == 'circle':
result.append(self.circle_to_mobject(element)) result.append(self.circle_to_mobject(element, style))
elif element.tagName == 'ellipse': elif element.tagName == 'ellipse':
result.append(self.ellipse_to_mobject(element)) result.append(self.ellipse_to_mobject(element, style))
elif element.tagName in ['polygon', 'polyline']: elif element.tagName in ['polygon', 'polyline']:
result.append(self.polygon_to_mobject(element)) result.append(self.polygon_to_mobject(element, style))
else: else:
log.warning(f"Unsupported element type: {element.tagName}") log.warning(f"Unsupported element type: {element.tagName}")
pass # TODO pass # TODO
result = [m for m in result if m is not None] result = [m.insert_n_curves(0) for m in result if m is not None]
self.handle_transforms(element, VGroup(*result)) self.handle_transforms(element, VGroup(*result))
if len(result) > 1 and not self.unpack_groups: if len(result) > 1 and not self.unpack_groups:
result = [VGroup(*result)] result = [VGroup(*result)]
return result return result
def g_to_mobjects(self, g_element): def generate_default_style(self):
mob = VGroup(*self.get_mobjects_from(g_element)) style = {
self.handle_transforms(g_element, mob) "fill-opacity": self.fill_opacity,
return mob.submobjects "stroke-width": self.stroke_width,
"stroke-opacity": self.stroke_opacity,
}
if self.color:
style["fill"] = style["stroke"] = self.color
if self.fill_color:
style["fill"] = self.fill_color
if self.stroke_color:
style["stroke"] = self.stroke_color
return style
def path_string_to_mobject(self, path_string): def path_string_to_mobject(self, path_string, style):
return VMobjectFromSVGPathstring( return VMobjectFromSVGPathstring(
path_string, path_string,
**self.path_string_config, **self.path_string_config,
**parse_style(style, self.generate_default_style()),
) )
def use_to_mobjects(self, use_element): def use_to_mobjects(self, use_element, local_style):
# Remove initial "#" character # Remove initial "#" character
ref = use_element.getAttribute("xlink:href")[1:] ref = use_element.getAttribute("xlink:href")[1:]
if ref not in self.ref_to_element: if ref not in self.ref_to_element:
log.warning(f"{ref} not recognized") log.warning(f"{ref} not recognized")
return VGroup() return VGroup()
def_element, def_style = self.ref_to_element[ref]
style = local_style.copy()
style.update(def_style)
return self.get_mobjects_from( return self.get_mobjects_from(
self.ref_to_element[ref] def_element, style
) )
def attribute_to_float(self, attr): def attribute_to_float(self, attr):
@ -146,57 +232,46 @@ class SVGMobject(VMobject):
]) ])
return float(stripped_attr) return float(stripped_attr)
def polygon_to_mobject(self, polygon_element): def polygon_to_mobject(self, polygon_element, style):
path_string = polygon_element.getAttribute("points") path_string = polygon_element.getAttribute("points")
for digit in string.digits: for digit in string.digits:
path_string = path_string.replace(f" {digit}", f"L {digit}") path_string = path_string.replace(f" {digit}", f"L {digit}")
path_string = path_string.replace("L", "M", 1) path_string = path_string.replace("L", "M", 1)
return self.path_string_to_mobject(path_string) return self.path_string_to_mobject(path_string, style)
def circle_to_mobject(self, circle_element): def circle_to_mobject(self, circle_element, style):
x, y, r = [ x, y, r = (
self.attribute_to_float( self.attribute_to_float(
circle_element.getAttribute(key) circle_element.getAttribute(key)
) )
if circle_element.hasAttribute(key) if circle_element.hasAttribute(key)
else 0.0 else 0.0
for key in ("cx", "cy", "r") for key in ("cx", "cy", "r")
] )
return Circle(radius=r).shift(x * RIGHT + y * DOWN) return Circle(
radius=r,
**parse_style(style, self.generate_default_style())
).shift(x * RIGHT + y * DOWN)
def ellipse_to_mobject(self, circle_element): def ellipse_to_mobject(self, circle_element, style):
x, y, rx, ry = [ x, y, rx, ry = (
self.attribute_to_float( self.attribute_to_float(
circle_element.getAttribute(key) circle_element.getAttribute(key)
) )
if circle_element.hasAttribute(key) if circle_element.hasAttribute(key)
else 0.0 else 0.0
for key in ("cx", "cy", "rx", "ry") for key in ("cx", "cy", "rx", "ry")
] )
result = Circle() result = Circle(**parse_style(style, self.generate_default_style()))
result.stretch(rx, 0) result.stretch(rx, 0)
result.stretch(ry, 1) result.stretch(ry, 1)
result.shift(x * RIGHT + y * DOWN) result.shift(x * RIGHT + y * DOWN)
return result return result
def rect_to_mobject(self, rect_element): def rect_to_mobject(self, rect_element, style):
fill_color = rect_element.getAttribute("fill")
stroke_color = rect_element.getAttribute("stroke")
stroke_width = rect_element.getAttribute("stroke-width") stroke_width = rect_element.getAttribute("stroke-width")
corner_radius = rect_element.getAttribute("rx") corner_radius = rect_element.getAttribute("rx")
# input preprocessing
fill_opacity = 1
if fill_color in ["", "none", "#FFF", "#FFFFFF"] or Color(fill_color) == Color(WHITE):
fill_opacity = 0
fill_color = BLACK # shdn't be necessary but avoids error msgs
if fill_color in ["#000", "#000000"]:
fill_color = WHITE
if stroke_color in ["", "none", "#FFF", "#FFFFFF"] or Color(stroke_color) == Color(WHITE):
stroke_width = 0
stroke_color = BLACK
if stroke_color in ["#000", "#000000"]:
stroke_color = WHITE
if stroke_width in ["", "none", "0"]: if stroke_width in ["", "none", "0"]:
stroke_width = 0 stroke_width = 0
@ -205,6 +280,9 @@ class SVGMobject(VMobject):
corner_radius = float(corner_radius) corner_radius = float(corner_radius)
parsed_style = parse_style(style, self.generate_default_style())
parsed_style["stroke_width"] = stroke_width
if corner_radius == 0: if corner_radius == 0:
mob = Rectangle( mob = Rectangle(
width=self.attribute_to_float( width=self.attribute_to_float(
@ -213,10 +291,7 @@ class SVGMobject(VMobject):
height=self.attribute_to_float( height=self.attribute_to_float(
rect_element.getAttribute("height") rect_element.getAttribute("height")
), ),
stroke_width=stroke_width, **parsed_style,
stroke_color=stroke_color,
fill_color=fill_color,
fill_opacity=fill_opacity
) )
else: else:
mob = RoundedRectangle( mob = RoundedRectangle(
@ -226,11 +301,8 @@ class SVGMobject(VMobject):
height=self.attribute_to_float( height=self.attribute_to_float(
rect_element.getAttribute("height") rect_element.getAttribute("height")
), ),
stroke_width=stroke_width, corner_radius=corner_radius,
stroke_color=stroke_color, **parsed_style
fill_color=fill_color,
fill_opacity=fill_opacity,
corner_radius=corner_radius
) )
mob.shift(mob.get_center() - mob.get_corner(UP + LEFT)) mob.shift(mob.get_center() - mob.get_corner(UP + LEFT))
@ -345,8 +417,8 @@ class SVGMobject(VMobject):
all_childNodes_have_id.append(self.get_all_childNodes_have_id(e)) all_childNodes_have_id.append(self.get_all_childNodes_have_id(e))
return self.flatten([e for e in all_childNodes_have_id if e]) return self.flatten([e for e in all_childNodes_have_id if e])
def update_ref_to_element(self, defs): def update_ref_to_element(self, defs, style):
new_refs = dict([(e.getAttribute('id'), e) for e in self.get_all_childNodes_have_id(defs)]) new_refs = dict([(e.getAttribute('id'), (e, style)) for e in self.get_all_childNodes_have_id(defs)])
self.ref_to_element.update(new_refs) self.ref_to_element.update(new_refs)
@ -404,6 +476,7 @@ class VMobjectFromSVGPathstring(VMobject):
upper_command = command.upper() upper_command = command.upper()
if upper_command == "Z": if upper_command == "Z":
func() # `close_path` takes no arguments func() # `close_path` takes no arguments
relative_point = self.get_last_point()
continue continue
number_types = np.array(list(number_types_str)) number_types = np.array(list(number_types_str))
@ -411,7 +484,7 @@ class VMobjectFromSVGPathstring(VMobject):
number_list = _PathStringParser(coord_string, number_types_str).args number_list = _PathStringParser(coord_string, number_types_str).args
number_groups = np.array(number_list).reshape((-1, n_numbers)) number_groups = np.array(number_list).reshape((-1, n_numbers))
for numbers in number_groups: for ind, numbers in enumerate(number_groups):
if command.islower(): if command.islower():
# Treat it as a relative command # Treat it as a relative command
numbers[number_types == "x"] += relative_point[0] numbers[number_types == "x"] += relative_point[0]
@ -427,10 +500,12 @@ class VMobjectFromSVGPathstring(VMobject):
args = list(np.hstack(( args = list(np.hstack((
numbers.reshape((-1, 2)), np.zeros((n_numbers // 2, 1)) numbers.reshape((-1, 2)), np.zeros((n_numbers // 2, 1))
))) )))
if upper_command == "M" and ind != 0:
# M x1 y1 x2 y2 is equal to M x1 y1 L x2 y2
func, _ = self.command_to_function("L")
func(*args) func(*args)
relative_point = self.get_last_point() relative_point = self.get_last_point()
def add_elliptical_arc_to(self, rx, ry, x_axis_rotation, large_arc_flag, sweep_flag, point): def add_elliptical_arc_to(self, rx, ry, x_axis_rotation, large_arc_flag, sweep_flag, point):
def close_to_zero(a, threshold=1e-5): def close_to_zero(a, threshold=1e-5):
return abs(a) < threshold return abs(a) < threshold

View file

@ -16,7 +16,7 @@ from manimlib.utils.tex_file_writing import display_during_execution
SCALE_FACTOR_PER_FONT_POINT = 0.001 SCALE_FACTOR_PER_FONT_POINT = 0.001
tex_string_to_mob_map = {} tex_string_with_color_to_mob_map = {}
class SingleStringTex(VMobject): class SingleStringTex(VMobject):
@ -35,24 +35,26 @@ class SingleStringTex(VMobject):
super().__init__(**kwargs) super().__init__(**kwargs)
assert(isinstance(tex_string, str)) assert(isinstance(tex_string, str))
self.tex_string = tex_string self.tex_string = tex_string
if tex_string not in tex_string_to_mob_map: if tex_string not in tex_string_with_color_to_mob_map:
with display_during_execution(f" Writing \"{tex_string}\""): with display_during_execution(f" Writing \"{tex_string}\""):
full_tex = self.get_tex_file_body(tex_string) full_tex = self.get_tex_file_body(tex_string)
filename = tex_to_svg_file(full_tex) filename = tex_to_svg_file(full_tex)
svg_mob = SVGMobject( svg_mob = SVGMobject(
filename, filename,
height=None, height=None,
color=self.color,
stroke_width=self.stroke_width,
path_string_config={ path_string_config={
"should_subdivide_sharp_curves": True, "should_subdivide_sharp_curves": True,
"should_remove_null_curves": True, "should_remove_null_curves": True,
} }
) )
tex_string_to_mob_map[tex_string] = svg_mob tex_string_with_color_to_mob_map[(self.color, tex_string)] = svg_mob
self.add(*( self.add(*(
sm.copy() sm.copy()
for sm in tex_string_to_mob_map[tex_string] for sm in tex_string_with_color_to_mob_map[(self.color, tex_string)]
)) ))
self.init_colors() self.init_colors(override=False)
if self.height is None: if self.height is None:
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)

View file

@ -3,7 +3,6 @@ import os
import re import re
import io import io
import typing import typing
import warnings
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import functools import functools
import pygments import pygments
@ -14,6 +13,7 @@ from contextlib import contextmanager
from pathlib import Path from pathlib import Path
import manimpango import manimpango
from manimlib.logger import log
from manimlib.constants import * from manimlib.constants import *
from manimlib.mobject.geometry import Dot from manimlib.mobject.geometry import Dot
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
@ -54,10 +54,9 @@ class Text(SVGMobject):
self.full2short(kwargs) self.full2short(kwargs)
digest_config(self, kwargs) digest_config(self, kwargs)
if self.size: if self.size:
warnings.warn( log.warning(
"self.size has been deprecated and will " "`self.size` has been deprecated and will "
"be removed in future.", "be removed in future.",
DeprecationWarning
) )
self.font_size = self.size self.font_size = self.size
if self.lsh == -1: if self.lsh == -1:
@ -86,6 +85,9 @@ class Text(SVGMobject):
if self.height is None: if self.height is None:
self.scale(TEXT_MOB_SCALE_FACTOR) self.scale(TEXT_MOB_SCALE_FACTOR)
def init_colors(self, override=True):
super().init_colors(override=override)
def remove_empty_path(self, file_name): def remove_empty_path(self, file_name):
with open(file_name, 'r') as fpr: with open(file_name, 'r') as fpr:
content = fpr.read() content = fpr.read()

View file

@ -260,7 +260,7 @@ class TexturedSurface(Surface):
super().init_uniforms() super().init_uniforms()
self.uniforms["num_textures"] = self.num_textures self.uniforms["num_textures"] = self.num_textures
def init_colors(self): def init_colors(self, override=True):
self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]]) self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]])
def set_opacity(self, opacity, recurse=True): def set_opacity(self, opacity, recurse=True):

View file

@ -90,7 +90,7 @@ class VMobject(Mobject):
}) })
# Colors # Colors
def init_colors(self): def init_colors(self, override=True):
self.set_fill( self.set_fill(
color=self.fill_color or self.color, color=self.fill_color or self.color,
opacity=self.fill_opacity, opacity=self.fill_opacity,
@ -103,6 +103,9 @@ class VMobject(Mobject):
) )
self.set_gloss(self.gloss) self.set_gloss(self.gloss)
self.set_flat_stroke(self.flat_stroke) self.set_flat_stroke(self.flat_stroke)
if not override:
for submobjects in self.submobjects:
submobjects.init_colors(override=False)
return self return self
def set_rgba_array(self, rgba_array, name=None, recurse=False): def set_rgba_array(self, rgba_array, name=None, recurse=False):