Merge pull request #1464 from 3b1b/matrix-exp-development

Matrix exp development
This commit is contained in:
Grant Sanderson 2021-04-08 14:18:25 -07:00 committed by GitHub
commit 7a11e3d20f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 178 additions and 76 deletions

View file

@ -25,6 +25,10 @@ class ShowPartial(Animation):
if not self.should_match_start:
self.mobject.lock_matching_data(self.mobject, self.starting_mobject)
def finish(self):
super().finish()
self.mobject.unlock_data()
def interpolate_submobject(self, submob, start_submob, alpha):
submob.pointwise_become_partial(
start_submob, *self.get_bounds(alpha)

View file

@ -14,6 +14,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Dot
from manimlib.mobject.shape_matchers import SurroundingRectangle
from manimlib.mobject.shape_matchers import Underline
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.geometry import Line
from manimlib.utils.bezier import interpolate
@ -156,45 +157,50 @@ class ShowPassingFlash(ShowPartial):
class VShowPassingFlash(Animation):
CONFIG = {
"time_width": 0.3,
"taper_width": 0.1,
"taper_width": 0.02,
"remover": True,
}
def begin(self):
self.mobject.align_stroke_width_data_to_points()
# Compute an array of stroke widths for each submobject
# which tapers out at either end
self.submob_to_anchor_widths = dict()
for sm in self.mobject.get_family():
original_widths = sm.get_stroke_widths()
anchor_widths = np.array([*original_widths[0::3], original_widths[-1]])
def taper_kernel(x):
if x < self.taper_width:
return x
elif x > 1 - self.taper_width:
return 1.0 - x
return 1.0
taper_array = list(map(taper_kernel, np.linspace(0, 1, len(anchor_widths))))
self.submob_to_anchor_widths[hash(sm)] = anchor_widths * taper_array
super().begin()
def interpolate_submobject(self, submobject, starting_sumobject, alpha):
original_widths = starting_sumobject.get_stroke_widths()
# anchor_widths = np.array([*original_widths[0::3, 0], original_widths[-1, 0]])
anchor_widths = np.array([0, *original_widths[3::3, 0], 0])
n_anchors = len(anchor_widths)
time_width = self.time_width
# taper_width = self.taper_width
anchor_widths = self.submob_to_anchor_widths[hash(submobject)]
# Create a gaussian such that 3 sigmas out on either side
# will equals time_width * (number of points)
sigma = time_width / 6
mu = interpolate(-time_width / 2, 1 + time_width / 2, alpha)
offset = math.exp(-4.5) # 3 sigmas out
# will equals time_width
tw = self.time_width
sigma = tw / 6
mu = interpolate(-tw / 2, 1 + tw / 2, alpha)
def kernel_func(x):
result = math.exp(-0.5 * ((x - mu) / sigma)**2) - offset
result = max(result, 0)
# if x < taper_width:
# result *= x / taper_width
# elif x > 1 - taper_width:
# result *= (1 - x) / taper_width
return result
def gauss_kernel(x):
if abs(x - mu) > 3 * sigma:
return 0
z = (x - mu) / sigma
return math.exp(-0.5 * z * z)
kernel_array = np.array([
kernel_func(n / (n_anchors - 1))
for n in range(n_anchors)
])
kernel_array = list(map(gauss_kernel, np.linspace(0, 1, len(anchor_widths))))
scaled_widths = anchor_widths * kernel_array
new_widths = np.zeros(submobject.get_num_points())
new_widths[0::3] = scaled_widths[:-1]
new_widths[1::3] = (scaled_widths[:-1] + scaled_widths[1:]) / 2
new_widths[2::3] = scaled_widths[1:]
new_widths[1::3] = (new_widths[0::3] + new_widths[2::3]) / 2
submobject.set_stroke(width=new_widths)
def finish(self):
@ -203,6 +209,32 @@ class VShowPassingFlash(Animation):
submob.match_style(start)
class FlashAround(VShowPassingFlash):
CONFIG = {
"stroke_width": 4.0,
"color": YELLOW,
"buff": SMALL_BUFF,
"time_width": 1.0,
"n_inserted_curves": 20,
}
def __init__(self, mobject, **kwargs):
digest_config(self, kwargs)
path = self.get_path(mobject)
path.insert_n_curves(self.n_inserted_curves)
path.set_points(path.get_points_without_null_curves())
path.set_stroke(self.color, self.stroke_width)
super().__init__(path, **kwargs)
def get_path(self, mobject):
return SurroundingRectangle(mobject, buff=self.buff)
class FlashUnder(FlashAround):
def get_path(self, mobject):
return Underline(mobject, buff=self.buff)
class ShowCreationThenDestruction(ShowPassingFlash):
CONFIG = {
"time_width": 2.0,

View file

@ -10,7 +10,6 @@ from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.utils.config_ops import digest_config
@ -129,7 +128,7 @@ class TransformMatchingShapes(TransformMatchingParts):
class TransformMatchingTex(TransformMatchingParts):
CONFIG = {
"mobject_type": Tex,
"mobject_type": VMobject,
"group_type": VGroup,
}

View file

@ -77,16 +77,24 @@ class CameraFrame(Mobject):
self.set_euler_angles(theta, phi, gamma)
return self
def set_euler_angles(self, theta=None, phi=None, gamma=None):
def set_euler_angles(self, theta=None, phi=None, gamma=None, units=RADIANS):
if theta is not None:
self.data["euler_angles"][0] = theta
self.data["euler_angles"][0] = theta * units
if phi is not None:
self.data["euler_angles"][1] = phi
self.data["euler_angles"][1] = phi * units
if gamma is not None:
self.data["euler_angles"][2] = gamma
self.data["euler_angles"][2] = gamma * units
self.refresh_rotation_matrix()
return self
def reorient(self, theta_degrees=None, phi_degrees=None, gamma_degrees=None):
"""
Shortcut for set_euler_angles, defaulting to taking
in angles in degrees
"""
self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES)
return self
def set_theta(self, theta):
return self.set_euler_angles(theta=theta)

View file

@ -234,7 +234,9 @@ def get_configuration(args):
# Default to making window half the screen size
# but make it full screen if -f is passed in
monitor = get_monitors()[custom_config["window_monitor"]]
monitors = get_monitors()
mon_index = custom_config["window_monitor"]
monitor = monitors[min(mon_index, len(monitors) - 1)]
window_width = monitor.width
if not args.full_screen:
window_width //= 2

View file

@ -50,6 +50,9 @@ RIGHT_SIDE = FRAME_X_RADIUS * RIGHT
PI = np.pi
TAU = 2 * PI
DEGREES = TAU / 360
# Nice to have a constant for readability
# when juxtaposed with expressions like 30 * DEGREES
RADIANS = 1
FFMPEG_BIN = "ffmpeg"

View file

@ -25,8 +25,8 @@ class CoordinateSystem():
"""
CONFIG = {
"dimension": 2,
"x_range": np.array([-8, 8, 1.0]),
"y_range": np.array([-4, 4, 1.0]),
"x_range": np.array([-8.0, 8.0, 1.0]),
"y_range": np.array([-4.0, 4.0, 1.0]),
"width": None,
"height": None,
"num_sampled_graph_points_per_tick": 5,
@ -343,11 +343,13 @@ class Axes(VGroup, CoordinateSystem):
class ThreeDAxes(Axes):
CONFIG = {
"dimension": 3,
"x_range": np.array([-6, 6, 1]),
"y_range": np.array([-5, 5, 1]),
"z_range": np.array([-4, 4, 1]),
"x_range": np.array([-6.0, 6.0, 1.0]),
"y_range": np.array([-5.0, 5.0, 1.0]),
"z_range": np.array([-4.0, 4.0, 1.0]),
"z_axis_config": {},
"z_normal": DOWN,
"height": None,
"width": None,
"depth": None,
"num_axis_pieces": 20,
"gloss": 0.5,

View file

@ -20,6 +20,9 @@ class ScreenRectangle(Rectangle):
class FullScreenRectangle(ScreenRectangle):
CONFIG = {
"height": FRAME_HEIGHT,
"fill_color": GREY_E,
"fill_opacity": 1,
"stroke_width": 0,
}

View file

@ -57,13 +57,13 @@ class Matrix(VMobject):
CONFIG = {
"v_buff": 0.8,
"h_buff": 1.3,
"bracket_h_buff": MED_SMALL_BUFF,
"bracket_v_buff": MED_SMALL_BUFF,
"bracket_h_buff": 0.2,
"bracket_v_buff": 0.25,
"add_background_rectangles_to_entries": False,
"include_background_rectangle": False,
"element_to_mobject": Tex,
"element_to_mobject_config": {},
"element_alignment_corner": DR,
"element_alignment_corner": DOWN,
}
def __init__(self, matrix, **kwargs):
@ -132,6 +132,12 @@ class Matrix(VMobject):
for i in range(len(self.mob_matrix[0]))
])
def get_rows(self):
return VGroup(*[
VGroup(*row)
for row in self.mob_matrix
])
def set_column_colors(self, *colors):
columns = self.get_columns()
for color, column in zip(colors, columns):
@ -163,6 +169,7 @@ class DecimalMatrix(Matrix):
class IntegerMatrix(Matrix):
CONFIG = {
"element_to_mobject": Integer,
"element_alignment_corner": UP,
}

View file

@ -176,6 +176,7 @@ class Mobject(object):
def match_points(self, mobject):
self.set_points(mobject.get_points())
return self
def get_points(self):
return self.data["points"]
@ -561,7 +562,7 @@ class Mobject(object):
)
return self
def scale(self, scale_factor, **kwargs):
def scale(self, scale_factor, min_scale_factor=1e-8, **kwargs):
"""
Default behavior is to scale about the center of the mobject.
The argument about_edge can be a vector, indicating which side of
@ -571,6 +572,7 @@ class Mobject(object):
Otherwise, if about_point is given a value, scaling is done with
respect to that point.
"""
scale_factor = max(scale_factor, min_scale_factor)
self.apply_points_function(
lambda points: scale_factor * points,
works_on_bounding_box=True,

View file

@ -5,7 +5,6 @@ from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.bezier import interpolate
from manimlib.utils.config_ops import digest_config
from manimlib.utils.config_ops import merge_dicts_recursively
from manimlib.utils.iterables import list_difference_update
from manimlib.utils.simple_functions import fdiv
from manimlib.utils.space_ops import normalize
@ -144,7 +143,7 @@ class NumberLine(Line):
direction=direction,
buff=buff
)
if x < 0 and self.line_to_number_direction[0] == 0:
if x < 0 and direction[0] == 0:
# Align without the minus sign
num_mob.shift(num_mob[0].get_width() * LEFT / 2)
return num_mob
@ -155,10 +154,11 @@ class NumberLine(Line):
kwargs["font_size"] = font_size
if excluding is None:
excluding = self.numbers_to_exclude
numbers = VGroup()
for x in x_values:
if x in self.numbers_to_exclude:
continue
if excluding is not None and x in excluding:
continue
numbers.add(self.get_number_mobject(x, **kwargs))

View file

@ -128,10 +128,11 @@ class DecimalNumber(VMobject):
def set_value(self, number):
move_to_point = self.get_edge_center(self.edge_to_fix)
style = self.get_style()
old_submobjects = self.submobjects
self.set_submobjects_from_number(number)
self.move_to(move_to_point, self.edge_to_fix)
self.set_style(**style)
for sm1, sm2 in zip(self.submobjects, old_submobjects):
sm1.match_style(sm2)
return self
def scale(self, scale_factor, **kwargs):

View file

@ -64,16 +64,17 @@ class BackgroundRectangle(SurroundingRectangle):
class Cross(VGroup):
CONFIG = {
"stroke_color": RED,
"stroke_width": 6,
"stroke_width": [0, 6, 0],
}
def __init__(self, mobject, **kwargs):
VGroup.__init__(self,
Line(UP + LEFT, DOWN + RIGHT),
Line(UP + RIGHT, DOWN + LEFT),
)
super().__init__(
Line(UL, DR),
Line(UR, DL),
)
self.insert_n_curves(2)
self.replace(mobject, stretch=True)
self.set_stroke(self.stroke_color, self.stroke_width)
self.set_stroke(self.stroke_color, width=self.stroke_width)
class Underline(Line):

View file

@ -8,6 +8,7 @@ from manimlib.animation.growing import GrowFromCenter
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.tex_mobject import TexText
from manimlib.mobject.svg.text_mobject import Text
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.space_ops import get_norm
@ -61,9 +62,10 @@ class Brace(SingleStringTex):
mob.shift(self.get_direction() * shift_distance)
return self
def get_text(self, *text, **kwargs):
text_mob = TexText(*text)
self.put_at_tip(text_mob, **kwargs)
def get_text(self, text, **kwargs):
buff = kwargs.pop("buff", SMALL_BUFF)
text_mob = Text(text, **kwargs)
self.put_at_tip(text_mob, buff=buff)
return text_mob
def get_tex(self, *tex, **kwargs):

View file

@ -169,7 +169,11 @@ class SVGMobject(VMobject):
else 0.0
for key in ("cx", "cy", "rx", "ry")
]
return Circle().scale(rx * RIGHT + ry * UP).shift(x * RIGHT + y * DOWN)
result = Circle()
result.stretch(rx, 0)
result.stretch(ry, 1)
result.shift(x * RIGHT + y * DOWN)
return result
def rect_to_mobject(self, rect_element):
fill_color = rect_element.getAttribute("fill")

View file

@ -152,7 +152,7 @@ class SingleStringTex(VMobject):
class Tex(SingleStringTex):
CONFIG = {
"arg_separator": " ",
"arg_separator": "",
# Note, use of isolate is largely rendered
# moot by the fact that you can surround such strings in
# {{ and }} as needed.

View file

@ -10,6 +10,7 @@ import manimpango
from manimlib.constants import *
from manimlib.mobject.geometry import Dot
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.config_ops import digest_config
from manimlib.utils.customization import get_customization
from manimlib.utils.directories import get_downloads_dir, get_text_dir
@ -100,6 +101,19 @@ class Text(SVGMobject):
index = self.text.find(word, index + len(word))
return indexes
def get_parts_by_text(self, word):
return VGroup(*(
self[i:j]
for i, j in self.find_indexes(word)
))
def get_part_by_text(self, word):
parts = self.get_parts_by_text(word)
if len(parts) > 0:
return parts[0]
else:
return None
def full2short(self, config):
for kwargs in [config, self.CONFIG]:
if kwargs.__contains__('line_spacing_height'):
@ -212,6 +226,7 @@ class Text(SVGMobject):
self.text,
)
@contextmanager
def register_font(font_file: typing.Union[str, Path]):
"""Temporarily add a font file to Pango's search path.

View file

@ -244,7 +244,7 @@ class VMobject(Mobject):
return self.data['stroke_rgba'][:, 3]
def get_stroke_widths(self):
return self.data['stroke_width']
return self.data['stroke_width'][:, 0]
# TODO, it's weird for these to return the first of various lists
# rather than the full information
@ -848,8 +848,8 @@ class VMobject(Mobject):
old_points = self.get_points()
func(self, *args, **kwargs)
if not np.all(self.get_points() == old_points):
self.refresh_triangulation()
self.refresh_unit_normal()
self.refresh_triangulation()
return wrapper
@triggers_refreshed_triangulation
@ -870,9 +870,10 @@ class VMobject(Mobject):
self.make_approximately_smooth()
return self
@triggers_refreshed_triangulation
def flip(self, *args, **kwargs):
super().flip(*args, **kwargs)
self.refresh_unit_normal()
self.refresh_triangulation()
return self
# For shaders

View file

@ -62,10 +62,13 @@ def move_submobjects_along_vector_field(mobject, func):
return mobject
def move_points_along_vector_field(mobject, func):
def move_points_along_vector_field(mobject, func, coordinate_system):
cs = coordinate_system
origin = cs.get_origin()
def apply_nudge(self, dt):
self.mobject.apply_function(
lambda p: p + func(p) * dt
mobject.apply_function(
lambda p: p + (cs.c2p(*func(*cs.p2c(p))) - origin) * dt
)
mobject.add_updater(apply_nudge)
return mobject
@ -128,7 +131,7 @@ class VectorField(VGroup):
origin, _output, buff=0,
**vector_config
)
vect.shift(_input)
vect.shift(_input - origin)
vect.set_rgba_array([[*self.value_to_rgb(norm), self.opacity]])
return vect
@ -162,19 +165,21 @@ class StreamLines(VGroup):
self.init_style()
def point_func(self, point):
return self.coordinate_system.c2p(
*self.func(*self.coordinate_system.p2c(point))
)
in_coords = self.coordinate_system.p2c(point)
out_coords = self.func(*in_coords)
return self.coordinate_system.c2p(*out_coords)
def draw_lines(self):
lines = []
origin = self.coordinate_system.get_origin()
for point in self.get_start_points():
points = [point]
total_arc_len = 0
# for t in np.arange(0, self.virtual_time, self.dt):
time = 0
for x in range(self.max_time_steps):
time += self.dt
last_point = points[-1]
new_point = last_point + self.dt * self.point_func(last_point)
new_point = last_point + self.dt * (self.point_func(last_point) - origin)
points.append(new_point)
total_arc_len += get_norm(new_point - last_point)
if get_norm(last_point) > self.cutoff_norm:
@ -182,8 +187,10 @@ class StreamLines(VGroup):
if total_arc_len > self.arc_len:
break
line = VMobject()
line.virtual_time = time
step = max(1, int(len(points) / self.n_samples_per_line))
line.set_points_smoothly(points[::step])
line.set_points_as_corners(points[::step])
line.make_approximately_smooth()
lines.append(line)
self.set_submobjects(lines)
@ -220,7 +227,7 @@ class StreamLines(VGroup):
rgbas[:, 3] = self.stroke_opacity
line.set_rgba_array(rgbas, "stroke_rgba")
else:
self.set_stroke(self.stroke_color)
self.set_stroke(self.stroke_color, opacity=self.stroke_opacity)
if self.taper_stroke_width:
width = [0, self.stroke_width, 0]
@ -234,7 +241,7 @@ class AnimatedStreamLines(VGroup):
"lag_range": 4,
"line_anim_class": VShowPassingFlash,
"line_anim_config": {
"run_time": 4,
# "run_time": 4,
"rate_func": linear,
"time_width": 0.5,
},
@ -244,7 +251,11 @@ class AnimatedStreamLines(VGroup):
super().__init__(**kwargs)
self.stream_lines = stream_lines
for line in stream_lines:
line.anim = self.line_anim_class(line, **self.line_anim_config)
line.anim = self.line_anim_class(
line,
run_time=line.virtual_time,
**self.line_anim_config,
)
line.anim.begin()
line.time = -self.lag_range * random.random()
self.add(line.anim.mobject)

View file

@ -74,12 +74,14 @@ class SwitchOff(LaggedStartMap):
class Lighthouse(SVGMobject):
CONFIG = {
"file_name": "lighthouse",
"height": LIGHTHOUSE_HEIGHT,
"fill_color": WHITE,
"fill_opacity": 1.0,
}
def __init__(self, **kwargs):
super().__init__("lighthouse", **kwargs)
def move_to(self, point):
self.next_to(point, DOWN, buff=0)

View file

@ -23,6 +23,7 @@ def get_text_dir():
def get_mobject_data_dir():
return guarantee_existence(os.path.join(get_temp_dir(), "mobject_data"))
def get_downloads_dir():
return guarantee_existence(os.path.join(get_temp_dir(), "manim_downloads"))

View file

@ -10,7 +10,7 @@ def get_full_raster_image_path(image_file_name):
return find_file(
image_file_name,
directories=[get_raster_image_dir()],
extensions=[".jpg", ".png", ".gif", ""]
extensions=[".jpg", ".jpeg", ".png", ".gif", ""]
)

View file

@ -38,7 +38,9 @@ class Window(PygletWindow):
def find_initial_position(self, size):
custom_position = get_customization()["window_position"]
monitor = get_monitors()[get_customization()["window_monitor"]]
monitors = get_monitors()
mon_index = get_customization()["window_monitor"]
monitor = monitors[min(mon_index, len(monitors) - 1)]
window_width, window_height = size
# Position might be specified with a string of the form
# x,y for integers x and y