3b1b-manim/mobject/svg_mobject.py

336 lines
12 KiB
Python
Raw Normal View History

2016-04-17 00:31:38 -07:00
from xml.dom import minidom
import itertools as it
import re
2016-04-17 00:31:38 -07:00
import warnings
import string
2016-04-17 00:31:38 -07:00
from constants import *
from vectorized_mobject import VMobject, VGroup
2016-04-17 00:31:38 -07:00
from topics.geometry import Rectangle, Circle
from utils.bezier import is_closed
from utils.config_ops import digest_config, digest_locals
2016-04-17 00:31:38 -07:00
def string_to_numbers(num_string):
    """
    Parse every number in an SVG number list and return them as floats, in order.

    Used for path coordinates, transform arguments and polygon points.
    Separators may be any mix of commas and whitespace (spaces, tabs,
    newlines). They may be omitted wherever the SVG number grammar allows it:
    "10-5" gives [10.0, -5.0] and ".5.5" gives [0.5, 0.5]. Exponents may be
    written with "e" or "E" and an optional sign.

    Raises ValueError if the string contains anything other than numbers and
    separators (the original split-and-float version also raised ValueError
    on such input).
    """
    number_pattern = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    # Whatever remains after removing the numbers must be pure separators
    leftover = re.sub(number_pattern, "", num_string)
    if re.sub(r"[\s,]", "", leftover):
        raise ValueError("could not parse numbers from %r" % (num_string,))
    return [float(token) for token in re.findall(number_pattern, num_string)]
2016-04-17 00:31:38 -07:00
class SVGMobject(VMobject):
    """
    A VMobject built from the shapes in an SVG file.

    The file is located by ensure_valid_file and parsed with xml.dom.minidom.
    Supported elements are path, use, rect, circle, ellipse, polygon and
    polyline, plus nested g/svg groups. Each becomes a VMobject, which is then
    centered and resized by move_into_position.
    """
    CONFIG = {
        "should_center" : True,
        "height" : 2,
        "width" : None,
        #Must be filled in in a subclass, or when called
        "file_name" : None,
        "unpack_groups" : True, # if False, creates a hierarchy of VGroups
        "stroke_width" : 0,
        "fill_opacity" : 1,
        # "fill_color" : LIGHT_GREY,
        "propagate_style_to_family" : True,
    }
    def __init__(self, **kwargs):
        digest_config(self, kwargs, locals())
        # Resolve self.file_path before VMobject.__init__ calls generate_points
        self.ensure_valid_file()
        VMobject.__init__(self, **kwargs)
        self.move_into_position()

    def ensure_valid_file(self):
        """
        Resolve self.file_name to an existing path and store it in self.file_path.

        The name is tried in SVG_IMAGE_DIR as-is, then in SVG_IMAGE_DIR with
        ".svg" appended, then as given.

        Raises:
            Exception: if file_name is None.
            IOError: if none of the candidate paths exists.
        """
        if self.file_name is None:
            raise Exception("Must specify file for SVGMobject")
        possible_paths = [
            os.path.join(SVG_IMAGE_DIR, self.file_name),
            os.path.join(SVG_IMAGE_DIR, self.file_name + ".svg"),
            self.file_name,
        ]
        for path in possible_paths:
            if os.path.exists(path):
                self.file_path = path
                return
        raise IOError("No file matching %s in image directory"%self.file_name)

    def generate_points(self):
        """
        Parse self.file_path and add the resulting mobjects as submobjects.

        Only outermost <svg> elements are visited here. Nested <svg> elements
        are already reached through get_mobjects_from, so visiting them again
        would add their shapes twice.
        """
        doc = minidom.parse(self.file_path)
        try:
            self.ref_to_element = {}
            for svg in doc.getElementsByTagName("svg"):
                if self._has_svg_ancestor(svg):
                    continue
                mobjects = self.get_mobjects_from(svg)
                if self.unpack_groups:
                    self.add(*mobjects)
                elif len(mobjects) == 1 and isinstance(mobjects[0], VGroup):
                    # This SVGMobject plays the role of the outermost group,
                    # so unpack that group one level
                    self.add(*mobjects[0].submobjects)
                else:
                    # Zero mobjects, or a single shape that was never wrapped
                    # in a group: unpacking its submobjects would drop it
                    self.add(*mobjects)
        finally:
            # Release the DOM even when conversion fails
            doc.unlink()

    @staticmethod
    def _has_svg_ancestor(element):
        """Return True if some ancestor element of `element` is an <svg>."""
        parent = element.parentNode
        while isinstance(parent, minidom.Element):
            if parent.tagName == "svg":
                return True
            parent = parent.parentNode
        return False

    def get_mobjects_from(self, element):
        """
        Recursively convert a DOM node into a list of mobjects.

        Non-element nodes (text, comments) yield []. The children of <defs>
        are registered for later <use> lookups rather than drawn. The
        element's own x/y/transform attributes are applied to everything it
        produced. When unpack_groups is False and several mobjects result,
        they are wrapped in one VGroup so the SVG hierarchy is preserved.
        """
        if not isinstance(element, minidom.Element):
            return []
        tag = element.tagName
        result = []
        if tag == 'defs':
            self.update_ref_to_element(element)
        elif tag == 'style':
            pass #TODO, handle style
        elif tag in ['g', 'svg']:
            for child in element.childNodes:
                result.extend(self.get_mobjects_from(child))
        elif tag == 'path':
            result.append(self.path_string_to_mobject(
                element.getAttribute('d')
            ))
        elif tag == 'use':
            result.extend(self.use_to_mobjects(element))
        elif tag == 'rect':
            result.append(self.rect_to_mobject(element))
        elif tag == 'circle':
            result.append(self.circle_to_mobject(element))
        elif tag == 'ellipse':
            result.append(self.ellipse_to_mobject(element))
        elif tag in ['polygon', 'polyline']:
            result.append(self.polygon_to_mobject(element))
        else:
            pass ##TODO
            # warnings.warn("Unknown element type: " + element.tagName)
        # Converters return None for shapes they deliberately skip. A real
        # list (not filter()) is required because len() is used below.
        result = [mob for mob in result if mob is not None]
        self.handle_transforms(element, VMobject(*result))
        if len(result) > 1 and not self.unpack_groups:
            result = [VGroup(*result)]
        return result

    def g_to_mobjects(self, g_element):
        """
        Convert a <g> element and return the transformed submobjects.

        Not called from the code shown here; get_mobjects_from handles <g>
        directly.
        """
        mob = VMobject(*self.get_mobjects_from(g_element))
        self.handle_transforms(g_element, mob)
        return mob.submobjects

    def path_string_to_mobject(self, path_string):
        """Build a mobject from an SVG path "d" string."""
        return VMobjectFromSVGPathstring(path_string)

    def use_to_mobjects(self, use_element):
        """
        Return fresh mobjects for the element referenced by a <use>.

        The reference is read from "xlink:href", falling back to the SVG 2
        "href" attribute. It must name an element registered from a <defs>
        block. An unknown reference produces a warning and an empty list.
        """
        href = use_element.getAttribute("xlink:href")
        if not href:
            href = use_element.getAttribute("href")
        #Remove initial "#" character
        ref = href[1:] if href.startswith("#") else href
        if ref not in self.ref_to_element:
            warnings.warn("%s not recognized"%ref)
            return []
        return self.get_mobjects_from(
            self.ref_to_element[ref]
        )

    def polygon_to_mobject(self, polygon_element):
        """
        Convert a <polygon> or <polyline> into a path mobject.

        The whole "points" list becomes one moveto command. The path parser
        turns every coordinate pair after the first into an implicit lineto,
        so commas and/or whitespace may separate the numbers. Polygons are
        closed with a trailing "Z"; polylines stay open. Returns None when
        "points" is empty.
        """
        points_string = polygon_element.getAttribute("points").strip()
        if not points_string:
            return None
        path_string = "M" + points_string
        if polygon_element.tagName == "polygon":
            path_string += "Z"
        return self.path_string_to_mobject(path_string)

    # <circle class="st1" cx="143.8" cy="268" r="22.6"/>
    def circle_to_mobject(self, circle_element):
        """
        Convert a <circle> (cx, cy, r) into a Circle.

        Missing attributes default to 0. SVG's y axis points down, hence the
        shift along DOWN.
        """
        x, y, r = [
            float(circle_element.getAttribute(key))
            if circle_element.hasAttribute(key)
            else 0.0
            for key in ("cx", "cy", "r")
        ]
        return Circle(radius = r).shift(x*RIGHT+y*DOWN)

    def ellipse_to_mobject(self, circle_element):
        """
        Convert an <ellipse> (cx, cy, rx, ry) into a non-uniformly scaled Circle.

        The parameter name is kept as circle_element for backward
        compatibility. Missing attributes default to 0.
        """
        x, y, rx, ry = [
            float(circle_element.getAttribute(key))
            if circle_element.hasAttribute(key)
            else 0.0
            for key in ("cx", "cy", "rx", "ry")
        ]
        return Circle().scale(rx*RIGHT + ry*UP).shift(x*RIGHT+y*DOWN)

    def rect_to_mobject(self, rect_element):
        """
        Convert a <rect> into a filled white Rectangle.

        The rectangle's top-left corner is placed at the origin. Any x/y
        offset is applied afterwards by handle_transforms.

        Rects whose "fill" is white are skipped (returns None).
        NOTE(review): presumably this drops page backgrounds; confirm intent.
        Raises ValueError if width/height are missing or not plain numbers.
        """
        if rect_element.hasAttribute("fill"):
            if Color(str(rect_element.getAttribute("fill"))) == Color(WHITE):
                return
        mob = Rectangle(
            width = float(rect_element.getAttribute("width")),
            height = float(rect_element.getAttribute("height")),
            stroke_width = 0,
            fill_color = WHITE,
            fill_opacity = 1.0
        )
        # Put the top-left corner (SVG's rect origin) at the origin
        mob.shift(mob.get_center()-mob.get_corner(UP+LEFT))
        return mob

    def handle_transforms(self, element, mobject):
        """
        Apply an element's placement attributes to `mobject` in place.

        The x/y attributes (e.g. on <use> or <rect>) shift the content first,
        each defaulting to 0 independently. The "transform" attribute is then
        applied. A transform list such as "translate(10, 20) scale(2)" is
        composed right to left, as SVG specifies. Supported functions are
        matrix, translate, scale, rotate, skewX and skewY; anything else is
        skipped with a warning.
        """
        x = self._attribute_to_float(element, 'x')
        y = self._attribute_to_float(element, 'y')
        if x != 0 or y != 0:
            #Flip y: SVG's y axis points down
            mobject.shift(x*RIGHT + y*DOWN)
        transform_string = element.getAttribute('transform')
        if not transform_string.strip():
            return
        total = np.identity(3)
        for name, args in re.findall(r"([A-Za-z]+)\s*\(([^)]*)\)", transform_string):
            matrix = self._svg_transform_to_matrix(name, args)
            if matrix is not None:
                # "A B" means A(B(p)): multiply left to right
                total = np.dot(total, matrix)
        self._apply_svg_matrix(mobject, total)

    @staticmethod
    def _attribute_to_float(element, name, default = 0.0):
        """Read a numeric attribute; missing or unparseable values give `default`."""
        value = element.getAttribute(name).strip()
        if not value:
            return default
        try:
            return float(value)
        except ValueError:
            warnings.warn("Could not parse %s=%r, using %s" % (name, value, default))
            return default

    def _svg_transform_to_matrix(self, name, args_string):
        """
        Return the 3x3 affine matrix, in SVG coordinates, for one transform function.

        The matrix is laid out as [[a, c, e], [b, d, f], [0, 0, 1]], matching
        SVG's matrix(a, b, c, d, e, f). Returns None (with a warning) for
        unknown functions or a wrong number of arguments.
        """
        try:
            args = string_to_numbers(args_string)
        except ValueError:
            warnings.warn("Could not parse transform %s(%s)" % (name, args_string))
            return None
        count = len(args)
        if name == "matrix" and count == 6:
            a, b, c, d, e, f = args
        elif name == "translate" and count in (1, 2):
            # ty defaults to 0
            a, b, c, d = 1, 0, 0, 1
            e = args[0]
            f = args[1] if count == 2 else 0
        elif name == "scale" and count in (1, 2):
            # sy defaults to sx
            a, b, c, d, e, f = args[0], 0, 0, args[-1], 0, 0
        elif name == "rotate" and count in (1, 3):
            angle = np.radians(args[0])
            cos, sin = np.cos(angle), np.sin(angle)
            cx, cy = (args[1], args[2]) if count == 3 else (0, 0)
            a, b, c, d = cos, sin, -sin, cos
            # rotate(t, cx, cy) == translate(cx, cy) rotate(t) translate(-cx, -cy)
            e = cx - cos*cx + sin*cy
            f = cy - sin*cx - cos*cy
        elif name == "skewX" and count == 1:
            a, b, c, d, e, f = 1, 0, np.tan(np.radians(args[0])), 1, 0, 0
        elif name == "skewY" and count == 1:
            a, b, c, d, e, f = 1, np.tan(np.radians(args[0])), 0, 1, 0, 0
        else:
            warnings.warn("Unsupported transform %s(%s)" % (name, args_string))
            return None
        return np.array([[a, c, e], [b, d, f], [0, 0, 1]], dtype = float)

    def _apply_svg_matrix(self, mobject, matrix):
        """
        Apply a 3x3 SVG-space affine matrix to `mobject` in manim space.

        Shapes here have y flipped relative to SVG (paths are rotated by pi
        about RIGHT, and circles are shifted along DOWN). The off-diagonal
        terms and the y translation therefore change sign.
        """
        a, c, e = matrix[0]
        b, d, f = matrix[1]
        linear = np.identity(self.dim)
        # Points are row vectors, so new_points = points . linear
        linear[:2, :2] = [[a, -b], [-c, d]]
        for mob in mobject.family_members_with_points():
            mob.points = np.dot(mob.points, linear)
        mobject.shift(e*RIGHT + f*DOWN)

    def update_ref_to_element(self, defs):
        """Register every direct child of <defs> that has an id, for <use> lookups."""
        new_refs = dict([
            (element.getAttribute('id'), element)
            for element in defs.childNodes
            if isinstance(element, minidom.Element) and element.hasAttribute('id')
        ])
        self.ref_to_element.update(new_refs)

    def move_into_position(self):
        """Optionally center, then rescale to the configured height and/or width."""
        if self.should_center:
            self.center()
        if self.height is not None:
            self.scale_to_fit_height(self.height)
        if self.width is not None:
            self.scale_to_fit_width(self.width)
2016-04-23 23:36:05 -07:00
2016-04-17 00:31:38 -07:00
class VMobjectFromSVGPathstring(VMobject):
    """
    A VMobject whose points come from an SVG path "d" string.

    Each command is translated into cubic Bezier control points on
    self.growing_path. A moveto after the first starts a new subpath.
    Elliptical arcs ("A") are not supported.
    """
    def __init__(self, path_string, **kwargs):
        digest_locals(self)
        VMobject.__init__(self, **kwargs)

    def get_path_commands(self):
        """
        Return the SVG path command letters: upper case (absolute) first,
        then the lower case (relative) forms.
        """
        result = [
            "M", #moveto
            "L", #lineto
            "H", #horizontal lineto
            "V", #vertical lineto
            "C", #curveto
            "S", #smooth curveto
            "Q", #quadratic Bezier curve
            "T", #smooth quadratic Bezier curveto
            "A", #elliptical Arc
            "Z", #closepath
        ]
        # Build the lower-case copies before concatenating. Extending a list
        # from a lazy map over itself never terminates on Python 3.
        return result + [command.lower() for command in result]

    def generate_points(self):
        """Parse self.path_string command by command, then flip y to manim's orientation."""
        pattern = "[%s]"%("".join(self.get_path_commands()))
        pairs = zip(
            re.findall(pattern, self.path_string),
            re.split(pattern, self.path_string)[1:]
        )
        #Which mobject should new points be added to
        self.growing_path = self
        for command, coord_string in pairs:
            self.handle_command(command, coord_string)
        #people treat y-coordinate differently
        self.rotate(np.pi, RIGHT, about_point = ORIGIN)

    def handle_command(self, command, coord_string):
        """
        Append the control points for one path command to self.growing_path.

        Lower-case (relative) commands are offset by the current point. Lines
        are stored as degenerate cubics: each anchor is repeated three times.
        Quadratics are approximated by doubling their control point.

        Raises:
            Exception: for the unsupported "A" (arc) command.
        """
        isLower = command.islower()
        command = command.upper()
        #new_points are the points that will be added to the curr_points
        #list. This variable may get modified in the conditionals below.
        points = self.growing_path.points
        new_points = self.string_to_points(coord_string)

        if command == "M": #moveto
            if isLower and len(points) > 0:
                new_points[0] += points[-1]
            if len(points) > 0:
                self.growing_path = self.add_subpath(new_points[:1])
            else:
                self.growing_path.start_at(new_points[0])
            if len(new_points) <= 1:
                return
            # Pairs after the first are implicit linetos
            points = self.growing_path.points
            new_points = new_points[1:]
            command = "L"

        if isLower and len(points) > 0:
            new_points += points[-1]

        if command in ["L", "H", "V"]: #lineto
            if command == "H":
                new_points[0,1] = points[-1,1]
            elif command == "V":
                # The single number was parsed into the x slot; for relative
                # "v" undo the x offset added above and apply the y offset.
                if isLower:
                    new_points[0,0] -= points[-1,0]
                    new_points[0,0] += points[-1,1]
                new_points[0,1] = new_points[0,0]
                new_points[0,0] = points[-1,0]
            new_points = new_points.repeat(3, axis = 0)
        elif command == "C": #curveto
            pass #Yay! No action required
        elif command in ["S", "T"]: #smooth curveto
            # Reflect the previous handle about the current point. Right after
            # a moveto there is no previous handle, so the SVG spec uses the
            # current point itself.
            if len(points) >= 2:
                handle1 = points[-1]+(points[-1]-points[-2])
            else:
                handle1 = points[-1]
            new_points = np.append([handle1], new_points, axis = 0)
        if command in ["Q", "T"]: #quadratic Bezier curve
            #TODO, this is a suboptimal approximation
            new_points = np.append([new_points[0]], new_points, axis = 0)
        elif command == "A": #elliptical Arc
            raise Exception("Not implemented")
        elif command == "Z": #closepath
            if not is_closed(points):
                #Both handles and new anchor are the start
                new_points = points[[0, 0, 0]]
            # self.mark_paths_closed = True

        #Handle situations where there's multiple relative control points
        if isLower and len(new_points) > 3:
            for i in range(3, len(new_points), 3):
                new_points[i:i+3] -= points[-1]
                new_points[i:i+3] += new_points[i-1]

        self.growing_path.add_control_points(new_points)

    def string_to_points(self, coord_string):
        """
        Parse coord_string into an array of shape (n, self.dim).

        Numbers are read as x, y pairs; an odd trailing number is paired with
        0, and the remaining coordinates are 0.
        """
        numbers = string_to_numbers(coord_string)
        if len(numbers)%2 == 1:
            numbers.append(0)
        # Floor division: on Python 3 "/" gives a float, which np.zeros rejects
        num_points = len(numbers)//2
        result = np.zeros((num_points, self.dim))
        result[:,:2] = np.array(numbers).reshape((num_points, 2))
        return result

    def get_original_path_string(self):
        """Return the path string this mobject was built from."""
        return self.path_string