2023-06-15 10:40:18 -07:00
|
|
|
from manim_imports_ext import *
|
|
|
|
from _2022.convolutions.discrete import *
|
|
|
|
from _2023.clt.main import *
|
|
|
|
|
|
|
|
|
|
|
|
# Six-entry skewed probability distribution (entries sum to 1); used below as
# the second die's distribution (`dist2`) in SumAlongDiagonal.
SKEW_DISTRIBUTION = [0.12, 0.23, 0.31, 0.18, 0.12, 0.04]
|
|
|
|
|
|
|
|
# Helpers
|
|
|
|
|
|
|
|
|
|
|
|
def get_bar_group(
    dist,
    bar_colors=(BLUE_D, TEAL_D),
    value_labels=None,
    width_ratio=0.7,
    height=2.0,
    number_config=None,
    label_buff=SMALL_BUFF,
):
    """
    Build a bar chart for `dist` in which every bar carries two labels.

    Parameters
    ----------
    dist
        Sequence of probabilities, one per bar.
    bar_colors
        Colors forwarded to `dist_to_bars`.
    value_labels
        Optional mobjects placed below the bars (one per bar, e.g. die
        faces). When omitted, invisible `VectorizedPoint` placeholders are
        used so every bar still has three parts.
    width_ratio
        Label width as a fraction of the first bar's width.
    height
        Height forwarded to `dist_to_bars`.
    number_config
        Keyword arguments for each `DecimalNumber` probability label.
        Defaults to no extra arguments.
    label_buff
        Gap between a bar and the label below it.

    Returns
    -------
    VGroup
        One `VGroup(bar, value_label, p_label)` per entry of `dist`. Each
        of those groups also exposes the attributes `.bar`, `.die` and
        `.value_label`. Note the naming: `.die` is the label *below* the
        bar (the `value_labels` entry) and `.value_label` is the numeric
        probability *above* the bar.
    """
    # Avoid a shared mutable default; a fresh dict is made per call.
    if number_config is None:
        number_config = {}

    bars = dist_to_bars(dist, bar_colors=bar_colors, height=height)
    # All labels are sized relative to the first bar.
    label_width = width_ratio * bars[0].get_width()

    # Numeric probabilities, placed above the bars
    p_labels = VGroup(*(DecimalNumber(x, **number_config) for x in dist))
    p_labels.set_max_width(label_width)
    for p_label, bar in zip(p_labels, bars):
        p_label.next_to(bar, UP, SMALL_BUFF)

    if value_labels is None:
        value_labels = VectorizedPoint().replicate(len(dist))

    # Value labels (e.g. die faces), placed below the bars
    for value_label, bar in zip(value_labels, bars):
        value_label.set_width(label_width)
        value_label.next_to(bar, DOWN, buff=label_buff)

    labeled_bars = VGroup(*(
        VGroup(bar, value_label, p_label)
        for bar, value_label, p_label in zip(bars, value_labels, p_labels)
    ))
    for group in labeled_bars:
        # Unpacks (bar, value_label, p_label) in that order, so `.die` is the
        # label below the bar and `.value_label` is the probability number.
        group.bar, group.die, group.value_label = group

    return labeled_bars
|
|
|
|
|
|
|
|
|
|
|
|
def die_sum_labels(color1=BLUE_E, color2=RED_E, height=1.0):
    """
    Build one vertical label per sum from 2 to 12.

    Each label stacks, top to bottom: a die in `color1`, a "+", a die in
    `color2`, a rotated "=", and the sum itself. Every label is scaled to
    `height`, and the labels are then laid out left to right.

    Returns a VGroup holding the eleven labels, in increasing order of sum.
    """
    templates = []
    for fill in (color1, color2):
        face = DieFace(1, fill_color=fill)
        # Drop submobject 1 of the face (presumably the single pip of the
        # one-face, leaving a blank die - confirm against DieFace).
        face.remove(face[1])
        face.set_height(height / 3)
        templates.append(face)
    first_die, second_die = templates

    columns = []
    for total in range(2, 13):
        column = VGroup(
            first_die.copy(),
            Tex("+", font_size=24),
            second_die.copy(),
            Tex("=", font_size=24).rotate(90 * DEGREES),
            Tex(str(total), font_size=30),
        )
        column.arrange(DOWN, buff=SMALL_BUFF)
        # Manual spacing tweaks: nudge the top two parts down, then the top
        # die down once more.
        column[:2].shift(0.05 * DOWN)
        column[:1].shift(0.05 * DOWN)
        column.set_height(height)
        columns.append(column)

    labels = VGroup(*columns)
    labels.arrange(RIGHT)
    return labels
|
|
|
|
|
|
|
|
|
|
|
|
def rotate_sum_label(sum_label):
    """
    Re-lay out `sum_label` horizontally, in place.

    The parts are arranged left to right, the second-to-last part is
    rotated by 90 degrees, and the last two parts are rescaled to the full
    label height and pushed slightly to the right. Presumably meant for the
    five-part labels produced by `die_sum_labels`, whose last two parts are
    the "=" sign and the total - confirm against callers.

    Returns None; the label is modified in place.
    """
    sum_label.arrange(RIGHT, buff=SMALL_BUFF)

    second_to_last = sum_label[-2]
    last = sum_label[-1]
    second_to_last.rotate(90 * DEGREES)

    # Scale the trailing pair up to the label's height, anchored at its left
    tail = sum_label[-2:]
    tail.set_height(sum_label.get_height(), about_edge=LEFT)
    tail.shift(SMALL_BUFF * RIGHT)
    # The final part gets one extra nudge to the right
    last.shift(SMALL_BUFF * RIGHT)
|
|
|
|
|
|
|
|
|
|
|
|
def p_mob(mob, scale_factor=1.0):
    """
    Wrap a copy of `mob` in probability notation, i.e. "P(<mob>)".

    A `Tex` of the form "P(OO...O)" is typeset, where the number of "O"
    placeholders approximates the width-to-height ratio of `mob`. A copy of
    `mob` then takes the place of those placeholders.

    Parameters
    ----------
    mob
        The mobject to show as the argument of P(...). It is not modified.
    scale_factor
        Extra scaling applied to the copy after it has been fitted.

    Returns
    -------
    VGroup
        The first two glyphs of the Tex, the fitted copy, and the last glyph
        of the Tex (presumably "P", "(" and ")"). The copy is also available
        as the `.arg` attribute of the result.
    """
    used_mob = mob.copy()
    aspect_ratio = mob.get_width() / mob.get_height()
    # Always reserve at least one placeholder. For a tall mobject the ratio
    # rounds to 0, which previously produced an empty placeholder string and
    # a degenerate `tex[""]` selection to fit the copy into.
    n_placeholders = max(1, int(np.round(aspect_ratio)))
    Os = "O" * n_placeholders
    tex = Tex(f"P({Os})")
    # Match the placeholder's width (dimension 0), then apply the extra scale
    used_mob.replace(tex[Os], dim_to_match=0)
    used_mob.scale(scale_factor)
    result = VGroup(*tex[:2], used_mob, tex[-1])
    result.arg = used_mob
    return result
|
|
|
|
|
|
|
|
|
|
|
|
# Scenes
|
|
|
|
|
|
|
|
|
|
|
|
class SumAlongDiagonal(InteractiveScene):
    """
    Scene showing how the distribution of a sum of two dice arises.

    Two weighted dice are sampled repeatedly, then a 6x6 table of pairwise
    probabilities is built, shown as a 3d bar plot, and summed along its
    diagonals to give the distribution of the sum.

    NOTE(review): the nesting of statements in this class was reconstructed
    from statement semantics because the source view carried no indentation.
    Spots where a trailing `self.wait()` could belong either inside or after
    a loop are marked with NOTE(review) below - confirm them.
    """
    samples = 4

    # Distributions of the first (blue) and second (red) die
    dist1 = EXP_DISTRIBUTION
    dist2 = SKEW_DISTRIBUTION

    # Bar color pairs for the two dice and for the sum
    dist1_colors = (BLUE_D, TEAL_D)
    dist2_colors = (RED_D, GOLD_E)
    sum_colors = (GREEN_E, YELLOW_E)

    def construct(self):
        """Run the full animation; see the section comments for each step."""
        # Setup distributions
        dist1 = self.dist1
        dist2 = self.dist2
        blue_dice = get_die_faces(fill_color=BLUE_E, dot_color=WHITE)
        red_dice = get_die_faces(fill_color=RED_E, dot_color=WHITE)
        bar_groups = VGroup(
            get_bar_group(dist1, self.dist1_colors, blue_dice),
            get_bar_group(dist2, self.dist2_colors, red_dice),
        )

        bar_groups.arrange(DOWN, buff=LARGE_BUFF)
        bar_groups.to_edge(LEFT)

        self.add(bar_groups)

        # Setup the sum distribution
        conv_dist = np.convolve(dist1, dist2)
        sum_labels = die_sum_labels()
        sum_bar_group = get_bar_group(
            conv_dist, self.sum_colors, sum_labels,
            number_config=dict(num_decimal_places=3, font_size=30),
            label_buff=MED_SMALL_BUFF,
        )
        sum_bar_group.to_edge(RIGHT, buff=LARGE_BUFF)
        sum_bar_group.set_y(0)

        # One open-topped bucket (a base and two side lines) per possible sum
        buckets = VGroup()
        for bar in sum_bar_group:
            base = Line(LEFT, RIGHT)
            base.match_width(bar)
            base.move_to(bar[0], DOWN)
            v_lines = Line(DOWN, UP).replicate(2)
            v_lines.set_height(6)
            v_lines[0].move_to(base.get_left(), DOWN)
            v_lines[1].move_to(base.get_right(), DOWN)
            bucket = VGroup(base, *v_lines)
            bucket.set_stroke(GREY_C, 2)
            buckets.add(bucket)

        self.add(sum_labels)
        self.add(buckets)

        # Repeatedly sample from these two (for a while)
        self.show_repeated_samples(
            dist1, dist2, *bar_groups, buckets, sum_labels,
            n_animated_runs=1, n_total_runs=2,
        )

        # Ask about sum values
        rects = VGroup(*(
            SurroundingRectangle(sum_label)
            for sum_label in sum_labels
        ))
        words1 = Text("What's the probability\nof this?", font_size=36)
        words2 = Text("Or this?", font_size=36)

        words1.next_to(rects[0], DOWN, MED_SMALL_BUFF)
        words2.next_to(rects[1], DOWN, MED_SMALL_BUFF)

        self.play(ShowCreation(rects[0]), FadeIn(words1))
        self.wait()
        self.play(
            TransformMatchingStrings(words1, words2, run_time=1),
            FadeOut(rects[0]),
            FadeIn(rects[1]),
        )
        self.wait()
        for i in range(1, len(rects) - 1):
            self.play(
                FadeOut(rects[i]), FadeIn(rects[i + 1]),
                words2.animate.match_x(rects[i + 1]),
                run_time=0.5
            )
        # NOTE(review): assumed to run once after the loop - confirm
        self.wait()
        self.play(FadeOut(words2), FadeOut(rects[-1]))
        self.play(FadeOut(sum_labels), FadeOut(buckets))

        # Draw grid of dice values
        grid = Square().get_grid(6, 6, buff=0, fill_rows_first=False, )
        grid.flip(RIGHT)
        grid.set_stroke(WHITE, 1)
        grid.set_height(5.5)
        grid.to_edge(RIGHT, buff=LARGE_BUFF)
        grid.to_edge(UP)

        blue_row = blue_dice.copy()
        red_col = red_dice.copy()
        # Every sixth square starts a new column: blue dice go below those
        for square, die in zip(grid[::6], blue_row):
            die.set_width(0.5 * square.get_width())
            die.next_to(square, DOWN)
        # The first six squares form one column: red dice go to their left
        for square, die in zip(grid, red_col):
            die.set_width(0.5 * square.get_width())
            die.next_to(square, LEFT)

        self.play(
            ShowCreation(grid, lag_ratio=0.5),
            TransformFromCopy(blue_dice, blue_row),
            TransformFromCopy(red_dice, red_col),
        )

        # Square n holds blue die n // 6 paired with red die n % 6
        dice_pairs = VGroup()
        anims = []
        for n, square in enumerate(grid):
            templates = VGroup(
                blue_row[n // 6],
                red_col[n % 6],
            )
            pair = templates.copy()
            pair.arrange(RIGHT, buff=SMALL_BUFF)
            pair.set_width(square.get_width() * 0.7)
            pair.move_to(square)
            dice_pairs.add(pair)
            anims.extend([
                TransformFromCopy(templates, pair),
            ])

        self.play(LaggedStart(*anims, lag_ratio=0.1))
        self.add(dice_pairs)
        self.wait()

        # Index layout relied on by isolate_pairs and later sections:
        # [0] blue_row, [1] red_col, [2] grid, [3] dice_pairs
        full_table = VGroup(blue_row, red_col, grid, dice_pairs)

        # Highlight a sequence of (blue, red) pairs, ending on (4, 2)
        pairs = [
            (1, 1),
            (1, 2),
            (1, 3),
            (1, 4),
            (1, 5),
            (2, 5),
            (2, 4),
            (2, 3),
            (3, 3),
            (3, 2),
            (4, 2),
        ]

        for pair in pairs:
            self.isolate_pairs(bar_groups, full_table, pair)
            self.wait(0.5)
        self.wait()

        # Show the probability of (4, 2)
        def get_p_label(dice):
            # Lays out "P(pair) = P(first die) P(second die)"
            p_label = VGroup(
                p_mob(dice), Tex("="),
                p_mob(dice[0]),
                p_mob(dice[1]),
            )
            p_label.arrange(RIGHT, buff=SMALL_BUFF)
            return p_label

        i0, j0 = pairs[-1]
        p_label = get_p_label(dice_pairs[(i0 - 1) * 6 + (j0 - 1)])
        p_label.to_edge(UP)
        p_label.shift(2.5 * LEFT)

        movers = VGroup(
            dice_pairs[(i0 - 1) * 6 + j0 - 1],
            blue_row[(i0 - 1)],
            red_col[(j0 - 1)],
        ).copy()
        self.play(
            bar_groups.animate.set_width(2.5, about_edge=DL),
            full_table.animate.set_width(5.5, about_edge=DR),
            LaggedStart(
                # Index 2 of each p_mob result is the mobject inside P(...)
                movers[0].animate.replace(p_label[0][2]),
                movers[1].animate.replace(p_label[2][2]),
                movers[2].animate.replace(p_label[3][2]),
                lag_ratio=0.25,
            ),
        )
        self.play(FadeIn(p_label))
        self.remove(movers)
        self.wait()

        # Show numerical product
        prod_rhs = Tex("= (0.00)(0.00)")
        num_rhs = Tex("= 0.000")
        value1, value2 = prod_rhs.make_number_changeable("0.00", replace_all=True)
        value1.set_value(dist1[i0 - 1]).set_color(BLUE)
        value2.set_value(dist2[j0 - 1]).set_color(RED)
        prod_rhs.next_to(p_label[1], DOWN, MED_LARGE_BUFF, aligned_edge=LEFT)
        pair_prob = num_rhs.make_number_changeable("0.000")
        pair_prob.set_value(dist1[i0 - 1] * dist2[j0 - 1])
        num_rhs.next_to(prod_rhs, DOWN, MED_LARGE_BUFF, aligned_edge=LEFT)

        self.play(LaggedStart(
            # Index 2 of a labeled bar is its probability number
            Transform(bar_groups[0][i0 - 1][2].copy(), value1.copy(), remover=True),
            Transform(bar_groups[1][j0 - 1][2].copy(), value2.copy(), remover=True),
            Write(prod_rhs, lag_ratio=0.1),
            lag_ratio=0.5
        ))
        self.wait()
        self.play(FadeIn(num_rhs, DOWN))
        self.wait()

        example = VGroup(p_label, prod_rhs, num_rhs)

        # Assumption
        morty = Mortimer(height=2.0).flip()
        morty.next_to(example, DOWN, buff=2.0)
        morty.shift(0.5 * LEFT)

        self.play(
            morty.says(
                "Assuming rolls\nare independent!",
                mode="surprised",
            ),
            VFadeIn(morty),
        )
        self.play(Blink(morty))
        self.wait()
        self.play(LaggedStartMap(FadeOut, VGroup(morty, morty.bubble, morty.bubble.content)))

        # Set up the multiplication table
        full_table.generate_target()
        full_table.target.set_width(6.5)
        full_table.target.set_opacity(1)
        full_table.target[2].set_fill(opacity=0)
        full_table.target.to_edge(RIGHT, buff=1.5)
        full_table.target.set_y(0)

        self.play(
            MoveToTarget(full_table),
            example.animate.scale(0.75, about_edge=UL).shift(0.5 * LEFT)
        )

        # Copies of each die's probability numbers, to sit along the margins
        marginal_labels = VGroup(*(
            VGroup(*(bar[2] for bar in group)).copy()
            for group in bar_groups
        ))
        # blue_row, red_col and dice_pairs (the grid itself is left out)
        margin_dice = VGroup(*full_table[:2], full_table[3])
        # zip stops after the two marginal groups, so dice_pairs is not
        # touched here; it is handled separately below
        for margins, dice in zip(marginal_labels, margin_dice):
            margins.generate_target()
            dice.generate_target()
            for die, prob in zip(dice.target, margins.target):
                # Shrink the die, put its probability beneath it, and keep
                # the pair centered where the die used to be
                center = die.get_center()
                die.scale(0.7)
                die.set_opacity(0.85)
                prob.next_to(die, DOWN, buff=0.125)
                prob.scale(1.5)
                VGroup(prob, die).move_to(center)

        marginal_labels[0].target.set_fill(BLUE, 1)
        marginal_labels[1].target.set_fill(RED, 1)

        margin_dice[2].generate_target()
        for dice in margin_dice[2].target:
            dice.scale(0.75, about_edge=UP)
            dice.set_opacity(0.85)
            dice.set_stroke(width=1)
            dice.shift(0.1 * UP)

        for groups in zip(marginal_labels, margin_dice):
            self.play(*map(MoveToTarget, groups), lag_ratio=0.001)
        self.play(MoveToTarget(margin_dice[2], lag_ratio=0.001))
        self.wait()

        full_table.add(marginal_labels)

        # Fill in multiplication table
        grid_probs = VGroup()
        self.add(grid_probs)
        for n, square in enumerate(grid):
            i, j = n // 6, n % 6
            dice = margin_dice[2][n]
            margin1 = marginal_labels[0][i]
            margin2 = marginal_labels[1][j]
            prob = DecimalNumber(
                margin1.get_value() * margin2.get_value(),
                num_decimal_places=3
            )
            prob.set_height(margin1.get_height())
            prob.next_to(dice, DOWN, SMALL_BUFF)

            rects = VGroup(*(
                SurroundingRectangle(mob, buff=SMALL_BUFF)
                for mob in [margin1, margin2, prob]
            ))
            rects.set_stroke(YELLOW, 1)

            grid_probs.add(prob)

            # Keep the worked example on screen in sync with this cell
            value1.set_value(dist1[i])
            value2.set_value(dist2[j])
            pair_prob.set_value(dist1[i] * dist2[j])
            new_p_label = get_p_label(dice.copy().set_opacity(1))
            new_p_label.replace(p_label, dim_to_match=1)
            p_label.become(new_p_label)

            bar_groups.set_opacity(0.35)
            bar_groups[0][i].set_opacity(1)
            bar_groups[1][j].set_opacity(1)

            # Flash the rectangles around the two factors and the product
            self.add(rects)
            self.wait(0.25)
            self.remove(rects)
        # NOTE(review): assumed to run once after the loop - confirm
        self.wait()

        full_table.add(grid_probs)

        # Fade out example
        self.play(
            FadeOut(example),
            bar_groups.animate.set_opacity(1),
        )
        self.wait()

        # Show it as a 3d plot
        bar_groups.fix_in_frame()
        sum_bar_group.fix_in_frame()

        bars_3d = VGroup()
        scale_factor = 30
        for square, prob in zip(grid, grid_probs):
            prism = VCube()
            prism.set_fill(GREY_D, 0.85)
            prism.set_stroke(WHITE, 1, 0.5)
            prism.match_width(square)
            prism.set_depth(scale_factor * prob.get_value(), stretch=True)
            prism.move_to(square, IN)
            # Save the full-size state, then start flattened and transparent
            # so that Restore makes the bar grow out of the table
            prism.save_state()
            prism.stretch(0.001, 2, about_edge=IN)
            prism.set_opacity(0)
            bars_3d.add(prism)

        self.play(
            LaggedStartMap(Restore, bars_3d),
            self.frame.animate.reorient(12, 65, 0).move_to([-0.48, 0.37, 0.77]).set_height(9.43),
            run_time=3,
        )
        self.add(full_table, *bars_3d)
        self.play(
            self.frame.animate.reorient(43, 66, 0).move_to([-0.06, 0.21, -0.29]).set_height(10.59),
            run_time=7,
        )
        self.wait()

        # Show sum distribution
        self.play(
            LaggedStartMap(FadeOut, bar_groups, shift=2 * LEFT),
            FadeIn(sum_bar_group),
            full_table.animate.to_edge(LEFT),
            self.frame.animate.reorient(4, 65, 0).move_to([0.64, 0.8, 0.69]).set_height(10.59),
            *(
                MaintainPositionRelativeTo(bar, full_table)
                for bar in bars_3d
            ),
            run_time=2,
        )
        self.wait()

        # Reposition
        self.play(
            self.frame.animate.to_default_state(),
            FadeOut(bars_3d),
            run_time=2,
        )
        self.wait()

        # Add up along all diagonals
        low_group = VGroup(blue_row, marginal_labels[0])
        left_group = VGroup(red_col, marginal_labels[1])
        # Diagonal s collects the cells whose two zero-based indices sum to s
        diagonals = VGroup(*(
            VGroup(*(
                VGroup(dice_pairs[n], grid_probs[n])
                for n in range(36)
                if (n // 6) + (n % 6) == s
            ))
            for s in range(11)
        ))

        diagonals.save_state()
        rects = VGroup()
        # Rotate by 45 degrees so each diagonal is axis-aligned, box it, then
        # rotate the boxes and the diagonals back together
        diagonals.rotate(45 * DEGREES)
        for diagonal in diagonals:
            rect = SurroundingRectangle(diagonal)
            rect.stretch(0.9, 1)
            rect.round_corners()
            rects.add(rect)
        VGroup(rects, diagonals).rotate(-45 * DEGREES, about_point=diagonals.get_center())

        rects.set_fill(YELLOW, 0.25)
        rects.set_stroke(YELLOW, 2)

        # The first iteration fades out the marginal rows instead of a rect
        last_rect = VGroup(low_group, left_group)
        for n, rect in enumerate(rects):
            self.add(rect, diagonals)
            diagonals.generate_target()
            diagonals.target.set_opacity(0.5)
            diagonals.target[n].set_opacity(1)

            sum_bar_group.generate_target()
            sum_bar_group.target.set_opacity(0.4)
            sum_bar_group.target[n].set_opacity(1)

            self.play(
                FadeOut(last_rect),
                FadeIn(rect),
                MoveToTarget(diagonals),
                MoveToTarget(sum_bar_group),
            )
            last_rect = rect
        # NOTE(review): assumed to run once after the loop - confirm
        self.wait()
        self.play(
            diagonals.animate.set_opacity(1),
            sum_bar_group.animate.set_opacity(1),
            FadeOut(last_rect),
            FadeIn(low_group),
            FadeIn(left_group),
        )

        # Show 3d grid again
        self.play(
            self.frame.animate.reorient(27, 66, 0).move_to([-0.16, 1.41, 0.77]).set_height(9.36),
            FadeIn(bars_3d),
            sum_bar_group.animate.to_edge(RIGHT),
            run_time=2,
        )
        self.wait()

        # Go through diagonals of the plot
        # Re-add the bars ordered by decreasing distance from the camera
        sorted_bars = VGroup(*bars_3d)
        camera_pos = self.frame.get_implied_camera_location()
        sorted_bars.sort(lambda p: -get_norm(p - camera_pos))
        self.add(*sorted_bars)

        diagonal_bar_groups = VGroup().replicate(11)

        for s in range(11):
            sum_bar_group.generate_target()
            sum_bar_group.target.set_opacity(0.2)
            sum_bar_group.target[s].set_opacity(1)

            for n, bar in enumerate(bars_3d):
                bar.generate_target()
                bar.target.set_opacity(0.1)
                bar.target.set_stroke(width=0)
                if (n // 6) + (n % 6) == s:
                    bar.target.set_opacity(1)
                    diagonal_bar_groups[s].add(bar)

            self.play(
                MoveToTarget(sum_bar_group),
                *map(MoveToTarget, bars_3d),
            )
        # NOTE(review): assumed to run once after the loop - confirm
        self.wait()
        self.play(
            bars_3d.animate.set_opacity(0.8).set_stroke(width=0.5),
            sum_bar_group.animate.set_opacity(1),
        )
        self.wait()

        # Highlight bars
        bars_3d.save_state()
        bar_highlights = bars_3d.copy()
        bar_highlights.set_fill(opacity=0)
        bar_highlights.set_stroke(TEAL, 3)
        self.play(ShowCreationThenFadeOut(bar_highlights, lag_ratio=0.001, run_time=2))

        # Collapse diagonals
        bars_3d.generate_target()
        for bar in bars_3d.target:
            # Halve each bar's footprint in x and y
            bar.stretch(0.5, 0)
            bar.stretch(0.5, 1)
        bars_3d.target.set_fill(opacity=1)
        bars_3d.target.set_submobject_colors_by_gradient(GREEN_D, YELLOW_D)
        bars_3d.target.set_stroke(WHITE, 1)
        self.play(
            self.frame.animate.reorient(36, 46, 0).move_to([-0.56, 0.55, 1.22]).set_height(7.71),
            MoveToTarget(bars_3d),
            # FadeOut(full_table, IN),
            sum_bar_group.animate.set_width(4.0).to_edge(RIGHT),
            run_time=2,
        )
        self.wait()

        diagonal_bar_groups.apply_depth_test()
        new_diagonals = diagonal_bar_groups.copy()
        # Stack the bars of each diagonal on top of one another
        for group in new_diagonals:
            group.arrange(IN, buff=0)
        new_diagonals.arrange(UR, buff=MED_SMALL_BUFF, aligned_edge=IN)
        new_diagonals.move_to(bars_3d.get_corner(DR))
        new_diagonals.shift(DR)

        self.play(
            self.frame.animate.reorient(40, 61, 0).move_to([1.69, 0.33, -0.73]).set_height(12.96),
            ReplacementTransform(
                diagonal_bar_groups, new_diagonals,
                lag_ratio=0.001,
            ),
            run_time=5,
        )
        self.add(full_table, new_diagonals)
        self.play(
            self.frame.animate.reorient(40, 85, 0).move_to([3.05, 1.93, 0.77]).set_height(14.93),
            sum_bar_group.animate.set_width(5.5, about_edge=RIGHT),
            FadeOut(full_table),
            run_time=3
        )
        self.wait()

    def show_repeated_samples(
        self,
        dist1,
        dist2,
        bar_group1,
        bar_group2,
        buckets,
        sum_labels,
        n_animated_runs=20,
        n_total_runs=150,
        marker_height=0.1,
        marker_color=YELLOW_D,
    ):
        """
        Sample both distributions repeatedly and stack a marker per sample.

        On each run one index is drawn from each distribution, the matching
        die of each bar group is shown, and a marker is dropped into the
        bucket at index `x + y`.

        Parameters
        ----------
        dist1, dist2
            Probabilities for the two dice; index k is the k-th die face.
        bar_group1, bar_group2
            Labeled bars as returned by `get_bar_group`, with a die face at
            index 1 of every labeled bar.
        buckets
            One bucket mobject per possible sum index.
        sum_labels
            One label per possible sum index.
        n_animated_runs
            Number of initial runs shown with full animations.
        n_total_runs
            Total number of runs.
        marker_height, marker_color
            Appearance of the markers.
        """
        marker_template = Rectangle(
            height=marker_height,
            width=buckets[0].get_width() * 0.8,
            fill_color=marker_color,
            fill_opacity=1,
            stroke_color=WHITE,
            stroke_width=1,
        )
        # Each stack starts as a point at the bucket's bottom, so that the
        # stack's top is where the next marker should land
        markers = VGroup(*(
            VGroup(VectorizedPoint(bucket.get_bottom()))
            for bucket in buckets
        ))

        # The support is taken from each distribution's own length rather
        # than a hard-coded 6, so the values always line up with the
        # probabilities passed in
        var1 = scipy.stats.rv_discrete(values=(range(len(dist1)), dist1))
        var2 = scipy.stats.rv_discrete(values=(range(len(dist2)), dist2))

        for n in range(n_total_runs):
            x = var1.rvs()
            y = var2.rvs()

            animate = n < n_animated_runs

            # Show dice
            dice = VGroup()
            for group, value in [(bar_group1, x), (bar_group2, y)]:
                die = group[value][1].copy()
                die.set_opacity(1)
                die.scale(2)
                die.next_to(group, RIGHT, LARGE_BUFF)
                dice.add(die)
                group.set_opacity(0.5)
                group[value].set_opacity(1)
                self.add(die)
                # NOTE(review): assumed to run once per die - confirm
                self.wait(0.25 if animate else 0.0)

            # Highlight sum
            sum_labels.set_opacity(0.25)
            sum_labels[x + y].set_opacity(1)

            # Drop marker in the appropriate sum bucket
            marker = marker_template.copy()
            marker.move_to(markers[x + y].get_top(), DOWN)
            if animate:
                self.play(FadeIn(marker, DOWN, rate_func=rush_into, run_time=0.5))
                self.wait(0.5)

            markers[x + y].add(marker)
            self.add(markers)

            if animate:
                self.play(LaggedStart(
                    FadeOut(dice[0]),
                    bar_group1.animate.set_opacity(0.5),
                    FadeOut(dice[1]),
                    bar_group2.animate.set_opacity(0.5),
                    sum_labels.animate.set_opacity(0.25),
                    run_time=0.5
                ))
            else:
                # Same reset as above, applied instantly
                self.wait(0.1)
                self.remove(dice)
                VGroup(bar_group1, bar_group2).set_opacity(0.5)
                sum_labels.set_opacity(0.25)

        self.wait()
        self.play(
            FadeOut(markers, lag_ratio=0.01),
            bar_group1.animate.set_opacity(1),
            bar_group2.animate.set_opacity(1),
            sum_labels.animate.set_opacity(1),
        )

    def isolate_pairs(self, bar_groups, full_table, *ij_tuples):
        """
        Dim everything, then highlight the given (blue, red) die pairs.

        Parameters
        ----------
        bar_groups
            The two labeled bar groups (blue first, red second).
        full_table
            Group laid out as [0] blue dice, [1] red dice, [2] grid squares,
            [3] dice pairs, as built in `construct`.
        ij_tuples
            One-based (i, j) pairs: i is the blue die value, j the red one.
        """
        full_table.set_opacity(0.25)
        full_table[2].set_fill(opacity=0)
        bar_groups.set_opacity(0.35)

        for i, j in ij_tuples:
            # Convert to zero-based indices; cell n matches the layout used
            # when the grid was filled (blue index * 6 + red index)
            im1 = i - 1
            jm1 = j - 1
            n = im1 * 6 + jm1

            bar_groups[0][im1].set_opacity(1)
            bar_groups[1][jm1].set_opacity(1)

            full_table[0][im1].set_opacity(1)
            full_table[1][jm1].set_opacity(1)
            full_table[2][n].set_stroke(opacity=1)
            full_table[3][n].set_opacity(1)
|
|
|
|
|
|
|
|
|
|
|
|
class ConvolveDiscreteDistributions(SumAlongDiagonal):
|
|
|
|
long_form = True
|
|
|
|
|
|
|
|
def construct(self):
|
|
|
|
# Set up two distributions
|
|
|
|
dist1 = self.dist1
|
|
|
|
dist2 = self.dist2
|
|
|
|
blue_dice = get_die_faces(fill_color=BLUE_E, dot_color=WHITE)
|
|
|
|
red_dice = get_die_faces(fill_color=RED_E, dot_color=WHITE)
|
|
|
|
top_bars, low_bars = bar_groups = VGroup(
|
|
|
|
get_bar_group(dist1, self.dist1_colors, blue_dice),
|
|
|
|
get_bar_group(dist2, self.dist2_colors, red_dice),
|
|
|
|
)
|
|
|
|
|
|
|
|
bar_groups.arrange(DOWN, buff=LARGE_BUFF)
|
|
|
|
bar_groups.to_edge(LEFT)
|
|
|
|
|
|
|
|
self.add(bar_groups)
|
|
|
|
|
|
|
|
# Setup the sum distribution
|
|
|
|
conv_dist = np.convolve(dist1, dist2)
|
|
|
|
sum_labels = die_sum_labels()
|
|
|
|
sum_bar_group = get_bar_group(conv_dist, self.sum_colors, sum_labels)
|
|
|
|
sum_bar_group.to_edge(RIGHT, buff=LARGE_BUFF)
|
|
|
|
sum_bar_group.set_y(0)
|
|
|
|
|
|
|
|
self.add(sum_bar_group)
|
|
|
|
|
|
|
|
# V lines
|
|
|
|
v_lines = get_bar_dividing_lines(top_bars)
|
|
|
|
for bar in top_bars:
|
|
|
|
v_line = Line(UP, DOWN).set_height(FRAME_HEIGHT)
|
|
|
|
v_line.set_stroke(GREY_C, 1, 0.75)
|
|
|
|
v_line.set_x(bar.get_left()[0])
|
|
|
|
v_line.set_y(0)
|
|
|
|
v_lines.add(v_line)
|
|
|
|
v_lines.add(v_lines[-1].copy().set_x(top_bars.get_right()[0]))
|
|
|
|
|
|
|
|
# Flip
|
|
|
|
low_bars.target = low_bars.generate_target()
|
|
|
|
low_bars.target.arrange(LEFT, aligned_edge=DOWN, buff=0).move_to(low_bars)
|
|
|
|
low_bars.target.move_to(low_bars)
|
|
|
|
|
|
|
|
rect = SurroundingRectangle(low_bars)
|
|
|
|
label = Text("Flip this")
|
|
|
|
label.next_to(rect, RIGHT)
|
|
|
|
|
|
|
|
low_arrow = Arrow(low_bars.get_right(), low_bars.get_left())
|
|
|
|
low_arrow.set_stroke(color=YELLOW)
|
|
|
|
low_arrow.next_to(low_bars, DOWN)
|
|
|
|
|
|
|
|
self.play(
|
|
|
|
ShowCreation(rect),
|
|
|
|
FadeIn(label)
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
self.play(
|
|
|
|
MoveToTarget(low_bars, path_arc=PI / 3, lag_ratio=0.005)
|
|
|
|
)
|
|
|
|
self.play(ShowCreation(low_arrow), FadeOut(rect))
|
|
|
|
self.wait()
|
|
|
|
self.play(
|
|
|
|
FadeOut(label), FadeOut(low_arrow),
|
|
|
|
ShowCreation(v_lines, lag_ratio=0.1),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
|
|
|
|
# Show corresponding pairs
|
|
|
|
rows = VGroup(blue_dice, VGroup(*reversed(red_dice))).copy()
|
|
|
|
pairs = VGroup(*(VGroup(*pair) for pair in zip(*rows)))
|
|
|
|
self.play(rows.animate.arrange(UP, SMALL_BUFF).next_to(bar_groups, RIGHT))
|
|
|
|
|
|
|
|
rows.generate_target()
|
|
|
|
rows.target.rotate(-90 * DEGREES)
|
|
|
|
for row in rows.target:
|
|
|
|
for die in row:
|
|
|
|
die.rotate(90 * DEGREES)
|
|
|
|
rows.target.arrange(RIGHT, buff=MED_SMALL_BUFF)
|
|
|
|
rows.target.set_height(5)
|
|
|
|
rows.target.next_to(bar_groups, RIGHT, buff=1)
|
|
|
|
|
|
|
|
sum_bar_group.generate_target()
|
|
|
|
sum_bar_group.target.set_opacity(0.25)
|
|
|
|
sum_bar_group.target[7 - 2].set_opacity(1)
|
|
|
|
|
|
|
|
self.play(
|
|
|
|
MoveToTarget(rows, run_time=2),
|
|
|
|
MoveToTarget(sum_bar_group),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
|
|
|
|
# Go through all pairs
|
|
|
|
last_rect = VMobject()
|
|
|
|
for n in [*range(6), *range(4, -1, -1)]:
|
|
|
|
pair = pairs[n]
|
|
|
|
rect = SurroundingRectangle(pair)
|
|
|
|
rect.round_corners()
|
|
|
|
bar_groups.generate_target()
|
|
|
|
bar_groups.target.set_opacity(0.35)
|
|
|
|
bar_groups.target[0][n].set_opacity(1)
|
|
|
|
bar_groups.target[1][5 - n].set_opacity(1)
|
|
|
|
self.play(
|
|
|
|
FadeIn(rect),
|
|
|
|
FadeOut(last_rect),
|
|
|
|
MoveToTarget(bar_groups),
|
|
|
|
run_time=0.5
|
|
|
|
)
|
|
|
|
self.wait(0.5)
|
|
|
|
last_rect = rect
|
|
|
|
self.play(FadeOut(last_rect), bar_groups.animate.set_opacity(1))
|
|
|
|
self.wait()
|
|
|
|
self.play(FadeOut(pairs), sum_bar_group.animate.set_opacity(1))
|
|
|
|
|
|
|
|
# March!
|
|
|
|
for bars in bar_groups:
|
|
|
|
for i, bar in zip(it.count(1), bars):
|
|
|
|
bar.index = i
|
|
|
|
|
|
|
|
for n in [7, 5, *range(2, 13)]:
|
|
|
|
sum_bar_group.generate_target()
|
|
|
|
sum_bar_group.target.set_opacity(0.25)
|
|
|
|
sum_bar_group.target[n - 2].set_opacity(1.0)
|
|
|
|
|
|
|
|
self.play(
|
|
|
|
get_row_shift(top_bars, low_bars, n),
|
|
|
|
MoveToTarget(sum_bar_group),
|
|
|
|
)
|
|
|
|
pairs = get_aligned_pairs(top_bars, low_bars, n)
|
|
|
|
|
|
|
|
label_pairs = VGroup(*(VGroup(m1.value_label, m2.value_label) for m1, m2 in pairs))
|
|
|
|
die_pairs = VGroup(*(VGroup(m1.die, m2.die) for m1, m2 in pairs))
|
|
|
|
pair_rects = VGroup(*(
|
|
|
|
SurroundingRectangle(pair, buff=0.05).set_stroke(YELLOW, 2).round_corners()
|
|
|
|
for pair in pairs
|
|
|
|
))
|
|
|
|
pair_rects.set_stroke(YELLOW, 2)
|
|
|
|
for rect in pair_rects:
|
|
|
|
rect.set_width(label_pairs[0].get_width() + 0.125, stretch=True)
|
|
|
|
|
|
|
|
fade_anims = []
|
|
|
|
|
|
|
|
# Spell out the full dot product
|
|
|
|
products = VGroup()
|
|
|
|
die_pair_targets = VGroup()
|
|
|
|
for die_pair in die_pairs:
|
|
|
|
product = VGroup(
|
|
|
|
p_mob(die_pair[0]),
|
|
|
|
p_mob(die_pair[1]),
|
|
|
|
)
|
|
|
|
product.arrange(RIGHT, buff=SMALL_BUFF)
|
|
|
|
die_pair_targets.add(VGroup(
|
|
|
|
product[0].arg,
|
|
|
|
product[1].arg,
|
|
|
|
))
|
|
|
|
products.add(product)
|
|
|
|
|
|
|
|
products.arrange(DOWN, buff=0.75)
|
|
|
|
products.move_to(midpoint(sum_bar_group.get_left(), bar_groups.get_right()))
|
|
|
|
products.shift(2 * UP).shift_onto_screen()
|
|
|
|
plusses = Tex("+", font_size=48).replicate(len(pairs))
|
|
|
|
plusses[-1].scale(0).set_opacity(0)
|
|
|
|
for plus, lp1, lp2 in zip(plusses, products, products[1:]):
|
|
|
|
plus.move_to(VGroup(lp1, lp2))
|
|
|
|
|
|
|
|
self.play(
|
|
|
|
ShowIncreasingSubsets(products),
|
|
|
|
ShowIncreasingSubsets(plusses),
|
|
|
|
ShowIncreasingSubsets(pair_rects),
|
|
|
|
run_time=0.35 * len(products)
|
|
|
|
)
|
|
|
|
self.wait(0.5)
|
|
|
|
|
|
|
|
prod_group = VGroup(*products, *plusses)
|
|
|
|
mover = prod_group.copy()
|
|
|
|
mover.sort(lambda p: -p[1])
|
|
|
|
mover.generate_target()
|
|
|
|
mover.target.set_opacity(0)
|
|
|
|
for mob in mover.target:
|
|
|
|
mob.replace(sum_bar_group[n - 2].value_label, stretch=True)
|
|
|
|
self.play(MoveToTarget(mover, remover=True, lag_ratio=0.002))
|
|
|
|
self.wait(0.5)
|
|
|
|
self.play(
|
|
|
|
FadeOut(prod_group),
|
|
|
|
FadeOut(pair_rects),
|
|
|
|
)
|
|
|
|
|
|
|
|
self.play(
|
|
|
|
get_row_shift(top_bars, low_bars, 7),
|
|
|
|
sum_bar_group.animate.set_opacity(1.0),
|
|
|
|
run_time=0.5
|
|
|
|
)
|
|
|
|
|
|
|
|
# Distribution labels
|
|
|
|
plabel_kw = dict(tex_to_color_map={"X": BLUE, "Y": RED})
|
|
|
|
PX = Tex("P_X", **plabel_kw)
|
|
|
|
PY = Tex("P_Y", **plabel_kw)
|
|
|
|
PXY = Tex("P_{X + Y}", **plabel_kw)
|
|
|
|
|
|
|
|
PX.next_to(top_bars.get_corner(UR), DR)
|
|
|
|
PY.next_to(low_bars.get_corner(UR), DR)
|
|
|
|
PXY.next_to(sum_bar_group, UP, MED_LARGE_BUFF)
|
|
|
|
|
|
|
|
# Function label
|
|
|
|
func_label = Text("Function", font_size=36)
|
|
|
|
func_label.next_to(PX, UP, LARGE_BUFF, aligned_edge=LEFT)
|
|
|
|
func_label.shift_onto_screen(buff=SMALL_BUFF)
|
|
|
|
arrow = Arrow(func_label, PX.get_top(), buff=0.2)
|
|
|
|
VGroup(func_label, arrow).set_color(YELLOW)
|
|
|
|
x_args = VGroup(*(
|
|
|
|
Tex(
|
|
|
|
f"({x}) = {np.round(dist1[x - 1], 2)}"
|
|
|
|
).next_to(PX, RIGHT, SMALL_BUFF)
|
|
|
|
for x in range(1, 7)
|
|
|
|
))
|
|
|
|
|
|
|
|
# Die rectangles
|
|
|
|
die_rects = VGroup()
|
|
|
|
value_rects = VGroup()
|
|
|
|
for index, x_arg in enumerate(x_args):
|
|
|
|
x_die = top_bars[index].die
|
|
|
|
value_label = top_bars[index].value_label
|
|
|
|
die_rect = SurroundingRectangle(x_die, buff=SMALL_BUFF)
|
|
|
|
value_rect = SurroundingRectangle(value_label, buff=SMALL_BUFF)
|
|
|
|
for rect in die_rect, value_rect:
|
|
|
|
rect.set_stroke(YELLOW, 2).round_corners()
|
|
|
|
die_rects.add(die_rect)
|
|
|
|
value_rects.add(value_rect)
|
|
|
|
|
|
|
|
index = 2
|
|
|
|
x_arg = x_args[index]
|
|
|
|
die_rect = die_rects[index]
|
|
|
|
value_rect = value_rects[index]
|
|
|
|
x_die = top_bars[index].die
|
|
|
|
value_label = top_bars[index].value_label
|
|
|
|
|
|
|
|
# Describe the distribution as a function
|
|
|
|
top_rect = SurroundingRectangle(top_bars)
|
|
|
|
top_rect.set_stroke(BLUE, 3)
|
|
|
|
top_rect.round_corners(radius=0.25)
|
|
|
|
|
|
|
|
self.play(ShowCreation(top_rect))
|
|
|
|
self.play(
|
|
|
|
Write(PX),
|
|
|
|
Write(func_label),
|
|
|
|
ShowCreation(arrow),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
|
|
|
|
self.play(ShowCreation(die_rect), FadeOut(top_rect))
|
|
|
|
self.play(FadeTransform(x_die.copy(), x_arg[:3]))
|
|
|
|
self.play(TransformFromCopy(die_rect, value_rect))
|
|
|
|
self.play(FadeTransform(value_label.copy(), x_arg[3:]))
|
|
|
|
self.wait()
|
|
|
|
for i in range(6):
|
|
|
|
self.remove(*die_rects, *value_rects, *x_args)
|
|
|
|
self.add(die_rects[i], value_rects[i], x_args[i])
|
|
|
|
self.wait(0.5)
|
|
|
|
|
|
|
|
# Label other distribution functions
|
|
|
|
func_group = VGroup(func_label, arrow)
|
|
|
|
func_group_Y = func_group.copy().shift(PY.get_center() - PX.get_center())
|
|
|
|
func_group_XY = func_group.copy().shift(PXY.get_center() - PX.get_center())
|
|
|
|
|
|
|
|
self.play(
|
|
|
|
TransformFromCopy(func_group, func_group_Y),
|
|
|
|
Write(PY),
|
|
|
|
FadeOut(die_rects[-1]),
|
|
|
|
FadeOut(value_rects[-1])
|
|
|
|
)
|
|
|
|
self.play(
|
|
|
|
TransformFromCopy(func_group, func_group_XY),
|
|
|
|
Write(PXY),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
self.play(LaggedStartMap(FadeOut, VGroup(
|
|
|
|
func_group, func_group_Y, func_group_XY, x_args[-1]
|
|
|
|
)))
|
|
|
|
|
|
|
|
# Label convolution
|
|
|
|
sum_bar_group.generate_target()
|
|
|
|
sum_bar_group.target.shift(DOWN)
|
|
|
|
conv_def = Tex(
|
|
|
|
R"\big[P_X * P_Y\big](s) = \sum_{x = 1}^6 P_X(x) \cdot P_Y(s - x)",
|
|
|
|
font_size=36,
|
|
|
|
isolate=["x = 1", "6"],
|
|
|
|
**plabel_kw,
|
|
|
|
)
|
|
|
|
conv_def.next_to(sum_bar_group.target, UP, buff=MED_LARGE_BUFF)
|
|
|
|
|
2023-06-23 10:56:51 -07:00
|
|
|
PXY_arg = Tex("(s)", font_size=36)
|
2023-06-15 10:40:18 -07:00
|
|
|
PXY.generate_target()
|
|
|
|
lhs = conv_def[:10]
|
|
|
|
PXY.target.next_to(lhs, UP, LARGE_BUFF).shift_onto_screen(buff=SMALL_BUFF)
|
2023-06-23 10:56:51 -07:00
|
|
|
PXY_arg.next_to(PXY.target, RIGHT, buff=SMALL_BUFF)
|
2023-06-15 10:40:18 -07:00
|
|
|
eq = Tex("=").rotate(90 * DEGREES)
|
|
|
|
eq.move_to(midpoint(PXY.target.get_bottom(), lhs.get_top()))
|
|
|
|
|
|
|
|
conv_rect = SurroundingRectangle(conv_def["P_X * P_Y"], buff=0.05)
|
|
|
|
conv_rect.set_stroke(YELLOW, 2)
|
|
|
|
conv_word = Text("Convolution")
|
|
|
|
conv_word.match_color(conv_rect)
|
|
|
|
conv_word.next_to(conv_rect, DOWN, buff=SMALL_BUFF)
|
|
|
|
|
|
|
|
self.play(LaggedStart(
|
|
|
|
MoveToTarget(sum_bar_group),
|
|
|
|
MoveToTarget(PXY),
|
2023-06-23 10:56:51 -07:00
|
|
|
FadeIn(PXY_arg, UP),
|
2023-06-15 10:40:18 -07:00
|
|
|
Write(eq),
|
|
|
|
TransformFromCopy(PX, lhs[1:3]),
|
|
|
|
TransformFromCopy(PY, lhs[4:6]),
|
|
|
|
Write(VGroup(lhs[0], lhs[3], *lhs[6:])),
|
|
|
|
))
|
|
|
|
self.wait()
|
|
|
|
self.play(
|
|
|
|
Write(conv_word),
|
|
|
|
ShowCreation(conv_rect),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
self.play(
|
|
|
|
conv_rect.animate.become(SurroundingRectangle(conv_def["*"], buff=0.05, stroke_width=1))
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
self.play(FadeOut(conv_word), FadeOut(conv_rect))
|
|
|
|
|
|
|
|
self.add(conv_def)
|
|
|
|
conv_def[10:].set_opacity(0)
|
|
|
|
|
|
|
|
# Question right hand side
|
|
|
|
question_rhs = Text("= (What formula goes here?)", font_size=30)
|
|
|
|
question_rhs.next_to(conv_def[:10], RIGHT)
|
|
|
|
self.play(Write(question_rhs))
|
|
|
|
self.wait()
|
|
|
|
|
|
|
|
# Show example input of 4
|
|
|
|
ex_rhs = Tex(R"(4) = P_{-}(1)P_{-}(3) + P_{-}(2)P_{-}(2) + P_{-}(3)P_{-}(1)")
|
|
|
|
ex_rhs.scale(0.9)
|
|
|
|
ex_rhs.next_to(PXY, RIGHT, buff=0.1)
|
|
|
|
for n, dot in enumerate(ex_rhs["-"]):
|
|
|
|
even = n % 2 == 0
|
|
|
|
substr = Tex("X" if even else "Y", font_size=24)
|
|
|
|
substr.set_color(BLUE if even else RED)
|
|
|
|
substr.move_to(dot)
|
|
|
|
dot[0].become(substr)
|
|
|
|
ex_rhs[ex_rhs.submobjects.index(dot[0]) + 2].match_color(substr)
|
|
|
|
|
|
|
|
PXY_copy = PXY.copy()
|
|
|
|
PXY_copy.generate_target()
|
|
|
|
VGroup(PXY_copy.target, ex_rhs).to_edge(RIGHT, buff=-0.5)
|
|
|
|
|
|
|
|
eq.generate_target()
|
|
|
|
eq.target.rotate(-90 * DEGREES)
|
|
|
|
eq.target.next_to(conv_def, LEFT)
|
|
|
|
|
|
|
|
PXY.generate_target()
|
|
|
|
PXY.target.next_to(eq.target, LEFT)
|
|
|
|
VGroup(PXY.target, eq.target).align_to(PXY_copy.target, LEFT)
|
|
|
|
|
|
|
|
example_box = SurroundingRectangle(VGroup(PXY_copy.target, ex_rhs))
|
|
|
|
example_box.set_stroke(TEAL, 1)
|
|
|
|
example_words = Text("For example")
|
|
|
|
example_words.match_color(example_box)
|
|
|
|
example_words.next_to(example_box, UP)
|
|
|
|
|
|
|
|
self.play(LaggedStart(
|
|
|
|
MoveToTarget(PXY_copy),
|
|
|
|
MoveToTarget(eq),
|
|
|
|
MoveToTarget(PXY),
|
2023-06-23 10:56:51 -07:00
|
|
|
Transform(PXY_arg, ex_rhs["(4)"], remover=True),
|
2023-06-15 10:40:18 -07:00
|
|
|
conv_def.animate.next_to(eq.target, RIGHT),
|
|
|
|
MaintainPositionRelativeTo(question_rhs, conv_def),
|
|
|
|
FadeIn(ex_rhs, LEFT),
|
|
|
|
PX.animate.next_to(bar_groups[0], LEFT),
|
|
|
|
PY.animate.next_to(bar_groups[1], LEFT),
|
|
|
|
self.frame.animate.set_height(9).move_to(0.5 * UP),
|
|
|
|
sum_bar_group.animate.shift(0.5 * DOWN),
|
|
|
|
ShowCreation(example_box),
|
|
|
|
FadeIn(example_words, UP),
|
|
|
|
run_time=2
|
|
|
|
))
|
|
|
|
self.wait()
|
|
|
|
|
|
|
|
example = VGroup(PXY_copy, ex_rhs)
|
|
|
|
self.add(example)
|
|
|
|
|
|
|
|
# Cycle through cases
|
|
|
|
example_box.save_state()
|
|
|
|
for part in ex_rhs[re.compile(R"P[^+]*P[^+]*\)")]:
|
|
|
|
self.play(
|
|
|
|
example_box.animate.replace(part, stretch=True).scale(1.1).set_stroke(width=2),
|
|
|
|
run_time=0.5
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
self.play(FadeOut(example_box))
|
|
|
|
self.wait()
|
|
|
|
|
|
|
|
# Show full definition
|
|
|
|
general_words = Text("In general")
|
|
|
|
general_words.next_to(conv_def, UP)
|
|
|
|
general_words.match_x(example_words)
|
|
|
|
general_words.set_color(TEAL)
|
|
|
|
|
|
|
|
example_words.generate_target()
|
|
|
|
example_words.target.scale(0.75)
|
|
|
|
example_words.target.set_y(4.5)
|
|
|
|
example_words.target.set_color(GREY_B)
|
|
|
|
|
|
|
|
conv_def.set_opacity(1)
|
|
|
|
self.play(
|
|
|
|
Write(conv_def[10:]),
|
|
|
|
FadeOut(question_rhs, DOWN),
|
|
|
|
FadeIn(general_words, DOWN),
|
|
|
|
MoveToTarget(example_words),
|
|
|
|
example.animate.scale(0.75).next_to(example_words.target, DOWN),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
|
|
|
|
# Talk through formula
|
|
|
|
s_arrow = Vector(0.5 * DOWN, stroke_color=YELLOW)
|
|
|
|
s_arrow.next_to(conv_def["s"][0], UP, SMALL_BUFF)
|
|
|
|
x_arrow, y_arrow = s_arrow.replicate(2)
|
|
|
|
x_arrow.next_to(conv_def["x"][1], UP, SMALL_BUFF)
|
|
|
|
y_arrow.next_to(conv_def["s - x"], UP, SMALL_BUFF)
|
|
|
|
|
|
|
|
self.play(GrowArrow(s_arrow))
|
|
|
|
self.wait()
|
|
|
|
self.play(LaggedStart(
|
|
|
|
TransformFromCopy(s_arrow, x_arrow),
|
|
|
|
TransformFromCopy(s_arrow, y_arrow),
|
|
|
|
))
|
|
|
|
self.wait()
|
|
|
|
|
|
|
|
xeq1 = conv_def["x = 1"]
|
|
|
|
y_arrow.save_state()
|
|
|
|
self.play(
|
|
|
|
s_arrow.animate.scale(0.75).rotate(90 * DEGREES).next_to(xeq1, LEFT, SMALL_BUFF),
|
|
|
|
y_arrow.animate.scale(0.75).rotate(-90 * DEGREES).next_to(xeq1, RIGHT, SMALL_BUFF),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
self.play(
|
|
|
|
s_arrow.animate.next_to(conv_def["6"], LEFT, SMALL_BUFF),
|
|
|
|
y_arrow.animate.next_to(conv_def["6"], RIGHT, SMALL_BUFF),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
self.play(
|
|
|
|
Restore(y_arrow),
|
|
|
|
FadeOut(s_arrow),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
self.play(LaggedStart(FadeOut(x_arrow), FadeOut(y_arrow)))
|
|
|
|
|
|
|
|
# Show zero'd example
|
|
|
|
bar_groups.generate_target()
|
|
|
|
bar_groups.target.scale(2 / 3, about_point=top_bars.get_top() + UP)
|
|
|
|
|
|
|
|
example_words = TexText("Plugging in $s = 4$")
|
|
|
|
example_words.next_to(conv_def, DOWN, buff=1.0)
|
|
|
|
|
|
|
|
example = Tex(
|
|
|
|
R"""
|
|
|
|
[P_X * P_Y](4) =
|
|
|
|
&P_X(1) \cdot P_Y(3) \; + \\
|
|
|
|
&P_X(2) \cdot P_Y(2) \; + \\
|
|
|
|
&P_X(3) \cdot P_Y(1) \; + \\
|
|
|
|
&P_X(4) \cdot P_Y(0) \; + \\
|
|
|
|
&P_X(5) \cdot P_Y(-1) \; + \\
|
|
|
|
&P_X(6) \cdot P_Y(-2)
|
|
|
|
""",
|
|
|
|
t2c={
|
|
|
|
"X": BLUE,
|
|
|
|
"Y": RED,
|
|
|
|
},
|
|
|
|
font_size=36
|
|
|
|
)
|
|
|
|
example.next_to(example_words, DOWN, buff=0.5)
|
|
|
|
|
|
|
|
summands = VGroup(*(
|
|
|
|
example[Rf"P_X({x}) \cdot P_Y({4 - x})"]
|
|
|
|
for x in range(1, 7)
|
|
|
|
))
|
|
|
|
plusses = example["+"]
|
|
|
|
plusses.shift(SMALL_BUFF * RIGHT)
|
|
|
|
plusses.add(VectorizedPoint(summands[-1].get_center()))
|
|
|
|
|
|
|
|
self.play(
|
|
|
|
FadeIn(example[R"[P_X * P_Y](4) ="]),
|
|
|
|
FadeOut(v_lines),
|
|
|
|
MoveToTarget(bar_groups),
|
|
|
|
MaintainPositionRelativeTo(PX, bar_groups[0]),
|
|
|
|
Write(example_words),
|
|
|
|
PY.animate.next_to(bar_groups.target[1], RIGHT).match_x(PX),
|
|
|
|
sum_bar_group.animate.scale(0.5).next_to(bar_groups.target, DOWN, buff=0.75, aligned_edge=LEFT),
|
|
|
|
)
|
|
|
|
last_rect = VectorizedPoint(summands[0].get_center())
|
|
|
|
for x, summand, plus in zip(it.count(1), summands, plusses):
|
|
|
|
rect = SurroundingRectangle(VGroup(summand, plus))
|
|
|
|
rect.set_stroke(BLUE, 2)
|
|
|
|
rect.round_corners()
|
|
|
|
rect.add(Tex(Rf"x = {x}").next_to(rect, DOWN, SMALL_BUFF))
|
|
|
|
self.play(
|
|
|
|
FadeTransform(
|
|
|
|
conv_def[R"P_X(x) \cdot P_Y(s - x)"].copy(),
|
|
|
|
summand,
|
|
|
|
),
|
|
|
|
FadeTransform(last_rect, rect),
|
|
|
|
FadeIn(plus),
|
|
|
|
run_time=0.5
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
last_rect = rect
|
|
|
|
self.play(FadeOut(last_rect))
|
|
|
|
|
|
|
|
# Highlight zeros
|
|
|
|
zeroed_terms = VGroup(*(
|
|
|
|
example[Rf"P_Y({n})"]
|
|
|
|
for n in range(0, -4, -1)
|
|
|
|
))
|
|
|
|
zeroed_rect = SurroundingRectangle(zeroed_terms)
|
|
|
|
zeroed_rect.set_stroke(RED, 3)
|
|
|
|
zeroed_rect.stretch(1.3, 0, about_edge=LEFT)
|
|
|
|
zeroed_rect.round_corners()
|
|
|
|
|
|
|
|
eq_zero = Tex("= 0")
|
|
|
|
eq_zero.next_to(zeroed_rect, RIGHT)
|
|
|
|
eq_zero.set_color(RED)
|
|
|
|
|
|
|
|
self.play(ShowCreation(zeroed_rect))
|
|
|
|
self.wait()
|
|
|
|
self.play(Write(eq_zero))
|
|
|
|
self.play(
|
|
|
|
summands[3:].animate.set_opacity(0.35),
|
|
|
|
plusses[3:].animate.set_opacity(0.35),
|
|
|
|
)
|
|
|
|
self.wait()
|
|
|
|
|
|
|
|
def show_bars_creation(self, bars, lag_ratio=0.05, run_time=3):
    """Build a staggered animation that grows a group of labeled bars into view.

    Each element of ``bars`` must unpack into exactly three mobjects.
    The first is collapsed to zero height (about its bottom edge) and made
    transparent, then animated back to its saved state. The second is
    counted in from 0 while being kept just above the first and faded in.
    The third is simply faded in.

    Parameters:
        bars: iterable of three-element groups, as described above.
        lag_ratio: forwarded to ``LaggedStart``.
        run_time: forwarded to ``LaggedStart``.

    Returns:
        A ``LaggedStart`` over all per-bar animations. Nothing is played
        here; the caller is expected to pass the result to ``self.play``.
    """
    def ride_on_rect(mob, alpha):
        # Keep the number sitting on top of its (growing) rectangle while
        # its opacity follows the animation's progress.
        return mob.next_to(mob.rect, UP, SMALL_BUFF).set_opacity(alpha)

    animations = []
    for labeled_bar in bars:
        rect, num, face = labeled_bar
        # The updater above finds the rectangle through this attribute.
        num.rect = rect

        # Remember the final look, then start flat and invisible.
        rect.save_state()
        rect.stretch(0, 1, about_edge=DOWN)
        rect.set_opacity(0)

        fade_face = FadeIn(face)
        grow_rect = rect.animate.restore()
        count_up = CountInFrom(num, 0)
        track_rect = UpdateFromAlphaFunc(num, ride_on_rect)
        animations += [fade_face, grow_rect, count_up, track_rect]

    return LaggedStart(*animations, lag_ratio=lag_ratio, run_time=run_time)
|
|
|
|
|
|
|
|
|
|
|
|
class ShowConvolutionOfLists(SumAlongDiagonal):
    """Scene presenting convolution as ``(bars) * (bars) = (bars)``.

    Reads ``self.dist1``, ``self.dist2``, ``self.dist1_colors``,
    ``self.dist2_colors`` and ``self.sum_colors``; none of them are defined
    in this class (presumably inherited from ``SumAlongDiagonal`` -- confirm).
    """

    def construct(self):
        """Lay out the bar-chart equation, name the operation, contrast
        "list of numbers" with "function", then step through the index
        pairs contributing to each output bar."""
        # Set up two distributions
        dist1 = self.dist1
        dist2 = self.dist2
        conv_dist = np.convolve(dist1, dist2)
        kw = dict(height=1.5)
        blue_bars, red_bars, sum_bars = bar_groups = VGroup(
            get_bar_group(dist1, self.dist1_colors, **kw),
            get_bar_group(dist2, self.dist2_colors, **kw),
            get_bar_group(conv_dist, self.sum_colors, **kw),
        )

        # Create equation
        parens = Tex("()()")
        parens.stretch(2, 1)
        parens.match_height(bar_groups)
        asterisk = Tex("*", font_size=96)
        equation = VGroup(
            parens[0], blue_bars, parens[1],
            asterisk,
            parens[2], red_bars, parens[3],
            Tex("=", font_size=96),
            sum_bars,
        )
        equation.arrange(RIGHT)
        equation.set_width(FRAME_WIDTH - 1)
        equation.to_edge(UP, buff=1.0)

        # Everything but the result is shown immediately; the result bars
        # are then transformed in from copies of both inputs.
        self.add(equation)
        self.remove(sum_bars)
        self.play(
            TransformFromCopy(blue_bars, sum_bars, lag_ratio=0.003),
            TransformFromCopy(red_bars, sum_bars, lag_ratio=0.003),
            run_time=1.5
        )
        self.wait()

        # Name operation
        arrow = Vector(0.5 * DOWN)
        arrow.next_to(asterisk, UP)
        name = Text("Convolution", font_size=60)
        name.next_to(arrow, UP)
        VGroup(arrow, name).set_color(YELLOW)

        self.play(
            Write(name),
            GrowArrow(arrow)
        )
        self.play(FlashAround(asterisk))
        self.wait()

        # Lists of numbers vs functions
        list_words = Text("List of numbers", font_size=36).replicate(3)
        func_words = Text("Function", font_size=36).replicate(3)
        crosses = VGroup()
        for list_word, func_word, bar_group in zip(list_words, func_words, bar_groups):
            list_word.next_to(bar_group, DOWN)
            func_word.next_to(bar_group, DOWN)
            crosses.add(Cross(list_word))

        for list_word, bar_group in zip(list_words, bar_groups):
            self.play(
                FadeIn(list_word, DOWN),
                # bar[2] is the third member of each labeled bar; in the
                # module-level get_bar_group that is the probability label.
                LaggedStart(*(
                    FlashAround(bar[2], time_width=1.5, buff=0.05)
                    for bar in bar_group
                ), lag_ratio=0.03, run_time=2)
            )
        self.wait()
        self.play(LaggedStartMap(ShowCreation, crosses, lag_ratio=0.1, run_time=1))
        # The crossed-out "list" labels slide down to make room for the
        # "Function" labels fading in at the same spot.
        self.play(
            LaggedStartMap(FadeIn, func_words, shift=0.5 * DOWN, scale=0.5, lag_ratio=0.5),
            list_words.animate.shift(0.5 * DOWN),
            crosses.animate.shift(0.5 * DOWN),
        )
        self.wait()

        # Cycle through appropriate pairs
        for s in range(len(sum_bars)):
            bar_groups.set_opacity(0.75)
            sum_bars[s].set_opacity(1)
            for x in range(len(blue_bars)):
                # Dim both inputs again before lighting up the next pair.
                bar_groups[:2].set_opacity(0.75)
                y = s - x
                # Only (x, y) pairs with x + y == s and y a valid index of
                # the second distribution contribute to output bar s.
                if 0 <= y < len(red_bars):
                    blue_bars[x].set_opacity(1)
                    red_bars[y].set_opacity(1)
                    self.wait(0.25)
            self.wait(0.5)
        self.play(bar_groups.animate.set_opacity(1))
        self.wait()
|
|
|
|
|
|
|
|
|
|
|
|
class ConvolveMatchingDiscreteDistributions(ConvolveDiscreteDistributions):
    """ConvolveDiscreteDistributions with EXP_DISTRIBUTION as both inputs."""
    # Both operands refer to the very same distribution object.
    dist1 = dist2 = EXP_DISTRIBUTION
|
|
|
|
|
|
|
|
|
|
|
|
class RepeatedDiscreteConvolutions(InteractiveScene):
    """Scene that repeatedly convolves a running distribution with a fixed one.

    The class attribute ``distribution`` supplies the starting list of
    probabilities. Each pass of the main loop slides the (flipped) lower
    bars under the upper bars, builds the convolved bars one at a time,
    and then promotes the result to be the new upper bars.
    """
    distribution = EXP_DISTRIBUTION

    def construct(self):
        """Set up the two copies of the distribution and run five convolutions."""
        # Divide up space
        base_line = Line(LEFT, RIGHT)
        base_line.set_width(FRAME_WIDTH)
        h_lines = base_line.replicate(4)
        h_lines.arrange(DOWN, buff=FRAME_HEIGHT / 3)
        h_lines.center()
        h_lines.set_stroke(WHITE, 1)
        # Only the two interior lines are drawn.
        self.add(h_lines[1:3])

        # Initial distributions
        top_bars = self.get_bar_group(self.distribution, colors=(BLUE, TEAL))
        top_bars.next_to(h_lines[1], UP, SMALL_BUFF)

        low_bars = top_bars.copy()
        low_bars.set_y(-top_bars.get_y())
        low_bars.next_to(h_lines[2], UP, SMALL_BUFF)

        VGroup(top_bars, low_bars).shift(2 * LEFT)

        self.add(top_bars, low_bars)

        # Add labels

        # Repeated convolution
        self.flip_bar_group(low_bars)
        # The flipped layout is what every later round restores to.
        low_bars.save_state()
        for _ in range(5):
            new_bars = self.show_convolution(top_bars, low_bars)
            self.wait()
            promote_result = new_bars.animate.move_to(top_bars, DL).set_anim_args(
                path_arc=-120 * DEGREES
            )
            self.play(
                promote_result,
                FadeOut(top_bars, UP),
                Restore(low_bars),
            )
            # TODO, things with labels

            top_bars = new_bars

    def get_bar_group(
        self,
        dist,
        colors=(BLUE, TEAL),
        y_unit=4,
        bar_width=0.35,
        num_decimal_places=2,
        min_value=1,
    ):
        """Return ``VGroup(bars, value_labels, prob_labels)`` for ``dist``.

        The returned group also carries ``dist`` as a ``.dist`` attribute,
        which ``show_convolution`` reads back.
        """
        bars = self.get_bars(dist, colors, y_unit, bar_width)
        value_labels = self.get_bar_value_labels(bars, min_value)
        prob_labels = self.get_bar_prob_labels(bars, dist, num_decimal_places)
        group = VGroup(bars, value_labels, prob_labels)
        group.dist = dist
        return group

    def get_bars(self, dist, colors=(BLUE, TEAL), y_unit=4, bar_width=0.35):
        """Return ``ChartBars`` for ``dist`` on throwaway axes.

        The axes span x in ``(0, len(dist))`` and y in ``(0, 1)``, with
        height ``y_unit`` and width ``bar_width * len(dist)``. The bars are
        colored with a gradient through ``colors``.
        """
        n_values = len(dist)
        axes = Axes(
            (0, n_values),
            (0, 1),
            height=y_unit,
            width=bar_width * n_values,
        )
        bars = ChartBars(axes, dist, fill_opacity=0.75)
        bars.set_submobject_colors_by_gradient(*colors)
        return bars

    def get_bar_value_labels(self, bars, min_value=1):
        """Return one ``Integer`` below each bar, counting up from ``min_value``."""
        labels = VGroup()
        for offset, bar in enumerate(bars):
            label = Integer(offset + min_value, font_size=16)
            label.next_to(bar, DOWN, SMALL_BUFF)
            labels.add(label)
        return labels

    def get_bar_prob_labels(self, bars, dist, num_decimal_places=2):
        """Return one ``DecimalNumber`` per entry of ``dist``, placed above the bars.

        Each label is capped at 75% of its bar's width.
        """
        numbers = [
            DecimalNumber(p, font_size=16, num_decimal_places=num_decimal_places)
            for p in dist
        ]
        prob_labels = VGroup(*numbers)
        for bar, label in zip(bars, prob_labels):
            label.set_max_width(0.75 * bar.get_width())
            label.next_to(bar, UP, SMALL_BUFF)
        return prob_labels

    def get_dist_label(self, indices):
        """Return a ``Tex`` of the form ``P_{X_{i}+X_{j}+...}``.

        With more than three indices, only the first and last terms are
        kept, separated by ``\\cdots``.
        """
        terms = [f"X_{{{i}}}" for i in indices]
        if len(indices) > 3:
            terms = [terms[0], R"\cdots", terms[-1]]
        sub_tex = "+".join(terms)
        return Tex(f"P_{{{sub_tex}}}")

    def flip_bar_group(self, bar_group):
        """Animate reversing the horizontal order of the bars in ``bar_group``.

        The target arranges the bars leftward with no gaps, aligned at the
        bottom, and is anchored to the bottom-right corner of the current
        first bar. Each bar's value and probability labels keep their
        position relative to that bar during the move.
        """
        bars = bar_group[0]
        flipped = bars.generate_target()
        bars.target = flipped
        flipped.arrange(LEFT, buff=0, aligned_edge=DOWN)
        flipped.align_to(bars[0], DR)

        flip_anim = MoveToTarget(bars, lag_ratio=0.05, path_arc=0.5)
        label_followers = [
            MaintainPositionRelativeTo(VGroup(value, prob), bar)
            for bar, value, prob in zip(*bar_group)
        ]
        self.play(flip_anim, *label_followers)
        self.add(bar_group)

    def show_convolution(self, top_bars, low_bars):
        """Animate the convolution of ``top_bars`` with ``low_bars``.

        For every output index ``n`` the lower bars are shifted, the
        contributing bar pairs are highlighted one at a time while a
        partial rectangle and partial probability label accumulate, and
        then the first ``n + 1`` result bars and labels are added.

        Returns:
            The bar group (see ``get_bar_group``) for the convolved
            distribution.
        """
        # New bars
        new_dist = np.convolve(top_bars.dist, low_bars.dist)
        # The smallest possible sum is the sum of the two smallest values.
        lowest_sum = top_bars[1][0].get_value() + low_bars[1][0].get_value()
        new_bars = self.get_bar_group(
            new_dist,
            y_unit=8,
            num_decimal_places=3,
            min_value=lowest_sum,
        )
        new_bars.next_to(BOTTOM, UP)
        new_bars.align_to(top_bars, LEFT)

        n_top = len(top_bars[0])
        n_low = len(low_bars[0])

        # March!
        for n in range(len(new_bars[0])):
            # Offset that puts the first lower bar n bar-widths to the
            # right of the first upper bar.
            x_diff = top_bars[0][0].get_x() - low_bars[0][0].get_x()
            x_diff += low_bars[0][0].get_width() * n
            self.play(
                low_bars.animate.shift(x_diff * RIGHT),
                run_time=0.5
            )

            # Index pairs (i, j) with i + j == n that are valid in both groups.
            index_pairs = [
                (k, n - k)
                for k in range(n + 1)
                if 0 <= n - k < n_low and 0 <= k < n_top
            ]
            highlights = VGroup()
            for i, j in index_pairs:
                highlights.add(VGroup(top_bars[0][i].copy(), low_bars[0][j].copy()))
            highlights.set_color(YELLOW)

            conv_rect = new_bars[0][n]
            value_label = new_bars[1][n]
            prob_label = new_bars[2][n]
            partial_rects = VGroup()
            partial_labels = VGroup()

            products = [top_bars.dist[i] * low_bars.dist[j] for i, j in index_pairs]
            for running_total in np.cumsum(products):
                # Scale a copy of the final bar to the fraction of the
                # total accumulated so far.
                partial_rect = conv_rect.copy()
                partial_rect.stretch(
                    running_total / new_bars.dist[n],
                    dim=1,
                    about_edge=DOWN,
                )
                partial_label = prob_label.copy()
                partial_label.set_value(running_total)
                partial_label.next_to(partial_rect, UP, SMALL_BUFF)
                partial_rects.add(partial_rect)
                partial_labels.add(partial_label)

            self.add(value_label)
            self.play(
                ShowSubmobjectsOneByOne(highlights, remover=True),
                ShowSubmobjectsOneByOne(partial_rects, remover=True),
                ShowSubmobjectsOneByOne(partial_labels, remover=True),
                run_time=0.15 * len(products)
            )
            self.add(*(group[:n + 1] for group in new_bars))
            self.wait(0.5)

        return new_bars
|