Posterior update animations in eop/bayes

This commit is contained in:
Grant Sanderson 2017-06-07 14:34:39 -07:00
parent f6340f42d7
commit a21c96991f
4 changed files with 684 additions and 317 deletions

View file

@ -43,7 +43,7 @@ class BayesOpeningQuote(OpeningQuote):
class IntroducePokerHand(PiCreatureScene, SampleSpaceScene): class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
CONFIG = { CONFIG = {
"community_cards_center" : 1.5*DOWN, "community_cards_center" : 1.5*DOWN,
"community_card_values" : ["AS", "QH", "10H", "2C", "5H"], "community_card_values" : ["10S", "QH", "AH", "2C", "5H"],
"your_hand_values" : ["JS", "KC"], "your_hand_values" : ["JS", "KC"],
} }
def construct(self): def construct(self):
@ -118,6 +118,7 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
straight_cards.submobjects.sort(card_cmp) straight_cards.submobjects.sort(card_cmp)
straight_cards.arrange_submobjects(RIGHT, buff = SMALL_BUFF) straight_cards.arrange_submobjects(RIGHT, buff = SMALL_BUFF)
straight_cards.next_to(community_cards, UP, aligned_edge = LEFT) straight_cards.next_to(community_cards, UP, aligned_edge = LEFT)
you.hand.target.shift(MED_SMALL_BUFF*UP)
self.play(LaggedStart( self.play(LaggedStart(
MoveToTarget, MoveToTarget,
@ -125,7 +126,6 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
run_time = 1.5 run_time = 1.5
)) ))
self.play(MoveToTarget(you.hand)) self.play(MoveToTarget(you.hand))
self.play(you.change, "hooray", straight_cards)
self.play(LaggedStart( self.play(LaggedStart(
ApplyMethod, ApplyMethod,
straight_cards, straight_cards,
@ -135,12 +135,14 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
lag_ratio = 0.5, lag_ratio = 0.5,
remover = True, remover = True,
)) ))
self.play(you.change, "hooray", straight_cards)
self.dither(2) self.dither(2)
self.play( self.play(
selected_community_cards.restore, selected_community_cards.restore,
you.hand.restore, you.hand.restore,
you.change_mode, "happy" you.change_mode, "happy"
) )
self.dither()
def show_flush_potential(self): def show_flush_potential(self):
you, her = self.you, self.her you, her = self.you, self.her
@ -156,6 +158,10 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
her.hand.target.next_to(heart_cards, UP) her.hand.target.next_to(heart_cards, UP)
her.hand.target.to_edge(UP) her.hand.target.to_edge(UP)
her.glasses.save_state()
her.glasses.move_to(her.hand.target)
her.glasses.set_fill(opacity = 0)
heart_qs = VGroup() heart_qs = VGroup()
hearts = VGroup() hearts = VGroup()
q_marks = VGroup() q_marks = VGroup()
@ -177,7 +183,7 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
self.play(LaggedStart(DrawBorderThenFill, heart_qs)) self.play(LaggedStart(DrawBorderThenFill, heart_qs))
self.play( self.play(
her.change, "happy", her.change, "happy",
DrawBorderThenFill(her.glasses) her.glasses.restore,
) )
self.pi_creatures.remove(her) self.pi_creatures.remove(her)
new_suit_pairs = [ new_suit_pairs = [
@ -272,19 +278,20 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
sample_space.divide_horizontally( sample_space.divide_horizontally(
p, colors = [SuitSymbol.CONFIG["red"], BLUE_E] p, colors = [SuitSymbol.CONFIG["red"], BLUE_E]
) )
top_label, bottom_label = sample_space.get_side_labels([ braces, labels = sample_space.get_side_braces_and_labels([
percentage.get_tex_string(), "95.5\\%" percentage.get_tex_string(), "95.5\\%"
]) ])
top_label, bottom_label = labels
self.play( self.play(
FadeIn(sample_space), FadeIn(sample_space),
ReplacementTransform(percentage, top_label[1]) ReplacementTransform(percentage, top_label)
) )
self.play(*map(GrowFromCenter, [ self.play(*map(GrowFromCenter, [
label[0] for label in top_label, bottom_label brace for brace in braces
])) ]))
self.dither(2) self.dither(2)
self.play(Write(bottom_label[1])) self.play(Write(bottom_label))
self.dither(2) self.dither(2)
self.sample_space = sample_space self.sample_space = sample_space
@ -297,7 +304,7 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
flush_hands, non_flush_hands = hand_lists = [ flush_hands, non_flush_hands = hand_lists = [
[self.get_hand(her, keys) for keys in key_list] [self.get_hand(her, keys) for keys in key_list]
for key_list in [ for key_list in [
[("3H", "8H"), ("4H", "AH"), ("JH", "KH")], [("3H", "8H"), ("4H", "5H"), ("JH", "KH")],
[("AC", "6D"), ("3D", "6S"), ("JH", "4C")], [("AC", "6D"), ("3D", "6S"), ("JH", "4C")],
] ]
] ]
@ -343,10 +350,7 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
self.money = money self.money = money
def change_belief(self): def change_belief(self):
numbers = VGroup(*[ numbers = self.sample_space.horizontal_parts.labels
label[1]
for label in self.sample_space.horizontal_parts.labels
])
rect = Rectangle(stroke_width = 0) rect = Rectangle(stroke_width = 0)
rect.set_fill(BLACK, 1) rect.set_fill(BLACK, 1)
rect.stretch_to_fit_width(numbers.get_width()) rect.stretch_to_fit_width(numbers.get_width())
@ -354,11 +358,12 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
rect.move_to(numbers, UP) rect.move_to(numbers, UP)
self.play(FadeIn(rect)) self.play(FadeIn(rect))
self.change_horizontal_division( anims = self.get_horizontal_division_change_animations(0.2)
0.2, anims.append(Animation(rect))
self.play(
*anims,
run_time = 3, run_time = 3,
rate_func = there_and_back, rate_func = there_and_back
added_anims = [Animation(rect)]
) )
self.play(FadeOut(rect)) self.play(FadeOut(rect))
@ -371,7 +376,7 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
cards.target.move_to(self.deck) cards.target.move_to(self.deck)
cards.target.to_edge(LEFT) cards.target.to_edge(LEFT)
self.sample_space.add(self.sample_space.horizontal_parts.labels) self.sample_space.add_braces_and_labels()
self.play( self.play(
self.deck.scale, 0.7, self.deck.scale, 0.7,
@ -389,9 +394,7 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
subtitle.scale(0.8) subtitle.scale(0.8)
subtitle.next_to(title, DOWN) subtitle.next_to(title, DOWN)
prior_word = subtitle.get_part_by_tex("prior") prior_word = subtitle.get_part_by_tex("prior")
numbers = VGroup(*[ numbers = self.sample_space.horizontal_parts.labels
label[1] for label in self.sample_space.horizontal_parts.labels
])
rect = SurroundingRectangle(numbers, color = GREEN) rect = SurroundingRectangle(numbers, color = GREEN)
arrow = Arrow(prior_word.get_bottom(), rect.get_top()) arrow = Arrow(prior_word.get_bottom(), rect.get_top())
arrow.highlight(GREEN) arrow.highlight(GREEN)
@ -431,15 +434,9 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
her.to_corner(UP+RIGHT) her.to_corner(UP+RIGHT)
you.make_eye_contact(her) you.make_eye_contact(her)
glasses = SVGMobject(file_name = "sunglasses") glasses = SunGlasses(her)
glasses.set_stroke(WHITE, width = 0)
glasses.set_fill(GREY, 1)
glasses.scale_to_fit_width(
1.1*her.eyes.get_width()
)
glasses.move_to(her.eyes, UP)
her.glasses = glasses her.glasses = glasses
self.you = you self.you = you
self.her = her self.her = her
return VGroup(you, her) return VGroup(you, her)
@ -464,7 +461,6 @@ class IntroducePokerHand(PiCreatureScene, SampleSpaceScene):
) )
return hand return hand
class HowDoesPokerWork(TeacherStudentsScene): class HowDoesPokerWork(TeacherStudentsScene):
def construct(self): def construct(self):
self.student_says( self.student_says(
@ -484,7 +480,566 @@ class YourGutKnowsBayesRule(TeacherStudentsScene):
self.change_student_modes("confused", "gracious", "guilty") self.change_student_modes("confused", "gracious", "guilty")
self.dither(3) self.dither(3)
class UpdatePokerPrior(SampleSpaceScene):
CONFIG = {
"double_heart_template" : "HH",
"cash_string" : "\\$\\$\\$",
}
def construct(self):
self.force_skipping()
self.add_sample_space()
self.add_top_conditionals()
self.react_to_top_conditionals()
self.add_bottom_conditionals()
# self.ask_where_conditionals_come_from()
# self.vary_conditionals()
self.show_restricted_space()
self.write_P_flush_given_bet()
self.reshape_rectangles()
self.compare_prior_to_posterior()
self.tweak_estimates()
self.compute_posterior()
def add_sample_space(self):
p = 1./22
sample_space = SampleSpace(fill_opacity = 0)
sample_space.shift(LEFT)
sample_space.divide_horizontally(p, colors = [
SuitSymbol.CONFIG["red"], BLUE_E
])
labels = self.get_prior_labels(p)
braces_and_labels = sample_space.get_side_braces_and_labels(labels)
self.play(
DrawBorderThenFill(sample_space),
Write(braces_and_labels)
)
self.dither()
sample_space.add(braces_and_labels)
self.sample_space = sample_space
def add_top_conditionals(self):
top_part = self.sample_space.horizontal_parts[0]
color = average_color(YELLOW, GREEN, GREEN)
p = 0.97
top_part.divide_vertically(p, colors = [color, BLUE])
label = self.get_conditional_label(p, True)
brace, _ignore = top_part.get_top_braces_and_labels([label])
explanation = TextMobject(
"Probability of", "high bet", "given", "flush"
)
explanation.highlight_by_tex("high bet", GREEN)
explanation.highlight_by_tex("flush", RED)
explanation.scale(0.6)
explanation.next_to(label, UP)
self.play(
FadeIn(top_part.vertical_parts),
Write(explanation, run_time = 3),
GrowFromCenter(brace),
)
self.play(LaggedStart(FadeIn, label, run_time = 2, lag_ratio = 0.7))
self.dither(2)
self.sample_space.add(brace, label)
self.top_explanation = explanation
self.top_conditional_rhs = label[-1]
def react_to_top_conditionals(self):
her = PiCreature(color = BLUE_B).flip()
her.next_to(self.sample_space, RIGHT)
her.to_edge(RIGHT)
glasses = SunGlasses(her)
glasses.save_state()
glasses.shift(UP)
glasses.set_fill(opacity = 0)
her.glasses = glasses
self.play(FadeIn(her))
self.play(glasses.restore)
self.play(
her.change_mode, "happy",
Animation(glasses)
)
self.dither(2)
self.her = her
def add_bottom_conditionals(self):
her = self.her
bottom_part = self.sample_space.horizontal_parts[1]
p = 0.3
bottom_part.divide_vertically(p, colors = [GREEN_E, BLUE_E])
label = self.get_conditional_label(p, False)
brace, _ignore = bottom_part.get_bottom_braces_and_labels([label])
explanation = TextMobject(
"Probability of", "high bet", "given", "no flush"
)
explanation.highlight_by_tex("high bet", GREEN)
explanation.highlight_by_tex("no flush", RED)
explanation.scale(0.6)
explanation.next_to(label, DOWN)
self.play(DrawBorderThenFill(bottom_part.vertical_parts))
self.play(GrowFromCenter(brace))
self.play(
her.change_mode, "shruggie",
MaintainPositionRelativeTo(her.glasses, her.eyes)
)
self.play(Write(explanation))
self.dither()
self.play(*[
ReplacementTransform(
VGroup(explanation[j].copy()),
VGroup(*label[i1:i2]),
run_time = 2,
rate_func = squish_rate_func(smooth, a, a+0.5)
)
for a, (i1, i2, j) in zip(np.linspace(0, 0.5, 4), [
(0, 1, 0),
(1, 2, 1),
(2, 3, 2),
(3, 6, 3),
])
])
self.play(Write(VGroup(*label[-2:])))
self.dither(2)
self.play(*map(FadeOut, [her, her.glasses]))
self.sample_space.add(brace, label)
self.bottom_explanation = explanation
self.bottom_conditional_rhs = label[-1]
def ask_where_conditionals_come_from(self):
randy = Randolph().flip()
randy.scale(0.75)
randy.to_edge(RIGHT)
randy.shift(2*DOWN)
words = TextMobject("Where do these \\\\", "numbers", "come from?")
numbers_word = words.get_part_by_tex("numbers")
numbers_word.highlight(YELLOW)
words.scale(0.7)
bubble = ThoughtBubble(height = 3, width = 4)
bubble.pin_to(randy)
bubble.shift(MED_LARGE_BUFF*RIGHT)
bubble.add_content(words)
numbers = VGroup(
self.top_conditional_rhs,
self.bottom_conditional_rhs
)
numbers.save_state()
arrows = VGroup(*[
Arrow(
numbers_word.get_left(),
num.get_right(),
buff = 2*SMALL_BUFF
)
for num in numbers
])
questions = VGroup(*map(TextMobject, [
"Does she bluff?",
"How much does she have?",
"Does she take risks?",
"What's her model of me?",
"\\vdots"
]))
questions.arrange_submobjects(DOWN, aligned_edge = LEFT)
questions[-1].next_to(questions[-2], DOWN)
questions.scale(0.7)
questions.next_to(randy, UP)
questions.shift_onto_screen()
self.play(
randy.change_mode, "confused",
ShowCreation(bubble),
Write(words, run_time = 2)
)
self.play(*map(ShowCreation, arrows))
self.play(numbers.highlight, YELLOW)
self.play(Blink(randy))
self.play(randy.change_mode, "maybe")
self.play(*map(FadeOut, [
bubble, words, arrows
]))
for question in questions:
self.play(
FadeIn(question),
randy.look_at, question
)
self.dither()
self.play(Blink(randy))
self.dither()
self.play(
randy.change_mode, "pondering",
FadeOut(questions)
)
self.randy = randy
def vary_conditionals(self):
randy = self.randy
rects = VGroup(*[
SurroundingRectangle(
VGroup(explanation),
buff = SMALL_BUFF,
)
for explanation, rhs in zip(
[self.top_explanation, self.bottom_explanation],
[self.top_conditional_rhs, self.bottom_conditional_rhs],
)
])
new_conditionals = [
(0.91, 0.4),
(0.83, 0.1),
(0.99, 0.2),
(0.97, 0.3),
]
self.play(*map(ShowCreation, rects))
self.play(FadeOut(rects))
for i, value in enumerate(it.chain(*new_conditionals)):
self.play(
randy.look_at, rects[i%2],
*self.get_conditional_change_anims(i%2, value)
)
if i%2 == 1:
self.dither()
self.play(FadeOut(randy))
def show_restricted_space(self):
high_bet_space, low_bet_space = [
VGroup(*[
self.sample_space.horizontal_parts[i].vertical_parts[j]
for i in range(2)
])
for j in range(2)
]
words = TexMobject("P(", self.cash_string, ")")
words.highlight_by_tex(self.cash_string, GREEN)
words.next_to(self.sample_space, RIGHT)
low_bet_space.generate_target()
for submob in low_bet_space.target:
submob.highlight(average_color(
submob.get_color(), *[BLACK]*4
))
arrows = VGroup(*[
Arrow(
words.get_left(),
submob.get_edge_center(vect),
color = submob.get_color()
)
for submob, vect in zip(high_bet_space, [DOWN, RIGHT])
])
self.play(MoveToTarget(low_bet_space))
self.play(
Write(words),
*map(ShowCreation, arrows)
)
self.dither()
for rect in high_bet_space:
self.play(Indicate(rect, scale_factor = 1))
self.play(*map(FadeOut, [words, arrows]))
self.high_bet_space = high_bet_space
def write_P_flush_given_bet(self):
posterior_tex = TexMobject(
"P(", self.double_heart_template,
"|", self.cash_string, ")"
)
posterior_tex.scale(0.7)
posterior_tex.highlight_by_tex(self.cash_string, GREEN)
self.insert_double_heart(posterior_tex)
rects = self.high_bet_space.copy()
rects = [rects[0].copy()] + list(rects)
for rect in rects:
rect.generate_target()
numerator = rects[0].target
plus = TexMobject("+")
denominator = VGroup(rects[1].target, plus, rects[2].target)
denominator.arrange_submobjects(RIGHT, buff = SMALL_BUFF)
frac_line = TexMobject("\\over")
frac_line.stretch_to_fit_width(denominator.get_width())
fraction = VGroup(numerator, frac_line, denominator)
fraction.arrange_submobjects(DOWN)
arrow = TexMobject("\\downarrow")
group = VGroup(posterior_tex, arrow, fraction)
group.arrange_submobjects(DOWN)
group.to_corner(UP+RIGHT)
self.play(Write(posterior_tex))
self.play(Write(arrow))
self.play(MoveToTarget(rects[0]))
self.dither()
self.play(*it.chain(
map(Write, [frac_line, plus]),
map(MoveToTarget, rects[1:])
))
self.dither(3)
self.posterior_tex = posterior_tex
self.to_fade = VGroup(arrow, frac_line, plus)
self.to_post_rects = VGroup(VGroup(*rects[:2]),rects[2])
def reshape_rectangles(self):
post_rects = self.get_posterior_rectangles()
braces, labels = self.get_posterior_rectangle_braces_and_labels(
post_rects, [self.posterior_tex.copy()]
)
height_rect = SurroundingRectangle(braces)
self.play(
FadeOut(self.to_fade),
ReplacementTransform(
self.to_post_rects, post_rects,
run_time = 2,
),
)
self.dither(2)
self.play(ReplacementTransform(self.posterior_tex, labels[0]))
self.posterior_tex = labels[0]
self.play(GrowFromCenter(braces))
self.dither()
self.play(ShowCreation(height_rect))
self.play(FadeOut(height_rect))
self.dither()
self.post_rects = post_rects
def compare_prior_to_posterior(self):
prior_tex = self.sample_space.horizontal_parts.labels[0]
post_tex = self.posterior_tex
prior_rect, post_rect = [
SurroundingRectangle(tex, stroke_width = 2)
for tex in [prior_tex, post_tex]
]
post_words = TextMobject("Posterior", "probability")
post_words.scale(0.8)
post_words.to_corner(UP+RIGHT)
post_arrow = Arrow(
post_words[0].get_bottom(), post_tex.get_top(),
color = WHITE
)
self.play(ShowCreation(prior_rect))
self.dither()
self.play(ReplacementTransform(prior_rect, post_rect))
self.dither()
self.play(FadeOut(post_rect))
self.play(Indicate(post_tex.get_part_by_tex(self.cash_string)))
self.dither()
self.play(
Write(post_words),
ShowCreation(post_arrow)
)
self.dither()
self.play(post_words[1].fade, 0.8)
self.dither(2)
self.play(*map(FadeOut, [post_words, post_arrow]))
def tweak_estimates(self):
post_rects = self.post_rects
self.revert_to_original_skipping_status()
self.preview_tweaks(post_rects)
def preview_tweaks(self, post_rects):
new_value_lists = [
(0.85, 0.1, 0.11),
(0.97, 0.3, 1./22),
]
for new_values in new_value_lists:
for i, value in zip(range(2), new_values):
self.play(*self.get_conditional_change_anims(
i, value, post_rects
))
self.play(*self.get_prior_change_anims(
new_values[-1], post_rects
))
self.dither()
def compute_posterior(self):
pass
######
def get_prior_labels(self, value):
p_str = "%0.3f"%value
q_str = "%0.3f"%(1-value)
labels = [
TexMobject(
"P(", s, self.double_heart_template, ")",
"= ", num
)
for s, num in ("", p_str), ("\\text{not }", q_str)
]
for label in labels:
label.scale(0.7)
self.insert_double_heart(label)
return labels
def get_conditional_label(self, value, given_flush = True):
label = TexMobject(
"P(", self.cash_string, "|",
"" if given_flush else "\\text{not }",
self.double_heart_template, ")",
"=", str(value)
)
self.insert_double_heart(label)
label.highlight_by_tex(self.cash_string, GREEN)
label.scale(0.7)
return label
def insert_double_heart(self, tex_mob):
double_heart = SuitSymbol("hearts")
double_heart.add(SuitSymbol("hearts"))
double_heart.arrange_submobjects(RIGHT, buff = SMALL_BUFF)
double_heart.get_tex_string = lambda : self.double_heart_template
template = tex_mob.get_part_by_tex(self.double_heart_template)
double_heart.replace(template)
tex_mob.submobjects[tex_mob.index_of_part(template)] = double_heart
return tex_mob
def get_prior_change_anims(self, value, post_rects = None):
space = self.sample_space
parts = space.horizontal_parts
anims = self.get_horizontal_division_change_animations(
value, new_label_kwargs = {
"labels" : self.get_prior_labels(value)
}
)
if post_rects is not None:
anims += self.get_posterior_rectangle_change_anims(post_rects)
return anims
def get_conditional_change_anims(
self, sub_sample_space_index, value,
post_rects = None
):
parts = self.sample_space.horizontal_parts
sub_sample_space = parts[sub_sample_space_index]
given_flush = (sub_sample_space_index == 0)
label = self.get_conditional_label(value, given_flush)
anims = self.get_division_change_animations(
sub_sample_space, sub_sample_space.vertical_parts, value,
dimension = 0,
new_label_kwargs = {"labels" : [label]},
)
if post_rects is not None:
anims += self.get_posterior_rectangle_change_anims(post_rects)
return anims
def get_top_conditional_change_anims(self, *args, **kwargs):
return self.get_conditional_change_anims(0, *args, **kwargs)
def get_bottom_conditional_change_anims(self, *args, **kwargs):
return self.get_conditional_change_anims(1, *args, **kwargs)
def get_prior_rectangles(self):
return VGroup(*[
self.sample_space.horizontal_parts[i].vertical_parts[0]
for i in range(2)
])
def get_posterior_rectangles(self):
prior_rects = self.get_prior_rectangles()
areas = [
rect.get_width()*rect.get_height()
for rect in prior_rects
]
total_area = sum(areas)
total_height = prior_rects.get_height()
post_rects = prior_rects.copy()
for rect, area in zip(post_rects, areas):
rect.stretch_to_fit_height(total_height * area/total_area)
rect.stretch_to_fit_width(
area/rect.get_height()
)
post_rects.arrange_submobjects(DOWN, buff = 0)
post_rects.next_to(
self.sample_space.full_space, RIGHT, MED_LARGE_BUFF
)
return post_rects
def get_posterior_rectangle_braces_and_labels(self, post_rects, labels):
braces = VGroup()
label_mobs = VGroup()
for label, rect in zip(labels, post_rects):
if not isinstance(label, Mobject):
label_mob = TexMobject(label)
label_mob.scale(0.7)
else:
label_mob = label
brace = Brace(
rect, RIGHT,
buff = SMALL_BUFF,
min_num_quads = 2
)
label_mob.next_to(brace, RIGHT, SMALL_BUFF)
label_mobs.add(label_mob)
braces.add(brace)
post_rects.braces = braces
post_rects.labels = label_mobs
return VGroup(braces, label_mobs)
def update_posterior_braces(self, post_rects):
braces = post_rects.braces
labels = post_rects.labels
for rect, brace, label in zip(post_rects, braces, labels):
brace.stretch_to_fit_height(rect.get_height())
brace.next_to(rect, RIGHT, SMALL_BUFF)
label.next_to(brace, RIGHT, SMALL_BUFF)
def get_posterior_rectangle_change_anims(self, post_rects):
def update_rects(rects):
new_rects = self.get_posterior_rectangles()
Transform(rects, new_rects).update(1)
if hasattr(rects, "braces"):
self.update_posterior_braces(rects)
return rects
anims = [UpdateFromFunc(post_rects, update_rects)]
if hasattr(post_rects, "braces"):
anims += map(Animation, [
post_rects.labels, post_rects.braces
])
return anims
class NextVideoWrapper(TeacherStudentsScene):
def construct(self):
title = TextMobject("Next video: Bayesian networks")
title.scale(0.8)
title.to_edge(UP, buff = SMALL_BUFF)
screen = ScreenRectangle(height = 4)
screen.next_to(title, DOWN)
title.save_state()
title.shift(DOWN)
title.set_fill(opacity = 0)
self.play(
title.restore,
self.teacher.change, "raise_right_hand"
)
self.play(ShowCreation(screen))
self.change_student_modes(*["pondering"]*3)
self.play(Animation(screen))
self.dither(5)

View file

@ -376,6 +376,8 @@ class VGroup(VMobject):
class VectorizedPoint(VMobject): class VectorizedPoint(VMobject):
CONFIG = { CONFIG = {
"color" : BLACK, "color" : BLACK,
"fill_opacity" : 0,
"stroke_width" : 0,
"artificial_width" : 0.01, "artificial_width" : 0.01,
"artificial_height" : 0.01, "artificial_height" : 0.01,
} }

View file

@ -11,253 +11,20 @@ from animation.simple_animations import Rotating
from topics.geometry import Circle, Line, Rectangle, Square, Arc, Polygon from topics.geometry import Circle, Line, Rectangle, Square, Arc, Polygon
from topics.three_dimensions import Cube from topics.three_dimensions import Cube
class DeckOfCards(VGroup): class SunGlasses(SVGMobject):
def __init__(self, **kwargs):
possible_values = map(str, range(1, 11)) + ["J", "Q", "K"]
possible_suits = ["hearts", "diamonds", "spades", "clubs"]
VGroup.__init__(self, *[
PlayingCard(value = value, suit = suit, **kwargs)
for value in possible_values
for suit in possible_suits
])
class PlayingCard(VGroup):
CONFIG = { CONFIG = {
"value" : None, "file_name" : "sunglasses",
"suit" : None, "glasses_width_to_eyes_width" : 1.1,
"key" : None, ##String like "8H" or "KS"
"height" : 2,
"height_to_width" : 3.5/2.5,
"card_height_to_symbol_height" : 7,
"card_width_to_corner_num_width" : 10,
"card_height_to_corner_num_height" : 10,
"color" : LIGHT_GREY,
"turned_over" : False,
"possible_suits" : ["hearts", "diamonds", "spades", "clubs"],
"possible_values" : map(str, range(2, 11)) + ["J", "Q", "K", "A"],
} }
def __init__(self, pi_creature, **kwargs):
def __init__(self, key = None, **kwargs): SVGMobject.__init__(self, **kwargs)
VGroup.__init__(self, key = key, **kwargs) self.set_stroke(WHITE, width = 0)
self.set_fill(GREY, 1)
def generate_points(self): self.scale_to_fit_width(
self.add(Rectangle( self.glasses_width_to_eyes_width*pi_creature.eyes.get_width()
height = self.height,
width = self.height/self.height_to_width,
stroke_color = WHITE,
stroke_width = 2,
fill_color = self.color,
fill_opacity = 1,
))
if self.turned_over:
self.set_fill(DARK_GREY)
self.set_stroke(LIGHT_GREY)
contents = VectorizedPoint(self.get_center())
else:
value = self.get_value()
symbol = self.get_symbol()
design = self.get_design(value, symbol)
corner_numbers = self.get_corner_numbers(value, symbol)
contents = VGroup(design, corner_numbers)
self.design = design
self.corner_numbers = corner_numbers
self.add(contents)
def get_value(self):
value = self.value
if value is None:
if self.key is not None:
value = self.key[:-1]
else:
value = random.choice(self.possible_values)
value = string.upper(str(value))
if value == "1":
value = "A"
if value not in self.possible_values:
raise Exception("Invalid card value")
face_card_to_value = {
"J" : 11,
"Q" : 12,
"K" : 13,
"A" : 14,
}
try:
self.numerical_value = int(value)
except:
self.numerical_value = face_card_to_value[value]
return value
def get_symbol(self):
suit = self.suit
if suit is None:
if self.key is not None:
suit = dict([
(string.upper(s[0]), s)
for s in self.possible_suits
])[string.upper(self.key[-1])]
else:
suit = random.choice(self.possible_suits)
if suit not in self.possible_suits:
raise Exception("Invalud suit value")
self.suit = suit
symbol_height = float(self.height) / self.card_height_to_symbol_height
symbol = SuitSymbol(suit, height = symbol_height)
return symbol
def get_design(self, value, symbol):
if value == "A":
return self.get_ace_design(symbol)
if value in map(str, range(2, 11)):
return self.get_number_design(value, symbol)
else:
return self.get_face_card_design(value, symbol)
def get_ace_design(self, symbol):
design = symbol.copy().scale(1.5)
design.move_to(self)
return design
def get_number_design(self, value, symbol):
num = int(value)
n_rows = {
2 : 2,
3 : 3,
4 : 2,
5 : 2,
6 : 3,
7 : 3,
8 : 3,
9 : 4,
10 : 4,
}[num]
n_cols = 1 if num in [2, 3] else 2
insertion_indices = {
5 : [0],
7 : [0],
8 : [0, 1],
9 : [1],
10 : [0, 2],
}.get(num, [])
top = self.get_top() + symbol.get_height()*DOWN
bottom = self.get_bottom() + symbol.get_height()*UP
column_points = [
interpolate(top, bottom, alpha)
for alpha in np.linspace(0, 1, n_rows)
]
design = VGroup(*[
symbol.copy().move_to(point)
for point in column_points
])
if n_cols == 2:
space = 0.2*self.get_width()
column_copy = design.copy().shift(space*RIGHT)
design.shift(space*LEFT)
design.add(*column_copy)
design.add(*[
symbol.copy().move_to(
center_of_mass(column_points[i:i+2])
)
for i in insertion_indices
])
for symbol in design:
if symbol.get_center()[1] < self.get_center()[1]:
symbol.rotate_in_place(np.pi)
return design
def get_face_card_design(self, value, symbol):
from topics.characters import PiCreature
sub_rect = Rectangle(
stroke_color = BLACK,
fill_opacity = 0,
height = 0.9*self.get_height(),
width = 0.6*self.get_width(),
) )
sub_rect.move_to(self) self.move_to(pi_creature.eyes, UP)
pi_color = average_color(symbol.get_color(), GREY)
pi_mode = {
"J" : "plain",
"Q" : "thinking",
"K" : "hooray"
}[value]
pi_creature = PiCreature(
mode = pi_mode,
color = pi_color,
)
pi_creature.scale_to_fit_width(0.8*sub_rect.get_width())
if value in ["Q", "K"]:
prefix = "king" if value == "K" else "queen"
crown = SVGMobject(file_name = prefix + "_crown")
crown.set_stroke(width = 0)
crown.set_fill(YELLOW, 1)
crown.stretch_to_fit_width(0.5*sub_rect.get_width())
crown.stretch_to_fit_height(0.17*sub_rect.get_height())
crown.move_to(pi_creature.eyes.get_center(), DOWN)
pi_creature.add_to_back(crown)
to_top_buff = 0
else:
to_top_buff = SMALL_BUFF*sub_rect.get_height()
pi_creature.next_to(sub_rect.get_top(), DOWN, to_top_buff)
# pi_creature.shift(0.05*sub_rect.get_width()*RIGHT)
pi_copy = pi_creature.copy()
pi_copy.rotate(np.pi, about_point = sub_rect.get_center())
return VGroup(sub_rect, pi_creature, pi_copy)
def get_corner_numbers(self, value, symbol):
value_mob = TextMobject(value)
width = self.get_width()/self.card_width_to_corner_num_width
height = self.get_height()/self.card_height_to_corner_num_height
value_mob.scale_to_fit_width(width)
value_mob.stretch_to_fit_height(height)
value_mob.next_to(
self.get_corner(UP+LEFT), DOWN+RIGHT,
buff = MED_LARGE_BUFF*width
)
value_mob.highlight(symbol.get_color())
corner_symbol = symbol.copy()
corner_symbol.scale_to_fit_width(width)
corner_symbol.next_to(
value_mob, DOWN,
buff = MED_SMALL_BUFF*width
)
corner_group = VGroup(value_mob, corner_symbol)
opposite_corner_group = corner_group.copy()
opposite_corner_group.rotate(
np.pi, about_point = self.get_center()
)
return VGroup(corner_group, opposite_corner_group)
class SuitSymbol(SVGMobject):
CONFIG = {
"height" : 0.5,
"fill_opacity" : 1,
"stroke_width" : 0,
"red" : "#D02028",
"black" : BLACK,
}
def __init__(self, suit_name, **kwargs):
digest_config(self, kwargs)
suits_to_colors = {
"hearts" : self.red,
"diamonds" : self.red,
"spades" : self.black,
"clubs" : self.black,
}
if suit_name not in suits_to_colors:
raise Exception("Invalid suit name")
SVGMobject.__init__(self, file_name = suit_name, **kwargs)
color = suits_to_colors[suit_name]
self.set_stroke(width = 0)
self.set_fill(color, 1)
self.scale_to_fit_height(self.height)
class Speedometer(VMobject): class Speedometer(VMobject):
CONFIG = { CONFIG = {

View file

@ -3,7 +3,7 @@ from helpers import *
from scene import Scene from scene import Scene
from animation.animation import Animation from animation.animation import Animation
from animation.transform import Transform from animation.transform import Transform, MoveToTarget
from mobject import Mobject from mobject import Mobject
from mobject.vectorized_mobject import VGroup, VMobject, VectorizedPoint from mobject.vectorized_mobject import VGroup, VMobject, VectorizedPoint
@ -22,30 +22,53 @@ class SampleSpaceScene(Scene):
def add_sample_space(self, **config): def add_sample_space(self, **config):
self.add(self.get_sample_space(**config)) self.add(self.get_sample_space(**config))
def change_horizontal_division(self, p_list, **kwargs): def get_division_change_animations(
assert(hasattr(self.sample_space, "horizontal_parts")) self, sample_space, parts, p_list,
added_anims = kwargs.pop("added_anims", []) dimension = 1,
new_division_kwargs = kwargs.pop("new_division_kwargs", {}) new_label_kwargs = None,
added_label_kwargs = kwargs.pop("label_kwargs", {}) **kwargs
):
if new_label_kwargs is None:
new_label_kwargs = {}
anims = []
p_list = sample_space.complete_p_list(p_list)
full_space = sample_space.full_space
curr_parts = self.sample_space.horizontal_parts vect = DOWN if dimension == 1 else RIGHT
new_division_kwargs["colors"] = [ parts.generate_target()
part.get_color() for part in curr_parts for part, p in zip(parts.target, p_list):
] part.replace(full_space, stretch = True)
new_parts = self.sample_space.get_horizontal_division( part.stretch(p, dimension)
p_list, **new_division_kwargs parts.target.arrange_submobjects(vect, buff = 0)
) parts.target.move_to(full_space)
anims = [Transform(curr_parts, new_parts)] anims.append(MoveToTarget(parts))
if hasattr(curr_parts, "labels"): if hasattr(parts, "labels"):
label_kwargs = curr_parts.label_kwargs label_kwargs = parts.label_kwargs
label_kwargs.update(added_label_kwargs) label_kwargs.update(new_label_kwargs)
new_labels = self.sample_space.get_subdivision_labels( new_braces, new_labels = sample_space.get_subdivision_braces_and_labels(
new_parts, **label_kwargs parts.target, **label_kwargs
) )
anims.append(Transform(curr_parts.labels, new_labels)) anims += [
anims += added_anims Transform(parts.braces, new_braces),
Transform(parts.labels, new_labels),
]
return anims
self.play(*anims, **kwargs) def get_horizontal_division_change_animations(self, p_list, **kwargs):
assert(hasattr(self.sample_space, "horizontal_parts"))
return self.get_division_change_animations(
self.sample_space, self.sample_space.horizontal_parts, p_list,
dimension = 1,
**kwargs
)
def get_vertical_division_change_animations(self, p_list, **kwargs):
assert(hasattr(self.sample_space, "vertical_parts"))
return self.get_division_change_animations(
self.sample_space, self.sample_space.vertical_parts, p_list,
dimension = 0,
**kwargs
)
class SampleSpace(VGroup): class SampleSpace(VGroup):
@ -55,7 +78,8 @@ class SampleSpace(VGroup):
"width" : 3, "width" : 3,
"fill_color" : DARK_GREY, "fill_color" : DARK_GREY,
"fill_opacity" : 0.8, "fill_opacity" : 0.8,
"stroke_width" : 0, "stroke_width" : 0.5,
"stroke_color" : LIGHT_GREY,
}, },
"default_label_scale_val" : 0.7, "default_label_scale_val" : 0.7,
} }
@ -76,12 +100,16 @@ class SampleSpace(VGroup):
def add_label(self, label): def add_label(self, label):
self.label = label self.label = label
def complete_p_list(self, p_list):
new_p_list = list(tuplify(p_list))
remainder = 1.0 - sum(new_p_list)
if abs(remainder) > EPSILON:
new_p_list.append(remainder)
return new_p_list
def get_division_along_dimension(self, p_list, dim, colors, vect): def get_division_along_dimension(self, p_list, dim, colors, vect):
p_list = list(tuplify(p_list)) p_list = self.complete_p_list(p_list)
if abs(1.0 - sum(p_list)) > EPSILON:
p_list.append(1.0 - sum(p_list))
colors = color_gradient(colors, len(p_list)) colors = color_gradient(colors, len(p_list))
perp_dim = 1-dim
last_point = self.full_space.get_edge_center(-vect) last_point = self.full_space.get_edge_center(-vect)
parts = VGroup() parts = VGroup()
@ -89,7 +117,7 @@ class SampleSpace(VGroup):
part = SampleSpace() part = SampleSpace()
part.set_fill(color, 1) part.set_fill(color, 1)
part.replace(self.full_space, stretch = True) part.replace(self.full_space, stretch = True)
part.stretch(factor, perp_dim) part.stretch(factor, dim)
part.move_to(last_point, -vect) part.move_to(last_point, -vect)
last_point = part.get_edge_center(vect) last_point = part.get_edge_center(vect)
parts.add(part) parts.add(part)
@ -100,14 +128,14 @@ class SampleSpace(VGroup):
colors = [GREEN_E, BLUE], colors = [GREEN_E, BLUE],
vect = DOWN vect = DOWN
): ):
return self.get_division_along_dimension(p_list, 0, colors, vect) return self.get_division_along_dimension(p_list, 1, colors, vect)
def get_vertical_division( def get_vertical_division(
self, p_list, self, p_list,
colors = [MAROON_B, YELLOW], colors = [MAROON_B, YELLOW],
vect = RIGHT vect = RIGHT
): ):
return self.get_division_along_dimension(p_list, 1, colors, vect) return self.get_division_along_dimension(p_list, 0, colors, vect)
def divide_horizontally(self, *args, **kwargs): def divide_horizontally(self, *args, **kwargs):
self.horizontal_parts = self.get_horizontal_division(*args, **kwargs) self.horizontal_parts = self.get_horizontal_division(*args, **kwargs)
@ -117,38 +145,53 @@ class SampleSpace(VGroup):
self.vertical_parts = self.get_vertical_division(*args, **kwargs) self.vertical_parts = self.get_vertical_division(*args, **kwargs)
self.add(self.vertical_parts) self.add(self.vertical_parts)
def get_subdivision_labels(self, parts, labels, direction, buff = SMALL_BUFF): def get_subdivision_braces_and_labels(self, parts, labels, direction, buff = SMALL_BUFF):
label_brace_groups = VGroup() label_brace_groups = VGroup()
label_mobs = VGroup()
braces = VGroup()
for label, part in zip(labels, parts): for label, part in zip(labels, parts):
brace = Brace(part, direction, min_num_quads = 1, buff = buff) brace = Brace(part, direction, min_num_quads = 1, buff = buff)
label_mob = TexMobject(label) if isinstance(label, Mobject):
label_mob.scale(self.default_label_scale_val) label_mob = label
else:
label_mob = TexMobject(label)
label_mob.scale(self.default_label_scale_val)
label_mob.next_to(brace, direction, buff) label_mob.next_to(brace, direction, buff)
full_label = VGroup(brace, label_mob)
part.add_label(full_label) braces.add(brace)
label_brace_groups.add(full_label) label_mobs.add(label_mob)
parts.labels = label_brace_groups parts.braces = braces
parts.labels = label_mobs
parts.label_kwargs = { parts.label_kwargs = {
"labels" : labels, "labels" : labels,
"direction" : direction, "direction" : direction,
"buff" : buff, "buff" : buff,
} }
return label_brace_groups return VGroup(parts.braces, parts.labels)
def get_side_labels(self, labels, direction = LEFT, **kwargs): def get_side_braces_and_labels(self, labels, direction = LEFT, **kwargs):
assert(hasattr(self, "horizontal_parts")) assert(hasattr(self, "horizontal_parts"))
parts = self.horizontal_parts parts = self.horizontal_parts
return self.get_subdivision_labels(parts, labels, direction, **kwargs) return self.get_subdivision_braces_and_labels(parts, labels, direction, **kwargs)
def get_top_labels(self, labels, **kwargs): def get_top_braces_and_labels(self, labels, **kwargs):
assert(hasattr(self, "vertical_parts")) assert(hasattr(self, "vertical_parts"))
parts = self.vertical_parts parts = self.vertical_parts
return self.get_subdivision_labels(parts, labels, UP, **kwargs) return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs)
def get_bototm_labels(self, labels, **kwargs): def get_bottom_braces_and_labels(self, labels, **kwargs):
assert(hasattr(self, "vertical_parts")) assert(hasattr(self, "vertical_parts"))
parts = self.vertical_parts parts = self.vertical_parts
return self.get_subdivision_labels(parts, labels, DOWN, **kwargs) return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs)
def add_braces_and_labels(self):
for attr in "horizontal_parts", "vertical_parts":
if not hasattr(self, attr):
continue
parts = getattr(self, attr)
for subattr in "braces", "labels":
if hasattr(parts, subattr):
self.add(getattr(parts, subattr))
### Cards ### ### Cards ###