diff --git a/from_3b1b/active/bayes/beta1.py b/from_3b1b/active/bayes/beta1.py index 3b8070ea..2837cd2f 100644 --- a/from_3b1b/active/bayes/beta1.py +++ b/from_3b1b/active/bayes/beta1.py @@ -40,6 +40,36 @@ class Thumbnail1(Scene): self.add(text) +class AltThumbnail1(Scene): + def construct(self): + N = 20 + n_trials = 10000 + p = 0.7 + outcomes = (np.random.random((N, n_trials)) < p).sum(0) + counts = [] + for k in range(N + 1): + counts.append((outcomes == k).sum()) + + hist = Histogram( + counts, + y_max=0.3, + y_tick_freq=0.05, + y_axis_numbers_to_show=[10, 20, 30], + x_label_freq=10, + ) + hist.set_width(FRAME_WIDTH - 1) + hist.bars.set_submobject_colors_by_gradient(YELLOW, YELLOW, GREEN, BLUE) + hist.bars.set_stroke(WHITE, 2) + + title = TextMobject("Binomial distribution") + title.set_width(12) + title.to_corner(UR, buff=0.8) + title.add_background_rectangle() + + self.add(hist) + self.add(title) + + class Thumbnail2(Scene): def construct(self): axes = self.get_axes() @@ -1827,8 +1857,8 @@ class AskAboutUnknownProbabilities(Scene): def show_many_coins(self, n_rows, n_cols): coin_choices = VGroup( - get_coin(BLUE_E, "H"), - get_coin(RED_E, "T"), + get_coin("H"), + get_coin("T"), ) coin_choices.set_stroke(width=0) coins = VGroup(*[ @@ -1873,10 +1903,10 @@ class AskProbabilityOfCoins(Scene): condition = VGroup( TextMobject("If you've seen"), Integer(80, color=BLUE_C), - get_coin(BLUE_E, "H").set_height(0.5), + get_coin("H").set_height(0.5), TextMobject("and"), Integer(20, color=RED_C), - get_coin(RED_E, "T").set_height(0.5), + get_coin("T").set_height(0.5), ) condition.arrange(RIGHT) condition.to_edge(UP) @@ -1886,7 +1916,7 @@ class AskProbabilityOfCoins(Scene): "\\text{What is }", "P(", "00", ")", "?" ) - coin = get_coin(BLUE_E, "H") + coin = get_coin("H") coin.replace(question.get_part_by_tex("00")) question.replace_submobject( question.index_of_part_by_tex("00"), @@ -1899,10 +1929,7 @@ class AskProbabilityOfCoins(Scene): random.shuffle(values) coins = VGroup(*[ - get_coin( - BLUE_E if symbol == "H" else RED_E, - symbol - ) + get_coin(symbol) for symbol in values ]) coins.arrange_in_grid(10, 10, buff=MED_SMALL_BUFF) @@ -3265,8 +3292,8 @@ class StateIndependence(Scene): class IllustrateBinomialSetupWithCoins(Scene): def construct(self): coins = [ - get_coin(BLUE_E, "H"), - get_coin(RED_E, "T"), + get_coin("H"), + get_coin("T"), ] coin_row = VGroup() @@ -3289,7 +3316,7 @@ class IllustrateBinomialSetupWithCoins(Scene): "k": GREEN, } ) - heads = get_coin(BLUE_E, "H") + heads = get_coin("H") template = prob_label.get_part_by_tex("00") heads.replace(template) prob_label.replace_submobject( diff --git a/from_3b1b/active/bayes/beta2.py b/from_3b1b/active/bayes/beta2.py index 3349e26f..d7cee960 100644 --- a/from_3b1b/active/bayes/beta2.py +++ b/from_3b1b/active/bayes/beta2.py @@ -1,519 +1,344 @@ from manimlib.imports import * from from_3b1b.active.bayes.beta_helpers import * from from_3b1b.active.bayes.beta1 import * +from from_3b1b.old.hyperdarts import Dartboard import scipy.stats OUTPUT_DIRECTORY = "bayes/beta2" -class PartTwoReady(Scene): +class WeightedCoin(Scene): def construct(self): - br = FullScreenFadeRectangle() - br.set_fill(GREY_E, 1) - self.add(br) - text = TextMobject( - "Part 2\\\\", - "Early view\\\\for supporters" + # Coin grid + bools = 50 * [True] + 50 * [False] + random.shuffle(bools) + grid = get_coin_grid(bools) + + sorted_grid = VGroup(*grid) + sorted_grid.submobjects.sort(key=lambda m: m.symbol) + + # Prob label + p_label = get_prob_coin_label() + p_label.set_height(0.7) + p_label.to_edge(UP) + + rhs = p_label[-1] + rhs_box = SurroundingRectangle(rhs, color=RED) + rhs_label = TextMobject("Not necessarily") + rhs_label.next_to(rhs_box, DOWN, LARGE_BUFF) + rhs_label.to_edge(RIGHT) + rhs_label.match_color(rhs_box) + + rhs_arrow = Arrow( + rhs_label.get_top(), + rhs_box.get_right(), + buff=SMALL_BUFF, + path_arc=60 * DEGREES, + color=rhs_box.get_color() ) - text.scale(1.5) - text[0].match_width(text[1], about_edge=DOWN) - text[0].shift(MED_SMALL_BUFF * UP) - text[1].set_color("#f96854") - self.add(text) + # Introduce coin + self.play(FadeIn( + grid, + run_time=2, + rate_func=linear, + lag_ratio=3 / len(grid), + )) + self.wait() + self.play( + grid.set_height, 5, + grid.to_edge, DOWN, + FadeInFromDown(p_label) + ) -class ShowBayesianUpdating(Scene): - CONFIG = { - "true_p": 0.72, - "random_seed": 4, - "initial_axis_scale_factor": 3.5 - } - - def construct(self): - # Axes - axes = self.get_axes() - self.add(axes) - - # Graph - n_heads = 0 - n_tails = 0 - graph = get_beta_graph(axes, n_heads, n_tails) - self.add(graph) - - # Get coins - true_p = self.true_p - bool_values = np.random.random(100) < true_p - bool_values[1] = True - coins = self.get_coins(bool_values) - coins.next_to(axes.y_axis, RIGHT, MED_LARGE_BUFF) - coins.to_edge(UP, LARGE_BUFF) - - # Probability label - p_label, prob, prob_box = self.get_probability_label() - self.add(p_label) - self.add(prob_box) - - # Slow animations - def head_likelihood(x): - return x - - def tail_likelihood(x): - return 1 - x - - n_previews = 10 - n_slow_previews = 5 - for x in range(n_previews): - coin = coins[x] - is_heads = bool_values[x] - - new_data_label = TextMobject("New data") - new_data_label.set_height(0.3) - arrow = Vector(0.5 * UP) - arrow.next_to(coin, DOWN, SMALL_BUFF) - new_data_label.next_to(arrow, DOWN, SMALL_BUFF) - new_data_label.shift(MED_SMALL_BUFF * RIGHT) - - if is_heads: - line = axes.get_graph(lambda x: x) - label = TexMobject("\\text{Scale by } x") - likelihood = head_likelihood - n_heads += 1 - else: - line = axes.get_graph(lambda x: 1 - x) - label = TexMobject("\\text{Scale by } (1 - x)") - likelihood = tail_likelihood - n_tails += 1 - label.next_to(graph, UP) - label.set_stroke(BLACK, 3, background=True) - line.set_stroke(YELLOW, 3) - - graph_copy = graph.copy() - graph_copy.unlock_triangulation() - scaled_graph = graph.copy() - scaled_graph.apply_function( - lambda p: axes.c2p( - axes.x_axis.p2n(p), - axes.y_axis.p2n(p) * likelihood(axes.x_axis.p2n(p)) - ) - ) - scaled_graph.set_color(GREEN) - - renorm_label = TextMobject("Renormalize") - renorm_label.move_to(label) - - new_graph = get_beta_graph(axes, n_heads, n_tails) - - renormalized_graph = scaled_graph.copy() - renormalized_graph.match_style(graph) - renormalized_graph.match_height(new_graph, stretch=True, about_edge=DOWN) - - if x < n_slow_previews: - self.play( - FadeInFromDown(coin), - FadeIn(new_data_label), - GrowArrow(arrow), - ) - self.play( - FadeOut(new_data_label), - FadeOut(arrow), - ShowCreation(line), - FadeIn(label), - ) - self.add(graph_copy, line, label) - self.play(Transform(graph_copy, scaled_graph)) - self.play( - FadeOut(line), - FadeOut(label), - FadeIn(renorm_label), - ) - self.play( - Transform(graph_copy, renormalized_graph), - FadeOut(graph), - ) - self.play(FadeOut(renorm_label)) - else: - self.add(coin) - graph_copy.become(scaled_graph) - self.add(graph_copy) - self.play( - Transform(graph_copy, renormalized_graph), - FadeOut(graph), - ) - graph = new_graph - self.remove(graph_copy) - self.add(new_graph) - - # Rescale y axis - axes.save_state() - sf = self.initial_axis_scale_factor - axes.y_axis.stretch(1 / sf, 1, about_point=axes.c2p(0, 0)) - for number in axes.y_axis.numbers: - number.stretch(sf, 1) - axes.y_axis.numbers[:4].set_opacity(0) + for coin in grid: + coin.generate_target() + sorted_coins = list(grid) + sorted_coins.sort(key=lambda m: m.symbol) + for c1, c2 in zip(sorted_coins, grid): + c1.target.move_to(c2) self.play( - Restore(axes, rate_func=lambda t: smooth(1 - t)), - graph.stretch, 1 / sf, 1, {"about_edge": DOWN}, - run_time=2, + FadeIn(rhs_label, lag_ratio=0.1), + ShowCreation(rhs_arrow), + ShowCreation(rhs_box), + LaggedStartMap( + MoveToTarget, grid, + path_arc=30 * DEGREES, + lag_ratio=0.01, + ), ) - # Fast animations - for x in range(n_previews, len(coins)): - coin = coins[x] - is_heads = bool_values[x] + # Alternate weightings + old_grid = VGroup(*sorted_coins) + rhs_junk_on_screen = True + for value in [0.2, 0.9, 0.0, 0.31]: + n = int(100 * value) + new_grid = get_coin_grid([True] * n + [False] * (100 - n)) + new_grid.replace(grid) - if is_heads: - n_heads += 1 - else: - n_tails += 1 - new_graph = get_beta_graph(axes, n_heads, n_tails) + anims = [] + if rhs_junk_on_screen: + anims += [ + FadeOut(rhs_box), + FadeOut(rhs_label), + FadeOut(rhs_arrow), + ] + rhs_junk_on_screen = False - self.add(coins[:x + 1]) + self.wait() self.play( - FadeIn(new_graph), - run_time=0.25, + FadeOutAndShift( + old_grid, + 0.1 * DOWN, + lag_ratio=0.01, + run_time=1.5 + ), + FadeIn(new_grid, lag_ratio=0.01, run_time=1.5), + ChangeDecimalToValue(rhs, value), + *anims, ) - self.play( - FadeOut(graph), - run_time=0.25, - ) - graph = new_graph + old_grid = new_grid - # Show confidence interval + long_rhs = DecimalNumber( + 0.31415926, + num_decimal_places=8, + show_ellipsis=True, + ) + long_rhs.match_height(rhs) + long_rhs.move_to(rhs, DL) + + self.play(ShowIncreasingSubsets(long_rhs, rate_func=linear)) + self.wait() + + # You just don't know + box = get_q_box(rhs) + + self.remove(rhs) + self.play( + FadeOut(old_grid, lag_ratio=0.1), + FadeOutAndShift(long_rhs, 0.1 * RIGHT, lag_ratio=0.1), + Write(box), + ) + p_label.add(box) + self.wait() + + # 7/10 heads + bools = [True] * 7 + [False] * 3 + random.shuffle(bools) + coins = VGroup(*[ + get_coin("H" if heads else "T") + for heads in bools + ]) + coins.arrange(RIGHT) + coins.set_height(0.7) + coins.next_to(p_label, DOWN, buff=LARGE_BUFF) + + heads_arrows = VGroup(*[ + Vector( + 0.5 * UP, + max_stroke_width_to_length_ratio=15, + max_tip_length_to_length_ratio=0.4, + ).next_to(coin, DOWN) + for coin in coins + if coin.symbol == "H" + ]) + numbers = VGroup(*[ + Integer(i + 1).next_to(arrow, DOWN, SMALL_BUFF) + for i, arrow in enumerate(heads_arrows) + ]) + + for coin in coins: + coin.save_state() + coin.stretch(0, 0) + coin.set_opacity(0) + + self.play(LaggedStartMap(Restore, coins), run_time=1) + self.play( + ShowIncreasingSubsets(heads_arrows), + ShowIncreasingSubsets(numbers), + rate_func=linear, + ) + self.wait() + + # Plot + axes = scaled_pdf_axes() + axes.to_edge(DOWN, buff=MED_SMALL_BUFF) + axes.y_axis.numbers.set_opacity(0) + axes.y_axis_label.set_opacity(0) + + x_axis_label = p_label[:4].copy() + x_axis_label.set_height(0.4) + x_axis_label.next_to(axes.c2p(1, 0), UR, buff=SMALL_BUFF) + axes.x_axis.add(x_axis_label) + + n_heads = 7 + n_tails = 3 + graph = get_beta_graph(axes, n_heads, n_tails) dist = scipy.stats.beta(n_heads + 1, n_tails + 1) - v_lines = VGroup() - labels = VGroup() - x_bounds = dist.interval(0.95) - for x in x_bounds: - line = DashedLine( - axes.c2p(x, 0), - axes.c2p(x, 12), - ) - line.set_color(YELLOW) - v_lines.add(line) - label = DecimalNumber(x) - label.set_height(0.25) - label.next_to(line, UP) - label.match_color(line) - labels.add(label) - true_graph = axes.get_graph(dist.pdf) - region = get_region_under_curve(axes, true_graph, *x_bounds) - region.set_fill(GREY_BROWN, 0.85) + + v_line = Line( + axes.c2p(0.7, 0), + axes.input_to_graph_point(0.7, true_graph), + ) + v_line.set_stroke(YELLOW, 4) + + region = get_region_under_curve(axes, true_graph, 0.6, 0.8) + region.set_fill(GREY, 0.85) region.set_stroke(YELLOW, 1) - label95 = TexMobject("95\\%") - fix_percent(label95.family_members_with_points()[-1]) - label95.move_to(region, DOWN) - label95.shift(0.5 * UP) + eq_label = VGroup( + p_label[:4].copy(), + TexMobject("= 0.7"), + ) + for mob in eq_label: + mob.set_height(0.4) + eq_label.arrange(RIGHT, buff=SMALL_BUFF) + pp_label = VGroup( + TexMobject("P("), + eq_label, + TexMobject(")"), + ) + for mob in pp_label[::2]: + mob.set_height(0.7) + mob.set_color(YELLOW) + pp_label.arrange(RIGHT, buff=SMALL_BUFF) + pp_label.move_to(axes.c2p(0.3, 3)) - self.play(*map(ShowCreation, v_lines)) self.play( - FadeIn(region), - Write(label95) + FadeOut(heads_arrows), + FadeOut(numbers), + Write(axes), + DrawBorderThenFill(graph), ) - self.wait() - for label in labels: - self.play(FadeInFromDown(label)) - self.wait() - - # Show true value - self.wait() - self.play(FadeOut(prob_box)) - self.play(ShowCreationThenFadeAround(prob)) - self.wait() - - # Much more data - many_bools = np.hstack([ - bool_values, - (np.random.random(1000) < true_p) - ]) - N_tracker = ValueTracker(100) - graph.N_tracker = N_tracker - graph.bools = many_bools - graph.axes = axes - graph.v_lines = v_lines - graph.labels = labels - graph.region = region - graph.label95 = label95 - - label95.width_ratio = label95.get_width() / region.get_width() - - def update_graph(graph): - N = int(graph.N_tracker.get_value()) - nh = sum(graph.bools[:N]) - nt = len(graph.bools[:N]) - nh - new_graph = get_beta_graph(graph.axes, nh, nt, step_size=0.05) - graph.become(new_graph) - - dist = scipy.stats.beta(nh + 1, nt + 1) - x_bounds = dist.interval(0.95) - for x, line, label in zip(x_bounds, graph.v_lines, graph.labels): - line.set_x(graph.axes.c2p(x, 0)[0]) - label.set_x(graph.axes.c2p(x, 0)[0]) - label.set_value(x) - - graph.labels[0].shift(MED_SMALL_BUFF * LEFT) - graph.labels[1].shift(MED_SMALL_BUFF * RIGHT) - - new_simple_graph = graph.axes.get_graph(dist.pdf) - new_region = get_region_under_curve(graph.axes, new_simple_graph, *x_bounds) - new_region.match_style(graph.region) - graph.region.become(new_region) - - graph.label95.set_width(graph.label95.width_ratio * graph.region.get_width()) - graph.label95.match_x(graph.region) - - self.add(graph, region, label95, p_label) self.play( - N_tracker.set_value, 1000, - UpdateFromFunc(graph, update_graph), - Animation(v_lines), - Animation(labels), - Animation(graph.region), - Animation(graph.label95), - run_time=5, + FadeIn(pp_label[::2]), + ShowCreation(v_line), + ) + self.wait() + self.play(TransformFromCopy(p_label[:4], eq_label[0])) + self.play( + GrowFromPoint(eq_label[1], v_line.get_center()) ) self.wait() - # - def get_axes(self): - axes = get_beta_dist_axes( - label_y=True, - y_unit=1, + # Look confused + randy = Randolph() + randy.set_height(1.5) + randy.next_to(axes.c2p(0, 0), UR, MED_LARGE_BUFF) + + self.play(FadeIn(randy)) + self.play(randy.change, "confused", pp_label.get_top()) + self.play(Blink(randy)) + self.wait() + self.play(FadeOut(randy)) + + # Remind what the title is + title = TextMobject( + "Probabilities", "of", "Probabilities" ) - axes.y_axis.numbers.set_submobjects([ - *axes.y_axis.numbers[:5], - *axes.y_axis.numbers[4::5] - ]) - sf = self.initial_axis_scale_factor - axes.y_axis.stretch(sf, 1, about_point=axes.c2p(0, 0)) - for number in axes.y_axis.numbers: - number.stretch(1 / sf, 1) - axes.y_axis_label.to_edge(LEFT) - axes.y_axis_label.add_background_rectangle(opacity=1) - axes.set_stroke(background=True) - return axes + title.arrange(DOWN, aligned_edge=LEFT) + title.next_to(axes.c2p(0, 0), UR, buff=MED_LARGE_BUFF) + title.align_to(pp_label, LEFT) - def get_coins(self, bool_values): - coins = VGroup(*[ - get_coin(BLUE_E, "H") - if heads else - get_coin(RED_E, "T") - for heads in bool_values - ]) - coins.arrange_in_grid(n_rows=10, buff=MED_LARGE_BUFF) - coins.set_height(5) - return coins + self.play(ShowIncreasingSubsets(title, rate_func=linear)) + self.wait() + self.play(FadeOut(title)) - def get_probability_label(self): - head = get_coin(BLUE_E, "H") - p_label = TexMobject( - "P(00) = ", - tex_to_color_map={"00": WHITE} + # Continuous values + v_line.tracker = ValueTracker(0.7) + v_line.axes = axes + v_line.graph = true_graph + v_line.add_updater( + lambda m: m.put_start_and_end_on( + m.axes.c2p(m.tracker.get_value(), 0), + m.axes.input_to_graph_point(m.tracker.get_value(), m.graph), + ) ) - template = p_label.get_part_by_tex("00") - head.replace(template) - p_label.replace_submobject( - p_label.index_of_part(template), - head, + + for value in [0.4, 0.9, 0.7]: + self.play( + v_line.tracker.set_value, value, + run_time=3, + ) + + # Label h + brace = Brace(rhs_box, DOWN, buff=SMALL_BUFF) + h_label = TexMobject("h", buff=SMALL_BUFF) + h_label.set_color(YELLOW) + h_label.next_to(brace, DOWN) + + self.play( + LaggedStartMap(FadeOutAndShift, coins, lambda m: (m, DOWN)), + GrowFromCenter(brace), + Write(h_label), ) - prob = DecimalNumber(self.true_p) - prob.next_to(p_label, RIGHT) - p_label.add(prob) - p_label.set_height(0.75) - p_label.to_corner(UR) + self.wait() - prob_box = SurroundingRectangle(prob, buff=SMALL_BUFF) - prob_box.set_fill(GREY_D, 1) - prob_box.set_stroke(WHITE, 2) - - q_marks = TexMobject("???") - q_marks.move_to(prob_box) - prob_box.add(q_marks) - - return p_label, prob, prob_box + # End + self.embed() -class HighlightReviewPartsReversed(HighlightReviewParts): - CONFIG = { - "reverse_order": True, - } - - -class LastTimeWrapper(Scene): +class Eq70(Scene): def construct(self): - fs_rect = FullScreenFadeRectangle(fill_opacity=1, fill_color=GREY_E) - self.add(fs_rect) - - title = TextMobject("Last Time") - title.scale(1.5) - title.to_edge(UP) - - rect = ScreenRectangle() - rect.set_height(6) - rect.set_fill(BLACK, 1) - rect.next_to(title, DOWN) - - self.play( - DrawBorderThenFill(rect), - FadeInFromDown(title), - ) + label = TexMobject("=", "70", "\\%", "?") + fix_percent(label.get_part_by_tex("\\%")[0]) + self.play(FadeIn(label)) self.wait() -class Grey(Scene): +class ShowInfiniteContinuum(Scene): def construct(self): - self.add(FullScreenFadeRectangle(fill_color=GREY_D, fill_opacity=1)) + # Axes + axes = scaled_pdf_axes() + axes.to_edge(DOWN, buff=MED_SMALL_BUFF) + axes.y_axis.numbers.set_opacity(0) + axes.y_axis_label.set_opacity(0) + self.add(axes) + # Label + p_label = get_prob_coin_label() + p_label.set_height(0.7) + p_label.to_edge(UP) + box = get_q_box(p_label[-1]) + p_label.add(box) -class ShowBayesRule(Scene): - def construct(self): - hyp = "\\text{Hypothesis}" - data = "\\text{Data}" - bayes = TexMobject( - f"P({hyp} \\,|\\, {data})", "=", "{", - f"P({data} \\,|\\, {hyp})", f"P({hyp})", - "\\over", f"P({data})", - tex_to_color_map={ - hyp: YELLOW, - data: GREEN, - } - ) + brace = Brace(box, DOWN, buff=SMALL_BUFF) + h_label = TexMobject("h") + h_label.next_to(brace, DOWN) + h_label.set_color(YELLOW) + eq = TexMobject("=") + eq.next_to(h_label, RIGHT) + value = DecimalNumber(0, num_decimal_places=4) + value.match_height(h_label) + value.next_to(eq, RIGHT) + value.set_color(YELLOW) - title = TextMobject("Bayes' rule") - title.scale(2) - title.to_edge(UP) + self.add(p_label) + self.add(brace) + self.add(h_label) - self.add(title) - self.add(*bayes[:5]) - self.wait() - self.play( - *[ - TransformFromCopy(bayes[i], bayes[j], path_arc=30 * DEGREES) - for i, j in [ - (0, 7), - (1, 10), - (2, 9), - (3, 8), - (4, 11), - ] - ], - FadeIn(bayes[5]), - run_time=1.5 - ) - self.wait() - self.play( - *[ - TransformFromCopy(bayes[i], bayes[j], path_arc=30 * DEGREES) - for i, j in [ - (0, 12), - (1, 13), - (4, 14), - (0, 16), - (3, 17), - (4, 18), - ] - ], - FadeIn(bayes[15]), - run_time=1.5 - ) - self.add(bayes) - self.wait() - - hyp_word = bayes.get_part_by_tex(hyp) - example_hyp = TextMobject( - "For example,\\\\", - "$0.9 < s < 0.99$", - ) - example_hyp[1].set_color(YELLOW) - example_hyp.next_to(hyp_word, DOWN, buff=1.5) - - data_word = bayes.get_part_by_tex(data) - example_data = TexMobject( - "48\\,", CMARK_TEX, - "\\,2\\,", XMARK_TEX, - ) - example_data.set_color_by_tex(CMARK_TEX, GREEN) - example_data.set_color_by_tex(XMARK_TEX, RED) - example_data.scale(1.5) - example_data.next_to(example_hyp, RIGHT, buff=1.5) - - hyp_arrow = Arrow( - hyp_word.get_bottom(), - example_hyp.get_top(), - ) - data_arrow = Arrow( - data_word.get_bottom(), - example_data.get_top(), - ) - - self.play( - GrowArrow(hyp_arrow), - FadeInFromPoint(example_hyp, hyp_word.get_center()), - ) - self.wait() - self.play( - GrowArrow(data_arrow), - FadeInFromPoint(example_data, data_word.get_center()), - ) - self.wait() - - -class VisualizeBayesRule(Scene): - def construct(self): - self.show_continuum() - self.show_arrows() - self.show_discrete_probabilities() - self.show_bayes_formula() - self.parallel_universes() - self.update_from_data() - - def show_continuum(self): - axes = get_beta_dist_axes(y_max=1, y_unit=0.1) - axes.y_axis.add_numbers( - *np.arange(0.2, 1.2, 0.2), - number_config={ - "num_decimal_places": 1, - } - ) - - p_label = TexMobject( - "P(s \\,|\\, \\text{data})", - tex_to_color_map={ - "s": YELLOW, - "\\text{data}": GREEN, - } - ) - p_label.scale(1.5) - p_label.to_edge(UP, LARGE_BUFF) - - s_part = p_label.get_part_by_tex("s").copy() + # Moving h + h_part = h_label.copy() x_line = Line(axes.c2p(0, 0), axes.c2p(1, 0)) x_line.set_stroke(YELLOW, 3) - arrow = Vector(DOWN) - arrow.next_to(s_part, DOWN, SMALL_BUFF) - value = DecimalNumber(0, num_decimal_places=4) - value.set_color(YELLOW) - value.next_to(arrow, DOWN) - - self.add(axes) - self.add(p_label) self.play( - s_part.next_to, x_line.get_start(), UR, SMALL_BUFF, - GrowArrow(arrow), - FadeInFromPoint(value, s_part.get_center()), + h_part.next_to, x_line.get_start(), UR, SMALL_BUFF, + Write(eq), + FadeInFromPoint(value, h_part.get_center()), ) - s_part.tracked = x_line + # Scan continuum + h_part.tracked = x_line value.tracked = x_line value.x_axis = axes.x_axis self.play( ShowCreation(x_line), UpdateFromFunc( - s_part, + h_part, lambda m: m.next_to(m.tracked.get_end(), UR, SMALL_BUFF) ), UpdateFromFunc( @@ -526,24 +351,16 @@ class VisualizeBayesRule(Scene): ) self.wait() self.play( - FadeOut(arrow), + FadeOut(eq), FadeOut(value), ) - self.p_label = p_label - self.s_part = s_part - self.value = value - self.x_line = x_line - self.axes = axes - - def show_arrows(self): - axes = self.axes - + # Arrows arrows = VGroup() arrow_template = Vector(DOWN) arrow_template.lock_triangulation() - def get_arrow(s, denom): + def get_arrow(s, denom, arrow_template=arrow_template, axes=axes): arrow = arrow_template.copy() arrow.set_height(4 / denom) arrow.move_to(axes.c2p(s, 0), DOWN) @@ -563,9 +380,9 @@ class VisualizeBayesRule(Scene): arrows.add(get_arrow(1 - 1 / k, k)) kw = { - "lag_ratio": 0.5, + "lag_ratio": 0.05, "run_time": 5, - "rate_func": lambda t: t**4, + "rate_func": lambda t: t**5, } arrows.save_state() for arrow in arrows: @@ -576,624 +393,88 @@ class VisualizeBayesRule(Scene): self.play(LaggedStartMap( ApplyMethod, arrows, lambda m: (m.scale, 0, {"about_edge": DOWN}), - **kw + lag_ratio=10 / len(arrows), + rate_func=smooth, + run_time=3, )) self.remove(arrows) self.wait() - def show_discrete_probabilities(self): - axes = self.axes - x_lines = VGroup() - dx = 0.01 - for x in np.arange(0, 1, dx): - line = Line( - axes.c2p(x, 0), - axes.c2p(x + dx, 0), - ) - line.set_stroke(BLUE, 3) - line.generate_target() - line.target.rotate( - 90 * DEGREES, - about_point=line.get_start() - ) - x_lines.add(line) +class TitleCard(Scene): + def construct(self): + text = TextMobject("A beginner's guide to\\\\probability density") + text.scale(2) + text.to_edge(UP, buff=1.5) - self.add(x_lines) - self.play( - FadeOut(self.x_line), - LaggedStartMap( - MoveToTarget, x_lines, - ) - ) + subtext = TextMobject("Probabilities of probabilities, ", "part 2") + subtext.set_width(FRAME_WIDTH - 3) + subtext[0].set_color(BLUE) + subtext.next_to(text, DOWN, LARGE_BUFF) - label = Integer(0) - label.set_height(0.5) - label.next_to(self.p_label[1], DOWN, LARGE_BUFF) - unit = TexMobject("\\%") - unit.match_height(label) - fix_percent(unit.family_members_with_points()[0]) - always(unit.next_to, label, RIGHT, SMALL_BUFF) + self.add(text) + self.play(FadeIn(subtext, lag_ratio=0.1, run_time=2)) + self.wait(2) - arrow = Arrow() - arrow.max_stroke_width_to_length_ratio = 1 - arrow.axes = axes - arrow.label = label - arrow.add_updater(lambda m: m.put_start_and_end_on( - m.label.get_bottom() + MED_SMALL_BUFF * DOWN, - m.axes.c2p(0.01 * m.label.get_value(), 0.03), - )) - self.add(label, unit, arrow) - self.play( - ChangeDecimalToValue(label, 99), - run_time=5, - ) - self.wait() - self.play(*map(FadeOut, [label, unit, arrow])) - - # Show prior label - p_label = self.p_label - given_data = p_label[2:4] - prior_label = TexMobject("P(s)", tex_to_color_map={"s": YELLOW}) - prior_label.match_height(p_label) - prior_label.move_to(p_label, DOWN, LARGE_BUFF) - - p_label.save_state() - self.play( - given_data.scale, 0.5, - given_data.set_opacity, 0.5, - given_data.to_corner, UR, - Transform(p_label[:2], prior_label[:2]), - Transform(p_label[-1], prior_label[-1]), - ) +class NamePdfs(Scene): + def construct(self): + label = TextMobject("Probability density\\\\function") + self.play(Write(label)) self.wait() - # Zoom in on the y-values - new_ticks = VGroup() - new_labels = VGroup() - dy = 0.01 - for y in np.arange(dy, 5 * dy, dy): - height = get_norm(axes.c2p(0, dy) - axes.c2p(0, 0)) - tick = axes.y_axis.get_tick(y, SMALL_BUFF) - label = DecimalNumber(y) - label.match_height(axes.y_axis.numbers[0]) - always(label.next_to, tick, LEFT, SMALL_BUFF) - - new_ticks.add(tick) - new_labels.add(label) - - for num in axes.y_axis.numbers: - height = num.get_height() - always(num.set_height, height, stretch=True) - - bars = VGroup() - dx = 0.01 - origin = axes.c2p(0, 0) - for x in np.arange(0, 1, dx): - rect = Rectangle( - width=get_norm(axes.c2p(dx, 0) - origin), - height=get_norm(axes.c2p(0, dy) - origin), - ) - rect.x = x - rect.set_stroke(BLUE, 1) - rect.set_fill(BLUE, 0.5) - rect.move_to(axes.c2p(x, 0), DL) - bars.add(rect) - - stretch_group = VGroup( - axes.y_axis, - bars, - new_ticks, - x_lines, - ) - x_lines.set_height( - bars.get_height(), - about_edge=DOWN, - stretch=True, - ) - - self.play( - stretch_group.stretch, 25, 1, {"about_point": axes.c2p(0, 0)}, - VFadeIn(bars), - VFadeIn(new_ticks), - VFadeIn(new_labels), - VFadeOut(x_lines), - run_time=4, - ) - - highlighted_bars = bars.copy() - highlighted_bars.set_color(YELLOW) - self.play( - LaggedStartMap( - FadeIn, highlighted_bars, - lag_ratio=0.5, - rate_func=there_and_back, - ), - ShowCreationThenFadeAround(new_labels[0]), - run_time=3, - ) - self.remove(highlighted_bars) - - # Nmae as prior - prior_name = TextMobject("Prior", " distribution") - prior_name.set_height(0.6) - prior_name.next_to(prior_label, DOWN, LARGE_BUFF) - - self.play(FadeInFromDown(prior_name)) - self.wait() - - # Show alternate distribution - bars.save_state() - for a, b in [(5, 2), (1, 6)]: - dist = scipy.stats.beta(a, b) - for bar, saved in zip(bars, bars.saved_state): - bar.target = saved.copy() - height = get_norm(axes.c2p(0.1 * dist.pdf(bar.x)) - axes.c2p(0, 0)) - bar.target.set_height(height, about_edge=DOWN, stretch=True) - - self.play(LaggedStartMap(MoveToTarget, bars, lag_ratio=0.00)) - self.wait() - self.play(Restore(bars)) - self.wait() - - uniform_name = TextMobject("Uniform") - uniform_name.match_height(prior_name) - uniform_name.move_to(prior_name, DL) - uniform_name.shift(RIGHT) - uniform_name.set_y(bars.get_top()[1] + MED_SMALL_BUFF, DOWN) - self.play( - prior_name[0].next_to, uniform_name, RIGHT, MED_SMALL_BUFF, DOWN, - FadeOutAndShift(prior_name[1], RIGHT), - FadeInFrom(uniform_name, LEFT) - ) - self.wait() - - self.bars = bars - self.uniform_label = VGroup(uniform_name, prior_name[0]) - - def show_bayes_formula(self): - uniform_label = self.uniform_label - p_label = self.p_label - bars = self.bars - - prior_label = VGroup( - p_label[0].deepcopy(), - p_label[1].deepcopy(), - p_label[4].deepcopy(), - ) - eq = TexMobject("=") - likelihood_label = TexMobject( - "P(", "\\text{data}", "|", "s", ")", - ) - likelihood_label.set_color_by_tex("data", GREEN) - likelihood_label.set_color_by_tex("s", YELLOW) - over = Line(LEFT, RIGHT) - p_data_label = TextMobject("P(", "\\text{data}", ")") - p_data_label.set_color_by_tex("data", GREEN) - - for mob in [eq, likelihood_label, over, p_data_label]: - mob.scale(1.5) - mob.set_opacity(0.1) - - eq.move_to(prior_label, LEFT) - over.set_width( - prior_label.get_width() + - likelihood_label.get_width() + - MED_SMALL_BUFF - ) - over.next_to(eq, RIGHT, MED_SMALL_BUFF) - p_data_label.next_to(over, DOWN, MED_SMALL_BUFF) - likelihood_label.next_to(over, UP, MED_SMALL_BUFF, RIGHT) - - self.play( - p_label.restore, - p_label.next_to, eq, LEFT, MED_SMALL_BUFF, - prior_label.next_to, over, UP, MED_SMALL_BUFF, LEFT, - FadeIn(eq), - FadeIn(likelihood_label), - FadeIn(over), - FadeIn(p_data_label), - FadeOut(uniform_label), - ) - - # Show new distribution - post_bars = bars.copy() - total_prob = 0 - for bar, p in zip(post_bars, np.arange(0, 1, 0.01)): - prob = scipy.stats.binom(50, p).pmf(48) - bar.stretch(prob, 1, about_edge=DOWN) - total_prob += 0.01 * prob - post_bars.stretch(1 / total_prob, 1, about_edge=DOWN) - post_bars.stretch(0.25, 1, about_edge=DOWN) # Lie to fit on screen... - post_bars.set_color(MAROON_D) - post_bars.set_fill(opacity=0.8) +class LabelH(Scene): + def construct(self): + p_label = get_prob_coin_label() + p_label.scale(1.5) brace = Brace(p_label, DOWN) - post_word = brace.get_text("Posterior") - post_word.scale(1.25, about_edge=UP) - post_word.set_color(MAROON_D) + h = TexMobject("h") + h.scale(2) + h.next_to(brace, DOWN) + self.add(p_label) + self.play(ShowCreationThenFadeAround(p_label)) self.play( - ReplacementTransform( - bars.copy().set_opacity(0), - post_bars, - ), GrowFromCenter(brace), - FadeInFrom(post_word, 0.25 * UP) - ) - self.wait() - self.play( - eq.set_opacity, 1, - likelihood_label.set_opacity, 1, - ) - self.wait() - - data = get_check_count_label(48, 2) - data.scale(1.5) - data.next_to(likelihood_label, DOWN, buff=2, aligned_edge=LEFT) - data_arrow = Arrow( - likelihood_label[1].get_bottom(), - data.get_top() - ) - data_arrow.set_color(GREEN) - - self.play( - GrowArrow(data_arrow), - GrowFromPoint(data, data_arrow.get_start()), - ) - self.wait() - self.play(FadeOut(data_arrow)) - self.play( - over.set_opacity, 1, - p_data_label.set_opacity, 1, - ) - self.wait() - - self.play( - FadeOut(brace), - FadeOut(post_word), - FadeOut(post_bars), - FadeOut(data), - p_label.set_opacity, 0.1, - eq.set_opacity, 0.1, - likelihood_label.set_opacity, 0.1, - over.set_opacity, 0.1, - p_data_label.set_opacity, 0.1, - ) - - self.bayes = VGroup( - p_label, eq, - prior_label, likelihood_label, - over, p_data_label - ) - self.data = data - - def parallel_universes(self): - bars = self.bars - - cols = VGroup() - squares = VGroup() - sample_colors = color_gradient( - [GREEN_C, GREEN_D, GREEN_E], - 100 - ) - for bar in bars: - n_rows = 12 - col = VGroup() - for x in range(n_rows): - square = Rectangle( - width=bar.get_width(), - height=bar.get_height() / n_rows, - ) - square.set_stroke(width=0) - square.set_fill(opacity=1) - square.set_color(random.choice(sample_colors)) - col.add(square) - squares.add(square) - col.arrange(DOWN, buff=0) - col.move_to(bar) - cols.add(col) - squares.shuffle() - - self.play( - LaggedStartMap( - VFadeInThenOut, squares, - lag_ratio=0.005, - run_time=3 - ) - ) - self.remove(squares) - squares.set_opacity(1) - self.wait() - - example_col = cols[95] - - self.play( - bars.set_opacity, 0.25, - FadeIn(example_col, lag_ratio=0.1), - ) - self.wait() - - dist = scipy.stats.binom(50, 0.95) - for x in range(12): - square = random.choice(example_col).copy() - square.set_fill(opacity=0) - square.set_stroke(YELLOW, 2) - self.add(square) - nc = dist.ppf(random.random()) - data = get_check_count_label(nc, 50 - nc) - data.next_to(example_col, UP) - - self.add(square, data) - self.wait(0.5) - self.remove(square, data) - self.wait() - - self.data.set_opacity(1) - self.play( - FadeIn(self.data), - FadeOut(example_col), - self.bayes[3].set_opacity, 1, - ) - self.wait() - - def update_from_data(self): - bars = self.bars - data = self.data - bayes = self.bayes - - new_bars = bars.copy() - new_bars.set_stroke(opacity=1) - new_bars.set_fill(opacity=0.8) - for bar, p in zip(new_bars, np.arange(0, 1, 0.01)): - dist = scipy.stats.binom(50, p) - scalar = dist.pmf(48) - bar.stretch(scalar, 1, about_edge=DOWN) - - self.play( - ReplacementTransform( - bars.copy().set_opacity(0), - new_bars - ), - bars.set_fill, {"opacity": 0.1}, - bars.set_stroke, {"opacity": 0.1}, - run_time=2, - ) - - # Show example bar - bar95 = VGroup( - bars[95].copy(), - new_bars[95].copy() - ) - bar95.save_state() - bar95.generate_target() - bar95.target.scale(2) - bar95.target.next_to(bar95, UP, LARGE_BUFF) - bar95.target.set_stroke(BLUE, 3) - - ex_label = TexMobject("s", "=", "0.95") - ex_label.set_color(YELLOW) - ex_label.next_to(bar95.target, DOWN, submobject_to_align=ex_label[-1]) - - highlight = SurroundingRectangle(bar95, buff=0) - highlight.set_stroke(YELLOW, 2) - - self.play(FadeIn(highlight)) - self.play( - MoveToTarget(bar95), - FadeInFromDown(ex_label), - data.shift, LEFT, - ) - self.wait() - - side_brace = Brace(bar95[1], RIGHT, buff=SMALL_BUFF) - side_label = side_brace.get_text("0.26", buff=SMALL_BUFF) - self.play( - GrowFromCenter(side_brace), - FadeIn(side_label) - ) - self.wait() - self.play( - FadeOut(side_brace), - FadeOut(side_label), - FadeOut(ex_label), - ) - self.play( - bar95.restore, - bar95.set_opacity, 0, - ) - - for bar in bars[94:80:-1]: - highlight.move_to(bar) - self.wait(0.5) - self.play(FadeOut(highlight)) - self.wait() - - # Emphasize formula terms - tops = VGroup() - for bar, new_bar in zip(bars, new_bars): - top = Line(bar.get_corner(UL), bar.get_corner(UR)) - top.set_stroke(YELLOW, 2) - top.generate_target() - top.target.move_to(new_bar, UP) - tops.add(top) - - rect = SurroundingRectangle(bayes[2]) - rect.set_stroke(YELLOW, 1) - rect.target = SurroundingRectangle(bayes[3]) - rect.target.match_style(rect) - self.play( - ShowCreation(rect), - ShowCreation(tops), - ) - self.wait() - self.play( - LaggedStartMap( - MoveToTarget, tops, - run_time=2, - lag_ratio=0.02, - ), - MoveToTarget(rect), - ) - self.play(FadeOut(tops)) - self.wait() - - # Show alternate priors - axes = self.axes - bar_groups = VGroup() - for bar, new_bar in zip(bars, new_bars): - bar_groups.add(VGroup(bar, new_bar)) - - bar_groups.save_state() - for a, b in [(5, 2), (7, 1)]: - dist = scipy.stats.beta(a, b) - for bar, saved in zip(bar_groups, bar_groups.saved_state): - bar.target = saved.copy() - height = get_norm(axes.c2p(0.1 * dist.pdf(bar[0].x)) - axes.c2p(0, 0)) - height = max(height, 1e-6) - bar.target.set_height(height, about_edge=DOWN, stretch=True) - - self.play(LaggedStartMap(MoveToTarget, bar_groups, lag_ratio=0)) - self.wait() - self.play(Restore(bar_groups)) - self.wait() - - # Rescale - ex_p_label = TexMobject( - "P(s = 0.95 | 00000000) = ", - tex_to_color_map={ - "s = 0.95": YELLOW, - "00000000": WHITE, - } - ) - ex_p_label.scale(1.5) - ex_p_label.next_to(bars, UP, LARGE_BUFF) - ex_p_label.align_to(bayes, LEFT) - template = ex_p_label.get_part_by_tex("00000000") - template.set_opacity(0) - - highlight = SurroundingRectangle(new_bars[95], buff=0) - highlight.set_stroke(YELLOW, 1) - - self.remove(data) - self.play( - FadeIn(ex_p_label), - VFadeOut(data[0]), - data[1:].move_to, template, - FadeIn(highlight) - ) - self.wait() - - numer = new_bars[95].copy() - numer.set_stroke(YELLOW, 1) - denom = new_bars[80:].copy() - h_line = Line(LEFT, RIGHT) - h_line.set_width(3) - h_line.set_stroke(width=2) - h_line.next_to(ex_p_label, RIGHT) - - self.play( - numer.next_to, h_line, UP, - denom.next_to, h_line, DOWN, - ShowCreation(h_line), - ) - self.wait() - self.play( - denom.space_out_submobjects, - rate_func=there_and_back - ) - self.play( - bayes[4].set_opacity, 1, - bayes[5].set_opacity, 1, - FadeOut(rect), - ) - self.wait() - - # Rescale - self.play( - FadeOut(highlight), - FadeOut(ex_p_label), - FadeOut(data), - FadeOut(h_line), - FadeOut(numer), - FadeOut(denom), - bayes.set_opacity, 1, - ) - - new_bars.unlock_shader_data() - self.remove(new_bars, *new_bars) - self.play( - new_bars.set_height, 5, {"about_edge": DOWN, "stretch": True}, - new_bars.set_color, MAROON_D, + FadeInFrom(h, UP), ) self.wait() -class UniverseOf95Percent(WhatsTheModel): - CONFIG = {"s": 0.95} - +class DrawUnderline(Scene): def construct(self): - self.introduce_buyer_and_seller() - for m, v in [(self.seller, RIGHT), (self.buyer, LEFT)]: - m.shift(v) - m.label.shift(v) - - pis = VGroup(self.seller, self.buyer) - label = get_prob_positive_experience_label(True, True) - label[-1].set_value(self.s) - label.set_height(1) - label.next_to(pis, UP, LARGE_BUFF) - self.add(label) - - for x in range(4): - self.play(*self.experience_animations( - self.seller, self.buyer, arc=30 * DEGREES, p=self.s - )) - - self.embed() - - -class UniverseOf50Percent(UniverseOf95Percent): - CONFIG = {"s": 0.5} - - -class OpenAndCloseAsideOnPdfs(Scene): - def construct(self): - labels = VGroup( - TextMobject("$\\langle$", "Aside on", " pdfs", "$\\rangle$"), - TextMobject("$\\langle$/", "Aside on", " pdfs", "$\\rangle$"), - ) - labels.set_width(FRAME_WIDTH / 2) - for label in labels: - label.set_color_by_tex("pdfs", YELLOW) - - self.play(FadeInFromDown(labels[0])) - self.wait() - self.play(Transform(*labels)) + line = Line(2 * LEFT, 2 * RIGHT) + line.set_stroke(PINK, 5) + self.play(ShowCreation(line)) self.wait() + line.reverse_points() + self.play(Uncreate(line)) class TryAssigningProbabilitiesToSpecificValues(Scene): def construct(self): - # To get "P(s = 95.9999%) ="" type labels + # To get "P(s = .7000001) = ???" type labels def get_p_label(value): result = TexMobject( - "P(", "{s}", "=", value, "\\%", ")", + # "P(", "{s}", "=", value, "\\%", ")", + "P(", "{h}", "=", value, ")", ) - fix_percent(result.get_part_by_tex("\\%")[0]) - result.set_color_by_tex("{s}", YELLOW) + # fix_percent(result.get_part_by_tex("\\%")[0]) + result.set_color_by_tex("{h}", YELLOW) return result labels = VGroup( - get_p_label("95.0000000"), - get_p_label("94.9999999"), - get_p_label("94.9314159"), - get_p_label("94.9271828"), - get_p_label("94.9466920"), - get_p_label("94.9161803"), + get_p_label("0.70000000"), + get_p_label("0.70000001"), + get_p_label("0.70314159"), + get_p_label("0.70271828"), + get_p_label("0.70466920"), + get_p_label("0.70161803"), ) labels.arrange(DOWN, buff=0.35, aligned_edge=LEFT) + labels.set_height(4.5) + labels.to_edge(DOWN, buff=LARGE_BUFF) q_marks = VGroup() gt_zero = VGroup() @@ -1224,7 +505,7 @@ class TryAssigningProbabilitiesToSpecificValues(Scene): for m1, m2 in [ (q_marks[0], q_marks[1]), (labels[0][:3], labels[1][:3]), - (labels[0][5], labels[1][5]), + (labels[0][-1], labels[1][-1]), ] ]) self.play(ShowIncreasingSubsets( @@ -1259,19 +540,20 @@ class TryAssigningProbabilitiesToSpecificValues(Scene): # Show sum group = VGroup(labels, gt_zero, v_dots) sum_label = TexMobject( - "\\sum_{s}", "P(", "{s}", ")", "=", - tex_to_color_map={"{s}": YELLOW}, + "\\sum_{0 \\le {h} \\le 1}", "P(", "{h}", ")", "=", + tex_to_color_map={"{h}": YELLOW}, ) # sum_label.set_color_by_tex("{s}", YELLOW) sum_label[0].set_color(WHITE) sum_label.scale(1.75) sum_label.next_to(ORIGIN, RIGHT, buff=1) + sum_label.shift(LEFT) morty = Mortimer() morty.set_height(2) morty.to_corner(DR) - self.play(group.next_to, ORIGIN, LEFT) + self.play(group.to_corner, DL) self.play( Write(sum_label), VFadeIn(morty), @@ -1316,11 +598,127 @@ class TryAssigningProbabilitiesToSpecificValues(Scene): self.wait() +class WanderingArrow(Scene): + def construct(self): + arrow = Vector(0.8 * DOWN) + arrow.move_to(4 * LEFT, DOWN) + for u in [1, -1, 1, -1, 1]: + self.play( + arrow.shift, u * 8 * RIGHT, + run_time=3 + ) + + +class ProbabilityToContinuousValuesSupplement(Scene): + def construct(self): + nl = UnitInterval() + nl.set_width(10) + nl.add_numbers( + *np.arange(0, 1.1, 0.1), + buff=0.3, + ) + nl.to_edge(LEFT) + self.add(nl) + + def f(x): + return -100 * (x - 0.6) * (x - 0.8) + + values = np.linspace(0.65, 0.75, 100) + lines = VGroup() + for x, color in zip(values, it.cycle([BLUE_E, BLUE_C])): + line = Line(ORIGIN, UP) + line.set_height(f(x)) + line.set_stroke(color, 1) + line.move_to(nl.n2p(x), DOWN) + lines.add(line) + + self.play(ShowCreation(lines, lag_ratio=0.9, run_time=5)) + + lines_row = lines.copy() + lines_row.generate_target() + for lt in lines_row.target: + lt.rotate(90 * DEGREES) + lines_row.target.arrange(RIGHT, buff=0) + lines_row.target.set_stroke(width=4) + lines_row.target.next_to(nl, UP, LARGE_BUFF) + lines_row.target.align_to(nl.n2p(0), LEFT) + + self.play( + MoveToTarget( + lines_row, + lag_ratio=0.1, + rate_func=rush_into, + run_time=4, + ) + ) + self.wait() + self.play( + lines.set_height, 0.01, {"about_edge": DOWN, "stretch": True}, + ApplyMethod( + lines_row.set_width, 0.01, {"about_edge": LEFT}, + rate_func=rush_into, + ), + run_time=6, + ) + self.wait() + + +class CarFactoryNumbers(Scene): + def construct(self): + # Test words + denom_words = TextMobject( + "in a test of 100 cars", + tex_to_color_map={"100": BLUE}, + ) + denom_words.to_corner(UR) + + numer_words = TextMobject( + "2 defects found", + tex_to_color_map={"2": RED} + ) + numer_words.move_to(denom_words, LEFT) + + self.play(Write(denom_words, run_time=1)) + self.wait() + self.play( + denom_words.next_to, numer_words, DOWN, {"aligned_edge": LEFT}, + FadeIn(numer_words), + ) + self.wait() + + # Question words + question = VGroup( + TextMobject("What can you say"), + TexMobject( + "\\text{about } P(\\text{defect})?", + tex_to_color_map={"\\text{defect}": RED} + ) + ) + + question.arrange(DOWN, aligned_edge=LEFT) + question.next_to(denom_words, DOWN, buff=1.5, aligned_edge=LEFT) + + self.play(FadeIn(question)) + self.wait() + + +class TeacherHoldingValue(TeacherStudentsScene): + def construct(self): + self.play(self.teacher.change, "raise_right_hand", self.screen) + self.change_all_student_modes( + "pondering", + look_at_arg=self.screen, + ) + self.wait(8) + + class ShowLimitToPdf(Scene): def construct(self): # Init axes = self.get_axes() - dist = scipy.stats.beta(4, 2) + alpha = 4 + beta = 2 + dist = scipy.stats.beta(alpha, beta) bars = self.get_bars(axes, dist, 0.05) axis_prob_label = TextMobject("Probability") @@ -1331,9 +729,9 @@ class ShowLimitToPdf(Scene): self.add(axis_prob_label) # From individual to ranges - kw = {"tex_to_color_map": {"s": YELLOW}} - eq_label = TexMobject("P(s = 0.8)", **kw) - ineq_label = TexMobject("P(0.8 < s < 0.85)", **kw) + kw = {"tex_to_color_map": {"h": YELLOW}} + eq_label = TexMobject("P(h = 0.8)", **kw) + ineq_label = TexMobject("P(0.8 < h < 0.85)", **kw) arrows = VGroup(Vector(DOWN), Vector(DOWN)) for arrow, x in zip(arrows, [0.8, 0.85]): @@ -1359,10 +757,12 @@ class ShowLimitToPdf(Scene): ) self.wait() + # Bars arrow = arrows[0] arrow.generate_target() arrow.target.next_to(bars[16], UP, SMALL_BUFF) - bars[16].set_color(GREEN) + highlighted_bar_color = RED_E + bars[16].set_color(highlighted_bar_color) for bar in bars: bar.save_state() @@ -1448,6 +848,42 @@ class ShowLimitToPdf(Scene): prob_label = VGroup(area_word, *prob_label[1:]) self.add(prob_label) + # Ask about where values come from + randy = Randolph(height=1) + randy.next_to(prob_label, UP, aligned_edge=LEFT) + + bubble = SpeechBubble( + height=2, + width=4, + ) + bubble.move_to(randy.get_corner(UR), DL) + bubble.write("Where do these\\\\probabilities come from?") + + self.play( + FadeIn(randy), + ShowCreation(bubble), + ) + self.play( + randy.change, "confused", + FadeIn(bubble.content, lag_ratio=0.1) + ) + self.play(Blink(randy)) + + bars.generate_target() + bars.save_state() + bars.target.arrange(RIGHT, buff=SMALL_BUFF, aligned_edge=DOWN) + bars.target.next_to(bars.get_bottom(), UP) + + self.play(MoveToTarget(bars)) + self.play(LaggedStartMap(Indicate, bars, scale_factor=1.05), run_time=1) + self.play(Restore(bars)) + self.play(Blink(randy)) + self.play( + FadeOut(randy), + FadeOut(bubble), + FadeOut(bubble.content), + ) + # Refine last_ineq_label = ineq_label last_bars = bars @@ -1455,8 +891,8 @@ class ShowLimitToPdf(Scene): for step_size in [0.025, 0.01, 0.005, 0.001]: new_bars = self.get_bars(axes, dist, step_size) new_ineq_label = TexMobject( - "P(0.8 < s < {:.3})".format(0.8 + step_size), - tex_to_color_map={"s": YELLOW}, + "P(0.8 < h < {:.3})".format(0.8 + step_size), + tex_to_color_map={"h": YELLOW}, ) if step_size <= 0.005: @@ -1464,7 +900,7 @@ class ShowLimitToPdf(Scene): arrow.generate_target() bar = new_bars[int(0.8 * len(new_bars))] - bar.set_color(GREEN) + bar.set_color(highlighted_bar_color) arrow.target.next_to(bar, UP, SMALL_BUFF) new_ineq_label.next_to(arrow.target, UP) @@ -1486,7 +922,7 @@ class ShowLimitToPdf(Scene): all_ineq_labels.add(new_ineq_label) # Show continuous graph - graph = get_beta_graph(axes, 3, 1) + graph = get_beta_graph(axes, alpha - 1, beta - 1) graph_curve = axes.get_graph(dist.pdf) graph_curve.set_stroke([YELLOW, GREEN]) @@ -1568,12 +1004,74 @@ class ShowLimitToPdf(Scene): ) ) self.wait(2) + + # What if it was heights + bars.restore() + height_word.move_to(area_word, RIGHT) + height_word.set_color(PINK) + step = 0.05 + new_y_numbers = VGroup(*[ + DecimalNumber(x) for x in np.arange(step, 5 * step, step) + ]) + for n1, n2 in zip(axes.y_axis.numbers, new_y_numbers): + n2.match_height(n1) + n2.add_background_rectangle( + opacity=1, + buff=SMALL_BUFF, + ) + n2.move_to(n1, RIGHT) + + self.play( + FadeOut(limit_words), + FadeOut(graph), + FadeIn(bars), + FadeOutAndShift(area_word, UP), + FadeInFrom(height_word, DOWN), + FadeInFrom(new_y_numbers, 0.5 * RIGHT), + ) + + # Height refine + rect = SurroundingRectangle(rhss[0][1]) + rect.set_stroke(RED, 3) + self.play(FadeIn(rect)) + + last_bars = bars + for step_size, rhs in zip(step_sizes[1:], rhss[1:]): + new_bars = self.get_bars(axes, dist, step_size) + bar = new_bars[int(0.8 * len(new_bars))] + bar.set_color(highlighted_bar_color) + new_bars.stretch( + step_size / 0.05, 1, + about_edge=DOWN, + ) + if step_size <= 0.05: + new_bars.set_stroke(width=0) + self.remove(last_bars) + self.play( + TransformFromCopy(last_bars, new_bars, lag_ratio=step_size), + rect.move_to, rhs[1], + ) + last_bars = new_bars + self.play( + FadeOut(last_bars), + FadeOutAndShiftDown(rect), + ) + self.wait() + + # Back to area + self.play( + FadeIn(graph), + FadeInFrom(area_word, 0.5 * DOWN), + FadeOutAndShift(height_word, 0.5 * UP), + FadeOut(new_y_numbers, lag_ratio=0.2), + ) self.play( arrow.scale, 0, {"about_edge": DOWN}, FadeOutAndShift(to_zero_words, DOWN), LaggedStartMap(FadeOutAndShiftDown, all_ineq_labels), LaggedStartMap(FadeOutAndShiftDown, rhss), ) + self.wait() # Ask about y_axis units arrow = Arrow( @@ -1595,7 +1093,6 @@ class ShowLimitToPdf(Scene): self.play( FadeOut(graph), FadeIn(bars), - FadeOut(limit_words) ) bars.generate_target() bars.save_state() @@ -1669,9 +1166,13 @@ class ShowLimitToPdf(Scene): new_bars = self.get_bars(axes, dist, step_size) if step_size <= 0.05: new_bars.set_stroke(width=0) - self.play(ReplacementTransform( - bars, new_bars, lag_ratio=step_size - )) + self.play( + ReplacementTransform( + bars, new_bars, lag_ratio=step_size + ), + run_time=3, + ) + self.wait() bars = new_bars self.add(graph, total_label) self.play( @@ -1738,7 +1239,7 @@ class ShowLimitToPdf(Scene): TexMobject("P("), DecimalNumber(min_x), TexMobject("\\le"), - TexMobject("s", color=YELLOW), + TexMobject("h", color=YELLOW), TexMobject("\\le"), DecimalNumber(max_x), TexMobject(")") @@ -1813,6 +1314,7 @@ class ShowLimitToPdf(Scene): run_time=2, ) self.wait() + # Stretch to area 1 self.play( ChangeDecimalToValue(p_label[1], 0), @@ -1835,7 +1337,7 @@ class ShowLimitToPdf(Scene): interpolate(m.new_x, 1, a), )) ), - run_time=3, + run_time=5, ) self.wait() @@ -1858,11 +1360,11 @@ class ShowLimitToPdf(Scene): ) axes.center() - s_label = TexMobject("s") - s_label.set_color(YELLOW) - s_label.next_to(axes.x_axis, RIGHT) - axes.x_axis.add(s_label) - axes.x_axis.s_label = s_label + h_label = TexMobject("h") + h_label.set_color(YELLOW) + h_label.next_to(axes.x_axis.n2p(1), UR, buff=0.2) + axes.x_axis.add(h_label) + axes.x_axis.label = h_label axes.x_axis.add_numbers( *np.arange(0.2, 1.2, 0.2), @@ -1892,594 +1394,945 @@ class ShowLimitToPdf(Scene): return bars -class BayesRuleWithPdf(ShowLimitToPdf): +class FiniteVsContinuum(Scene): def construct(self): - # Axes - axes = self.get_axes() - sf = 1.5 - axes.y_axis.stretch(sf, 1, about_point=axes.c2p(0, 0)) - for number in axes.y_axis.numbers: - number.stretch(1 / sf, 1) - self.add(axes) + # Title + f_title = TextMobject("Discrete context") + f_title.set_height(0.5) + f_title.to_edge(UP) + f_underline = Underline(f_title) + f_underline.scale(1.3) + f_title.add(f_underline) + self.add(f_title) - # Formula - bayes = self.get_formula() + # Equations + dice = get_die_faces()[::2] + cards = [PlayingCard(letter + "H") for letter in "A35"] - post = bayes[:5] - eq = bayes[5] - prior = bayes[6:9] - likelihood = bayes[9:14] - over = bayes[14] - p_data = bayes[15:] - - self.play(FadeInFromDown(bayes)) - self.wait() - - # Prior - prior_graph = get_beta_graph(axes, 0, 0) - prior_graph_top = Line( - prior_graph.get_corner(UL), - prior_graph.get_corner(UR), + eqs = VGroup( + self.get_union_equation(dice), + self.get_union_equation(cards), ) - prior_graph_top.set_stroke(YELLOW, 3) + for eq in eqs: + eq.set_width(FRAME_WIDTH - 1) + eqs.arrange(DOWN, buff=LARGE_BUFF) + eqs.next_to(f_underline, DOWN, LARGE_BUFF) - bayes.save_state() - bayes.set_opacity(0.2) - prior.set_opacity(1) + anims = [] + for eq in eqs: + movers = eq.mob_copies1.copy() + for m1, m2 in zip(movers, eq.mob_copies2): + m1.generate_target() + m1.target.replace(m2) + eq.mob_copies2.set_opacity(0) + eq.add(movers) + + self.play(FadeIn(eq[0])) + + anims.append(FadeIn(eq[1:])) + anims.append(LaggedStartMap( + MoveToTarget, movers, + path_arc=30 * DEGREES, + lag_ratio=0.1, + )) + self.wait() + for anim in anims: + self.play(anim) + + # Continuum label + c_title = TextMobject("Continuous context") + c_title.match_height(f_title) + c_underline = Underline(c_title) + c_underline.scale(1.25) self.play( - Restore(bayes, rate_func=reverse_smooth), - FadeIn(prior_graph), - ShowCreation(prior_graph_top), + Write(c_title, run_time=1), + ShowCreation(c_underline), + eqs[0].shift, 0.5 * UP, + eqs[1].shift, UP, ) - self.play(FadeOut(prior_graph_top)) + + # Range sum + c_eq = TexMobject( + "P\\big(", "x \\in [0.65, 0.75]", "\\big)", + "=", + "\\sum_{x \\in [0.65, 0.75]}", + "P(", "x", ")", + ) + c_eq.set_color_by_tex("P", YELLOW) + c_eq.set_color_by_tex(")", YELLOW) + c_eq.next_to(c_underline, DOWN, LARGE_BUFF) + c_eq.to_edge(LEFT) + + equals = c_eq.get_part_by_tex("=") + equals.shift(SMALL_BUFF * RIGHT) + e_cross = Line(DL, UR) + e_cross.replace(equals, dim_to_match=0) + e_cross.set_stroke(RED, 5) + + self.play(FadeIn(c_eq)) + self.wait(2) + self.play(ShowCreation(e_cross)) self.wait() - # Scale Down - nh = 1 - nt = 2 + def get_union_equation(self, mobs): + mob_copies1 = VGroup() + mob_copies2 = VGroup() + p_color = YELLOW - scaled_graph = axes.get_graph( - lambda x: scipy.stats.binom(3, x).pmf(1) + 1e-6 + # Create mob_set + brackets = TexMobject("\\big\\{\\big\\}")[0] + mob_set = VGroup(brackets[0]) + commas = VGroup() + for mob in mobs: + mc = mob.copy() + mc.match_height(mob_set[0]) + mob_copies1.add(mc) + comma = TexMobject(",") + commas.add(comma) + mob_set.add(mc) + mob_set.add(comma) + + mob_set.remove(commas[-1]) + commas.remove(commas[-1]) + mob_set.add(brackets[1]) + mob_set.arrange(RIGHT, buff=0.15) + commas.set_y(mob_set[1].get_bottom()[1]) + + mob_set.scale(0.8) + + # Create individual probabilities + probs = VGroup() + for mob in mobs: + prob = TexMobject("P(", "x = ", "00", ")") + index = prob.index_of_part_by_tex("00") + mc = mob.copy() + mc.replace(prob[index]) + mc.scale(0.8, about_edge=LEFT) + mc.match_y(prob[-1]) + mob_copies2.add(mc) + prob.replace_submobject(index, mc) + prob[0].set_color(p_color) + prob[1].match_y(mc) + prob[-1].set_color(p_color) + probs.add(prob) + + # Result + lhs = VGroup( + TexMobject("P\\big(", color=p_color), + TexMobject("x \\in"), + mob_set, + TexMobject("\\big)", color=p_color), ) - scaled_graph.set_stroke(GREEN) - scaled_region = get_region_under_curve(axes, scaled_graph, 0, 1) + lhs.arrange(RIGHT, buff=SMALL_BUFF) + group = VGroup(lhs, TexMobject("=")) + for prob in probs: + group.add(prob) + group.add(TexMobject("+")) + group.remove(group[-1]) - def to_uniform(p, axes=axes): - return axes.c2p( - axes.x_axis.p2n(p), - int(axes.y_axis.p2n(p) != 0), + group.arrange(RIGHT, buff=0.2) + group.mob_copies1 = mob_copies1 + group.mob_copies2 = mob_copies2 + + return group + + +class ComplainAboutRuleChange(TeacherStudentsScene): + def construct(self): + self.student_says( + "Wait, the rules\\\\changed?", + target_mode="sassy", + added_anims=[self.teacher.change, "tease"] + ) + self.change_student_modes("erm", "confused") + self.wait(4) + self.teacher_says("You may enjoy\\\\``Measure theory''") + self.change_all_student_modes( + "pondering", + look_at_arg=self.teacher.bubble + ) + self.wait(8) + + +class HalfFiniteHalfContinuous(Scene): + def construct(self): + # Basic symbols + box = Rectangle(width=3, height=1.2) + box.set_stroke(WHITE, 2) + box.set_fill(GREY_E, 1) + box.move_to(2.5 * LEFT, RIGHT) + + arrows = VGroup() + arrow_labels = VGroup() + for vect in [UP, DOWN]: + arrow = Arrow( + box.get_corner(vect + RIGHT), + box.get_corner(vect + RIGHT) + 3 * RIGHT + 1.5 * vect, + buff=MED_SMALL_BUFF, + ) + label = TexMobject("50\\%") + fix_percent(label[0][-1]) + label.set_color(YELLOW) + label.next_to( + arrow.get_center(), + vect + LEFT, + buff=SMALL_BUFF, ) - scaled_region.set_fill(opacity=0.75) - scaled_region.save_state() - scaled_region.apply_function(to_uniform) + arrow_labels.add(label) + arrows.add(arrow) - self.play( - Restore(scaled_region), - UpdateFromAlphaFunc( - scaled_region, - lambda m, a: m.set_opacity(a * 0.75), - ), - likelihood.set_opacity, 1, - ) - self.wait() + zero = Integer(0) + zero.set_height(0.5) + zero.next_to(arrows[0].get_end(), RIGHT) - # Rescale - new_graph = get_beta_graph(axes, nh, nt) - self.play( - ApplyMethod( - scaled_region.set_height, new_graph.get_height(), - {"about_edge": DOWN, "stretch": True}, - run_time=2, - ), - over.set_opacity, 1, - p_data.set_opacity, 1, - ) - self.wait() - self.play( - post.set_opacity, 1, - eq.set_opacity, 1, - ) - self.wait() - - # Use lower case - new_bayes = self.get_formula(lowercase=True) - new_bayes.replace(bayes, dim_to_match=0) - rects = VGroup( - SurroundingRectangle(new_bayes[0][0]), - SurroundingRectangle(new_bayes[6][0]), - ) - rects.set_stroke(YELLOW, 3) - - self.remove(bayes) - bayes = self.get_formula() - bayes.unlock_triangulation() - self.add(bayes) - self.play(Transform(bayes, new_bayes)) - self.play(ShowCreationThenFadeOut(rects)) - - def get_formula(self, lowercase=False): - p_sym = "p" if lowercase else "P" - bayes = TexMobject( - p_sym + "({s} \\,|\\, \\text{data})", "=", - "{" + p_sym + "({s})", - "P(\\text{data} \\,|\\, {s})", - "\\over", - "P(\\text{data})", - tex_to_color_map={ - "{s}": YELLOW, - "\\text{data}": GREEN, + # Half Gaussian + axes = Axes( + x_min=0, + x_max=6.5, + y_min=0, + y_max=0.25, + y_axis_config={ + "tick_frequency": 1 / 16, + "unit_size": 10, + "include_tip": False, } ) - bayes.set_height(1.5) - bayes.to_edge(UP) - return bayes + axes.next_to(arrows[1].get_end(), RIGHT) + dist = scipy.stats.norm(0, 2) + graph = axes.get_graph(dist.pdf) + graph_fill = graph.copy() + close_off_graph(axes, graph_fill) + graph.set_stroke(BLUE, 3) + graph_fill.set_fill(BLUE_E, 1) + graph_fill.set_stroke(BLUE_E, 0) -class TalkThroughCoinExample(ShowBayesianUpdating): - def construct(self): - # Setup - axes = self.get_axes() - x_label = TexMobject("x") - x_label.next_to(axes.x_axis.get_end(), UR, MED_SMALL_BUFF) - axes.add(x_label) - - p_label, prob, prob_box = self.get_probability_label() - prob_box_x = x_label.copy().move_to(prob_box) - - self.add(axes) - self.add(p_label) - self.add(prob_box) - - self.wait() - q_marks = prob_box[1] - prob_box.remove(q_marks) - self.play( - FadeOut(q_marks), - TransformFromCopy(x_label, prob_box_x) - ) - prob_box.add(prob_box_x) - - # Setup coins - bool_values = (np.random.random(100) < self.true_p) - bool_values[:5] = [True, False, True, True, False] - coins = self.get_coins(bool_values) - coins.next_to(axes.y_axis, RIGHT, MED_LARGE_BUFF) - coins.to_edge(UP) - - # Random coin - rows = VGroup() - for x in range(5): - row = self.get_coins(np.random.random(10) < self.true_p) - row.arrange(RIGHT, buff=MED_LARGE_BUFF) - row.set_width(6) - row.move_to(UP) - rows.add(row) - - last_row = VMobject() - for row in rows: - self.play( - FadeOutAndShift(last_row, DOWN), - FadeIn(row, lag_ratio=0.1) - ) - last_row = row - self.play(FadeOutAndShift(last_row, DOWN)) - - # Uniform pdf - region = get_beta_graph(axes, 0, 0) - graph = Line( - region.get_corner(UL), - region.get_corner(UR), - ) - func_label = TexMobject("f(x) =", "1") - func_label.next_to(graph, UP) - - self.play( - FadeIn(func_label, lag_ratio=0.1), - ShowCreation(graph), - ) - self.add(region, graph) - self.play(FadeIn(region)) - self.wait() - - # First flip - coin = coins[0] - arrow = Vector(0.5 * UP) - arrow.next_to(coin, DOWN, SMALL_BUFF) - data_label = TextMobject("New data") - data_label.set_height(0.25) - data_label.next_to(arrow, DOWN) - data_label.shift(0.5 * RIGHT) - - self.play( - FadeInFrom(coin, DOWN), - GrowArrow(arrow), - Write(data_label, run_time=1) - ) - self.wait() - - # Show Bayes rule - bayes = TexMobject( - "p({x} | \\text{data})", "=", - "p({x})", - "{P(\\text{data} | {x})", - "\\over", - "P(\\text{data})", - tex_to_color_map={ - "{x}": WHITE, - "\\text{data}": GREEN, - } - ) - bayes.next_to(func_label, UP, LARGE_BUFF, LEFT) - - likelihood = bayes[9:14] - p_data = bayes[15:] - likelihood_rect = SurroundingRectangle(likelihood, buff=0.05) - likelihood_rect.save_state() - p_data_rect = SurroundingRectangle(p_data, buff=0.05) - - likelihood_x_label = TexMobject("x") - likelihood_x_label.next_to(likelihood_rect, UP) - - self.play(FadeInFromDown(bayes)) - self.wait() - self.play(ShowCreation(likelihood_rect)) - self.wait() - - self.play(TransformFromCopy(likelihood[-2], likelihood_x_label)) - self.wait() - - # Scale by x - times_x = TexMobject("\\cdot \\, x") - times_x.next_to(func_label, RIGHT, buff=0.2) - - new_graph = axes.get_graph(lambda x: x) - sub_region = get_region_under_curve(axes, new_graph, 0, 1) - - self.play( - Write(times_x), - Transform(graph, new_graph), - ) - self.play( - region.set_opacity, 0.5, - FadeIn(sub_region), - ) - self.wait() - - # Show example scalings - low_x = 0.1 - high_x = 0.9 - lines = VGroup() - for x in [low_x, high_x]: - lines.add(Line(axes.c2p(x, 0), axes.c2p(x, 1))) - - lines.set_stroke(YELLOW, 3) - - for x, line in zip([low_x, high_x], lines): - self.play(FadeIn(line)) - self.play(line.scale, x, {"about_edge": DOWN}) - self.wait() - self.play(FadeOut(lines)) - - # Renormalize - self.play( - FadeOut(likelihood_x_label), - ReplacementTransform(likelihood_rect, p_data_rect), - ) - self.wait() - - one = func_label[1] - two = TexMobject("2") - two.move_to(one, LEFT) - - self.play( - FadeOut(region), - sub_region.stretch, 2, 1, {"about_edge": DOWN}, - sub_region.set_color, BLUE, - graph.stretch, 2, 1, {"about_edge": DOWN}, - FadeInFromDown(two), - FadeOutAndShift(one, UP), - ) - region = sub_region - func_label = VGroup(func_label[0], two, times_x) - self.add(func_label) - - self.play(func_label.shift, 0.5 * UP) - self.wait() - - const = TexMobject("C") - const.scale(0.9) - const.move_to(two, DR) - const.shift(0.07 * RIGHT) - self.play( - FadeOutAndShift(two, UP), - FadeInFrom(const, DOWN) - ) - self.remove(func_label) - func_label = VGroup(func_label[0], const, times_x) - self.add(func_label) - self.play(FadeOut(p_data_rect)) - self.wait() - - # Show tails - coin = coins[1] - self.play( - arrow.next_to, coin, DOWN, SMALL_BUFF, - MaintainPositionRelativeTo(data_label, arrow), - FadeInFromDown(coin), - ) - self.wait() - - to_prior_arrow = Arrow( - func_label[0][3], - bayes[6], - max_tip_length_to_length_ratio=0.15, - stroke_width=3, - ) - to_prior_arrow.set_color(RED) - - self.play(Indicate(func_label, scale_factor=1.2, color=RED)) - self.play(ShowCreation(to_prior_arrow)) - self.wait() - self.play(FadeOut(to_prior_arrow)) - - # Scale by (1 - x) - eq_1mx = TexMobject("(1 - x)") - dot = TexMobject("\\cdot") - rhs_part = VGroup(dot, eq_1mx) - rhs_part.arrange(RIGHT, buff=0.2) - rhs_part.move_to(func_label, RIGHT) - - l_1mx = eq_1mx.copy() - likelihood_rect.restore() - l_1mx.next_to(likelihood_rect, UP, SMALL_BUFF) - - self.play( - ShowCreation(likelihood_rect), - FadeInFrom(l_1mx, 0.5 * DOWN), - ) - self.wait() - self.play(ShowCreationThenFadeOut(Underline(p_label))) - self.play(Indicate(coins[1])) - self.wait() - self.play( - TransformFromCopy(l_1mx, eq_1mx), - FadeInFrom(dot, RIGHT), - func_label.next_to, dot, LEFT, 0.2, + half_gauss = Group( + graph, graph_fill, axes, ) - scaled_graph = axes.get_graph(lambda x: 2 * x * (1 - x)) - scaled_region = get_region_under_curve(axes, scaled_graph, 0, 1) + # Random Decimal + number = DecimalNumber(num_decimal_places=4) + number.set_height(0.6) + number.move_to(box) - self.play(Transform(graph, scaled_graph)) - self.play(FadeIn(scaled_region)) - self.wait() + number.time = 0 + number.last_change = 0 + number.change_freq = 0.2 - # Renormalize - self.remove(likelihood_rect) - self.play( - TransformFromCopy(likelihood_rect, p_data_rect), - FadeOut(l_1mx) - ) - new_graph = get_beta_graph(axes, 1, 1) - group = VGroup(graph, scaled_region) - self.play( - group.set_height, - new_graph.get_height(), {"about_edge": DOWN, "stretch": True}, - group.set_color, BLUE, - FadeOut(region), - ) - region = scaled_region - self.play(FadeOut(p_data_rect)) - self.wait() - self.play(ShowCreationThenFadeAround(const)) + def update_number(number, dt, dist=dist): + number.time += dt - # Repeat - exp1 = Integer(1) - exp1.set_height(0.2) - exp1.move_to(func_label[2].get_corner(UR), DL) - exp1.shift(0.02 * DOWN + 0.07 * RIGHT) + if (number.time - number.last_change) < number.change_freq: + return - exp2 = exp1.copy() - exp2.move_to(eq_1mx.get_corner(UR), DL) - exp2.shift(0.1 * RIGHT) - exp2.align_to(exp1, DOWN) - - shift_vect = UP + 0.5 * LEFT - VGroup(exp1, exp2).shift(shift_vect) - - self.play( - FadeInFrom(exp1, DOWN), - FadeInFrom(exp2, DOWN), - VGroup(func_label, dot, eq_1mx).shift, shift_vect, - bayes.scale, 0.5, - bayes.next_to, p_label, DOWN, LARGE_BUFF, {"aligned_edge": RIGHT}, - ) - nh = 1 - nt = 1 - for coin, is_heads in zip(coins[2:10], bool_values[2:10]): - self.play( - arrow.next_to, coin, DOWN, SMALL_BUFF, - MaintainPositionRelativeTo(data_label, arrow), - FadeInFrom(coin, DOWN), - ) - if is_heads: - nh += 1 - old_exp = exp1 + number.last_change = number.time + rand_val = random.random() + if rand_val < 0.5: + number.set_value(0) else: - nt += 1 - old_exp = exp2 + number.set_value(dist.ppf(rand_val)) - new_exp = old_exp.copy() - new_exp.increment_value(1) + number.add_updater(update_number) - dist = scipy.stats.beta(nh + 1, nt + 1) - new_graph = axes.get_graph(dist.pdf) - new_region = get_region_under_curve(axes, new_graph, 0, 1) - new_region.match_style(region) + v_line = SurroundingRectangle(zero) + v_line.save_state() + v_line.set_stroke(YELLOW, 3) - self.play( - FadeOut(graph), - FadeOut(region), - FadeIn(new_graph), - FadeIn(new_region), - FadeOutAndShift(old_exp, MED_SMALL_BUFF * UP), - FadeInFrom(new_exp, MED_SMALL_BUFF * DOWN), - ) - graph = new_graph - region = new_region - self.remove(new_exp) - self.add(old_exp) - old_exp.increment_value() - self.wait() + def update_v_line(v_line, number=number, axes=axes, graph=graph): + x = number.get_value() + if x < 0.5: + v_line.restore() + else: + v_line.set_width(1e-6) + p1 = axes.c2p(x, 0) + p2 = axes.input_to_graph_point(x, graph) + v_line.set_height(get_norm(p2 - p1), stretch=True) + v_line.move_to(p1, DOWN) - if coin is coins[4]: - area_label = TextMobject("Area = 1") - area_label.move_to(axes.c2p(0.6, 0.8)) - self.play(GrowFromPoint( - area_label, const.get_center() - )) + v_line.add_updater(update_v_line) + + # Add everything + self.add(box) + self.add(number) + self.wait(4) + self.play( + GrowArrow(arrows[0]), + FadeIn(arrow_labels[0]), + GrowFromPoint(zero, box.get_corner(UR)) + ) + self.wait(2) + self.play( + GrowArrow(arrows[1]), + FadeIn(arrow_labels[1]), + FadeIn(half_gauss), + ) + self.add(v_line) + + self.wait(30) -class PDefectEqualsQmark(Scene): +class SumToIntegral(Scene): def construct(self): - label = TexMobject( - "P(\\text{Defect}) = ???", - tex_to_color_map={ - "\\text{Defect}": RED, + # Titles + titles = VGroup( + TextMobject("Discrete context"), + TextMobject("Continuous context"), + ) + titles.set_height(0.5) + for title, vect in zip(titles, [LEFT, RIGHT]): + title.move_to(vect * FRAME_WIDTH / 4) + title.to_edge(UP, buff=MED_SMALL_BUFF) + + v_line = Line(UP, DOWN).set_height(FRAME_HEIGHT) + h_line = Line(LEFT, RIGHT).set_width(FRAME_WIDTH) + h_line.next_to(titles, DOWN) + h_line.set_x(0) + v_line.center() + + self.play( + ShowCreation(VGroup(h_line, v_line)), + LaggedStartMap( + FadeInFrom, titles, + lambda m: (m, -0.2 * m.get_center()[0] * RIGHT), + run_time=1, + lag_ratio=0.1, + ), + ) + self.wait() + + # Sum and int + kw = {"tex_to_color_map": {"S": BLUE}} + s_sym = TexMobject("\\sum", "_{x \\in S} P(x)", **kw) + i_sym = TexMobject("\\int_{S} p(x)", "\\text{d}x", **kw) + syms = VGroup(s_sym, i_sym) + syms.scale(2) + for sym, title in zip(syms, titles): + sym.shift(-sym[-1].get_center()) + sym.match_x(title) + + arrow = Arrow( + s_sym[0].get_corner(UP), + i_sym[0].get_corner(UP), + path_arc=-90 * DEGREES, + ) + arrow.set_color(YELLOW) + + self.play(Write(s_sym, run_time=1)) + anims = [ShowCreation(arrow)] + for i, j in [(0, 0), (2, 1), (3, 2)]: + source = s_sym[i].deepcopy() + target = i_sym[j] + target.save_state() + source.generate_target() + target.replace(source, stretch=True) + source.target.replace(target, stretch=True) + target.set_opacity(0) + source.target.set_opacity(0) + anims += [ + Restore(target, path_arc=-60 * DEGREES), + MoveToTarget(source, path_arc=-60 * DEGREES), + ] + self.play(LaggedStart(*anims)) + self.play(FadeInFromDown(i_sym[3])) + self.add(i_sym) + self.wait() + self.play( + FadeOutAndShift(arrow, UP), + syms.next_to, h_line, DOWN, {"buff": MED_LARGE_BUFF}, + syms.match_x, syms, + ) + + # Add curve area in editing + # Add bar chart + axes = Axes( + x_min=0, + x_max=10, + y_min=0, + y_max=7, + y_axis_config={ + "unit_size": 0.75, } ) - self.play(FadeInFrom(label, DOWN)) - self.wait() + axes.set_width(0.5 * FRAME_WIDTH - 1) + axes.next_to(s_sym, DOWN) + axes.y_axis.add_numbers(2, 4, 6) + bars = VGroup() + for x, y in [(1, 1), (4, 3), (7, 2)]: + bar = Rectangle() + bar.set_stroke(WHITE, 1) + bar.set_fill(BLUE_D, 1) + line = Line(axes.c2p(x, 0), axes.c2p(x + 2, y)) + bar.replace(line, stretch=True) + bars.add(bar) -class UpdateOnceWithBinomial(TalkThroughCoinExample): - def construct(self): - # Fair bit of copy-pasting from above. If there's - # time, refactor this properly - # Setup - axes = self.get_axes() - x_label = TexMobject("x") - x_label.next_to(axes.x_axis.get_end(), UR, MED_SMALL_BUFF) - axes.add(x_label) + addition_formula = TexMobject(*"1+3+2") + addition_formula.space_out_submobjects(2.1) + addition_formula.next_to(bars, UP) - p_label, prob, prob_box = self.get_probability_label() - prob_box_x = x_label.copy().move_to(prob_box) - - q_marks = prob_box[1] - prob_box.remove(q_marks) - prob_box.add(prob_box_x) - - self.add(axes) - self.add(p_label) - self.add(prob_box) - - # Coins - bool_values = (np.random.random(100) < self.true_p) - bool_values[:5] = [True, False, True, True, False] - coins = self.get_coins(bool_values) - coins.next_to(axes.y_axis, RIGHT, MED_LARGE_BUFF) - coins.to_edge(UP) - self.add(coins[:10]) - - # Uniform pdf - region = get_beta_graph(axes, 0, 0) - graph = axes.get_graph( - lambda x: 1, - min_samples=30, - ) - self.add(region, graph) - - # Show Bayes rule - bayes = TexMobject( - "p({x} | \\text{data})", "=", - "p({x})", - "{P(\\text{data} | {x})", - "\\over", - "P(\\text{data})", - tex_to_color_map={ - "{x}": WHITE, - "\\text{data}": GREEN, - } - ) - bayes.move_to(axes.c2p(0, 2.5)) - bayes.align_to(coins, LEFT) - - likelihood = bayes[9:14] - # likelihood_rect = SurroundingRectangle(likelihood, buff=0.05) - - self.add(bayes) - - # All data at once - brace = Brace(coins[:10], DOWN) - all_data_label = brace.get_text("One update from all data") - - self.wait() - self.play( - GrowFromCenter(brace), - FadeInFrom(all_data_label, 0.2 * UP), - ) - self.wait() - - # Binomial formula - nh = sum(bool_values[:10]) - nt = sum(~bool_values[:10]) - - likelihood_brace = Brace(likelihood, UP) - t2c = { - str(nh): BLUE, - str(nt): RED, - } - binom_formula = TexMobject( - "{10 \\choose ", str(nh), "}", - "x^{", str(nh), "}", - "(1-x)^{" + str(nt) + "}", - tex_to_color_map=t2c, - ) - binom_formula[0][-1].set_color(BLUE) - binom_formula[1].set_color(WHITE) - binom_formula.set_width(likelihood_brace.get_width() + 0.5) - binom_formula.next_to(likelihood_brace, UP) + for bar in bars: + bar.save_state() + bar.stretch(0, 1, about_edge=DOWN) self.play( - TransformFromCopy(brace, likelihood_brace), - FadeOut(all_data_label), - FadeIn(binom_formula) + Write(axes), + LaggedStartMap(Restore, bars), + LaggedStartMap(FadeInFromDown, addition_formula), ) self.wait() - # New plot - rhs = TexMobject( - "C \\cdot", - "x^{", str(nh), "}", - "(1-x)^{", str(nt), "}", - tex_to_color_map=t2c - ) - rhs.next_to(bayes[:5], DOWN, LARGE_BUFF, aligned_edge=LEFT) - eq = TexMobject("=") - eq.rotate(90 * DEGREES) - eq.next_to(bayes[:5], DOWN, buff=0.35) - - dist = scipy.stats.beta(nh + 1, nt + 1) - new_graph = axes.get_graph(dist.pdf) - new_graph.shift(1e-6 * UP) - new_graph.set_stroke(WHITE, 1, opacity=0.5) - new_region = get_region_under_curve(axes, new_graph, 0, 1) - new_region.match_style(region) - new_region.set_opacity(0.75) - - self.add(new_region, new_graph, bayes) - region.unlock_triangulation() + # Confusion + morty = Mortimer() + morty.to_corner(DR) + morty.look_at(i_sym) self.play( - FadeOut(graph), - FadeOut(region), - FadeIn(new_graph), - FadeIn(new_region), + *map(FadeOut, [axes, bars, addition_formula]), + FadeIn(morty) + ) + self.play(morty.change, "maybe") + self.play(Blink(morty)) + self.play(morty.change, "confused", i_sym.get_right()) + self.play(Blink(morty)) + self.wait() + + # Focus on integral + self.play( + Uncreate(VGroup(v_line, h_line)), + FadeOutAndShift(titles, UP), + FadeOutAndShift(morty, RIGHT), + FadeOutAndShift(s_sym, LEFT), + i_sym.center, + i_sym.to_edge, LEFT + ) + + arrows = VGroup() + for vect in [UP, DOWN]: + corner = i_sym[-1].get_corner(RIGHT + vect) + arrows.add(Arrow( + corner, + corner + 2 * RIGHT + 2 * vect, + path_arc=-np.sign(vect[1]) * 60 * DEGREES, + )) + + self.play(*map(ShowCreation, arrows)) + + # Types of integration + dist = scipy.stats.beta(7 + 1, 3 + 1) + axes_pair = VGroup() + graph_pair = VGroup() + for arrow in arrows: + axes = get_beta_dist_axes(y_max=5, y_unit=1) + axes.set_width(4) + axes.next_to(arrow.get_end(), RIGHT) + graph = axes.get_graph(dist.pdf) + graph.set_stroke(BLUE, 2) + graph.set_fill(BLUE_E, 0) + graph.make_smooth() + axes_pair.add(axes) + graph_pair.add(graph) + + r_axes, l_axes = axes_pair + r_graph, l_graph = graph_pair + r_name = TextMobject("Riemann\\\\Integration") + r_name.next_to(r_axes, RIGHT) + l_name = TextMobject("Lebesgue\\\\Integration$^*$") + l_name.next_to(l_axes, RIGHT) + footnote = TextMobject("*a bit more complicated than\\\\these bars make it look") + footnote.match_width(l_name) + footnote.next_to(l_name, DOWN) + + self.play(LaggedStart( + FadeIn(r_axes), + FadeIn(r_graph), + FadeIn(r_name), + FadeIn(l_axes), + FadeIn(l_graph), + FadeIn(l_name), run_time=1, - ) + )) + + # Approximation bars + def get_riemann_rects(dx, axes=r_axes, func=dist.pdf): + bars = VGroup() + for x in np.arange(0, 1, dx): + bar = Rectangle() + line = Line( + axes.c2p(x, 0), + axes.c2p(x + dx, func(x)), + ) + bar.replace(line, stretch=True) + bar.set_stroke(BLUE_E, width=10 * dx, opacity=1) + bar.set_fill(BLUE, 0.5) + bars.add(bar) + return bars + + def get_lebesgue_bars(dy, axes=l_axes, func=dist.pdf, mx=0.7, y_max=dist.pdf(0.7)): + bars = VGroup() + for y in np.arange(dy, y_max + dy, dy): + x0 = binary_search(func, y, 0, mx) or mx + x1 = binary_search(func, y, mx, 1) or mx + line = Line(axes.c2p(x0, y - dy), axes.c2p(x1, y)) + bar = Rectangle() + bar.set_stroke(RED_E, 0) + bar.set_fill(RED_E, 0.5) + bar.replace(line, stretch=True) + bars.add(bar) + return bars + + r_bar_groups = [] + l_bar_groups = [] + Ns = [10, 20, 40, 80, 160] + Ms = [2, 4, 8, 16, 32] + for N, M in zip(Ns, Ms): + r_bar_groups.append(get_riemann_rects(dx=1 / N)) + l_bar_groups.append(get_lebesgue_bars(dy=1 / M)) self.play( - Write(eq), - FadeInFrom(rhs, UP) + FadeIn(r_bar_groups[0], lag_ratio=0.1), + FadeIn(l_bar_groups[0], lag_ratio=0.1), + FadeIn(footnote), ) self.wait() + for rbg0, rbg1, lbg0, lbg1 in zip(r_bar_groups, r_bar_groups[1:], l_bar_groups, l_bar_groups[1:]): + self.play( + ReplacementTransform( + rbg0, rbg1, + lag_ratio=1 / len(rbg0), + run_time=2, + ), + ReplacementTransform( + lbg0, lbg1, + lag_ratio=1 / len(lbg0), + run_time=2, + ), + ) + self.wait() + self.play( + FadeOut(r_bar_groups[-1]), + FadeOut(l_bar_groups[-1]), + r_graph.set_fill, BLUE_E, 1, + l_graph.set_fill, RED_E, 1, + ) + + +class MeasureTheoryLeadsTo(Scene): + def construct(self): + words = TextMobject("Measure Theory") + words.set_color(RED) + arrow = Vector(DOWN) + arrow.next_to(words, DOWN, buff=SMALL_BUFF) + arrow.set_stroke(width=7) + arrow.rotate(45 * DEGREES, about_point=arrow.get_start()) + self.play( + FadeInFrom(words, DOWN), + GrowArrow(arrow), + UpdateFromAlphaFunc(arrow, lambda m, a: m.set_opacity(a)), + ) + self.wait() + + +class WhenIWasFirstLearning(TeacherStudentsScene): + def construct(self): + self.teacher.change_mode("raise_right_hand") + self.play( + self.get_student_changes("pondering", "thinking", "tease"), + self.teacher.change, "thinking", + ) + + younger = BabyPiCreature(color=GREY_BROWN) + younger.set_height(2) + younger.move_to(self.students, DL) + + self.look_at(self.screen) + self.wait() + self.play( + ReplacementTransform(self.teacher, younger), + LaggedStartMap( + FadeOutAndShift, self.students, + lambda m: (m, DOWN), + ) + ) + + # Bubble + bubble = ThoughtBubble() + bubble[-1].set_fill(GREEN_SCREEN, 1) + bubble.move_to(younger.get_corner(UR), DL) + + self.play( + Write(bubble), + younger.change, "maybe", bubble.get_bubble_center(), + ) + self.play(Blink(younger)) + for mode in ["confused", "angry", "pondering", "maybe"]: + self.play(younger.change, mode) + for x in range(2): + self.wait() + if random.random() < 0.5: + self.play(Blink(younger)) + + +class PossibleYetProbabilityZero(Scene): + def construct(self): + poss = TextMobject("Possible") + prob = TextMobject("Probability = 0") + total = TextMobject("P(dart hits somewhere) = 1") + # total[1].next_to(total[0][0], RIGHT) + words = VGroup(poss, prob, total) + words.scale(1.5) + words.arrange(DOWN, aligned_edge=LEFT, buff=MED_LARGE_BUFF) + + self.play(Write(poss, run_time=0.5)) + self.wait() + self.play(FadeInFrom(prob, UP)) + self.wait() + self.play(FadeInFrom(total, UP)) + self.wait() + + +class TiePossibleToDensity(Scene): + def construct(self): + poss = TextMobject("Possibility") + prob = TextMobject("Probability", " $>$ 0") + dens = TextMobject("Probability \\emph{density}", " $>$ 0") + dens[0].set_color(BLUE) + implies = TexMobject("\\Rightarrow") + implies2 = implies.copy() + + poss.next_to(implies, LEFT) + prob.next_to(implies, RIGHT) + dens.next_to(implies, RIGHT) + cross = Cross(implies) + + self.camera.frame.scale(0.7, about_point=dens.get_center()) + + self.add(poss) + self.play( + FadeInFrom(prob, LEFT), + Write(implies, run_time=1) + ) + self.wait() + self.play(ShowCreation(cross)) + self.wait() + + self.play( + VGroup(implies, cross, prob).shift, UP, + FadeIn(implies2), + FadeIn(dens), + ) + self.wait() + + self.embed() + + +class DrawBigRect(Scene): + def construct(self): + rect = Rectangle(width=7, height=2.5) + rect.set_stroke(RED, 5) + rect.to_edge(RIGHT) + + words = TextMobject("Not how to\\\\think about it") + words.set_color(RED) + words.align_to(rect, LEFT) + words.to_edge(UP) + + arrow = Arrow( + words.get_bottom(), + rect.get_top(), + buff=0.25, + color=RED, + ) + + self.play(ShowCreation(rect)) + self.play( + FadeInFromDown(words), + GrowArrow(arrow), + ) + self.wait() + + +class Thumbnail(Scene): + def construct(self): + dartboard = Dartboard() + axes = NumberPlane( + x_min=-1.25, + x_max=1.25, + y_min=-1.25, + y_max=1.25, + axis_config={ + "unit_size": 0.5 * dartboard.get_width(), + "tick_frequency": 0.25, + }, + x_line_frequency=1.0, + y_line_frequency=1.0, + ) + group = VGroup(dartboard, axes) + group.to_edge(LEFT, buff=0) + + # Arrow + arrow = Vector(DR, max_stroke_width_to_length_ratio=np.inf) + arrow.move_to(axes.c2p(PI / 10, np.exp(1) / 10), DR) + arrow.scale(1.5, about_edge=DR) + arrow.set_stroke(WHITE, 10) + + black_arrow = arrow.copy() + black_arrow.set_color(BLACK) + black_arrow.set_stroke(width=20) + + arrow.points[0] += 0.025 * DR + + # Coords + coords = TexMobject("(x, y) = (0.31415\\dots, 0.27182\\dots)") + coords.set_width(5.5) + coords.set_stroke(BLACK, 10, background=True) + coords.next_to(axes.get_bottom(), UP, buff=0) + + # Words + words = VGroup( + TextMobject("Probability = 0"), + TextMobject("$\\dots$but still possible"), + ) + for word in words: + word.set_width(6) + words.arrange(DOWN, buff=MED_LARGE_BUFF) + words.next_to(axes, RIGHT) + words.to_edge(UP, buff=LARGE_BUFF) + + # Pi + morty = Mortimer() + morty.to_corner(DR) + morty.change("confused", words) + + self.add(group) + self.add(black_arrow) + self.add(arrow) + self.add(coords) + self.add(words) + self.add(morty) + + self.embed() + + +class Part2EndScreen(PatreonEndScreen): + CONFIG = { + "scroll_time": 30, + "specific_patrons": [ + "1stViewMaths", + "Adam Dřínek", + "Aidan Shenkman", + "Alan Stein", + "Albin Egasse", + "Alex Mijalis", + "Alexander Mai", + "Alexis Olson", + "Ali Yahya", + "Andrew Busey", + "Andrew Cary", + "Andrew R. Whalley", + "Anthony Losego", + "Aravind C V", + "Arjun Chakroborty", + "Arthur Zey", + "Ashwin Siddarth", + "Augustine Lim", + "Austin Goodman", + "Avi Finkel", + "Awoo", + "Axel Ericsson", + "Ayan Doss", + "AZsorcerer", + "Barry Fam", + "Ben Delo", + "Bernd Sing", + "Bill Gatliff", + "Bob Sanderson", + "Boris Veselinovich", + "Bradley Pirtle", + "Brandon Huang", + "Brian Staroselsky", + "Britt Selvitelle", + "Britton Finley", + "Burt Humburg", + "Calvin Lin", + "Charles Southerland", + "Charlie N", + "Chenna Kautilya", + "Chris Connett", + "Chris Druta", + "Christian Kaiser", + "cinterloper", + "Clark Gaebel", + "Colwyn Fritze-Moor", + "Cooper Jones", + "Corey Ogburn", + "D. Sivakumar", + "Dan Herbatschek", + "Daniel Brown", + "Daniel Herrera C", + "Darrell Thomas", + "Dave B", + "Dave Kester", + "dave nicponski", + "David B. Hill", + "David Clark", + "David Gow", + "Delton Ding", + "Dominik Wagner", + "Eddie Landesberg", + "Eduardo Rodriguez", + "emptymachine", + "Eric Younge", + "Eryq Ouithaqueue", + "Federico Lebron", + "Fernando Via Canel", + "Frank R. Brown, Jr.", + "Gavin", + "Giovanni Filippi", + "Goodwine", + "Hal Hildebrand", + "Hitoshi Yamauchi", + "Ivan Sorokin", + "Jacob Baxter", + "Jacob Harmon", + "Jacob Hartmann", + "Jacob Magnuson", + "Jalex Stark", + "Jameel Syed", + "James Beall", + "Jason Hise", + "Jayne Gabriele", + "Jean-Manuel Izaret", + "Jeff Dodds", + "Jeff Linse", + "Jeff Straathof", + "Jimmy Yang", + "John C. Vesey", + "John Camp", + "John Haley", + "John Le", + "John Luttig", + "John Rizzo", + "John V Wertheim", + "Jonathan Heckerman", + "Jonathan Wilson", + "Joseph John Cox", + "Joseph Kelly", + "Josh Kinnear", + "Joshua Claeys", + "Joshua Ouellette", + "Juan Benet", + "Kai-Siang Ang", + "Kanan Gill", + "Karl Niu", + "Kartik Cating-Subramanian", + "Kaustuv DeBiswas", + "Killian McGuinness", + "Klaas Moerman", + "Kros Dai", + "L0j1k", + "Lael S Costa", + "LAI Oscar", + "Lambda GPU Workstations", + "Laura Gast", + "Lee Redden", + "Linh Tran", + "Luc Ritchie", + "Ludwig Schubert", + "Lukas Biewald", + "Lukas Zenick", + "Magister Mugit", + "Magnus Dahlström", + "Magnus Hiie", + "Manoj Rewatkar - RITEK SOLUTIONS", + "Mark B Bahu", + "Mark Heising", + "Mark Mann", + "Martin Price", + "Mathias Jansson", + "Matt Godbolt", + "Matt Langford", + "Matt Roveto", + "Matt Russell", + "Matteo Delabre", + "Matthew Bouchard", + "Matthew Cocke", + "Maxim Nitsche", + "Michael Bos", + "Michael Day", + "Michael Hardel", + "Michael W White", + "Mihran Vardanyan", + "Mirik Gogri", + "Molly Mackinlay", + "Mustafa Mahdi", + "Márton Vaitkus", + "Nate Heckmann", + "Nicholas Cahill", + "Nikita Lesnikov", + "Oleg Leonov", + "Omar Zrien", + "Owen Campbell-Moore", + "Patrick Lucas", + "Pavel Dubov", + "Pesho Ivanov", + "Petar Veličković", + "Peter Ehrnstrom", + "Peter Francis", + "Peter Mcinerney", + "Pierre Lancien", + "Pradeep Gollakota", + "Rafael Bove Barrios", + "Randy C. Will", + "rehmi post", + "Rex Godby", + "Ripta Pasay", + "Rish Kundalia", + "Roman Sergeychik", + "Roobie", + "Ryan Atallah", + "Ryan Prayogo", + "Samuel Judge", + "SansWord Huang", + "Scott Gray", + "Scott Walter, Ph.D.", + "soekul", + "Solara570", + "Steve Huynh", + "Steve Muench", + "Steve Sperandeo", + "Steven Siddals", + "Stevie Metke", + "Sunil Nagaraj", + "supershabam", + "Susanne Fenja Mehr-Koks", + "Suteerth Vishnu", + "Suthen Thomas", + "Tal Einav", + "Taras Bobrovytsky", + "Tauba Auerbach", + "Ted Suzman", + "THIS IS THE point OF NO RE tUUurRrhghgGHhhnnn", + "Thomas J Sargent", + "Thomas Tarler", + "Tianyu Ge", + "Tihan Seale", + "Tyler Herrmann", + "Tyler McAtee", + "Tyler VanValkenburg", + "Tyler Veness", + "Vassili Philippov", + "Vasu Dubey", + "Veritasium", + "Vignesh Ganapathi Subramanian", + "Vinicius Reis", + "Vladimir Solomatin", + "Wooyong Ee", + "Xuanji Li", + "Yana Chernobilsky", + "YinYangBalance.Asia", + "Yorick Lesecque", + "Yu Jun", + "Yurii Monastyrshyn", + ], + } diff --git a/from_3b1b/active/bayes/beta3.py b/from_3b1b/active/bayes/beta3.py new file mode 100644 index 00000000..86ccbe13 --- /dev/null +++ b/from_3b1b/active/bayes/beta3.py @@ -0,0 +1,2150 @@ +from manimlib.imports import * +from from_3b1b.active.bayes.beta_helpers import * +from from_3b1b.active.bayes.beta1 import * +from from_3b1b.active.bayes.beta2 import ShowLimitToPdf + +import scipy.stats + +OUTPUT_DIRECTORY = "bayes/beta3" + + +class RemindOfWeightedCoin(Scene): + def construct(self): + # Largely copied from beta2 + + # Prob label + p_label = get_prob_coin_label() + p_label.set_height(0.7) + p_label.to_edge(UP) + + rhs = p_label[-1] + q_box = get_q_box(rhs) + p_label.add(q_box) + + self.add(p_label) + + # Coin grid + def get_random_coin_grid(p): + bools = np.random.random(100) < p + grid = get_coin_grid(bools) + return grid + + grid = get_random_coin_grid(0.5) + grid.next_to(p_label, DOWN, MED_LARGE_BUFF) + + self.play(LaggedStartMap( + FadeIn, grid, + lag_ratio=2 / len(grid), + run_time=3, + )) + self.wait() + + # Label as h + brace = Brace(q_box, DOWN, buff=SMALL_BUFF) + h_label = TexMobject("h") + h_label.next_to(brace, DOWN) + eq = TexMobject("=") + eq.next_to(h_label, RIGHT) + h_decimal = DecimalNumber(0.5) + h_decimal.next_to(eq, RIGHT) + + self.play( + GrowFromCenter(brace), + FadeInFrom(h_label, UP), + grid.scale, 0.8, {"about_edge": DOWN}, + ) + self.wait() + + # Alternate weightings + tail_grid = get_random_coin_grid(0) + head_grid = get_random_coin_grid(1) + grid70 = get_random_coin_grid(0.7) + alt_grids = [tail_grid, head_grid, grid70] + for ag in alt_grids: + ag.replace(grid) + + for coins in [grid, *alt_grids]: + for coin in coins: + coin.generate_target() + coin.target.rotate(90 * DEGREES, axis=UP) + coin.target.set_opacity(0) + + def get_grid_swap_anims(g1, g2): + return [ + LaggedStartMap(MoveToTarget, g1, lag_ratio=0.02, run_time=1.5, remover=True), + LaggedStartMap(MoveToTarget, g2, lag_ratio=0.02, run_time=1.5, rate_func=reverse_smooth), + ] + + self.play( + FadeIn(eq), + UpdateFromAlphaFunc(h_decimal, lambda m, a: m.set_opacity(a)), + ChangeDecimalToValue(h_decimal, 0, run_time=2), + *get_grid_swap_anims(grid, tail_grid) + ) + self.wait() + self.play( + ChangeDecimalToValue(h_decimal, 1, run_time=1.5), + *get_grid_swap_anims(tail_grid, head_grid) + ) + self.wait() + self.play( + ChangeDecimalToValue(h_decimal, 0.7, run_time=1.5), + *get_grid_swap_anims(head_grid, grid70) + ) + self.wait() + + # Graph + axes = scaled_pdf_axes() + axes.to_edge(DOWN, buff=MED_SMALL_BUFF) + axes.y_axis.numbers.set_opacity(0) + axes.y_axis_label.set_opacity(0) + + h_lines = VGroup() + for y in range(15): + h_line = Line(axes.c2p(0, y), axes.c2p(1, y)) + h_lines.add(h_line) + h_lines.set_stroke(WHITE, 0.5, opacity=0.5) + axes.add(h_lines) + + x_axis_label = p_label[:4].copy() + x_axis_label.set_height(0.4) + x_axis_label.next_to(axes.c2p(1, 0), UR, buff=SMALL_BUFF) + axes.x_axis.add(x_axis_label) + + n_heads_tracker = ValueTracker(3) + n_tails_tracker = ValueTracker(3) + + def get_graph(axes=axes, nht=n_heads_tracker, ntt=n_tails_tracker): + dist = scipy.stats.beta(nht.get_value() + 1, ntt.get_value() + 1) + graph = axes.get_graph(dist.pdf, step_size=0.05) + graph.set_stroke(BLUE, 3) + graph.set_fill(BLUE_E, 1) + return graph + + graph = always_redraw(get_graph) + + area_label = TextMobject("Area = 1") + area_label.set_height(0.5) + area_label.move_to(axes.c2p(0.5, 1)) + + # pdf label + pdf_label = TextMobject("probability ", "density ", "function") + pdf_label.next_to(axes.input_to_graph_point(0.5, graph), UP) + pdf_target_template = TextMobject("p", "d", "f") + pdf_target_template.next_to(axes.input_to_graph_point(0.7, graph), UR) + pdf_label.generate_target() + for part, letter2 in zip(pdf_label.target, pdf_target_template): + for letter1 in part: + letter1.move_to(letter2) + part[1:].set_opacity(0) + + # Add plot + self.add(axes, *self.mobjects) + self.play( + FadeOut(eq), + FadeOut(h_decimal), + LaggedStartMap(MoveToTarget, grid70, run_time=1, remover=True), + FadeIn(axes), + ) + self.play( + DrawBorderThenFill(graph), + FadeIn(area_label, rate_func=squish_rate_func(smooth, 0.5, 1), run_time=2), + Write(pdf_label, run_time=1), + ) + self.wait() + + # Region + lh_tracker = ValueTracker(0.7) + rh_tracker = ValueTracker(0.7) + + def get_region(axes=axes, graph=graph, lh_tracker=lh_tracker, rh_tracker=rh_tracker): + lh = lh_tracker.get_value() + rh = rh_tracker.get_value() + region = get_region_under_curve(axes, graph, lh, rh) + region.set_fill(GREY, 0.85) + region.set_stroke(YELLOW, 1) + return region + + region = always_redraw(get_region) + + region_area_label = DecimalNumber(num_decimal_places=3) + region_area_label.next_to(axes.c2p(0.7, 0), UP, MED_LARGE_BUFF) + + def update_ra_label(label, nht=n_heads_tracker, ntt=n_tails_tracker, lht=lh_tracker, rht=rh_tracker): + dist = scipy.stats.beta(nht.get_value() + 1, ntt.get_value() + 1) + area = dist.cdf(rht.get_value()) - dist.cdf(lht.get_value()) + label.set_value(area) + + region_area_label.add_updater(update_ra_label) + + range_label = VGroup( + TexMobject("0.6 \\le"), + p_label[:4].copy(), + TexMobject("\\le 0.8"), + ) + for mob in range_label: + mob.set_height(0.4) + range_label.arrange(RIGHT, buff=SMALL_BUFF) + pp_label = VGroup( + TexMobject("P("), + range_label, + TexMobject(")"), + ) + for mob in pp_label[::2]: + mob.set_height(0.7) + mob.set_color(YELLOW) + pp_label.arrange(RIGHT, buff=SMALL_BUFF) + pp_label.move_to(axes.c2p(0.3, 3)) + + self.play( + FadeIn(pp_label[::2]), + MoveToTarget(pdf_label), + FadeOut(area_label), + ) + self.wait() + self.play(TransformFromCopy(p_label[:4], range_label[1])) + self.wait() + self.play(TransformFromCopy(axes.x_axis.numbers[2], range_label[0])) + self.play(TransformFromCopy(axes.x_axis.numbers[3], range_label[2])) + self.wait() + + self.add(region) + self.play( + lh_tracker.set_value, 0.6, + rh_tracker.set_value, 0.8, + UpdateFromAlphaFunc( + region_area_label, + lambda m, a: m.set_opacity(a), + rate_func=squish_rate_func(smooth, 0.25, 1) + ), + run_time=3, + ) + self.wait() + + # 7/10 heads + bools = [True] * 7 + [False] * 3 + random.shuffle(bools) + coins = VGroup(*[ + get_coin("H" if heads else "T") + for heads in bools + ]) + coins.arrange(RIGHT) + coins.set_height(0.7) + coins.next_to(h_label, DOWN, buff=MED_LARGE_BUFF) + + heads = [c for c in coins if c.symbol == "H"] + numbers = VGroup(*[ + Integer(i + 1).set_height(0.2).next_to(coin, DOWN, SMALL_BUFF) + for i, coin in enumerate(heads) + ]) + + for coin in coins: + coin.save_state() + coin.rotate(90 * DEGREES, UP) + coin.set_opacity(0) + + pp_label.generate_target() + pp_label.target.set_height(0.5) + pp_label.target.next_to(axes.c2p(0, 2), RIGHT, MED_LARGE_BUFF) + + self.play( + LaggedStartMap(Restore, coins), + MoveToTarget(pp_label), + run_time=1, + ) + self.play(ShowIncreasingSubsets(numbers)) + self.wait() + + # Move plot + self.play( + n_heads_tracker.set_value, 7, + n_tails_tracker.set_value, 3, + FadeOut(pdf_label, rate_func=squish_rate_func(smooth, 0, 0.5)), + run_time=2 + ) + self.wait() + + # How does the answer change with more data + new_bools = [True] * 63 + [False] * 27 + random.shuffle(new_bools) + bools = [c.symbol == "H" for c in coins] + new_bools + grid = get_coin_grid(bools) + grid.set_height(3.5) + grid.next_to(axes.c2p(0, 3), RIGHT, MED_LARGE_BUFF) + + self.play( + FadeOut(numbers), + ReplacementTransform(coins, grid[:10]), + ) + self.play( + FadeIn(grid[10:], lag_ratio=0.1, rate_func=linear), + pp_label.next_to, grid, DOWN, + ) + self.wait() + self.add(graph, region, region_area_label, p_label, q_box, brace, h_label) + self.play( + n_heads_tracker.set_value, 70, + n_tails_tracker.set_value, 30, + ) + self.wait() + origin = axes.c2p(0, 0) + self.play( + axes.y_axis.stretch, 0.5, 1, {"about_point": origin}, + h_lines.stretch, 0.5, 1, {"about_point": origin}, + ) + self.wait() + + # Shift the shape around + pairs = [ + (70 * 3, 30 * 3), + (35, 15), + (35 + 20, 15 + 20), + (7, 3), + (70, 30), + ] + for nh, nt in pairs: + self.play( + n_heads_tracker.set_value, nh, + n_tails_tracker.set_value, nt, + run_time=2, + ) + self.wait() + + # End + self.embed() + + +class LastTimeWrapper(Scene): + def construct(self): + fs_rect = FullScreenFadeRectangle(fill_opacity=1, fill_color=GREY_E) + self.add(fs_rect) + + title = TextMobject("Last Time") + title.scale(1.5) + title.to_edge(UP) + + rect = ScreenRectangle() + rect.set_height(6) + rect.set_fill(BLACK, 1) + rect.next_to(title, DOWN) + + self.play( + DrawBorderThenFill(rect), + FadeInFromDown(title), + ) + self.wait() + + +class ComplainAboutSimplisticModel(ExternallyAnimatedScene): + pass + + +class BayesianFrequentistDivide(Scene): + def construct(self): + # Setup Bayesian vs. Frequentist divide + b_label = TextMobject("Bayesian") + f_label = TextMobject("Frequentist") + labels = VGroup(b_label, f_label) + for label, vect in zip(labels, [LEFT, RIGHT]): + label.set_height(0.7) + label.move_to(vect * FRAME_WIDTH / 4) + label.to_edge(UP, buff=0.35) + + h_line = Line(LEFT, RIGHT) + h_line.set_width(FRAME_WIDTH) + h_line.next_to(labels, DOWN) + v_line = Line(UP, DOWN) + v_line.set_height(FRAME_HEIGHT) + v_line.center() + + for label in labels: + label.save_state() + label.set_y(0) + self.play( + FadeInFrom(label, -normalize(label.get_center())), + ) + self.wait() + self.play( + ShowCreation(VGroup(v_line, h_line)), + *map(Restore, labels), + ) + self.wait() + + # Overlay ShowBayesianUpdating in editing + # Frequentist list (ignore?) + kw = { + "tex_to_color_map": { + "$p$-value": YELLOW, + "$H_0$": PINK, + "$\\alpha$": BLUE, + }, + "alignment": "", + } + freq_list = VGroup( + TextMobject("1. State a null hypothesis $H_0$", **kw), + TextMobject("2. Choose a test statistic,\\\\", "$\\qquad$ compute its value", **kw), + TextMobject("3. Calculate a $p$-value", **kw), + TextMobject("4. Choose a significance value $\\alpha$", **kw), + TextMobject("5. Reject $H_0$ if $p$-value\\\\", "$\\qquad$ is less than $\\alpha$", **kw), + ) + + freq_list.set_width(0.5 * FRAME_WIDTH - 1) + freq_list.arrange(DOWN, buff=MED_LARGE_BUFF, aligned_edge=LEFT) + freq_list.move_to(FRAME_WIDTH * RIGHT / 4) + freq_list.to_edge(DOWN, buff=LARGE_BUFF) + + # Frequentist icon + axes = get_beta_dist_axes(y_max=5, y_unit=1) + axes.set_width(0.5 * FRAME_WIDTH - 1) + axes.move_to(FRAME_WIDTH * RIGHT / 4 + DOWN) + + dist = scipy.stats.norm(0.5, 0.1) + graph = axes.get_graph(dist.pdf) + graphs = VGroup() + for x_min, x_max in [(0, 0.3), (0.3, 0.7), (0.7, 1.0)]: + graph = axes.get_graph(dist.pdf, x_min=x_min, x_max=x_max) + graph.add_line_to(axes.c2p(x_max, 0)) + graph.add_line_to(axes.c2p(x_min, 0)) + graph.add_line_to(graph.get_start()) + graphs.add(graph) + + graphs.set_stroke(width=0) + graphs.set_fill(RED, 1) + graphs[1].set_fill(GREY_D, 1) + + H_words = VGroup(*[TextMobject("Reject\\\\$H_0$") for x in range(2)]) + for H_word, graph, vect in zip(H_words, graphs[::2], [RIGHT, LEFT]): + H_word.next_to(graph, UP, MED_LARGE_BUFF) + arrow = Arrow( + H_word.get_bottom(), + graph.get_center() + 0.75 * vect, + buff=SMALL_BUFF + ) + H_word.add(arrow) + + H_words.set_color(RED) + self.add(H_words) + + self.add(axes) + self.add(graphs) + + self.embed() + + # Transition to 2x2 + # Go back to prior + # Label uniform prior + # Talk about real coin prior + # Update ad infinitum + + +class ArgumentBetweenBayesianAndFrequentist(Scene): + def construct(self): + pass + + +# From version 1 +class ShowBayesianUpdating(Scene): + CONFIG = { + "true_p": 0.72, + "random_seed": 4, + "initial_axis_scale_factor": 3.5 + } + + def construct(self): + # Axes + axes = scaled_pdf_axes(self.initial_axis_scale_factor) + self.add(axes) + + # Graph + n_heads = 0 + n_tails = 0 + graph = get_beta_graph(axes, n_heads, n_tails) + self.add(graph) + + # Get coins + true_p = self.true_p + bool_values = np.random.random(100) < true_p + bool_values[1] = True + coins = self.get_coins(bool_values) + coins.next_to(axes.y_axis, RIGHT, MED_LARGE_BUFF) + coins.to_edge(UP, LARGE_BUFF) + + # Probability label + p_label, prob, prob_box = self.get_probability_label() + self.add(p_label) + self.add(prob_box) + + # Slow animations + def head_likelihood(x): + return x + + def tail_likelihood(x): + return 1 - x + + n_previews = 10 + n_slow_previews = 5 + for x in range(n_previews): + coin = coins[x] + is_heads = bool_values[x] + + new_data_label = TextMobject("New data") + new_data_label.set_height(0.3) + arrow = Vector(0.5 * UP) + arrow.next_to(coin, DOWN, SMALL_BUFF) + new_data_label.next_to(arrow, DOWN, SMALL_BUFF) + new_data_label.shift(MED_SMALL_BUFF * RIGHT) + + if is_heads: + line = axes.get_graph(lambda x: x) + label = TexMobject("\\text{Scale by } x") + likelihood = head_likelihood + n_heads += 1 + else: + line = axes.get_graph(lambda x: 1 - x) + label = TexMobject("\\text{Scale by } (1 - x)") + likelihood = tail_likelihood + n_tails += 1 + label.next_to(graph, UP) + label.set_stroke(BLACK, 3, background=True) + line.set_stroke(YELLOW, 3) + + graph_copy = graph.copy() + graph_copy.unlock_triangulation() + scaled_graph = graph.copy() + scaled_graph.apply_function( + lambda p: axes.c2p( + axes.x_axis.p2n(p), + axes.y_axis.p2n(p) * likelihood(axes.x_axis.p2n(p)) + ) + ) + scaled_graph.set_color(GREEN) + + renorm_label = TextMobject("Renormalize") + renorm_label.move_to(label) + + new_graph = get_beta_graph(axes, n_heads, n_tails) + + renormalized_graph = scaled_graph.copy() + renormalized_graph.match_style(graph) + renormalized_graph.match_height(new_graph, stretch=True, about_edge=DOWN) + + if x < n_slow_previews: + self.play( + FadeInFromDown(coin), + FadeIn(new_data_label), + GrowArrow(arrow), + ) + self.play( + FadeOut(new_data_label), + FadeOut(arrow), + ShowCreation(line), + FadeIn(label), + ) + self.add(graph_copy, line, label) + self.play(Transform(graph_copy, scaled_graph)) + self.play( + FadeOut(line), + FadeOut(label), + FadeIn(renorm_label), + ) + self.play( + Transform(graph_copy, renormalized_graph), + FadeOut(graph), + ) + self.play(FadeOut(renorm_label)) + else: + self.add(coin) + graph_copy.become(scaled_graph) + self.add(graph_copy) + self.play( + Transform(graph_copy, renormalized_graph), + FadeOut(graph), + ) + graph = new_graph + self.remove(graph_copy) + self.add(new_graph) + + # Rescale y axis + axes.save_state() + sf = self.initial_axis_scale_factor + axes.y_axis.stretch(1 / sf, 1, about_point=axes.c2p(0, 0)) + for number in axes.y_axis.numbers: + number.stretch(sf, 1) + axes.y_axis.numbers[:4].set_opacity(0) + + self.play( + Restore(axes, rate_func=lambda t: smooth(1 - t)), + graph.stretch, 1 / sf, 1, {"about_edge": DOWN}, + run_time=2, + ) + + # Fast animations + for x in range(n_previews, len(coins)): + coin = coins[x] + is_heads = bool_values[x] + + if is_heads: + n_heads += 1 + else: + n_tails += 1 + new_graph = get_beta_graph(axes, n_heads, n_tails) + + self.add(coins[:x + 1]) + self.add(new_graph) + self.remove(graph) + self.wait(0.25) + # self.play( + # FadeIn(new_graph), + # run_time=0.25, + # ) + # self.play( + # FadeOut(graph), + # run_time=0.25, + # ) + graph = new_graph + + # Show confidence interval + dist = scipy.stats.beta(n_heads + 1, n_tails + 1) + v_lines = VGroup() + labels = VGroup() + x_bounds = dist.interval(0.95) + for x in x_bounds: + line = DashedLine( + axes.c2p(x, 0), + axes.c2p(x, 12), + ) + line.set_color(YELLOW) + v_lines.add(line) + label = DecimalNumber(x) + label.set_height(0.25) + label.next_to(line, UP) + label.match_color(line) + labels.add(label) + + true_graph = axes.get_graph(dist.pdf) + region = get_region_under_curve(axes, true_graph, *x_bounds) + region.set_fill(GREY_BROWN, 0.85) + region.set_stroke(YELLOW, 1) + + label95 = TexMobject("95\\%") + fix_percent(label95.family_members_with_points()[-1]) + label95.move_to(region, DOWN) + label95.shift(0.5 * UP) + + self.play(*map(ShowCreation, v_lines)) + self.play( + FadeIn(region), + Write(label95) + ) + self.wait() + for label in labels: + self.play(FadeInFromDown(label)) + self.wait() + + # Show true value + self.wait() + self.play(FadeOut(prob_box)) + self.play(ShowCreationThenFadeAround(prob)) + self.wait() + + # Much more data + many_bools = np.hstack([ + bool_values, + (np.random.random(1000) < true_p) + ]) + N_tracker = ValueTracker(100) + graph.N_tracker = N_tracker + graph.bools = many_bools + graph.axes = axes + graph.v_lines = v_lines + graph.labels = labels + graph.region = region + graph.label95 = label95 + + label95.width_ratio = label95.get_width() / region.get_width() + + def update_graph(graph): + N = int(graph.N_tracker.get_value()) + nh = sum(graph.bools[:N]) + nt = len(graph.bools[:N]) - nh + new_graph = get_beta_graph(graph.axes, nh, nt, step_size=0.05) + graph.become(new_graph) + + dist = scipy.stats.beta(nh + 1, nt + 1) + x_bounds = dist.interval(0.95) + for x, line, label in zip(x_bounds, graph.v_lines, graph.labels): + line.set_x(graph.axes.c2p(x, 0)[0]) + label.set_x(graph.axes.c2p(x, 0)[0]) + label.set_value(x) + + graph.labels[0].shift(MED_SMALL_BUFF * LEFT) + graph.labels[1].shift(MED_SMALL_BUFF * RIGHT) + + new_simple_graph = graph.axes.get_graph(dist.pdf) + new_region = get_region_under_curve(graph.axes, new_simple_graph, *x_bounds) + new_region.match_style(graph.region) + graph.region.become(new_region) + + graph.label95.set_width(graph.label95.width_ratio * graph.region.get_width()) + graph.label95.match_x(graph.region) + + self.add(graph, region, label95, p_label) + self.play( + N_tracker.set_value, 1000, + UpdateFromFunc(graph, update_graph), + Animation(v_lines), + Animation(labels), + Animation(graph.region), + Animation(graph.label95), + run_time=5, + ) + self.wait() + + # + + def get_coins(self, bool_values): + coins = VGroup(*[ + get_coin("H" if heads else "T") + for heads in bool_values + ]) + coins.arrange_in_grid(n_rows=10, buff=MED_LARGE_BUFF) + coins.set_height(5) + return coins + + def get_probability_label(self): + head = get_coin("H") + p_label = TexMobject( + "P(00) = ", + tex_to_color_map={"00": WHITE} + ) + template = p_label.get_part_by_tex("00") + head.replace(template) + p_label.replace_submobject( + p_label.index_of_part(template), + head, + ) + prob = DecimalNumber(self.true_p) + prob.next_to(p_label, RIGHT) + p_label.add(prob) + p_label.set_height(0.75) + p_label.to_corner(UR) + + prob_box = SurroundingRectangle(prob, buff=SMALL_BUFF) + prob_box.set_fill(GREY_D, 1) + prob_box.set_stroke(WHITE, 2) + + q_marks = TexMobject("???") + q_marks.move_to(prob_box) + prob_box.add(q_marks) + + return p_label, prob, prob_box + + +class HighlightReviewPartsReversed(HighlightReviewParts): + CONFIG = { + "reverse_order": True, + } + + +class Grey(Scene): + def construct(self): + self.add(FullScreenFadeRectangle(fill_color=GREY_D, fill_opacity=1)) + + +class ShowBayesRule(Scene): + def construct(self): + hyp = "\\text{Hypothesis}" + data = "\\text{Data}" + bayes = TexMobject( + f"P({hyp} \\,|\\, {data})", "=", "{", + f"P({data} \\,|\\, {hyp})", f"P({hyp})", + "\\over", f"P({data})", + tex_to_color_map={ + hyp: YELLOW, + data: GREEN, + } + ) + + title = TextMobject("Bayes' rule") + title.scale(2) + title.to_edge(UP) + + self.add(title) + self.add(*bayes[:5]) + self.wait() + self.play( + *[ + TransformFromCopy(bayes[i], bayes[j], path_arc=30 * DEGREES) + for i, j in [ + (0, 7), + (1, 10), + (2, 9), + (3, 8), + (4, 11), + ] + ], + FadeIn(bayes[5]), + run_time=1.5 + ) + self.wait() + self.play( + *[ + TransformFromCopy(bayes[i], bayes[j], path_arc=30 * DEGREES) + for i, j in [ + (0, 12), + (1, 13), + (4, 14), + (0, 16), + (3, 17), + (4, 18), + ] + ], + FadeIn(bayes[15]), + run_time=1.5 + ) + self.add(bayes) + self.wait() + + hyp_word = bayes.get_part_by_tex(hyp) + example_hyp = TextMobject( + "For example,\\\\", + "$0.9 < s < 0.99$", + ) + example_hyp[1].set_color(YELLOW) + example_hyp.next_to(hyp_word, DOWN, buff=1.5) + + data_word = bayes.get_part_by_tex(data) + example_data = TexMobject( + "48\\,", CMARK_TEX, + "\\,2\\,", XMARK_TEX, + ) + example_data.set_color_by_tex(CMARK_TEX, GREEN) + example_data.set_color_by_tex(XMARK_TEX, RED) + example_data.scale(1.5) + example_data.next_to(example_hyp, RIGHT, buff=1.5) + + hyp_arrow = Arrow( + hyp_word.get_bottom(), + example_hyp.get_top(), + ) + data_arrow = Arrow( + data_word.get_bottom(), + example_data.get_top(), + ) + + self.play( + GrowArrow(hyp_arrow), + FadeInFromPoint(example_hyp, hyp_word.get_center()), + ) + self.wait() + self.play( + GrowArrow(data_arrow), + FadeInFromPoint(example_data, data_word.get_center()), + ) + self.wait() + + +class VisualizeBayesRule(Scene): + def construct(self): + self.show_continuum() + self.show_arrows() + self.show_discrete_probabilities() + self.show_bayes_formula() + self.parallel_universes() + self.update_from_data() + + def show_continuum(self): + axes = get_beta_dist_axes(y_max=1, y_unit=0.1) + axes.y_axis.add_numbers( + *np.arange(0.2, 1.2, 0.2), + number_config={ + "num_decimal_places": 1, + } + ) + + p_label = TexMobject( + "P(s \\,|\\, \\text{data})", + tex_to_color_map={ + "s": YELLOW, + "\\text{data}": GREEN, + } + ) + p_label.scale(1.5) + p_label.to_edge(UP, LARGE_BUFF) + + s_part = p_label.get_part_by_tex("s").copy() + x_line = Line(axes.c2p(0, 0), axes.c2p(1, 0)) + x_line.set_stroke(YELLOW, 3) + + arrow = Vector(DOWN) + arrow.next_to(s_part, DOWN, SMALL_BUFF) + value = DecimalNumber(0, num_decimal_places=4) + value.set_color(YELLOW) + value.next_to(arrow, DOWN) + + self.add(axes) + self.add(p_label) + self.play( + s_part.next_to, x_line.get_start(), UR, SMALL_BUFF, + GrowArrow(arrow), + FadeInFromPoint(value, s_part.get_center()), + ) + + s_part.tracked = x_line + value.tracked = x_line + value.x_axis = axes.x_axis + self.play( + ShowCreation(x_line), + UpdateFromFunc( + s_part, + lambda m: m.next_to(m.tracked.get_end(), UR, SMALL_BUFF) + ), + UpdateFromFunc( + value, + lambda m: m.set_value( + m.x_axis.p2n(m.tracked.get_end()) + ) + ), + run_time=3, + ) + self.wait() + self.play( + FadeOut(arrow), + FadeOut(value), + ) + + self.p_label = p_label + self.s_part = s_part + self.value = value + self.x_line = x_line + self.axes = axes + + def show_arrows(self): + axes = self.axes + + arrows = VGroup() + arrow_template = Vector(DOWN) + arrow_template.lock_triangulation() + + def get_arrow(s, denom): + arrow = arrow_template.copy() + arrow.set_height(4 / denom) + arrow.move_to(axes.c2p(s, 0), DOWN) + arrow.set_color(interpolate_color( + GREY_A, GREY_C, random.random() + )) + return arrow + + for k in range(2, 50): + for n in range(1, k): + if np.gcd(n, k) != 1: + continue + s = n / k + arrows.add(get_arrow(s, k)) + for k in range(50, 1000): + arrows.add(get_arrow(1 / k, k)) + arrows.add(get_arrow(1 - 1 / k, k)) + + kw = { + "lag_ratio": 0.5, + "run_time": 5, + "rate_func": lambda t: t**4, + } + arrows.save_state() + for arrow in arrows: + arrow.stretch(0, 0) + arrow.set_stroke(width=0) + arrow.set_opacity(0) + self.play(Restore(arrows, **kw)) + self.play(LaggedStartMap( + ApplyMethod, arrows, + lambda m: (m.scale, 0, {"about_edge": DOWN}), + **kw + )) + self.remove(arrows) + self.wait() + + def show_discrete_probabilities(self): + axes = self.axes + + x_lines = VGroup() + dx = 0.01 + for x in np.arange(0, 1, dx): + line = Line( + axes.c2p(x, 0), + axes.c2p(x + dx, 0), + ) + line.set_stroke(BLUE, 3) + line.generate_target() + line.target.rotate( + 90 * DEGREES, + about_point=line.get_start() + ) + x_lines.add(line) + + self.add(x_lines) + self.play( + FadeOut(self.x_line), + LaggedStartMap( + MoveToTarget, x_lines, + ) + ) + + label = Integer(0) + label.set_height(0.5) + label.next_to(self.p_label[1], DOWN, LARGE_BUFF) + unit = TexMobject("\\%") + unit.match_height(label) + fix_percent(unit.family_members_with_points()[0]) + always(unit.next_to, label, RIGHT, SMALL_BUFF) + + arrow = Arrow() + arrow.max_stroke_width_to_length_ratio = 1 + arrow.axes = axes + arrow.label = label + arrow.add_updater(lambda m: m.put_start_and_end_on( + m.label.get_bottom() + MED_SMALL_BUFF * DOWN, + m.axes.c2p(0.01 * m.label.get_value(), 0.03), + )) + + self.add(label, unit, arrow) + self.play( + ChangeDecimalToValue(label, 99), + run_time=5, + ) + self.wait() + self.play(*map(FadeOut, [label, unit, arrow])) + + # Show prior label + p_label = self.p_label + given_data = p_label[2:4] + prior_label = TexMobject("P(s)", tex_to_color_map={"s": YELLOW}) + prior_label.match_height(p_label) + prior_label.move_to(p_label, DOWN, LARGE_BUFF) + + p_label.save_state() + self.play( + given_data.scale, 0.5, + given_data.set_opacity, 0.5, + given_data.to_corner, UR, + Transform(p_label[:2], prior_label[:2]), + Transform(p_label[-1], prior_label[-1]), + ) + self.wait() + + # Zoom in on the y-values + new_ticks = VGroup() + new_labels = VGroup() + dy = 0.01 + for y in np.arange(dy, 5 * dy, dy): + height = get_norm(axes.c2p(0, dy) - axes.c2p(0, 0)) + tick = axes.y_axis.get_tick(y, SMALL_BUFF) + label = DecimalNumber(y) + label.match_height(axes.y_axis.numbers[0]) + always(label.next_to, tick, LEFT, SMALL_BUFF) + + new_ticks.add(tick) + new_labels.add(label) + + for num in axes.y_axis.numbers: + height = num.get_height() + always(num.set_height, height, stretch=True) + + bars = VGroup() + dx = 0.01 + origin = axes.c2p(0, 0) + for x in np.arange(0, 1, dx): + rect = Rectangle( + width=get_norm(axes.c2p(dx, 0) - origin), + height=get_norm(axes.c2p(0, dy) - origin), + ) + rect.x = x + rect.set_stroke(BLUE, 1) + rect.set_fill(BLUE, 0.5) + rect.move_to(axes.c2p(x, 0), DL) + bars.add(rect) + + stretch_group = VGroup( + axes.y_axis, + bars, + new_ticks, + x_lines, + ) + x_lines.set_height( + bars.get_height(), + about_edge=DOWN, + stretch=True, + ) + + self.play( + stretch_group.stretch, 25, 1, {"about_point": axes.c2p(0, 0)}, + VFadeIn(bars), + VFadeIn(new_ticks), + VFadeIn(new_labels), + VFadeOut(x_lines), + run_time=4, + ) + + highlighted_bars = bars.copy() + highlighted_bars.set_color(YELLOW) + self.play( + LaggedStartMap( + FadeIn, highlighted_bars, + lag_ratio=0.5, + rate_func=there_and_back, + ), + ShowCreationThenFadeAround(new_labels[0]), + run_time=3, + ) + self.remove(highlighted_bars) + + # Nmae as prior + prior_name = TextMobject("Prior", " distribution") + prior_name.set_height(0.6) + prior_name.next_to(prior_label, DOWN, LARGE_BUFF) + + self.play(FadeInFromDown(prior_name)) + self.wait() + + # Show alternate distribution + bars.save_state() + for a, b in [(5, 2), (1, 6)]: + dist = scipy.stats.beta(a, b) + for bar, saved in zip(bars, bars.saved_state): + bar.target = saved.copy() + height = get_norm(axes.c2p(0.1 * dist.pdf(bar.x)) - axes.c2p(0, 0)) + bar.target.set_height(height, about_edge=DOWN, stretch=True) + + self.play(LaggedStartMap(MoveToTarget, bars, lag_ratio=0.00)) + self.wait() + self.play(Restore(bars)) + self.wait() + + uniform_name = TextMobject("Uniform") + uniform_name.match_height(prior_name) + uniform_name.move_to(prior_name, DL) + uniform_name.shift(RIGHT) + uniform_name.set_y(bars.get_top()[1] + MED_SMALL_BUFF, DOWN) + self.play( + prior_name[0].next_to, uniform_name, RIGHT, MED_SMALL_BUFF, DOWN, + FadeOutAndShift(prior_name[1], RIGHT), + FadeInFrom(uniform_name, LEFT) + ) + self.wait() + + self.bars = bars + self.uniform_label = VGroup(uniform_name, prior_name[0]) + + def show_bayes_formula(self): + uniform_label = self.uniform_label + p_label = self.p_label + bars = self.bars + + prior_label = VGroup( + p_label[0].deepcopy(), + p_label[1].deepcopy(), + p_label[4].deepcopy(), + ) + eq = TexMobject("=") + likelihood_label = TexMobject( + "P(", "\\text{data}", "|", "s", ")", + ) + likelihood_label.set_color_by_tex("data", GREEN) + likelihood_label.set_color_by_tex("s", YELLOW) + over = Line(LEFT, RIGHT) + p_data_label = TextMobject("P(", "\\text{data}", ")") + p_data_label.set_color_by_tex("data", GREEN) + + for mob in [eq, likelihood_label, over, p_data_label]: + mob.scale(1.5) + mob.set_opacity(0.1) + + eq.move_to(prior_label, LEFT) + over.set_width( + prior_label.get_width() + + likelihood_label.get_width() + + MED_SMALL_BUFF + ) + over.next_to(eq, RIGHT, MED_SMALL_BUFF) + p_data_label.next_to(over, DOWN, MED_SMALL_BUFF) + likelihood_label.next_to(over, UP, MED_SMALL_BUFF, RIGHT) + + self.play( + p_label.restore, + p_label.next_to, eq, LEFT, MED_SMALL_BUFF, + prior_label.next_to, over, UP, MED_SMALL_BUFF, LEFT, + FadeIn(eq), + FadeIn(likelihood_label), + FadeIn(over), + FadeIn(p_data_label), + FadeOut(uniform_label), + ) + + # Show new distribution + post_bars = bars.copy() + total_prob = 0 + for bar, p in zip(post_bars, np.arange(0, 1, 0.01)): + prob = scipy.stats.binom(50, p).pmf(48) + bar.stretch(prob, 1, about_edge=DOWN) + total_prob += 0.01 * prob + post_bars.stretch(1 / total_prob, 1, about_edge=DOWN) + post_bars.stretch(0.25, 1, about_edge=DOWN) # Lie to fit on screen... + post_bars.set_color(MAROON_D) + post_bars.set_fill(opacity=0.8) + + brace = Brace(p_label, DOWN) + post_word = brace.get_text("Posterior") + post_word.scale(1.25, about_edge=UP) + post_word.set_color(MAROON_D) + + self.play( + ReplacementTransform( + bars.copy().set_opacity(0), + post_bars, + ), + GrowFromCenter(brace), + FadeInFrom(post_word, 0.25 * UP) + ) + self.wait() + self.play( + eq.set_opacity, 1, + likelihood_label.set_opacity, 1, + ) + self.wait() + + data = get_check_count_label(48, 2) + data.scale(1.5) + data.next_to(likelihood_label, DOWN, buff=2, aligned_edge=LEFT) + data_arrow = Arrow( + likelihood_label[1].get_bottom(), + data.get_top() + ) + data_arrow.set_color(GREEN) + + self.play( + GrowArrow(data_arrow), + GrowFromPoint(data, data_arrow.get_start()), + ) + self.wait() + self.play(FadeOut(data_arrow)) + self.play( + over.set_opacity, 1, + p_data_label.set_opacity, 1, + ) + self.wait() + + self.play( + FadeOut(brace), + FadeOut(post_word), + FadeOut(post_bars), + FadeOut(data), + p_label.set_opacity, 0.1, + eq.set_opacity, 0.1, + likelihood_label.set_opacity, 0.1, + over.set_opacity, 0.1, + p_data_label.set_opacity, 0.1, + ) + + self.bayes = VGroup( + p_label, eq, + prior_label, likelihood_label, + over, p_data_label + ) + self.data = data + + def parallel_universes(self): + bars = self.bars + + cols = VGroup() + squares = VGroup() + sample_colors = color_gradient( + [GREEN_C, GREEN_D, GREEN_E], + 100 + ) + for bar in bars: + n_rows = 12 + col = VGroup() + for x in range(n_rows): + square = Rectangle( + width=bar.get_width(), + height=bar.get_height() / n_rows, + ) + square.set_stroke(width=0) + square.set_fill(opacity=1) + square.set_color(random.choice(sample_colors)) + col.add(square) + squares.add(square) + col.arrange(DOWN, buff=0) + col.move_to(bar) + cols.add(col) + squares.shuffle() + + self.play( + LaggedStartMap( + VFadeInThenOut, squares, + lag_ratio=0.005, + run_time=3 + ) + ) + self.remove(squares) + squares.set_opacity(1) + self.wait() + + example_col = cols[95] + + self.play( + bars.set_opacity, 0.25, + FadeIn(example_col, lag_ratio=0.1), + ) + self.wait() + + dist = scipy.stats.binom(50, 0.95) + for x in range(12): + square = random.choice(example_col).copy() + square.set_fill(opacity=0) + square.set_stroke(YELLOW, 2) + self.add(square) + nc = dist.ppf(random.random()) + data = get_check_count_label(nc, 50 - nc) + data.next_to(example_col, UP) + + self.add(square, data) + self.wait(0.5) + self.remove(square, data) + self.wait() + + self.data.set_opacity(1) + self.play( + FadeIn(self.data), + FadeOut(example_col), + self.bayes[3].set_opacity, 1, + ) + self.wait() + + def update_from_data(self): + bars = self.bars + data = self.data + bayes = self.bayes + + new_bars = bars.copy() + new_bars.set_stroke(opacity=1) + new_bars.set_fill(opacity=0.8) + for bar, p in zip(new_bars, np.arange(0, 1, 0.01)): + dist = scipy.stats.binom(50, p) + scalar = dist.pmf(48) + bar.stretch(scalar, 1, about_edge=DOWN) + + self.play( + ReplacementTransform( + bars.copy().set_opacity(0), + new_bars + ), + bars.set_fill, {"opacity": 0.1}, + bars.set_stroke, {"opacity": 0.1}, + run_time=2, + ) + + # Show example bar + bar95 = VGroup( + bars[95].copy(), + new_bars[95].copy() + ) + bar95.save_state() + bar95.generate_target() + bar95.target.scale(2) + bar95.target.next_to(bar95, UP, LARGE_BUFF) + bar95.target.set_stroke(BLUE, 3) + + ex_label = TexMobject("s", "=", "0.95") + ex_label.set_color(YELLOW) + ex_label.next_to(bar95.target, DOWN, submobject_to_align=ex_label[-1]) + + highlight = SurroundingRectangle(bar95, buff=0) + highlight.set_stroke(YELLOW, 2) + + self.play(FadeIn(highlight)) + self.play( + MoveToTarget(bar95), + FadeInFromDown(ex_label), + data.shift, LEFT, + ) + self.wait() + + side_brace = Brace(bar95[1], RIGHT, buff=SMALL_BUFF) + side_label = side_brace.get_text("0.26", buff=SMALL_BUFF) + self.play( + GrowFromCenter(side_brace), + FadeIn(side_label) + ) + self.wait() + self.play( + FadeOut(side_brace), + FadeOut(side_label), + FadeOut(ex_label), + ) + self.play( + bar95.restore, + bar95.set_opacity, 0, + ) + + for bar in bars[94:80:-1]: + highlight.move_to(bar) + self.wait(0.5) + self.play(FadeOut(highlight)) + self.wait() + + # Emphasize formula terms + tops = VGroup() + for bar, new_bar in zip(bars, new_bars): + top = Line(bar.get_corner(UL), bar.get_corner(UR)) + top.set_stroke(YELLOW, 2) + top.generate_target() + top.target.move_to(new_bar, UP) + tops.add(top) + + rect = SurroundingRectangle(bayes[2]) + rect.set_stroke(YELLOW, 1) + rect.target = SurroundingRectangle(bayes[3]) + rect.target.match_style(rect) + self.play( + ShowCreation(rect), + ShowCreation(tops), + ) + self.wait() + self.play( + LaggedStartMap( + MoveToTarget, tops, + run_time=2, + lag_ratio=0.02, + ), + MoveToTarget(rect), + ) + self.play(FadeOut(tops)) + self.wait() + + # Show alternate priors + axes = self.axes + bar_groups = VGroup() + for bar, new_bar in zip(bars, new_bars): + bar_groups.add(VGroup(bar, new_bar)) + + bar_groups.save_state() + for a, b in [(5, 2), (7, 1)]: + dist = scipy.stats.beta(a, b) + for bar, saved in zip(bar_groups, bar_groups.saved_state): + bar.target = saved.copy() + height = get_norm(axes.c2p(0.1 * dist.pdf(bar[0].x)) - axes.c2p(0, 0)) + height = max(height, 1e-6) + bar.target.set_height(height, about_edge=DOWN, stretch=True) + + self.play(LaggedStartMap(MoveToTarget, bar_groups, lag_ratio=0)) + self.wait() + self.play(Restore(bar_groups)) + self.wait() + + # Rescale + ex_p_label = TexMobject( + "P(s = 0.95 | 00000000) = ", + tex_to_color_map={ + "s = 0.95": YELLOW, + "00000000": WHITE, + } + ) + ex_p_label.scale(1.5) + ex_p_label.next_to(bars, UP, LARGE_BUFF) + ex_p_label.align_to(bayes, LEFT) + template = ex_p_label.get_part_by_tex("00000000") + template.set_opacity(0) + + highlight = SurroundingRectangle(new_bars[95], buff=0) + highlight.set_stroke(YELLOW, 1) + + self.remove(data) + self.play( + FadeIn(ex_p_label), + VFadeOut(data[0]), + data[1:].move_to, template, + FadeIn(highlight) + ) + self.wait() + + numer = new_bars[95].copy() + numer.set_stroke(YELLOW, 1) + denom = new_bars[80:].copy() + h_line = Line(LEFT, RIGHT) + h_line.set_width(3) + h_line.set_stroke(width=2) + h_line.next_to(ex_p_label, RIGHT) + + self.play( + numer.next_to, h_line, UP, + denom.next_to, h_line, DOWN, + ShowCreation(h_line), + ) + self.wait() + self.play( + denom.space_out_submobjects, + rate_func=there_and_back + ) + self.play( + bayes[4].set_opacity, 1, + bayes[5].set_opacity, 1, + FadeOut(rect), + ) + self.wait() + + # Rescale + self.play( + FadeOut(highlight), + FadeOut(ex_p_label), + FadeOut(data), + FadeOut(h_line), + FadeOut(numer), + FadeOut(denom), + bayes.set_opacity, 1, + ) + + new_bars.unlock_shader_data() + self.remove(new_bars, *new_bars) + self.play( + new_bars.set_height, 5, {"about_edge": DOWN, "stretch": True}, + new_bars.set_color, MAROON_D, + ) + self.wait() + + +class UniverseOf95Percent(WhatsTheModel): + CONFIG = {"s": 0.95} + + def construct(self): + self.introduce_buyer_and_seller() + for m, v in [(self.seller, RIGHT), (self.buyer, LEFT)]: + m.shift(v) + m.label.shift(v) + + pis = VGroup(self.seller, self.buyer) + label = get_prob_positive_experience_label(True, True) + label[-1].set_value(self.s) + label.set_height(1) + label.next_to(pis, UP, LARGE_BUFF) + self.add(label) + + for x in range(4): + self.play(*self.experience_animations( + self.seller, self.buyer, arc=30 * DEGREES, p=self.s + )) + + self.embed() + + +class UniverseOf50Percent(UniverseOf95Percent): + CONFIG = {"s": 0.5} + + +class OpenAndCloseAsideOnPdfs(Scene): + def construct(self): + labels = VGroup( + TextMobject("$\\langle$", "Aside on", " pdfs", "$\\rangle$"), + TextMobject("$\\langle$/", "Aside on", " pdfs", "$\\rangle$"), + ) + labels.set_width(FRAME_WIDTH / 2) + for label in labels: + label.set_color_by_tex("pdfs", YELLOW) + + self.play(FadeInFromDown(labels[0])) + self.wait() + self.play(Transform(*labels)) + self.wait() + + +class BayesRuleWithPdf(ShowLimitToPdf): + def construct(self): + # Axes + axes = self.get_axes() + sf = 1.5 + axes.y_axis.stretch(sf, 1, about_point=axes.c2p(0, 0)) + for number in axes.y_axis.numbers: + number.stretch(1 / sf, 1) + self.add(axes) + + # Formula + bayes = self.get_formula() + + post = bayes[:5] + eq = bayes[5] + prior = bayes[6:9] + likelihood = bayes[9:14] + over = bayes[14] + p_data = bayes[15:] + + self.play(FadeInFromDown(bayes)) + self.wait() + + # Prior + prior_graph = get_beta_graph(axes, 0, 0) + prior_graph_top = Line( + prior_graph.get_corner(UL), + prior_graph.get_corner(UR), + ) + prior_graph_top.set_stroke(YELLOW, 3) + + bayes.save_state() + bayes.set_opacity(0.2) + prior.set_opacity(1) + + self.play( + Restore(bayes, rate_func=reverse_smooth), + FadeIn(prior_graph), + ShowCreation(prior_graph_top), + ) + self.play(FadeOut(prior_graph_top)) + self.wait() + + # Scale Down + nh = 1 + nt = 2 + + scaled_graph = axes.get_graph( + lambda x: scipy.stats.binom(3, x).pmf(1) + 1e-6 + ) + scaled_graph.set_stroke(GREEN) + scaled_region = get_region_under_curve(axes, scaled_graph, 0, 1) + + def to_uniform(p, axes=axes): + return axes.c2p( + axes.x_axis.p2n(p), + int(axes.y_axis.p2n(p) != 0), + ) + + scaled_region.set_fill(opacity=0.75) + scaled_region.save_state() + scaled_region.apply_function(to_uniform) + + self.play( + Restore(scaled_region), + UpdateFromAlphaFunc( + scaled_region, + lambda m, a: m.set_opacity(a * 0.75), + ), + likelihood.set_opacity, 1, + ) + self.wait() + + # Rescale + new_graph = get_beta_graph(axes, nh, nt) + self.play( + ApplyMethod( + scaled_region.set_height, new_graph.get_height(), + {"about_edge": DOWN, "stretch": True}, + run_time=2, + ), + over.set_opacity, 1, + p_data.set_opacity, 1, + ) + self.wait() + self.play( + post.set_opacity, 1, + eq.set_opacity, 1, + ) + self.wait() + + # Use lower case + new_bayes = self.get_formula(lowercase=True) + new_bayes.replace(bayes, dim_to_match=0) + rects = VGroup( + SurroundingRectangle(new_bayes[0][0]), + SurroundingRectangle(new_bayes[6][0]), + ) + rects.set_stroke(YELLOW, 3) + + self.remove(bayes) + bayes = self.get_formula() + bayes.unlock_triangulation() + self.add(bayes) + self.play(Transform(bayes, new_bayes)) + self.play(ShowCreationThenFadeOut(rects)) + + def get_formula(self, lowercase=False): + p_sym = "p" if lowercase else "P" + bayes = TexMobject( + p_sym + "({s} \\,|\\, \\text{data})", "=", + "{" + p_sym + "({s})", + "P(\\text{data} \\,|\\, {s})", + "\\over", + "P(\\text{data})", + tex_to_color_map={ + "{s}": YELLOW, + "\\text{data}": GREEN, + } + ) + bayes.set_height(1.5) + bayes.to_edge(UP) + return bayes + + +class TalkThroughCoinExample(ShowBayesianUpdating): + def construct(self): + # Setup + axes = self.get_axes() + x_label = TexMobject("x") + x_label.next_to(axes.x_axis.get_end(), UR, MED_SMALL_BUFF) + axes.add(x_label) + + p_label, prob, prob_box = self.get_probability_label() + prob_box_x = x_label.copy().move_to(prob_box) + + self.add(axes) + self.add(p_label) + self.add(prob_box) + + self.wait() + q_marks = prob_box[1] + prob_box.remove(q_marks) + self.play( + FadeOut(q_marks), + TransformFromCopy(x_label, prob_box_x) + ) + prob_box.add(prob_box_x) + + # Setup coins + bool_values = (np.random.random(100) < self.true_p) + bool_values[:5] = [True, False, True, True, False] + coins = self.get_coins(bool_values) + coins.next_to(axes.y_axis, RIGHT, MED_LARGE_BUFF) + coins.to_edge(UP) + + # Random coin + rows = VGroup() + for x in range(5): + row = self.get_coins(np.random.random(10) < self.true_p) + row.arrange(RIGHT, buff=MED_LARGE_BUFF) + row.set_width(6) + row.move_to(UP) + rows.add(row) + + last_row = VMobject() + for row in rows: + self.play( + FadeOutAndShift(last_row, DOWN), + FadeIn(row, lag_ratio=0.1) + ) + last_row = row + self.play(FadeOutAndShift(last_row, DOWN)) + + # Uniform pdf + region = get_beta_graph(axes, 0, 0) + graph = Line( + region.get_corner(UL), + region.get_corner(UR), + ) + func_label = TexMobject("f(x) =", "1") + func_label.next_to(graph, UP) + + self.play( + FadeIn(func_label, lag_ratio=0.1), + ShowCreation(graph), + ) + self.add(region, graph) + self.play(FadeIn(region)) + self.wait() + + # First flip + coin = coins[0] + arrow = Vector(0.5 * UP) + arrow.next_to(coin, DOWN, SMALL_BUFF) + data_label = TextMobject("New data") + data_label.set_height(0.25) + data_label.next_to(arrow, DOWN) + data_label.shift(0.5 * RIGHT) + + self.play( + FadeInFrom(coin, DOWN), + GrowArrow(arrow), + Write(data_label, run_time=1) + ) + self.wait() + + # Show Bayes rule + bayes = TexMobject( + "p({x} | \\text{data})", "=", + "p({x})", + "{P(\\text{data} | {x})", + "\\over", + "P(\\text{data})", + tex_to_color_map={ + "{x}": WHITE, + "\\text{data}": GREEN, + } + ) + bayes.next_to(func_label, UP, LARGE_BUFF, LEFT) + + likelihood = bayes[9:14] + p_data = bayes[15:] + likelihood_rect = SurroundingRectangle(likelihood, buff=0.05) + likelihood_rect.save_state() + p_data_rect = SurroundingRectangle(p_data, buff=0.05) + + likelihood_x_label = TexMobject("x") + likelihood_x_label.next_to(likelihood_rect, UP) + + self.play(FadeInFromDown(bayes)) + self.wait() + self.play(ShowCreation(likelihood_rect)) + self.wait() + + self.play(TransformFromCopy(likelihood[-2], likelihood_x_label)) + self.wait() + + # Scale by x + times_x = TexMobject("\\cdot \\, x") + times_x.next_to(func_label, RIGHT, buff=0.2) + + new_graph = axes.get_graph(lambda x: x) + sub_region = get_region_under_curve(axes, new_graph, 0, 1) + + self.play( + Write(times_x), + Transform(graph, new_graph), + ) + self.play( + region.set_opacity, 0.5, + FadeIn(sub_region), + ) + self.wait() + + # Show example scalings + low_x = 0.1 + high_x = 0.9 + lines = VGroup() + for x in [low_x, high_x]: + lines.add(Line(axes.c2p(x, 0), axes.c2p(x, 1))) + + lines.set_stroke(YELLOW, 3) + + for x, line in zip([low_x, high_x], lines): + self.play(FadeIn(line)) + self.play(line.scale, x, {"about_edge": DOWN}) + self.wait() + self.play(FadeOut(lines)) + + # Renormalize + self.play( + FadeOut(likelihood_x_label), + ReplacementTransform(likelihood_rect, p_data_rect), + ) + self.wait() + + one = func_label[1] + two = TexMobject("2") + two.move_to(one, LEFT) + + self.play( + FadeOut(region), + sub_region.stretch, 2, 1, {"about_edge": DOWN}, + sub_region.set_color, BLUE, + graph.stretch, 2, 1, {"about_edge": DOWN}, + FadeInFromDown(two), + FadeOutAndShift(one, UP), + ) + region = sub_region + func_label = VGroup(func_label[0], two, times_x) + self.add(func_label) + + self.play(func_label.shift, 0.5 * UP) + self.wait() + + const = TexMobject("C") + const.scale(0.9) + const.move_to(two, DR) + const.shift(0.07 * RIGHT) + self.play( + FadeOutAndShift(two, UP), + FadeInFrom(const, DOWN) + ) + self.remove(func_label) + func_label = VGroup(func_label[0], const, times_x) + self.add(func_label) + self.play(FadeOut(p_data_rect)) + self.wait() + + # Show tails + coin = coins[1] + self.play( + arrow.next_to, coin, DOWN, SMALL_BUFF, + MaintainPositionRelativeTo(data_label, arrow), + FadeInFromDown(coin), + ) + self.wait() + + to_prior_arrow = Arrow( + func_label[0][3], + bayes[6], + max_tip_length_to_length_ratio=0.15, + stroke_width=3, + ) + to_prior_arrow.set_color(RED) + + self.play(Indicate(func_label, scale_factor=1.2, color=RED)) + self.play(ShowCreation(to_prior_arrow)) + self.wait() + self.play(FadeOut(to_prior_arrow)) + + # Scale by (1 - x) + eq_1mx = TexMobject("(1 - x)") + dot = TexMobject("\\cdot") + rhs_part = VGroup(dot, eq_1mx) + rhs_part.arrange(RIGHT, buff=0.2) + rhs_part.move_to(func_label, RIGHT) + + l_1mx = eq_1mx.copy() + likelihood_rect.restore() + l_1mx.next_to(likelihood_rect, UP, SMALL_BUFF) + + self.play( + ShowCreation(likelihood_rect), + FadeInFrom(l_1mx, 0.5 * DOWN), + ) + self.wait() + self.play(ShowCreationThenFadeOut(Underline(p_label))) + self.play(Indicate(coins[1])) + self.wait() + self.play( + TransformFromCopy(l_1mx, eq_1mx), + FadeInFrom(dot, RIGHT), + func_label.next_to, dot, LEFT, 0.2, + ) + + scaled_graph = axes.get_graph(lambda x: 2 * x * (1 - x)) + scaled_region = get_region_under_curve(axes, scaled_graph, 0, 1) + + self.play(Transform(graph, scaled_graph)) + self.play(FadeIn(scaled_region)) + self.wait() + + # Renormalize + self.remove(likelihood_rect) + self.play( + TransformFromCopy(likelihood_rect, p_data_rect), + FadeOut(l_1mx) + ) + new_graph = get_beta_graph(axes, 1, 1) + group = VGroup(graph, scaled_region) + self.play( + group.set_height, + new_graph.get_height(), {"about_edge": DOWN, "stretch": True}, + group.set_color, BLUE, + FadeOut(region), + ) + region = scaled_region + self.play(FadeOut(p_data_rect)) + self.wait() + self.play(ShowCreationThenFadeAround(const)) + + # Repeat + exp1 = Integer(1) + exp1.set_height(0.2) + exp1.move_to(func_label[2].get_corner(UR), DL) + exp1.shift(0.02 * DOWN + 0.07 * RIGHT) + + exp2 = exp1.copy() + exp2.move_to(eq_1mx.get_corner(UR), DL) + exp2.shift(0.1 * RIGHT) + exp2.align_to(exp1, DOWN) + + shift_vect = UP + 0.5 * LEFT + VGroup(exp1, exp2).shift(shift_vect) + + self.play( + FadeInFrom(exp1, DOWN), + FadeInFrom(exp2, DOWN), + VGroup(func_label, dot, eq_1mx).shift, shift_vect, + bayes.scale, 0.5, + bayes.next_to, p_label, DOWN, LARGE_BUFF, {"aligned_edge": RIGHT}, + ) + nh = 1 + nt = 1 + for coin, is_heads in zip(coins[2:10], bool_values[2:10]): + self.play( + arrow.next_to, coin, DOWN, SMALL_BUFF, + MaintainPositionRelativeTo(data_label, arrow), + FadeInFrom(coin, DOWN), + ) + if is_heads: + nh += 1 + old_exp = exp1 + else: + nt += 1 + old_exp = exp2 + + new_exp = old_exp.copy() + new_exp.increment_value(1) + + dist = scipy.stats.beta(nh + 1, nt + 1) + new_graph = axes.get_graph(dist.pdf) + new_region = get_region_under_curve(axes, new_graph, 0, 1) + new_region.match_style(region) + + self.play( + FadeOut(graph), + FadeOut(region), + FadeIn(new_graph), + FadeIn(new_region), + FadeOutAndShift(old_exp, MED_SMALL_BUFF * UP), + FadeInFrom(new_exp, MED_SMALL_BUFF * DOWN), + ) + graph = new_graph + region = new_region + self.remove(new_exp) + self.add(old_exp) + old_exp.increment_value() + self.wait() + + if coin is coins[4]: + area_label = TextMobject("Area = 1") + area_label.move_to(axes.c2p(0.6, 0.8)) + self.play(GrowFromPoint( + area_label, const.get_center() + )) + + +class PDefectEqualsQmark(Scene): + def construct(self): + label = TexMobject( + "P(\\text{Defect}) = ???", + tex_to_color_map={ + "\\text{Defect}": RED, + } + ) + self.play(FadeInFrom(label, DOWN)) + self.wait() + + +class UpdateOnceWithBinomial(TalkThroughCoinExample): + def construct(self): + # Fair bit of copy-pasting from above. If there's + # time, refactor this properly + # Setup + axes = self.get_axes() + x_label = TexMobject("x") + x_label.next_to(axes.x_axis.get_end(), UR, MED_SMALL_BUFF) + axes.add(x_label) + + p_label, prob, prob_box = self.get_probability_label() + prob_box_x = x_label.copy().move_to(prob_box) + + q_marks = prob_box[1] + prob_box.remove(q_marks) + prob_box.add(prob_box_x) + + self.add(axes) + self.add(p_label) + self.add(prob_box) + + # Coins + bool_values = (np.random.random(100) < self.true_p) + bool_values[:5] = [True, False, True, True, False] + coins = self.get_coins(bool_values) + coins.next_to(axes.y_axis, RIGHT, MED_LARGE_BUFF) + coins.to_edge(UP) + self.add(coins[:10]) + + # Uniform pdf + region = get_beta_graph(axes, 0, 0) + graph = axes.get_graph( + lambda x: 1, + min_samples=30, + ) + self.add(region, graph) + + # Show Bayes rule + bayes = TexMobject( + "p({x} | \\text{data})", "=", + "p({x})", + "{P(\\text{data} | {x})", + "\\over", + "P(\\text{data})", + tex_to_color_map={ + "{x}": WHITE, + "\\text{data}": GREEN, + } + ) + bayes.move_to(axes.c2p(0, 2.5)) + bayes.align_to(coins, LEFT) + + likelihood = bayes[9:14] + # likelihood_rect = SurroundingRectangle(likelihood, buff=0.05) + + self.add(bayes) + + # All data at once + brace = Brace(coins[:10], DOWN) + all_data_label = brace.get_text("One update from all data") + + self.wait() + self.play( + GrowFromCenter(brace), + FadeInFrom(all_data_label, 0.2 * UP), + ) + self.wait() + + # Binomial formula + nh = sum(bool_values[:10]) + nt = sum(~bool_values[:10]) + + likelihood_brace = Brace(likelihood, UP) + t2c = { + str(nh): BLUE, + str(nt): RED, + } + binom_formula = TexMobject( + "{10 \\choose ", str(nh), "}", + "x^{", str(nh), "}", + "(1-x)^{" + str(nt) + "}", + tex_to_color_map=t2c, + ) + binom_formula[0][-1].set_color(BLUE) + binom_formula[1].set_color(WHITE) + binom_formula.set_width(likelihood_brace.get_width() + 0.5) + binom_formula.next_to(likelihood_brace, UP) + + self.play( + TransformFromCopy(brace, likelihood_brace), + FadeOut(all_data_label), + FadeIn(binom_formula) + ) + self.wait() + + # New plot + rhs = TexMobject( + "C \\cdot", + "x^{", str(nh), "}", + "(1-x)^{", str(nt), "}", + tex_to_color_map=t2c + ) + rhs.next_to(bayes[:5], DOWN, LARGE_BUFF, aligned_edge=LEFT) + eq = TexMobject("=") + eq.rotate(90 * DEGREES) + eq.next_to(bayes[:5], DOWN, buff=0.35) + + dist = scipy.stats.beta(nh + 1, nt + 1) + new_graph = axes.get_graph(dist.pdf) + new_graph.shift(1e-6 * UP) + new_graph.set_stroke(WHITE, 1, opacity=0.5) + new_region = get_region_under_curve(axes, new_graph, 0, 1) + new_region.match_style(region) + new_region.set_opacity(0.75) + + self.add(new_region, new_graph, bayes) + region.unlock_triangulation() + self.play( + FadeOut(graph), + FadeOut(region), + FadeIn(new_graph), + FadeIn(new_region), + run_time=1, + ) + self.play( + Write(eq), + FadeInFrom(rhs, UP) + ) + self.wait() diff --git a/from_3b1b/active/bayes/beta_helpers.py b/from_3b1b/active/bayes/beta_helpers.py index 9981c56c..26f9ecee 100644 --- a/from_3b1b/active/bayes/beta_helpers.py +++ b/from_3b1b/active/bayes/beta_helpers.py @@ -5,6 +5,11 @@ import scipy.stats CMARK_TEX = "\\text{\\ding{51}}" XMARK_TEX = "\\text{\\ding{55}}" +COIN_COLOR_MAP = { + "H": BLUE_E, + "T": RED_E, +} + class Histogram(Group): CONFIG = { @@ -171,29 +176,6 @@ def get_random_process(choices, shuffle_time=2, total_time=3, change_rate=0.05, return container -def get_coin(color, symbol): - coin = VGroup() - circ = Circle() - circ.set_fill(color, 1) - circ.set_stroke(WHITE, 1) - circ.set_height(1) - label = TextMobject(symbol) - label.set_height(0.5 * circ.get_height()) - label.move_to(circ) - coin.add(circ, label) - coin.symbol = symbol - coin.lock_triangulation() - return coin - - -def get_random_coin(**kwargs): - coins = VGroup( - get_coin(BLUE_E, "H"), - get_coin(RED_E, "T"), - ) - return get_random_process(coins, **kwargs) - - def get_die_faces(): dot = Dot() dot.set_width(0.15) @@ -242,6 +224,69 @@ def get_random_card(height=1, **kwargs): return get_random_process(cards, **kwargs) +# Coins +def get_coin(symbol, color=None): + if color is None: + color = COIN_COLOR_MAP.get(symbol, GREY_E) + coin = VGroup() + circ = Circle() + circ.set_fill(color, 1) + circ.set_stroke(WHITE, 1) + circ.set_height(1) + label = TextMobject(symbol) + label.set_height(0.5 * circ.get_height()) + label.move_to(circ) + coin.add(circ, label) + coin.symbol = symbol + coin.lock_triangulation() + return coin + + +def get_random_coin(**kwargs): + return get_random_process([get_coin("H"), get_coin("T")], **kwargs) + + +def get_prob_coin_label(symbol="H", color=None, p=0.5, num_decimal_places=2): + label = TexMobject("P", "(", "00", ")", "=",) + coin = get_coin(symbol, color) + template = label.get_part_by_tex("00") + coin.replace(template) + label.replace_submobject(label.index_of_part(template), coin) + rhs = DecimalNumber(p, num_decimal_places=num_decimal_places) + rhs.next_to(label, RIGHT, buff=MED_SMALL_BUFF) + label.add(rhs) + return label + + +def get_q_box(mob): + box = SurroundingRectangle(mob) + box.set_stroke(WHITE, 1) + box.set_fill(GREY_E, 1) + q_marks = TexMobject("???") + max_width = 0.8 * box.get_width() + max_height = 0.8 * box.get_height() + + if q_marks.get_width() > max_width: + q_marks.set_width(max_width) + + if q_marks.get_height() > max_height: + q_marks.set_height(max_height) + + q_marks.move_to(box) + box.add(q_marks) + return box + + +def get_coin_grid(bools, height=6): + coins = VGroup(*[ + get_coin("H" if heads else "T") + for heads in bools + ]) + coins.arrange_in_grid() + coins.set_height(height) + return coins + + def get_prob_positive_experience_label(include_equals=False, include_decimal=False, include_q_mark=False): @@ -325,11 +370,37 @@ def get_beta_dist_axes(y_max=20, y_unit=2, label_y=False, **kwargs): return result +def scaled_pdf_axes(scale_factor=3.5): + axes = get_beta_dist_axes( + label_y=True, + y_unit=1, + ) + axes.y_axis.numbers.set_submobjects([ + *axes.y_axis.numbers[:5], + *axes.y_axis.numbers[4::5] + ]) + sf = scale_factor + axes.y_axis.stretch(sf, 1, about_point=axes.c2p(0, 0)) + for number in axes.y_axis.numbers: + number.stretch(1 / sf, 1) + axes.y_axis_label.to_edge(LEFT) + axes.y_axis_label.add_background_rectangle(opacity=1) + axes.set_stroke(background=True) + return axes + + +def close_off_graph(axes, graph): + x_max = axes.x_axis.p2n(graph.get_end()) + graph.add_line_to(axes.c2p(x_max, 0)) + graph.add_line_to(axes.c2p(0, 0)) + graph.lock_triangulation() + return graph + + def get_beta_graph(axes, n_plus, n_minus, **kwargs): dist = scipy.stats.beta(n_plus + 1, n_minus + 1) graph = axes.get_graph(dist.pdf, **kwargs) - graph.add_line_to(axes.c2p(1, 0)) - graph.add_line_to(axes.c2p(0, 0)) + close_off_graph(axes, graph) graph.set_stroke(BLUE, 2) graph.set_fill(BLUE_E, 1) graph.lock_triangulation() diff --git a/from_3b1b/active/ctracing.py b/from_3b1b/active/ctracing.py new file mode 100644 index 00000000..416975f3 --- /dev/null +++ b/from_3b1b/active/ctracing.py @@ -0,0 +1,754 @@ +from manimlib.imports import * +from from_3b1b.active.sir import * + + +class LastFewMonths(Scene): + def construct(self): + words = TextMobject("Last ", "few\\\\", "months:") + words.set_height(4) + underlines = VGroup() + for word in words: + underline = Line(LEFT, RIGHT) + underline.match_width(word) + underline.next_to(word, DOWN, SMALL_BUFF) + underlines.add(underline) + underlines[0].stretch(1.4, 0, about_edge=LEFT) + underlines.set_color(BLUE) + + # self.play(ShowCreation(underlines)) + self.play(ShowIncreasingSubsets(words, run_time=0.75, rate_func=linear)) + self.wait() + + +class UnemploymentTitle(Scene): + def construct(self): + words = TextMobject("Unemployment claims\\\\per week in the US")[0] + words.set_width(FRAME_WIDTH - 1) + words.to_edge(UP) + arrow = Arrow( + words.get_bottom(), + words.get_bottom() + 3 * RIGHT + 3 * DOWN, + stroke_width=10, + tip_length=0.5, + ) + arrow.set_color(BLUE_E) + words.set_color(BLACK) + self.play( + ShowIncreasingSubsets(words), + ShowCreation(arrow), + ) + self.wait() + + +class ExplainTracing(Scene): + def construct(self): + # Words + words = VGroup( + TextMobject("Testing, ", "Testing, ", "Testing!"), + TextMobject("Contact Tracing"), + ) + words[0].set_color(GREEN) + words[1].set_color(BLUE_B) + words.set_width(FRAME_WIDTH - 2) + words.arrange(DOWN, buff=1) + + self.play(ShowIncreasingSubsets(words[0], rate_func=linear)) + self.wait() + self.play(Write(words[1], run_time=1)) + self.wait() + + self.play( + words[1].to_edge, UP, + FadeOutAndShift(words[0], 6 * UP) + ) + + ct_word = words[1][0] + + # Groups + clusters = VGroup() + for x in range(4): + cluster = VGroup() + for y in range(4): + cluster.add(Randolph()) + cluster.arrange_in_grid(buff=LARGE_BUFF) + clusters.add(cluster) + clusters.scale(0.5) + clusters.arrange_in_grid(buff=2) + clusters.set_height(4) + + self.play(FadeIn(clusters)) + + pis = VGroup() + boxes = VGroup() + for cluster in clusters: + for pi in cluster: + pis.add(pi) + box = SurroundingRectangle(pi, buff=0.05) + boxes.add(box) + pi.box = box + + boxes.set_stroke(WHITE, 1) + + sicky = clusters[0][2] + covid_words = TextMobject("COVID-19\\\\Positive!") + covid_words.set_color(RED) + arrow = Vector(RIGHT, color=RED) + arrow.next_to(sicky, LEFT) + covid_words.next_to(arrow, LEFT, SMALL_BUFF) + + self.play( + sicky.change, "sick", + sicky.set_color, "#9BBD37", + FadeInFrom(covid_words, RIGHT), + GrowArrow(arrow), + ) + self.play(ShowCreation(sicky.box)) + self.wait(2) + anims = [] + for pi in clusters[0]: + if pi is not sicky: + anims.append(ApplyMethod(pi.change, "tired")) + anims.append(ShowCreation(pi.box)) + self.play(*anims) + self.wait() + + self.play(VFadeIn( + boxes[4:], + run_time=2, + rate_func=there_and_back_with_pause, + )) + self.wait() + + self.play(FadeOut( + VGroup( + covid_words, + arrow, + *boxes[:4], + *pis, + ), + lag_ratio=0.1, + run_time=3, + )) + self.play(ct_word.move_to, 2 * UP) + + # Underlines + implies = TexMobject("\\Downarrow") + implies.scale(2) + implies.next_to(ct_word, DOWN, MED_LARGE_BUFF) + loc_tracking = TextMobject("Location Tracking") + loc_tracking.set_color(GREY_BROWN) + loc_tracking.match_height(ct_word) + loc_tracking.next_to(implies, DOWN, MED_LARGE_BUFF) + + q_marks = TexMobject("???") + q_marks.scale(2) + q_marks.next_to(implies, RIGHT) + + cross = Cross(implies) + cross.set_stroke(RED, 7) + + self.play( + Write(implies), + FadeInFrom(loc_tracking, UP) + ) + self.play(FadeIn(q_marks, lag_ratio=0.1)) + self.wait() + + parts = VGroup(ct_word[:7], ct_word[7:]) + lines = VGroup() + for part in parts: + line = Line(part.get_left(), part.get_right()) + line.align_to(part[0], DOWN) + line.shift(0.1 * DOWN) + lines.add(line) + + ct_word.set_stroke(BLACK, 2, background=True) + self.add(lines[1], ct_word) + self.play(ShowCreation(lines[1])) + self.wait() + self.play(ShowCreation(lines[0])) + self.wait() + + self.play( + ShowCreation(cross), + FadeOutAndShift(q_marks, RIGHT), + FadeOut(lines), + ) + self.wait() + + dp_3t = TextMobject("DP-3T") + dp_3t.match_height(ct_word) + dp_3t.move_to(loc_tracking) + dp_3t_long = TextMobject("Decentralized Privacy-Preserving Proximity Tracing") + dp_3t_long.next_to(dp_3t, DOWN, LARGE_BUFF) + + arrow = Vector(UP) + arrow.set_stroke(width=8) + arrow.move_to(implies) + + self.play( + FadeInFromDown(dp_3t), + FadeOut(loc_tracking), + FadeOut(implies), + FadeOut(cross), + ShowCreation(arrow) + ) + self.play(Write(dp_3t_long)) + self.wait() + + +class ContactTracingMisnomer(Scene): + def construct(self): + # Word play + words = TextMobject("Contact ", "Tracing") + words.scale(2) + rects = VGroup(*[ + SurroundingRectangle(word, buff=0.2) + for word in words + ]) + expl1 = TextMobject("Doesn't ``trace'' you...") + expl2 = TextMobject("...or your contacts") + expls = VGroup(expl1, expl2) + colors = [RED, BLUE] + + self.add(words) + for vect, rect, expl, color in zip([UP, DOWN], reversed(rects), expls, colors): + arrow = Vector(-vect) + arrow.next_to(rect, vect, SMALL_BUFF) + expl.next_to(arrow, vect, SMALL_BUFF) + rect.set_color(color) + arrow.set_color(color) + expl.set_color(color) + + self.play( + FadeInFrom(expl, -vect), + GrowArrow(arrow), + ShowCreation(rect), + ) + self.wait() + + self.play(Write( + VGroup(*self.mobjects), + rate_func=lambda t: smooth(1 - t), + run_time=3, + )) + + +class ContactTracingWords(Scene): + def construct(self): + words = TextMobject("Contact\\\\", "Tracing") + words.set_height(4) + for word in words: + self.add(word) + self.wait() + self.wait() + return + self.play(ShowIncreasingSubsets(words)) + self.wait() + self.play( + words.set_height, 1, + words.to_corner, UL, + ) + self.wait() + + +class WanderingDotsWithLines(Scene): + def construct(self): + sim = SIRSimulation( + city_population=20, + person_type=DotPerson, + person_config={ + "color_map": { + "S": GREY, + "I": GREY, + "R": GREY, + }, + "infection_ring_style": { + "stroke_color": YELLOW, + }, + "max_speed": 0.5, + }, + infection_time=100, + ) + + for person in sim.people: + person.set_status("S") + person.infection_start_time += random.random() + + lines = VGroup() + + max_dist = 1.25 + + def update_lines(lines): + lines.remove(*lines.submobjects) + for p1 in sim.people: + for p2 in sim.people: + if p1 is p2: + continue + dist = get_norm(p1.get_center() - p2.get_center()) + if dist < max_dist: + line = Line(p1.get_center(), p2.get_center()) + alpha = (max_dist - dist) / max_dist + line.set_stroke( + interpolate_color(WHITE, RED, alpha), + width=4 * alpha + ) + lines.add(line) + + lines.add_updater(update_lines) + + self.add(lines) + self.add(sim) + self.wait(10) + for person in sim.people: + person.set_status("I") + person.infection_start_time += random.random() + self.wait(50) + + +class WhatAboutPeopleWithoutPhones(TeacherStudentsScene): + def construct(self): + self.student_says( + "What about people\\\\without phones?", + target_mode="sassy", + added_anims=[self.teacher.change, "guilty"] + ) + self.change_student_modes("angry", "angry", "sassy") + self.wait() + self.play(self.teacher.change, "tease") + self.wait() + + words = VectorizedPoint() + words.scale(1.5) + words.to_corner(UL) + + self.play( + FadeInFromDown(words), + RemovePiCreatureBubble(self.students[2]), + *[ + ApplyMethod(pi.change, "pondering", words) + for pi in self.pi_creatures + ] + ) + self.wait(5) + + +class PiGesture1(Scene): + def construct(self): + randy = Randolph(mode="raise_right_hand", height=2) + bubble = randy.get_bubble( + bubble_class=SpeechBubble, + height=2, width=3, + ) + bubble.write("This one's\\\\great") + bubble.content.scale(0.8) + bubble.content.set_color(BLACK) + bubble.set_color(BLACK) + bubble.set_fill(opacity=0) + randy.set_stroke(BLACK, 5, background=True) + self.add(randy, bubble, bubble.content) + + +class PiGesture2(Scene): + def construct(self): + randy = Randolph(mode="raise_left_hand", height=2) + randy.look(UL) + # randy.flip() + randy.set_color(GREY_BROWN) + bubble = randy.get_bubble( + bubble_class=SpeechBubble, + height=2, width=3, + direction=LEFT, + ) + bubble.write("So is\\\\this one") + bubble.content.scale(0.8) + bubble.content.set_color(BLACK) + bubble.set_color(BLACK) + bubble.set_fill(opacity=0) + randy.set_stroke(BLACK, 5, background=True) + self.add(randy, bubble, bubble.content) + + +class PiGesture3(Scene): + def construct(self): + randy = Randolph(mode="hooray", height=2) + randy.flip() + bubble = randy.get_bubble( + bubble_class=SpeechBubble, + height=2, width=3, + direction=LEFT, + ) + bubble.write("And this\\\\one") + bubble.content.scale(0.8) + bubble.content.set_color(BLACK) + bubble.set_color(BLACK) + bubble.set_fill(opacity=0) + randy.set_stroke(BLACK, 5, background=True) + self.add(randy, bubble, bubble.content) + + +class AppleGoogleCoop(Scene): + def construct(self): + logos = Group( + self.get_apple_logo(), + self.get_google_logo(), + ) + for logo in logos: + logo.set_height(2) + apple, google = logos + + logos.arrange(RIGHT, buff=3) + + arrows = VGroup() + for vect, u in zip([UP, DOWN], [0, 1]): + m1, m2 = logos[u], logos[1 - u] + arrows.add(Arrow( + m1.get_edge_center(vect), + m2.get_edge_center(vect), + path_arc=-90 * DEGREES, + buff=MED_LARGE_BUFF, + stroke_width=10, + )) + + self.play(LaggedStart( + Write(apple), + FadeIn(google), + lag_ratio=0.7, + )) + self.wait() + self.play(ShowCreation(arrows, run_time=2)) + self.wait() + + def get_apple_logo(self): + result = SVGMobject("apple_logo") + result.set_color("#b3b3b3") + return result + + def get_google_logo(self): + result = ImageMobject("google_logo_black") + return result + + +class LocationTracking(Scene): + def construct(self): + question = TextMobject( + "Would you like this company to track\\\\", + "and occasionally sell your location?" + ) + question.to_edge(UP, buff=LARGE_BUFF) + + slider = Rectangle(width=1.25, height=0.5) + slider.round_corners(radius=0.25) + slider.set_fill(GREEN, 1) + slider.next_to(question, DOWN, buff=MED_LARGE_BUFF) + + dot = Dot(radius=0.25) + dot.set_fill(GREY_C, 1) + dot.set_stroke(WHITE, 3) + dot.move_to(slider, RIGHT) + + morty = Mortimer() + morty.next_to(slider, RIGHT) + morty.to_edge(DOWN) + + bubble = morty.get_bubble( + height=2, + width=3, + direction=LEFT, + ) + + answer = TextMobject("Um...", "no.") + answer.set_height(0.4) + answer.set_color(YELLOW) + bubble.add_content(answer) + + self.add(morty) + + self.play( + FadeInFromDown(question), + Write(slider), + FadeIn(dot), + ) + self.play(morty.change, "confused", slider) + self.play(Blink(morty)) + self.play( + FadeIn(bubble), + Write(answer[0]), + ) + self.wait() + self.play( + dot.move_to, slider, LEFT, + slider.set_fill, {"opacity": 0}, + FadeIn(answer[1]), + morty.change, "sassy" + ) + self.play(Blink(morty)) + self.wait(2) + self.play(Blink(morty)) + self.wait(2) + + +class MoreLinks(Scene): + def construct(self): + words = TextMobject("See more links\\\\in the description.") + words.scale(2) + words.to_edge(UP, buff=2) + arrows = VGroup(*[ + Vector(1.5 * DOWN, stroke_width=10) + for x in range(4) + ]) + arrows.arrange(RIGHT, buff=0.75) + arrows.next_to(words, DOWN, buff=0.5) + for arrow, color in zip(arrows, [BLUE_D, BLUE_C, BLUE_E, GREY_BROWN]): + arrow.set_color(color) + self.play(Write(words)) + self.play(LaggedStartMap(ShowCreation, arrows)) + self.wait() + + +class LDMEndScreen(PatreonEndScreen): + CONFIG = { + "scroll_time": 20, + "specific_patrons": [ + "1stViewMaths", + "Aaron", + "Adam Dřínek", + "Adam Margulies", + "Aidan Shenkman", + "Alan Stein", + "Albin Egasse", + "Alex Mijalis", + "Alexander Mai", + "Alexis Olson", + "Ali Yahya", + "Andreas Snekloth Kongsgaard", + "Andrew Busey", + "Andrew Cary", + "Andrew R. Whalley", + "Aravind C V", + "Arjun Chakroborty", + "Arthur Zey", + "Ashwin Siddarth", + "Augustine Lim", + "Austin Goodman", + "Avi Finkel", + "Awoo", + "Axel Ericsson", + "Ayan Doss", + "AZsorcerer", + "Barry Fam", + "Bartosz Burclaf", + "Ben Delo", + "Benjamin Bailey", + "Bernd Sing", + "Bill Gatliff", + "Boris Veselinovich", + "Bradley Pirtle", + "Brandon Huang", + "Brendan Shah", + "Brian Cloutier", + "Brian Staroselsky", + "Britt Selvitelle", + "Britton Finley", + "Burt Humburg", + "Calvin Lin", + "Carl-Johan R. Nordangård", + "Charles Southerland", + "Charlie N", + "Chris Connett", + "Chris Druta", + "Christian Kaiser", + "cinterloper", + "Clark Gaebel", + "Colwyn Fritze-Moor", + "Corey Ogburn", + "D. Sivakumar", + "Dan Herbatschek", + "Daniel Brown", + "Daniel Herrera C", + "Darrell Thomas", + "Dave B", + "Dave Cole", + "Dave Kester", + "dave nicponski", + "David B. Hill", + "David Clark", + "David Gow", + "Delton Ding", + "Dominik Wagner", + "Eduardo Rodriguez", + "Emilio Mendoza", + "emptymachine", + "Eric Younge", + "Eryq Ouithaqueue", + "Federico Lebron", + "Fernando Via Canel", + "Frank R. Brown, Jr.", + "gary", + "Giovanni Filippi", + "Goodwine", + "Hal Hildebrand", + "Heptonion", + "Hitoshi Yamauchi", + "Isaac Gubernick", + "Ivan Sorokin", + "Jacob Baxter", + "Jacob Harmon", + "Jacob Hartmann", + "Jacob Magnuson", + "Jalex Stark", + "Jameel Syed", + "James Beall", + "Jason Hise", + "Jayne Gabriele", + "Jean-Manuel Izaret", + "Jeff Dodds", + "Jeff Linse", + "Jeff Straathof", + "Jeffrey Wolberg", + "Jimmy Yang", + "Joe Pregracke", + "Johan Auster", + "John C. Vesey", + "John Camp", + "John Haley", + "John Le", + "John Luttig", + "John Rizzo", + "John V Wertheim", + "jonas.app", + "Jonathan Heckerman", + "Jonathan Wilson", + "Joseph John Cox", + "Joseph Kelly", + "Josh Kinnear", + "Joshua Claeys", + "Joshua Ouellette", + "Juan Benet", + "Julien Dubois", + "Kai-Siang Ang", + "Kanan Gill", + "Karl Niu", + "Kartik Cating-Subramanian", + "Kaustuv DeBiswas", + "Killian McGuinness", + "kkm", + "Klaas Moerman", + "Kristoffer Börebäck", + "Kros Dai", + "L0j1k", + "Lael S Costa", + "LAI Oscar", + "Lambda GPU Workstations", + "Laura Gast", + "Lee Redden", + "Linh Tran", + "Luc Ritchie", + "Ludwig Schubert", + "Lukas Biewald", + "Lukas Zenick", + "Magister Mugit", + "Magnus Dahlström", + "Magnus Hiie", + "Manoj Rewatkar - RITEK SOLUTIONS", + "Mark B Bahu", + "Mark Heising", + "Mark Hopkins", + "Mark Mann", + "Martin Price", + "Mathias Jansson", + "Matt Godbolt", + "Matt Langford", + "Matt Roveto", + "Matt Russell", + "Matteo Delabre", + "Matthew Bouchard", + "Matthew Cocke", + "Maxim Nitsche", + "Michael Bos", + "Michael Hardel", + "Michael W White", + "Mirik Gogri", + "Molly Mackinlay", + "Mustafa Mahdi", + "Márton Vaitkus", + "Nero Li", + "Nicholas Cahill", + "Nikita Lesnikov", + "Nitu Kitchloo", + "Oleg Leonov", + "Oliver Steele", + "Omar Zrien", + "Omer Tuchfeld", + "Patrick Gibson", + "Patrick Lucas", + "Pavel Dubov", + "Pesho Ivanov", + "Petar Veličković", + "Peter Ehrnstrom", + "Peter Francis", + "Peter Mcinerney", + "Pierre Lancien", + "Pradeep Gollakota", + "Rafael Bove Barrios", + "Raghavendra Kotikalapudi", + "Randy C. Will", + "rehmi post", + "Rex Godby", + "Ripta Pasay", + "Rish Kundalia", + "Roman Sergeychik", + "Roobie", + "Ryan Atallah", + "Samuel Judge", + "SansWord Huang", + "Scott Gray", + "Scott Walter, Ph.D.", + "soekul", + "Solara570", + "Spyridon Michalakis", + "Stephen Shanahan", + "Steve Huynh", + "Steve Muench", + "Steve Sperandeo", + "Steven Siddals", + "Stevie Metke", + "Sundar Subbarayan", + "supershabam", + "Suteerth Vishnu", + "Suthen Thomas", + "Tal Einav", + "Taras Bobrovytsky", + "Tauba Auerbach", + "Ted Suzman", + "Terry Hayes", + "THIS IS THE point OF NO RE tUUurRrhghgGHhhnnn", + "Thomas J Sargent", + "Thomas Tarler", + "Tianyu Ge", + "Tihan Seale", + "Tim Erbes", + "Tim Kazik", + "Tomasz Legutko", + "Tyler Herrmann", + "Tyler Parcell", + "Tyler VanValkenburg", + "Tyler Veness", + "Ubiquity Ventures", + "Vassili Philippov", + "Vasu Dubey", + "Veritasium", + "Vignesh Ganapathi Subramanian", + "Vinicius Reis", + "Vladimir Solomatin", + "Wooyong Ee", + "Xuanji Li", + "Yana Chernobilsky", + "Yavor Ivanov", + "Yetinother", + "YinYangBalance.Asia", + "Yu Jun", + "Yurii Monastyrshyn", + "Zachariah Rosenberg", + ], + } diff --git a/from_3b1b/active/diffyq/part2/fourier_series.py b/from_3b1b/active/diffyq/part2/fourier_series.py index 8ca76b28..ea58f9c0 100644 --- a/from_3b1b/active/diffyq/part2/fourier_series.py +++ b/from_3b1b/active/diffyq/part2/fourier_series.py @@ -331,6 +331,8 @@ class FourierOfPiSymbol(FourierCirclesScene): def add_vectors_circles_path(self): path = self.get_path() coefs = self.get_coefficients_of_path(path) + for coef in coefs: + print(coef) vectors = self.get_rotating_vectors(coefficients=coefs) circles = self.get_circles(vectors) self.set_decreasing_stroke_widths(circles) diff --git a/from_3b1b/active/diffyq/part3/temperature_graphs.py b/from_3b1b/active/diffyq/part3/temperature_graphs.py index 13d2c196..e1c2f1c1 100644 --- a/from_3b1b/active/diffyq/part3/temperature_graphs.py +++ b/from_3b1b/active/diffyq/part3/temperature_graphs.py @@ -1369,9 +1369,7 @@ class SineWaveScaledByExp(TemperatureGraphScene): theta=-80 * DEGREES, distance=50, ) - self.camera.set_frame_center( - 2 * RIGHT, - ) + self.camera.frame.move_to(2 * RIGHT) def show_sine_wave(self): time_tracker = ValueTracker(0) diff --git a/from_3b1b/active/ldm.py b/from_3b1b/active/ldm.py new file mode 100644 index 00000000..d6880eb4 --- /dev/null +++ b/from_3b1b/active/ldm.py @@ -0,0 +1,1445 @@ +from manimlib.imports import * +from from_3b1b.active.bayes.beta_helpers import * +import math + + +class StreamIntro(Scene): + def construct(self): + # Add logo + logo = Logo() + spikes = VGroup(*[ + spike + for layer in logo.spike_layers + for spike in layer + ]) + self.add(*logo.family_members_with_points()) + + # Add label + label = TextMobject("The lesson will\\\\begin shortly") + label.scale(2) + label.next_to(logo, DOWN) + self.add(label) + + self.camera.frame.move_to(DOWN) + + for spike in spikes: + point = spike.get_start() + spike.angle = angle_of_vector(point) + + anims = [] + for spike in spikes: + anims.append(Rotate( + spike, spike.angle * 28 * 2, + about_point=ORIGIN, + rate_func=linear, + )) + self.play(*anims, run_time=60 * 5) + self.wait(20) + + +class OldStreamIntro(Scene): + def construct(self): + morty = Mortimer() + morty.flip() + morty.set_height(2) + morty.to_corner(DL) + self.play(PiCreatureSays( + morty, "The lesson will\\\\begin soon.", + bubble_kwargs={ + "height": 2, + "width": 3, + }, + target_mode="hooray", + )) + bound = AnimatedBoundary(morty.bubble.content, max_stroke_width=1) + self.add(bound, morty.bubble, morty.bubble.content) + self.remove(morty.bubble.content) + morty.bubble.set_fill(opacity=0) + + self.camera.frame.scale(0.6, about_edge=DL) + + self.play(Blink(morty)) + self.wait(5) + self.play(Blink(morty)) + self.wait(3) + return + + text = TextMobject("The lesson will\\\\begin soon.") + text.set_height(1.5) + text.to_corner(DL, buff=LARGE_BUFF) + self.add(text) + + +class QuadraticFormula(TeacherStudentsScene): + def construct(self): + formula = TexMobject( + "\\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}", + ) + formula.next_to(self.students, UP, buff=MED_LARGE_BUFF, aligned_edge=LEFT) + self.add(formula) + + self.change_student_modes( + "angry", "tired", "sad", + look_at_arg=formula, + ) + self.teacher_says( + "It doesn't have\\\\to be this way.", + bubble_kwargs={ + "width": 4, + "height": 3, + } + ) + self.wait(5) + self.change_student_modes( + "pondering", "thinking", "erm", + look_at_arg=formula + ) + self.wait(12) + + +class SimplerQuadratic(Scene): + def construct(self): + tex = TexMobject("m \\pm \\sqrt{m^2 - p}") + tex.set_stroke(BLACK, 12, background=True) + tex.scale(1.5) + self.add(tex) + + +class CosGraphs(Scene): + def construct(self): + axes = Axes( + x_min=-0.75 * TAU, + x_max=0.75 * TAU, + y_min=-1.5, + y_max=1.5, + x_axis_config={ + "tick_frequency": PI / 4, + "include_tip": False, + }, + y_axis_config={ + "tick_frequency": 0.5, + "include_tip": False, + "unit_size": 1.5, + } + ) + + graph1 = axes.get_graph(np.cos) + graph2 = axes.get_graph(lambda x: np.cos(x)**2) + + graph1.set_stroke(YELLOW, 5) + graph2.set_stroke(BLUE, 5) + + label1 = TexMobject("\\cos(x)") + label2 = TexMobject("\\cos^2(x)") + + label1.match_color(graph1) + label1.set_height(0.75) + label1.next_to(axes.input_to_graph_point(-PI, graph1), DOWN) + + label2.match_color(graph2) + label2.set_height(0.75) + label2.next_to(axes.input_to_graph_point(PI, graph2), UP) + + for mob in [graph1, graph2, label1, label2]: + mc = mob.copy() + mc.set_stroke(BLACK, 10, background=True) + self.add(mc) + + self.add(axes) + self.add(graph1) + self.add(graph2) + self.add(label1) + self.add(label2) + + self.embed() + + +class SineWave(Scene): + def construct(self): + w_axes = self.get_wave_axes() + square, circle, c_axes = self.get_edge_group() + + self.add(w_axes) + self.add(square, circle, c_axes) + + theta_tracker = ValueTracker(0) + c_dot = Dot(color=YELLOW) + c_line = Line(DOWN, UP, color=GREEN) + w_dot = Dot(color=YELLOW) + w_line = Line(DOWN, UP, color=GREEN) + + def update_c_dot(dot, axes=c_axes, tracker=theta_tracker): + theta = tracker.get_value() + dot.move_to(axes.c2p( + np.cos(theta), + np.sin(theta), + )) + + def update_c_line(line, axes=c_axes, tracker=theta_tracker): + theta = tracker.get_value() + x = np.cos(theta) + y = np.sin(theta) + if y == 0: + y = 1e-6 + line.put_start_and_end_on( + axes.c2p(x, 0), + axes.c2p(x, y), + ) + + def update_w_dot(dot, axes=w_axes, tracker=theta_tracker): + theta = tracker.get_value() + dot.move_to(axes.c2p(theta, np.sin(theta))) + + def update_w_line(line, axes=w_axes, tracker=theta_tracker): + theta = tracker.get_value() + x = theta + y = np.sin(theta) + if y == 0: + y = 1e-6 + line.put_start_and_end_on( + axes.c2p(x, 0), + axes.c2p(x, y), + ) + + def get_partial_circle(circle=circle, tracker=theta_tracker): + result = circle.copy() + theta = tracker.get_value() + result.pointwise_become_partial( + circle, 0, clip(theta / TAU, 0, 1), + ) + result.set_stroke(RED, width=3) + return result + + def get_partial_wave(axes=w_axes, tracker=theta_tracker): + theta = tracker.get_value() + graph = axes.get_graph(np.sin, x_min=0, x_max=theta, step_size=0.025) + graph.set_stroke(BLUE, 3) + return graph + + def get_h_line(axes=w_axes, tracker=theta_tracker): + theta = tracker.get_value() + return Line( + axes.c2p(0, 0), + axes.c2p(theta, 0), + stroke_color=RED + ) + + c_dot.add_updater(update_c_dot) + c_line.add_updater(update_c_line) + w_dot.add_updater(update_w_dot) + w_line.add_updater(update_w_line) + partial_circle = always_redraw(get_partial_circle) + partial_wave = always_redraw(get_partial_wave) + h_line = always_redraw(get_h_line) + + self.add(partial_circle) + self.add(partial_wave) + self.add(h_line) + self.add(c_line, c_dot) + self.add(w_line, w_dot) + + sin_label = TexMobject( + "\\sin\\left(\\theta\\right)", + tex_to_color_map={"\\theta": RED} + ) + sin_label.next_to(w_axes.get_top(), UR) + self.add(sin_label) + + self.play( + theta_tracker.set_value, 1.25 * TAU, + run_time=15, + rate_func=linear, + ) + + def get_wave_axes(self): + wave_axes = Axes( + x_min=0, + x_max=1.25 * TAU, + y_min=-1.0, + y_max=1.0, + x_axis_config={ + "tick_frequency": TAU / 8, + "unit_size": 1.0, + }, + y_axis_config={ + "tick_frequency": 0.5, + "include_tip": False, + "unit_size": 1.5, + } + ) + wave_axes.y_axis.add_numbers( + -1, 1, number_config={"num_decimal_places": 1} + ) + wave_axes.to_edge(RIGHT, buff=MED_SMALL_BUFF) + + pairs = [ + (PI / 2, "\\frac{\\pi}{2}"), + (PI, "\\pi"), + (3 * PI / 2, "\\frac{3\\pi}{2}"), + (2 * PI, "2\\pi"), + ] + syms = VGroup() + for val, tex in pairs: + sym = TexMobject(tex) + sym.scale(0.5) + sym.next_to(wave_axes.c2p(val, 0), DOWN, MED_SMALL_BUFF) + syms.add(sym) + wave_axes.add(syms) + + theta = TexMobject("\\theta") + theta.set_color(RED) + theta.next_to(wave_axes.x_axis.get_end(), UP) + wave_axes.add(theta) + + return wave_axes + + def get_edge_group(self): + axes_max = 1.25 + radius = 1.5 + axes = Axes( + x_min=-axes_max, + x_max=axes_max, + y_min=-axes_max, + y_max=axes_max, + axis_config={ + "tick_frequency": 0.5, + "include_tip": False, + "numbers_with_elongated_ticks": [-1, 1], + "tick_size": 0.05, + "unit_size": radius, + }, + ) + axes.to_edge(LEFT, buff=MED_LARGE_BUFF) + + background = SurroundingRectangle(axes, buff=MED_SMALL_BUFF) + background.set_stroke(WHITE, 1) + background.set_fill(GREY_E, 1) + + circle = Circle(radius=radius) + circle.move_to(axes) + circle.set_stroke(WHITE, 1) + + nums = VGroup() + for u in 1, -1: + num = Integer(u) + num.set_height(0.2) + num.set_stroke(BLACK, 3, background=True) + num.next_to(axes.c2p(u, 0), DOWN + u * RIGHT, SMALL_BUFF) + nums.add(num) + + axes.add(nums) + + return background, circle, axes + + +class CosWave(SineWave): + CONFIG = { + "include_square": False, + } + + def construct(self): + w_axes = self.get_wave_axes() + square, circle, c_axes = self.get_edge_group() + + self.add(w_axes) + self.add(square, circle, c_axes) + + theta_tracker = ValueTracker(0) + c_dot = Dot(color=YELLOW) + c_line = Line(DOWN, UP, color=GREEN) + w_dot = Dot(color=YELLOW) + w_line = Line(DOWN, UP, color=GREEN) + + def update_c_dot(dot, axes=c_axes, tracker=theta_tracker): + theta = tracker.get_value() + dot.move_to(axes.c2p( + np.cos(theta), + np.sin(theta), + )) + + def update_c_line(line, axes=c_axes, tracker=theta_tracker): + theta = tracker.get_value() + x = np.cos(theta) + y = np.sin(theta) + line.set_points_as_corners([ + axes.c2p(0, y), + axes.c2p(x, y), + ]) + + def update_w_dot(dot, axes=w_axes, tracker=theta_tracker): + theta = tracker.get_value() + dot.move_to(axes.c2p(theta, np.cos(theta))) + + def update_w_line(line, axes=w_axes, tracker=theta_tracker): + theta = tracker.get_value() + x = theta + y = np.cos(theta) + if y == 0: + y = 1e-6 + line.set_points_as_corners([ + axes.c2p(x, 0), + axes.c2p(x, y), + ]) + + def get_partial_circle(circle=circle, tracker=theta_tracker): + result = circle.copy() + theta = tracker.get_value() + result.pointwise_become_partial( + circle, 0, clip(theta / TAU, 0, 1), + ) + result.set_stroke(RED, width=3) + return result + + def get_partial_wave(axes=w_axes, tracker=theta_tracker): + theta = tracker.get_value() + graph = axes.get_graph(np.cos, x_min=0, x_max=theta, step_size=0.025) + graph.set_stroke(PINK, 3) + return graph + + def get_h_line(axes=w_axes, tracker=theta_tracker): + theta = tracker.get_value() + return Line( + axes.c2p(0, 0), + axes.c2p(theta, 0), + stroke_color=RED + ) + + def get_square(line=c_line): + square = Square() + square.set_stroke(WHITE, 1) + square.set_fill(MAROON_B, opacity=0.5) + square.match_width(line) + square.move_to(line, DOWN) + return square + + def get_square_graph(axes=w_axes, tracker=theta_tracker): + theta = tracker.get_value() + graph = axes.get_graph( + lambda x: np.cos(x)**2, x_min=0, x_max=theta, step_size=0.025 + ) + graph.set_stroke(MAROON_B, 3) + return graph + + c_dot.add_updater(update_c_dot) + c_line.add_updater(update_c_line) + w_dot.add_updater(update_w_dot) + w_line.add_updater(update_w_line) + h_line = always_redraw(get_h_line) + partial_circle = always_redraw(get_partial_circle) + partial_wave = always_redraw(get_partial_wave) + + self.add(partial_circle) + self.add(partial_wave) + self.add(h_line) + self.add(c_line, c_dot) + self.add(w_line, w_dot) + + if self.include_square: + self.add(always_redraw(get_square)) + self.add(always_redraw(get_square_graph)) + + cos_label = TexMobject( + "\\cos\\left(\\theta\\right)", + tex_to_color_map={"\\theta": RED} + ) + cos_label.next_to(w_axes.get_top(), UR) + self.add(cos_label) + + self.play( + theta_tracker.set_value, 1.25 * TAU, + run_time=15, + rate_func=linear, + ) + + +class CosSquare(CosWave): + CONFIG = { + "include_square": True + } + + +class ComplexNumberPreview(Scene): + def construct(self): + plane = ComplexPlane(axis_config={"stroke_width": 4}) + plane.add_coordinates() + + z = complex(2, 1) + dot = Dot() + dot.move_to(plane.n2p(z)) + label = TexMobject("2+i") + label.set_color(YELLOW) + dot.set_color(YELLOW) + label.next_to(dot, UR, SMALL_BUFF) + label.set_stroke(BLACK, 5, background=True) + + line = Line(plane.n2p(0), plane.n2p(z)) + arc = Arc(start_angle=0, angle=np.log(z).imag, radius=0.5) + + self.add(plane) + self.add(line, arc) + self.add(dot) + self.add(label) + + self.embed() + + +class ComplexMultiplication(Scene): + def construct(self): + # Add plane + plane = ComplexPlane() + plane.add_coordinates() + + z = complex(2, 1) + z_dot = Dot(color=PINK) + z_dot.move_to(plane.n2p(z)) + z_label = TexMobject("z") + z_label.next_to(z_dot, UR, buff=0.5 * SMALL_BUFF) + z_label.match_color(z_dot) + + self.add(plane) + self.add(z_dot) + self.add(z_label) + + # Show 1 + one_vect = Vector(RIGHT) + one_vect.set_color(YELLOW) + one_vect.target = Vector(plane.n2p(z)) + one_vect.target.match_style(one_vect) + + z_rhs = TexMobject("=", "z \\cdot 1") + z_rhs[1].match_color(one_vect) + z_rhs.next_to(z_label, RIGHT, 1.5 * SMALL_BUFF, aligned_edge=DOWN) + z_rhs.set_stroke(BLACK, 3, background=True) + + one_label, i_label = [l for l in plane.coordinate_labels if l.get_value() == 1] + + self.play(GrowArrow(one_vect)) + self.wait() + self.add(one_vect, z_dot) + self.play( + MoveToTarget(one_vect), + TransformFromCopy(one_label, z_rhs), + ) + self.wait() + + # Show i + i_vect = Vector(UP, color=GREEN) + zi_point = plane.n2p(z * complex(0, 1)) + i_vect.target = Vector(zi_point) + i_vect.target.match_style(i_vect) + i_vect_label = TexMobject("z \\cdot i") + i_vect_label.match_color(i_vect) + i_vect_label.set_stroke(BLACK, 3, background=True) + i_vect_label.next_to(zi_point, UL, SMALL_BUFF) + + self.play(GrowArrow(i_vect)) + self.wait() + self.play( + MoveToTarget(i_vect), + TransformFromCopy(i_label, i_vect_label), + run_time=1, + ) + self.wait() + + self.play( + TransformFromCopy(one_vect, i_vect.target, path_arc=-90 * DEGREES), + ) + self.wait() + + # Transform plane + plane.generate_target() + for mob in plane.target.family_members_with_points(): + if isinstance(mob, Line): + mob.set_stroke(GREY, opacity=0.5) + new_plane = ComplexPlane(faded_line_ratio=0) + + self.remove(plane) + self.add(plane, new_plane, *self.mobjects) + + new_plane.generate_target() + new_plane.target.apply_complex_function(lambda w, z=z: w * z) + + self.play( + MoveToTarget(plane), + MoveToTarget(new_plane), + run_time=6, + rate_func=there_and_back_with_pause + ) + self.wait() + + # Show Example Point + w = complex(2, -1) + w_dot = Dot(plane.n2p(w), color=WHITE) + one_vects = VGroup(*[Vector(RIGHT) for x in range(2)]) + one_vects.arrange(RIGHT, buff=0) + one_vects.move_to(plane.n2p(0), LEFT) + one_vects.set_color(YELLOW) + new_i_vect = Vector(DOWN) + new_i_vect.move_to(plane.n2p(2), UP) + new_i_vect.set_color(GREEN) + vects = VGroup(*one_vects, new_i_vect) + vects.set_opacity(0.8) + + w_group = VGroup(*vects, w_dot) + w_group.target = VGroup( + one_vect.copy().set_opacity(0.8), + one_vect.copy().shift(plane.n2p(z)).set_opacity(0.8), + i_vect.copy().rotate(PI, about_point=ORIGIN).shift(2 * plane.n2p(z)).set_opacity(0.8), + Dot(plane.n2p(w * z), color=WHITE) + ) + + self.play(FadeInFromLarge(w_dot)) + self.wait() + self.play(ShowCreation(vects)) + self.wait() + + self.play( + MoveToTarget(plane), + MoveToTarget(new_plane), + MoveToTarget(w_group), + run_time=2, + path_arc=np.log(z).imag, + ) + self.wait() + + +class RotatePiCreature(Scene): + def construct(self): + randy = Randolph(mode="thinking") + randy.set_height(6) + + plane = ComplexPlane(x_min=-12, x_max=12) + plane.add_coordinates() + + self.camera.frame.move_to(3 * RIGHT) + + self.add(randy) + self.wait() + self.play(Rotate(randy, 30 * DEGREES, run_time=3)) + self.wait() + self.play(Rotate(randy, -30 * DEGREES)) + + self.add(plane, randy) + self.play( + ShowCreation(plane), + randy.set_opacity, 0.75, + ) + self.wait() + + dots = VGroup() + for mob in randy.family_members_with_points(): + for point in mob.get_anchors(): + dot = Dot(point) + dot.set_height(0.05) + dots.add(dot) + + self.play(ShowIncreasingSubsets(dots)) + self.wait() + + label = VGroup( + TexMobject("(x + iy)"), + Vector(DOWN), + TexMobject("(\\cos(30^\\circ) + i\\sin(30^\\circ))", "(x + iy)"), + ) + label[2][0].set_color(YELLOW) + label.arrange(DOWN) + label.to_corner(DR) + label.shift(3 * RIGHT) + + for mob in label: + mob.add_background_rectangle() + + self.play(FadeIn(label)) + self.wait() + + randy.add(dots) + self.play(Rotate(randy, 30 * DEGREES), run_time=3) + self.wait() + + +class ExpMeaning(Scene): + CONFIG = { + "include_circle": True + } + + def construct(self): + # Plane + plane = ComplexPlane(y_min=-6, y_max=6) + plane.shift(1.5 * DOWN) + plane.add_coordinates() + if self.include_circle: + circle = Circle(radius=1) + circle.set_stroke(RED, 1) + circle.move_to(plane.n2p(0)) + plane.add(circle) + + # Equation + equation = TexMobject( + "\\text{exp}(i\\theta) = ", + "1 + ", + "i\\theta + ", + "{(i\\theta)^2 \\over 2} + ", + "{(i\\theta)^3 \\over 6} + ", + "{(i\\theta)^4 \\over 24} + ", + "\\cdots", + tex_to_color_map={ + "\\theta": YELLOW, + "i": GREEN, + }, + ) + equation.add_background_rectangle(buff=MED_SMALL_BUFF, opacity=1) + equation.to_edge(UL, buff=0) + + # Label + theta_tracker = ValueTracker(0) + theta_label = VGroup( + TexMobject("\\theta = "), + DecimalNumber(0, num_decimal_places=4) + ) + theta_decimal = theta_label[1] + theta_decimal.add_updater( + lambda m, tt=theta_tracker: m.set_value(tt.get_value()) + ) + theta_label.arrange(RIGHT, buff=SMALL_BUFF) + theta_label.set_color(YELLOW) + theta_label.add_to_back(BackgroundRectangle( + theta_label, + buff=MED_SMALL_BUFF, + fill_opacity=1, + )) + theta_label.next_to(equation, DOWN, aligned_edge=LEFT, buff=0) + + # Vectors + def get_vectors(n_vectors=20, plane=plane, tracker=theta_tracker): + last_tip = plane.n2p(0) + z = complex(0, tracker.get_value()) + vects = VGroup() + colors = color_gradient([GREEN, YELLOW, RED], 6) + for i, color in zip(range(n_vectors), it.cycle(colors)): + vect = Vector(complex_to_R3(z**i / math.factorial(i))) + vect.set_color(color) + vect.shift(last_tip) + last_tip = vect.get_end() + vects.add(vect) + return vects + + vectors = always_redraw(get_vectors) + dot = Dot() + dot.set_height(0.03) + dot.add_updater(lambda m, vs=vectors: m.move_to(vs[-1].get_end())) + + self.add(plane) + self.add(vectors) + self.add(dot) + self.add(equation) + self.add(theta_label) + + self.play( + theta_tracker.set_value, 1, + run_time=3, + rate_func=smooth, + ) + self.wait() + for target in PI, TAU: + self.play( + theta_tracker.set_value, target, + run_time=10, + ) + self.wait() + + self.embed() + + +class ExpMeaningWithoutCircle(ExpMeaning): + CONFIG = { + "include_circle": False, + } + + +class PositionAndVelocityExample(Scene): + def construct(self): + plane = NumberPlane() + + self.add(plane) + + self.embed() + + +class EulersFormula(Scene): + def construct(self): + kw = {"tex_to_color_map": {"\\theta": YELLOW}} + formula = TexMobject( + "&e^{i\\theta} = \\\\ &\\cos\\left(\\theta\\right) + i\\cdot\\sin\\left(\\theta\\right)", + )[0] + formula[:4].scale(2, about_edge=UL) + formula[:4].shift(SMALL_BUFF * RIGHT + MED_LARGE_BUFF * UP) + VGroup(formula[2], formula[8], formula[17]).set_color(YELLOW) + formula.scale(1.5) + formula.set_stroke(BLACK, 5, background=True) + self.add(formula) + + +class EtoILimit(Scene): + def construct(self): + tex = TexMobject( + "\\lim_{n \\to \\infty} \\left(1 + \\frac{it}{n}\\right)^n", + )[0] + VGroup(tex[3], tex[12], tex[14]).set_color(YELLOW) + tex[9].set_color(BLUE) + tex.scale(1.5) + tex.set_stroke(BLACK, 5, background=True) + # self.add(tex) + + text = TextMobject("Interest rate\\\\of ", "$\\sqrt{-1}$") + text[1].set_color(BLUE) + text.scale(1.5) + text.set_stroke(BLACK, 5, background=True) + self.add(text) + + +class ImaginaryInterestRates(Scene): + def construct(self): + plane = ComplexPlane(x_min=-20, x_max=20, y_min=-20, y_max=20) + plane.add_coordinates() + circle = Circle(radius=1) + circle.set_stroke(YELLOW, 1) + self.add(plane, circle) + + frame = self.camera.frame + frame.save_state() + frame.generate_target() + frame.target.set_width(25) + frame.target.move_to(8 * RIGHT + 2 * DOWN) + + dt_tracker = ValueTracker(1) + + def get_vectors(tracker=dt_tracker, plane=plane, T=8): + dt = tracker.get_value() + last_z = 1 + vects = VGroup() + for t in np.arange(0, T, dt): + next_z = last_z + complex(0, 1) * last_z * dt + vects.add(Arrow( + plane.n2p(last_z), + plane.n2p(next_z), + buff=0, + )) + last_z = next_z + vects.set_submobject_colors_by_gradient(YELLOW, GREEN, BLUE) + return vects + + vects = get_vectors() + + line = Line() + line.add_updater(lambda m, v=vects: m.put_start_and_end_on( + ORIGIN, v[-1].get_start() if len(v) > 0 else RIGHT, + )) + + self.add(line) + self.play( + ShowIncreasingSubsets( + vects, + rate_func=linear, + int_func=np.ceil, + ), + MoveToTarget( + frame, + rate_func=squish_rate_func(smooth, 0.5, 1), + ), + run_time=8, + ) + self.wait() + self.play(FadeOut(line)) + + self.remove(vects) + vects = always_redraw(get_vectors) + self.add(vects) + self.play( + Restore(frame), + dt_tracker.set_value, 0.2, + run_time=5, + ) + self.wait() + self.play(dt_tracker.set_value, 0.01, run_time=3) + vects.clear_updaters() + self.wait() + + theta_tracker = ValueTracker(0) + + def get_arc(tracker=theta_tracker): + theta = tracker.get_value() + arc = Arc( + radius=1, + stroke_width=3, + stroke_color=RED, + start_angle=0, + angle=theta + ) + return arc + + arc = always_redraw(get_arc) + dot = Dot() + dot.add_updater(lambda m, arc=arc: m.move_to(arc.get_end())) + + label = VGroup( + DecimalNumber(0, num_decimal_places=3), + TextMobject("Years") + ) + label.arrange(RIGHT, aligned_edge=DOWN) + label.move_to(3 * LEFT + 1.5 * UP) + + label[0].set_color(RED) + label[0].add_updater(lambda m, tt=theta_tracker: m.set_value(tt.get_value())) + + self.add(BackgroundRectangle(label), label, arc, dot) + for n in range(1, 5): + target = n * PI / 2 + self.play( + theta_tracker.set_value, target, + run_time=3 + ) + self.wait(2) + + +class Logs(Scene): + def construct(self): + log = TexMobject( + "&\\text{log}(ab) = \\\\ &\\text{log}(a) + \\text{log}(b)", + tex_to_color_map={"a": BLUE, "b": YELLOW}, + alignment="", + ) + + log.scale(1.5) + log.set_stroke(BLACK, 5, background=True) + + self.add(log) + + +class LnX(Scene): + def construct(self): + sym = TexMobject("\\ln(x)") + sym.scale(3) + sym.shift(UP) + sym.set_stroke(BLACK, 5, background=True) + + word = TextMobject("Natural?") + word.scale(1.5) + word.set_color(YELLOW) + word.set_stroke(BLACK, 5, background=True) + word.next_to(sym, DOWN, buff=0.5) + arrow = Arrow(word.get_top(), sym[0][1].get_bottom()) + + self.add(sym, word, arrow) + + +class HarmonicSum(Scene): + def construct(self): + axes = Axes( + x_min=0, + x_max=13, + y_min=0, + y_max=1.25, + y_axis_config={ + "unit_size": 4, + "tick_frequency": 0.25, + } + ) + axes.to_corner(DL, buff=1) + axes.x_axis.add_numbers() + axes.y_axis.add_numbers( + *np.arange(0.25, 1.25, 0.25), + number_config={"num_decimal_places": 2}, + ) + self.add(axes) + + graph = axes.get_graph(lambda x: 1 / x, x_min=0.1, x_max=15) + graph_fill = graph.copy() + graph_fill.add_line_to(axes.c2p(15, 0)) + graph_fill.add_line_to(axes.c2p(1, 0)) + graph_fill.add_line_to(axes.c2p(1, 1)) + graph.set_stroke(WHITE, 3) + graph_fill.set_fill(BLUE_E, 0.5) + graph_fill.set_stroke(width=0) + self.add(graph, graph_fill) + + bars = VGroup() + bar_labels = VGroup() + for x in range(1, 15): + line = Line(axes.c2p(x, 0), axes.c2p(x + 1, 1 / x)) + bar = Rectangle() + bar.set_fill(GREEN_E, 1) + bar.replace(line, stretch=True) + bars.add(bar) + + label = TexMobject(f"1 \\over {x}") + label.set_height(0.7) + label.next_to(bar, UP, SMALL_BUFF) + bar_labels.add(label) + + bars.set_submobject_colors_by_gradient(GREEN_C, GREEN_E) + bars.set_stroke(WHITE, 1) + bars.set_fill(opacity=0.25) + + self.add(bars) + self.add(bar_labels) + + + self.embed() + + +class PowerTower(Scene): + def construct(self): + mob = TexMobject("4 = x^{x^{{x^{x^{x^{\cdot^{\cdot^{\cdot}}}}}}}}") + mob[0][-1].shift(0.1 * DL) + mob[0][-2].shift(0.05 * DL) + + mob.set_height(4) + mob.set_stroke(BLACK, 5, background=True) + + self.add(mob) + + +class ItoTheI(Scene): + def construct(self): + tex = TexMobject("i^i") + # tex = TexMobject("\\sqrt{-1}^{\\sqrt{-1}}") + tex.set_height(3) + tex.set_stroke(BLACK, 8, background=True) + self.add(tex) + + +class ComplexExponentialPlay(Scene): + def setup(self): + self.transform_alpha = 0 + + def construct(self): + # Plane + plane = ComplexPlane( + x_min=-2 * FRAME_WIDTH, + x_max=2 * FRAME_WIDTH, + y_min=-2 * FRAME_HEIGHT, + y_max=2 * FRAME_HEIGHT, + ) + plane.add_coordinates() + self.add(plane) + + # R Dot + r_dot = Dot(color=YELLOW) + + def update_r_dot(dot, point_tracker=self.mouse_drag_point): + point = point_tracker.get_location() + if abs(point[0]) < 0.1: + point[0] = 0 + if abs(point[1]) < 0.1: + point[1] = 0 + dot.move_to(point) + + r_dot.add_updater(update_r_dot) + self.mouse_drag_point.move_to(plane.n2p(1)) + + # Transformed sample dots + def func(z, dot=r_dot, plane=plane): + r = plane.p2n(dot.get_center()) + result = np.exp(r * z) + if abs(result) > 20: + result *= 20 / abs(result) + return result + + sample_dots = VGroup() + dot_template = Dot(radius=0.05) + dot_template.set_opacity(0.8) + spacing = 0.05 + for x in np.arange(-7, 7, spacing): + dot = dot_template.copy() + dot.set_color(TEAL) + dot.z = x + dot.move_to(plane.n2p(dot.z)) + sample_dots.add(dot) + for y in np.arange(-6, 6, spacing): + dot = dot_template.copy() + dot.set_color(MAROON) + dot.z = complex(0, y) + dot.move_to(plane.n2p(dot.z)) + sample_dots.add(dot) + + special_values = [1, complex(0, 1), -1, complex(0, -1)] + special_dots = VGroup(*[ + list(filter(lambda d: abs(d.z - x) < 0.01, sample_dots))[0] + for x in special_values + ]) + for dot in special_dots: + dot.set_fill(opacity=1) + dot.scale(1.2) + dot.set_stroke(WHITE, 2) + + sample_dots.save_state() + + def update_sample(sample, f=func, plane=plane, scene=self): + sample.restore() + sample.apply_function_to_submobject_positions( + lambda p: interpolate( + p, + plane.n2p(f(plane.p2n(p))), + scene.transform_alpha, + ) + ) + return sample + + sample_dots.add_updater(update_sample) + + # Sample lines + x_line = Line(plane.n2p(plane.x_min), plane.n2p(plane.x_max)) + y_line = Line(plane.n2p(plane.y_min), plane.n2p(plane.y_max)) + y_line.rotate(90 * DEGREES) + x_line.set_color(GREEN) + y_line.set_color(PINK) + axis_lines = VGroup(x_line, y_line) + for line in axis_lines: + line.insert_n_curves(50) + axis_lines.save_state() + + def update_axis_liens(lines=axis_lines, f=func, plane=plane, scene=self): + lines.restore() + lines.apply_function( + lambda p: interpolate( + p, + plane.n2p(f(plane.p2n(p))), + scene.transform_alpha, + ) + ) + lines.make_smooth() + + axis_lines.add_updater(update_axis_liens) + + # Labels + labels = VGroup( + TexMobject("f(1)"), + TexMobject("f(i)"), + TexMobject("f(-1)"), + TexMobject("f(-i)"), + ) + for label, dot in zip(labels, special_dots): + label.set_height(0.3) + label.match_color(dot) + label.set_stroke(BLACK, 3, background=True) + label.add_background_rectangle(opacity=0.5) + + def update_labels(labels, dots=special_dots, scene=self): + for label, dot in zip(labels, dots): + label.next_to(dot, UR, 0.5 * SMALL_BUFF) + label.set_opacity(self.transform_alpha) + + labels.add_updater(update_labels) + + # Titles + title = TexMobject( + "f(x) =", "\\text{exp}(r\\cdot x)", + tex_to_color_map={"r": YELLOW} + ) + title.to_corner(UL) + title.set_stroke(BLACK, 5, background=True) + brace = Brace(title[1:], UP, buff=SMALL_BUFF) + e_pow = TexMobject("e^{rx}", tex_to_color_map={"r": YELLOW}) + e_pow.add_background_rectangle() + e_pow.next_to(brace, UP, buff=SMALL_BUFF) + title.add(brace, e_pow) + + r_eq = VGroup( + TexMobject("r=", tex_to_color_map={"r": YELLOW}), + DecimalNumber(1) + ) + r_eq.arrange(RIGHT, aligned_edge=DOWN) + r_eq.next_to(title, DOWN, aligned_edge=LEFT) + r_eq[0].set_stroke(BLACK, 5, background=True) + r_eq[1].set_color(YELLOW) + r_eq[1].add_updater(lambda m: m.set_value(plane.p2n(r_dot.get_center()))) + + self.add(title) + self.add(r_eq) + + # self.add(axis_lines) + self.add(sample_dots) + self.add(r_dot) + self.add(labels) + + # Animations + def update_transform_alpha(mob, alpha, scene=self): + scene.transform_alpha = alpha + + frame = self.camera.frame + frame.set_height(10) + r_dot.clear_updaters() + r_dot.move_to(plane.n2p(1)) + + self.play( + UpdateFromAlphaFunc( + VectorizedPoint(), + update_transform_alpha, + ) + ) + self.play(r_dot.move_to, plane.n2p(2)) + self.wait() + self.play(r_dot.move_to, plane.n2p(PI)) + self.wait() + self.play(r_dot.move_to, plane.n2p(np.log(2))) + self.wait() + self.play(r_dot.move_to, plane.n2p(complex(0, np.log(2))), path_arc=90 * DEGREES, run_time=2) + self.wait() + self.play(r_dot.move_to, plane.n2p(complex(0, PI / 2))) + self.wait() + self.play(r_dot.move_to, plane.n2p(np.log(2)), run_time=2) + self.wait() + self.play(frame.set_height, 14) + self.play(r_dot.move_to, plane.n2p(complex(np.log(2), PI)), run_time=3) + self.wait() + self.play(r_dot.move_to, plane.n2p(complex(np.log(2), TAU)), run_time=3) + self.wait() + + self.embed() + + def on_mouse_scroll(self, point, offset): + frame = self.camera.frame + if self.zoom_on_scroll: + factor = 1 + np.arctan(10 * offset[1]) + frame.scale(factor, about_point=ORIGIN) + else: + self.transform_alpha = clip(self.transform_alpha + 5 * offset[1], 0, 1) + + +class LDMEndScreen(PatreonEndScreen): + CONFIG = { + "scroll_time": 20, + "specific_patrons": [ + "1stViewMaths", + "Adam Dřínek", + "Adam Margulies", + "Aidan Shenkman", + "Alan Stein", + "Alex Mijalis", + "Alexander Mai", + "Alexis Olson", + "Ali Yahya", + "Andreas Snekloth Kongsgaard", + "Andrew Busey", + "Andrew Cary", + "Andrew R. Whalley", + "Anthony Losego", + "Aravind C V", + "Arjun Chakroborty", + "Arthur Zey", + "Ashwin Siddarth", + "Augustine Lim", + "Austin Goodman", + "Avi Finkel", + "Awoo", + "Axel Ericsson", + "Ayan Doss", + "AZsorcerer", + "Barry Fam", + "Bartosz Burclaf", + "Ben Delo", + "Bernd Sing", + "Bill Gatliff", + "Bob Sanderson", + "Boris Veselinovich", + "Bradley Pirtle", + "Brandon Huang", + "Brendan Shah", + "Brian Cloutier", + "Brian Staroselsky", + "Britt Selvitelle", + "Britton Finley", + "Burt Humburg", + "Calvin Lin", + "Charles Southerland", + "Charlie N", + "Chenna Kautilya", + "Chris Connett", + "Chris Druta", + "Christian Kaiser", + "cinterloper", + "Clark Gaebel", + "Colwyn Fritze-Moor", + "Cooper Jones", + "Corey Ogburn", + "D. Sivakumar", + "Dan Herbatschek", + "Daniel Herrera C", + "Darrell Thomas", + "Dave B", + "Dave Cole", + "Dave Kester", + "dave nicponski", + "David B. Hill", + "David Clark", + "David Gow", + "Delton Ding", + "Eduardo Rodriguez", + "Emilio Mendoza Palafox", + "emptymachine", + "Eric Younge", + "Eryq Ouithaqueue", + "Federico Lebron", + "Fernando Via Canel", + "Frank R. Brown, Jr.", + "Giovanni Filippi", + "Goodwine", + "Hal Hildebrand", + "Heptonion", + "Hitoshi Yamauchi", + "Ivan Sorokin", + "Jacob Baxter", + "Jacob Harmon", + "Jacob Hartmann", + "Jacob Magnuson", + "Jalex Stark", + "Jameel Syed", + "James Beall", + "Jason Hise", + "Jayne Gabriele", + "Jean-Manuel Izaret", + "Jeff Dodds", + "Jeff Linse", + "Jeff Straathof", + "Jeffrey Wolberg", + "Jimmy Yang", + "Joe Pregracke", + "Johan Auster", + "John C. Vesey", + "John Camp", + "John Haley", + "John Le", + "John Luttig", + "John Rizzo", + "John V Wertheim", + "Jonathan Heckerman", + "Jonathan Wilson", + "Joseph John Cox", + "Joseph Kelly", + "Josh Kinnear", + "Joshua Claeys", + "Joshua Ouellette", + "Juan Benet", + "Julien Dubois", + "Kai-Siang Ang", + "Kanan Gill", + "Karl Niu", + "Kartik Cating-Subramanian", + "Kaustuv DeBiswas", + "Killian McGuinness", + "kkm", + "Klaas Moerman", + "Kristoffer Börebäck", + "Kros Dai", + "L0j1k", + "Lael S Costa", + "LAI Oscar", + "Lambda GPU Workstations", + "Laura Gast", + "Lee Redden", + "Linh Tran", + "Luc Ritchie", + "Ludwig Schubert", + "Lukas Biewald", + "Lukas Zenick", + "Magister Mugit", + "Magnus Dahlström", + "Magnus Hiie", + "Manoj Rewatkar", + "Mark B Bahu", + "Mark Heising", + "Mark Hopkins", + "Mark Mann", + "Martin Price", + "Mathias Jansson", + "Matt Godbolt", + "Matt Langford", + "Matt Roveto", + "Matt Russell", + "Matteo Delabre", + "Matthew Bouchard", + "Matthew Cocke", + "Maxim Nitsche", + "Michael Bos", + "Michael Day", + "Michael Hardel", + "Michael W White", + "Mihran Vardanyan", + "Mirik Gogri", + "Molly Mackinlay", + "Mustafa Mahdi", + "Márton Vaitkus", + "Nate Heckmann", + "Nero Li", + "Nicholas Cahill", + "Nikita Lesnikov", + "Oleg Leonov", + "Oliver Steele", + "Omar Zrien", + "Omer Tuchfeld", + "Patrick Lucas", + "Pavel Dubov", + "Pesho Ivanov", + "Petar Veličković", + "Peter Ehrnstrom", + "Peter Francis", + "Peter Mcinerney", + "Pierre Lancien", + "Pradeep Gollakota", + "Rafael Bove Barrios", + "Raghavendra Kotikalapudi", + "Randy C. Will", + "rehmi post", + "Rex Godby", + "Ripta Pasay", + "Rish Kundalia", + "Roman Sergeychik", + "Roobie", + "Ryan Atallah", + "Samuel Judge", + "SansWord Huang", + "Scott Gray", + "Scott Walter, Ph.D.", + "soekul", + "Solara570", + "Stephen Shanahan", + "Steve Huynh", + "Steve Muench", + "Steve Sperandeo", + "Steven Siddals", + "Stevie Metke", + "Sundar Subbarayan", + "Sunil Nagaraj", + "supershabam", + "Suteerth Vishnu", + "Suthen Thomas", + "Tal Einav", + "Taras Bobrovytsky", + "Tauba Auerbach", + "Ted Suzman", + "Thomas J Sargent", + "Thomas Tarler", + "Tianyu Ge", + "Tihan Seale", + "Tim Erbes", + "Tim Kazik", + "Tomasz Legutko", + "Tyler Herrmann", + "Tyler Parcell", + "Tyler VanValkenburg", + "Tyler Veness", + "Vassili Philippov", + "Vasu Dubey", + "Veritasium", + "Vignesh Ganapathi Subramanian", + "Vinicius Reis", + "Vladimir Solomatin", + "Wooyong Ee", + "Xuanji Li", + "Yana Chernobilsky", + "Yavor Ivanov", + "YinYangBalance.Asia", + "Yu Jun", + "Yurii Monastyrshyn", + ], + } diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index 8603f78b..9676b67f 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -1,5 +1,3 @@ -from functools import reduce -import operator as op import moderngl from colour import Color @@ -9,39 +7,127 @@ import itertools as it from manimlib.constants import * from manimlib.mobject.mobject import Mobject +from manimlib.mobject.mobject import Point from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import batch_by_property from manimlib.utils.simple_functions import fdiv from manimlib.utils.shaders import shader_info_to_id from manimlib.utils.shaders import shader_id_to_info from manimlib.utils.shaders import get_shader_code_from_file +from manimlib.utils.simple_functions import clip +from manimlib.utils.space_ops import angle_of_vector +from manimlib.utils.space_ops import rotation_matrix_transpose_from_quaternion +from manimlib.utils.space_ops import rotation_matrix_transpose +from manimlib.utils.space_ops import quaternion_from_angle_axis +from manimlib.utils.space_ops import quaternion_mult -# TODO, think about how to incorporate perspective, -# and change get_height, etc. to take orientation into account class CameraFrame(Mobject): CONFIG = { "width": FRAME_WIDTH, "height": FRAME_HEIGHT, - "center": ORIGIN, + "center_point": ORIGIN, + # Theta, phi, gamma + "euler_angles": [0, 0, 0], + "focal_distance": 5, } def init_points(self): - self.points = np.array([UL, UR, DR, DL]) - self.set_width(self.width, stretch=True) - self.set_height(self.height, stretch=True) - self.move_to(self.center) - self.save_state() + self.points = np.array([self.center_point]) + self.euler_angles = np.array(self.euler_angles, dtype='float64') + + def to_default_state(self): + self.center() + self.set_height(FRAME_HEIGHT) + self.set_width(FRAME_WIDTH) + self.set_rotation(0, 0, 0) + return self + + def get_inverse_camera_position_matrix(self): + result = np.identity(4) + # First shift so that origin of real space coincides with camera origin + result[:3, 3] = -self.get_center().T + # Rotate based on camera orientation + result[:3, :3] = np.dot(self.get_inverse_camera_rotation_matrix(), result[:3, :3]) + # Scale to have height 2 (matching the height of the box [-1, 1]^2) + result *= 2 / self.height + return result + + def get_inverse_camera_rotation_matrix(self): + theta, phi, gamma = self.euler_angles + quat = quaternion_mult( + quaternion_from_angle_axis(theta, OUT), + quaternion_from_angle_axis(phi, RIGHT), + quaternion_from_angle_axis(gamma, OUT), + ) + return rotation_matrix_transpose_from_quaternion(quat) + + def rotate(self, angle, axis=OUT, **kwargs): + curr_rot_T = self.get_inverse_camera_rotation_matrix() + added_rot_T = rotation_matrix_transpose(angle, axis) + new_rot_T = np.dot(curr_rot_T, added_rot_T) + Fz = new_rot_T[2] + phi = np.arccos(Fz[2]) + theta = angle_of_vector(Fz[:2]) + PI / 2 + partial_rot_T = np.dot( + rotation_matrix_transpose(phi, RIGHT), + rotation_matrix_transpose(theta, OUT), + ) + gamma = angle_of_vector(np.dot(partial_rot_T, new_rot_T.T)[:, 0]) + # TODO, write a function that converts quaternions to euler angles + self.euler_angles[:] = theta, phi, gamma + return self + + def set_rotation(self, theta=0, phi=0, gamma=0): + self.euler_angles[:] = theta, phi, gamma + return self + + def increment_theta(self, dtheta): + self.euler_angles[0] += dtheta + return self + + def increment_phi(self, dphi): + self.euler_angles[1] = clip(self.euler_angles[1] + dphi, 0, PI) + return self + + def increment_gamma(self, dgamma): + self.euler_angles[2] += dgamma + return self + + def scale(self, scale_factor, **kwargs): + # TODO, handle about_point and about_edge? + self.height *= scale_factor + self.width *= scale_factor + return self + + def set_height(self, height): + self.height = height + return self + + def set_width(self, width): + self.width = width + return self + + def get_height(self): + return self.height + + def get_width(self): + return self.width + + def get_center(self): + return self.points[0] + + def get_focal_distance(self): + return self.focal_distance + + def interpolate(self, mobject1, mobject2, alpha, **kwargs): + pass class Camera(object): CONFIG = { "background_image": None, - "frame_config": { - "width": FRAME_WIDTH, - "height": FRAME_HEIGHT, - "center": ORIGIN, - }, + "frame_config": {}, "pixel_height": DEFAULT_PIXEL_HEIGHT, "pixel_width": DEFAULT_PIXEL_WIDTH, "frame_rate": DEFAULT_FRAME_RATE, # TODO, move this elsewhere @@ -55,7 +141,8 @@ class Camera(object): "image_mode": "RGBA", "n_channels": 4, "pixel_array_dtype": 'uint8', - "line_width_multiple": 0.01, + "light_source_position": [-10, 10, 10], # TODO, add multiple light sources + "apply_depth_test": False, } def __init__(self, ctx=None, **kwargs): @@ -65,6 +152,7 @@ class Camera(object): self.init_context(ctx) self.init_shaders() self.init_textures() + self.init_light_source() def init_frame(self): self.frame = CameraFrame(**self.frame_config) @@ -78,13 +166,19 @@ class Camera(object): self.fbo = self.get_fbo() self.fbo.use() - self.ctx.enable(moderngl.BLEND) - self.ctx.blend_func = ( + flag = moderngl.BLEND + if self.apply_depth_test: + flag |= moderngl.DEPTH_TEST + self.ctx.enable(flag) + self.ctx.blend_func = ( # Needed? moderngl.SRC_ALPHA, moderngl.ONE_MINUS_SRC_ALPHA, moderngl.ONE, moderngl.ONE ) self.background_fbo = None + def init_light_source(self): + self.light_source = Point(self.light_source_position) + # Methods associated with the frame buffer def get_fbo(self): return self.ctx.simple_framebuffer( @@ -107,8 +201,8 @@ class Camera(object): frame_height = frame_width / aspect_ratio else: frame_width = aspect_ratio * frame_height - self.set_frame_height(frame_height) - self.set_frame_width(frame_width) + self.frame.set_height(frame_height) + self.frame.set_width(frame_width) def clear(self): rgba = (*Color(self.background_color).get_rgb(), self.background_opacity) @@ -163,7 +257,6 @@ class Camera(object): def get_pixel_height(self): return self.get_pixel_shape()[1] - # TODO, make these work for a rotated frame def get_frame_height(self): return self.frame.get_height() @@ -176,17 +269,12 @@ class Camera(object): def get_frame_center(self): return self.frame.get_center() - def set_frame_height(self, height): - self.frame.set_height(height, stretch=True) - - def set_frame_width(self, width): - self.frame.set_width(width, stretch=True) - - def set_frame_center(self, center): - self.frame.move_to(center) - def pixel_coords_to_space_coords(self, px, py, relative=False): - pw, ph = self.fbo.size + # pw, ph = self.fbo.size + # Bad hack, not sure why this is needed. + pw, ph = self.get_pixel_shape() + pw //= 2 + ph //= 2 fw, fh = self.get_frame_shape() fc = self.get_frame_center() if relative: @@ -196,19 +284,6 @@ class Camera(object): scale = fh / ph return fc + scale * np.array([(px - pw / 2), (py - ph / 2), 0]) - # TODO, account for 3d - # Also, move this to CameraFrame? - def is_in_frame(self, mobject): - fc = self.get_frame_center() - fh = self.get_frame_height() - fw = self.get_frame_width() - return not reduce(op.or_, [ - mobject.get_right()[0] < fc[0] - fw, - mobject.get_bottom()[1] > fc[1] + fh, - mobject.get_left()[0] > fc[0] + fw, - mobject.get_top()[1] < fc[1] - fh, - ]) - # Rendering def capture(self, *mobjects, **kwargs): self.refresh_shader_uniforms() @@ -266,15 +341,17 @@ class Camera(object): if shader is None: return # TODO, think about how uniforms come from mobjects as well. - fh = self.get_frame_height() - fc = self.get_frame_center() pw, ph = self.get_pixel_shape() + transform = self.frame.get_inverse_camera_position_matrix() + light = self.light_source.get_location() + transformed_light = np.dot(transform, [*light, 1])[:3] mapping = { - 'scale': fh / 2, # Scale based on frame size + 'to_screen_space': tuple(transform.T.flatten()), 'aspect_ratio': (pw / ph), # AR based on pixel shape - 'anti_alias_width': ANTI_ALIAS_WIDTH_OVER_FRAME_HEIGHT * fh, - 'frame_center': tuple(fc), + 'focal_distance': self.frame.get_focal_distance(), + 'anti_alias_width': 3 / ph, # 1.5 Pixel widths + 'light_source_position': tuple(transformed_light), } for key, value in mapping.items(): try: @@ -302,3 +379,8 @@ class Camera(object): texture.use(location=tid) self.path_to_texture_id[path] = tid return self.path_to_texture_id[path] + + +class ThreeDCamera(Camera): + # Purely here to keep old scenes happy + pass diff --git a/manimlib/camera/three_d_camera.py b/manimlib/camera/three_d_camera.py deleted file mode 100644 index 7cb3b954..00000000 --- a/manimlib/camera/three_d_camera.py +++ /dev/null @@ -1,232 +0,0 @@ -import numpy as np - -from manimlib.camera.camera import Camera -from manimlib.constants import * -from manimlib.mobject.three_d_utils import get_3d_vmob_end_corner -from manimlib.mobject.three_d_utils import get_3d_vmob_end_corner_unit_normal -from manimlib.mobject.three_d_utils import get_3d_vmob_start_corner -from manimlib.mobject.three_d_utils import get_3d_vmob_start_corner_unit_normal -from manimlib.mobject.types.point_cloud_mobject import Point -from manimlib.mobject.value_tracker import ValueTracker -from manimlib.utils.color import get_shaded_rgb -from manimlib.utils.simple_functions import clip_in_place -from manimlib.utils.space_ops import rotation_about_z -from manimlib.utils.space_ops import rotation_matrix - - -class ThreeDCamera(Camera): - CONFIG = { - "shading_factor": 0.2, - "distance": 20.0, - "default_distance": 5.0, - "phi": 0, # Angle off z axis - "theta": -90 * DEGREES, # Rotation about z axis - "gamma": 0, # Rotation about normal vector to camera - "light_source_start_point": 9 * DOWN + 7 * LEFT + 10 * OUT, - "frame_center": ORIGIN, - "should_apply_shading": True, - "exponential_projection": False, - "max_allowable_norm": 3 * FRAME_WIDTH, - } - - def __init__(self, *args, **kwargs): - Camera.__init__(self, *args, **kwargs) - self.phi_tracker = ValueTracker(self.phi) - self.theta_tracker = ValueTracker(self.theta) - self.distance_tracker = ValueTracker(self.distance) - self.gamma_tracker = ValueTracker(self.gamma) - self.light_source = Point(self.light_source_start_point) - self.frame_center = Point(self.frame_center) - self.fixed_orientation_mobjects = dict() - self.fixed_in_frame_mobjects = set() - self.reset_rotation_matrix() - - def capture(self, *mobjects, **kwargs): - self.reset_rotation_matrix() - Camera.capture(self, *mobjects, **kwargs) - - def get_value_trackers(self): - return [ - self.phi_tracker, - self.theta_tracker, - self.distance_tracker, - self.gamma_tracker, - ] - - def modified_rgbas(self, vmobject, rgbas): - if not self.should_apply_shading: - return rgbas - if vmobject.shade_in_3d and (vmobject.get_num_points() > 0): - light_source_point = self.light_source.points[0] - if len(rgbas) < 2: - shaded_rgbas = rgbas.repeat(2, axis=0) - else: - shaded_rgbas = np.array(rgbas[:2]) - shaded_rgbas[0, :3] = get_shaded_rgb( - shaded_rgbas[0, :3], - get_3d_vmob_start_corner(vmobject), - get_3d_vmob_start_corner_unit_normal(vmobject), - light_source_point, - ) - shaded_rgbas[1, :3] = get_shaded_rgb( - shaded_rgbas[1, :3], - get_3d_vmob_end_corner(vmobject), - get_3d_vmob_end_corner_unit_normal(vmobject), - light_source_point, - ) - return shaded_rgbas - return rgbas - - def get_stroke_rgbas(self, vmobject, background=False): - return self.modified_rgbas( - vmobject, vmobject.get_stroke_rgbas(background) - ) - - def get_fill_rgbas(self, vmobject): - return self.modified_rgbas( - vmobject, vmobject.get_fill_rgbas() - ) - - def get_mobjects_to_display(self, *args, **kwargs): - mobjects = Camera.get_mobjects_to_display( - self, *args, **kwargs - ) - rot_matrix = self.get_rotation_matrix() - - def z_key(mob): - if not (hasattr(mob, "shade_in_3d") and mob.shade_in_3d): - return np.inf - # Assign a number to a three dimensional mobjects - # based on how close it is to the camera - return np.dot( - mob.get_z_index_reference_point(), - rot_matrix.T - )[2] - return sorted(mobjects, key=z_key) - - def get_phi(self): - return self.phi_tracker.get_value() - - def get_theta(self): - return self.theta_tracker.get_value() - - def get_distance(self): - return self.distance_tracker.get_value() - - def get_gamma(self): - return self.gamma_tracker.get_value() - - def get_frame_center(self): - return self.frame_center.points[0] - - def set_phi(self, value): - self.phi_tracker.set_value(value) - - def set_theta(self, value): - self.theta_tracker.set_value(value) - - def set_distance(self, value): - self.distance_tracker.set_value(value) - - def set_gamma(self, value): - self.gamma_tracker.set_value(value) - - def set_frame_center(self, point): - self.frame_center.move_to(point) - - def reset_rotation_matrix(self): - self.rotation_matrix = self.generate_rotation_matrix() - - def get_rotation_matrix(self): - return self.rotation_matrix - - def generate_rotation_matrix(self): - phi = self.get_phi() - theta = self.get_theta() - gamma = self.get_gamma() - matrices = [ - rotation_about_z(-theta - 90 * DEGREES), - rotation_matrix(-phi, RIGHT), - rotation_about_z(gamma), - ] - result = np.identity(3) - for matrix in matrices: - result = np.dot(matrix, result) - return result - - def project_points(self, points): - frame_center = self.get_frame_center() - distance = self.get_distance() - rot_matrix = self.get_rotation_matrix() - - points = points - frame_center - points = np.dot(points, rot_matrix.T) - zs = points[:, 2] - for i in 0, 1: - if self.exponential_projection: - # Proper projedtion would involve multiplying - # x and y by d / (d-z). But for points with high - # z value that causes weird artifacts, and applying - # the exponential helps smooth it out. - factor = np.exp(zs / distance) - lt0 = zs < 0 - factor[lt0] = (distance / (distance - zs[lt0])) - else: - factor = (distance / (distance - zs)) - factor[(distance - zs) < 0] = 10**6 - # clip_in_place(factor, 0, 10**6) - points[:, i] *= factor - points = points + frame_center - return points - - def project_point(self, point): - return self.project_points(point.reshape((1, 3)))[0, :] - - def transform_points_pre_display(self, mobject, points): - points = super().transform_points_pre_display(mobject, points) - fixed_orientation = mobject in self.fixed_orientation_mobjects - fixed_in_frame = mobject in self.fixed_in_frame_mobjects - - if fixed_in_frame: - return points - if fixed_orientation: - center_func = self.fixed_orientation_mobjects[mobject] - center = center_func() - new_center = self.project_point(center) - return points + (new_center - center) - else: - return self.project_points(points) - - def add_fixed_orientation_mobjects( - self, *mobjects, - use_static_center_func=False, - center_func=None): - # This prevents the computation of mobject.get_center - # every single time a projetion happens - def get_static_center_func(mobject): - point = mobject.get_center() - return (lambda: point) - - for mobject in mobjects: - if center_func: - func = center_func - elif use_static_center_func: - func = get_static_center_func(mobject) - else: - func = mobject.get_center - for submob in mobject.get_family(): - self.fixed_orientation_mobjects[submob] = func - - def add_fixed_in_frame_mobjects(self, *mobjects): - for mobject in self.extract_mobject_family_members(mobjects): - self.fixed_in_frame_mobjects.add(mobject) - - def remove_fixed_orientation_mobjects(self, *mobjects): - for mobject in self.extract_mobject_family_members(mobjects): - if mobject in self.fixed_orientation_mobjects: - self.fixed_orientation_mobjects.remove(mobject) - - def remove_fixed_in_frame_mobjects(self, *mobjects): - for mobject in self.extract_mobject_family_members(mobjects): - if mobject in self.fixed_in_frame_mobjects: - self.fixed_in_frame_mobjects.remove(mobject) diff --git a/manimlib/constants.py b/manimlib/constants.py index 7422b13d..d6dd1770 100644 --- a/manimlib/constants.py +++ b/manimlib/constants.py @@ -154,7 +154,6 @@ DEFAULT_PIXEL_WIDTH = PRODUCTION_QUALITY_CAMERA_CONFIG["pixel_width"] DEFAULT_FRAME_RATE = 60 DEFAULT_STROKE_WIDTH = 4 -ANTI_ALIAS_WIDTH_OVER_FRAME_HEIGHT = 1e-3 FRAME_HEIGHT = 8.0 FRAME_WIDTH = FRAME_HEIGHT * DEFAULT_PIXEL_WIDTH / DEFAULT_PIXEL_HEIGHT diff --git a/manimlib/for_3b1b_videos/common_scenes.py b/manimlib/for_3b1b_videos/common_scenes.py index 87e85e30..12bc160c 100644 --- a/manimlib/for_3b1b_videos/common_scenes.py +++ b/manimlib/for_3b1b_videos/common_scenes.py @@ -155,6 +155,7 @@ class PatreonEndScreen(PatreonThanks, PiCreatureScene): "capitalize": True, "name_y_spacing": 0.6, "thanks_words": "Many thanks to this channel's supporters", + "scroll_time": 20, } def construct(self): @@ -250,7 +251,7 @@ class PatreonEndScreen(PatreonThanks, PiCreatureScene): columns.target.to_edge(DOWN, buff=4) vect = columns.target.get_center() - columns.get_center() distance = get_norm(vect) - wait_time = 20 + wait_time = self.scroll_time always_shift( columns, direction=normalize(vect), @@ -267,7 +268,8 @@ class PatreonEndScreen(PatreonThanks, PiCreatureScene): "akostrikov": "Aleksandr Kostrikov", "Jacob Baxter": "Will Fleshman", "Sansword Huang": "SansWord@TW", - "Still working on an upcoming skeptical humanist SciFi novels- Elux Luc": "Uber Miguel", + "Sunil Nagaraj": "Ubiquity Ventures", + "Nitu Kitchloo": "Ish Kitchloo", } for n1, n2 in modification_map.items(): if name.lower() == n1.lower(): @@ -345,6 +347,7 @@ class Banner(Scene): pis.set_height(self.pi_height) pis.arrange(RIGHT, aligned_edge=DOWN) pis.move_to(self.pi_bottom, DOWN) + self.pis = pis self.add(pis) if self.use_date: diff --git a/manimlib/imports.py b/manimlib/imports.py index 23c1498e..ae5a3e7a 100644 --- a/manimlib/imports.py +++ b/manimlib/imports.py @@ -30,9 +30,6 @@ from manimlib.animation.transform import * from manimlib.animation.update import * from manimlib.camera.camera import * -from manimlib.camera.mapping_camera import * -from manimlib.camera.moving_camera import * -from manimlib.camera.three_d_camera import * from manimlib.mobject.coordinate_systems import * from manimlib.mobject.changing import * @@ -50,10 +47,10 @@ from manimlib.mobject.svg.drawings import * from manimlib.mobject.svg.svg_mobject import * from manimlib.mobject.svg.tex_mobject import * from manimlib.mobject.svg.text_mobject import * -from manimlib.mobject.three_d_utils import * from manimlib.mobject.three_dimensions import * from manimlib.mobject.types.image_mobject import * from manimlib.mobject.types.point_cloud_mobject import * +from manimlib.mobject.types.surface_mobject import * from manimlib.mobject.types.vectorized_mobject import * from manimlib.mobject.mobject_update_utils import * from manimlib.mobject.value_tracker import * diff --git a/manimlib/mobject/coordinate_systems.py b/manimlib/mobject/coordinate_systems.py index e0a605b1..288b6354 100644 --- a/manimlib/mobject/coordinate_systems.py +++ b/manimlib/mobject/coordinate_systems.py @@ -84,9 +84,12 @@ class CoordinateSystem(): ) return self.axis_labels - def get_graph(self, function, **kwargs): - x_min = kwargs.pop("x_min", self.x_min) - x_max = kwargs.pop("x_max", self.x_max) + def get_graph(self, function, x_min=None, x_max=None, **kwargs): + if x_min is None: + x_min = self.x_min + if x_max is None: + x_max = self.x_max + graph = ParametricFunction( lambda t: self.coords_to_point(t, function(t)), t_min=x_min, diff --git a/manimlib/mobject/frame.py b/manimlib/mobject/frame.py index 312a3ce0..4ea2260d 100644 --- a/manimlib/mobject/frame.py +++ b/manimlib/mobject/frame.py @@ -3,9 +3,6 @@ from manimlib.mobject.geometry import Rectangle from manimlib.utils.config_ops import digest_config -# TODO, put CameraFrame in here? - - class ScreenRectangle(Rectangle): CONFIG = { "aspect_ratio": 16.0 / 9.0, diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 77957834..4a1e83f0 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -974,7 +974,7 @@ class Mobject(Container): def arrange_in_grid(self, n_rows=None, n_cols=None, **kwargs): submobs = self.submobjects if n_rows is None and n_cols is None: - n_cols = int(np.sqrt(len(submobs))) + n_rows = int(np.sqrt(len(submobs))) if n_rows is not None: v1 = RIGHT diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 2be9b758..8b6ff7d1 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -9,6 +9,9 @@ from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.utils.config_ops import digest_config +TEXT_MOB_SCALE_FACTOR = 0.05 + + class TextSetting(object): def __init__(self, start, end, font, slant, weight, line_num=-1): self.start = start @@ -45,8 +48,23 @@ class Text(SVGMobject): self.lsh = self.size if self.lsh == -1 else self.lsh file_name = self.text2svg() + self.remove_last_M(file_name) SVGMobject.__init__(self, file_name, **config) + nppc = self.n_points_per_curve + for each in self: + if len(each.points) == 0: + continue + points = each.points + last = points[0] + each.clear_points() + for index, point in enumerate(points): + each.append_points([point]) + if index != len(points) - 1 and (index + 1) % nppc == 0 and any(point != points[index+1]): + each.add_line_to(last) + last = points[index + 1] + each.add_line_to(last) + if self.t2c: self.set_color_by_t2c() if self.gradient: @@ -55,7 +73,15 @@ class Text(SVGMobject): self.set_color_by_t2g() # anti-aliasing - self.scale(0.1) + if self.height is None: + self.scale(TEXT_MOB_SCALE_FACTOR) + + def remove_last_M(self, file_name): + with open(file_name, 'r') as fpr: + content = fpr.read() + content = re.sub(r'Z M [^A-Za-z]*? "\/>', 'Z "/>', content) + with open(file_name, 'w') as fpw: + fpw.write(content) def find_indexes(self, word): m = re.match(r'\[([0-9\-]{0,}):([0-9\-]{0,})\]', word) diff --git a/manimlib/mobject/three_d_shading_utils.py b/manimlib/mobject/three_d_shading_utils.py deleted file mode 100644 index 4d8d7190..00000000 --- a/manimlib/mobject/three_d_shading_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy as np - -from manimlib.constants import ORIGIN -from manimlib.utils.space_ops import get_unit_normal - - -# TODO, these ideas should be deprecated - -def get_3d_vmob_gradient_start_and_end_points(vmob): - return ( - get_3d_vmob_start_corner(vmob), - get_3d_vmob_end_corner(vmob), - ) - - -def get_3d_vmob_start_corner_index(vmob): - return 0 - - -def get_3d_vmob_end_corner_index(vmob): - return ((len(vmob.points) - 1) // 6) * 3 - - -def get_3d_vmob_start_corner(vmob): - if vmob.get_num_points() == 0: - return np.array(ORIGIN) - return vmob.points[get_3d_vmob_start_corner_index(vmob)] - - -def get_3d_vmob_end_corner(vmob): - if vmob.get_num_points() == 0: - return np.array(ORIGIN) - return vmob.points[get_3d_vmob_end_corner_index(vmob)] - - -def get_3d_vmob_unit_normal(vmob, point_index): - n_points = vmob.get_num_points() - if vmob.get_num_points() == 0: - return np.array(ORIGIN) - i = point_index - im1 = i - 1 if i > 0 else (n_points - 2) - ip1 = i + 1 if i < (n_points - 1) else 1 - return get_unit_normal( - vmob.points[ip1] - vmob.points[i], - vmob.points[im1] - vmob.points[i], - ) - - -def get_3d_vmob_start_corner_unit_normal(vmob): - return get_3d_vmob_unit_normal( - vmob, get_3d_vmob_start_corner_index(vmob) - ) - - -def get_3d_vmob_end_corner_unit_normal(vmob): - return get_3d_vmob_unit_normal( - vmob, get_3d_vmob_end_corner_index(vmob) - ) diff --git a/manimlib/mobject/three_d_utils.py b/manimlib/mobject/three_d_utils.py deleted file mode 100644 index 46e83165..00000000 --- a/manimlib/mobject/three_d_utils.py +++ /dev/null @@ -1,64 +0,0 @@ -import numpy as np - -from manimlib.constants import ORIGIN -from manimlib.constants import UP -from manimlib.utils.space_ops import get_norm -from manimlib.utils.space_ops import get_unit_normal - - -# TODO, these ideas should be deprecated - - -def get_3d_vmob_gradient_start_and_end_points(vmob): - return ( - get_3d_vmob_start_corner(vmob), - get_3d_vmob_end_corner(vmob), - ) - - -def get_3d_vmob_start_corner_index(vmob): - return 0 - - -def get_3d_vmob_end_corner_index(vmob): - return ((len(vmob.points) - 1) // 6) * 3 - - -def get_3d_vmob_start_corner(vmob): - if vmob.get_num_points() == 0: - return np.array(ORIGIN) - return vmob.points[get_3d_vmob_start_corner_index(vmob)] - - -def get_3d_vmob_end_corner(vmob): - if vmob.get_num_points() == 0: - return np.array(ORIGIN) - return vmob.points[get_3d_vmob_end_corner_index(vmob)] - - -def get_3d_vmob_unit_normal(vmob, point_index): - n_points = vmob.get_num_points() - if len(vmob.get_anchors()) <= 2: - return np.array(UP) - i = point_index - im3 = i - 3 if i > 2 else (n_points - 4) - ip3 = i + 3 if i < (n_points - 3) else 3 - unit_normal = get_unit_normal( - vmob.points[ip3] - vmob.points[i], - vmob.points[im3] - vmob.points[i], - ) - if get_norm(unit_normal) == 0: - return np.array(UP) - return unit_normal - - -def get_3d_vmob_start_corner_unit_normal(vmob): - return get_3d_vmob_unit_normal( - vmob, get_3d_vmob_start_corner_index(vmob) - ) - - -def get_3d_vmob_end_corner_unit_normal(vmob): - return get_3d_vmob_unit_normal( - vmob, get_3d_vmob_end_corner_index(vmob) - ) diff --git a/manimlib/mobject/three_dimensions.py b/manimlib/mobject/three_dimensions.py index 7cf11163..8e44d7c9 100644 --- a/manimlib/mobject/three_dimensions.py +++ b/manimlib/mobject/three_dimensions.py @@ -1,29 +1,14 @@ from manimlib.constants import * from manimlib.mobject.geometry import Square +from manimlib.mobject.types.surface_mobject import SurfaceMobject from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.utils.iterables import listify -from manimlib.utils.space_ops import z_to_vector - -############## -# TODO, replace these with a special 3d type, not VMobject - - -class ThreeDVMobject(VMobject): +class ParametricSurface(SurfaceMobject): CONFIG = { - "shade_in_3d": True, - } - - -class ParametricSurface(VGroup): - CONFIG = { - "u_min": 0, - "u_max": 1, - "v_min": 0, - "v_max": 1, - "resolution": 32, + "u_range": (0, 1), + "v_range": (0, 1), + "resolution": (32, 32), "surface_piece_config": {}, "fill_color": BLUE_D, "fill_opacity": 1.0, @@ -34,93 +19,72 @@ class ParametricSurface(VGroup): "pre_function_handle_to_anchor_scale_factor": 0.00001, } - def __init__(self, func, **kwargs): - VGroup.__init__(self, **kwargs) - self.func = func - self.setup_in_uv_space() - self.apply_function(lambda p: func(p[0], p[1])) - if self.should_make_jagged: - self.make_jagged() - - def get_u_values_and_v_values(self): - res = listify(self.resolution) - if len(res) == 1: - u_res = v_res = res[0] + def __init__(self, function=None, **kwargs): + if function is None: + self.uv_func = self.func else: - u_res, v_res = res - u_min = self.u_min - u_max = self.u_max - v_min = self.v_min - v_max = self.v_max + self.uv_func = function + super().__init__(**kwargs) - u_values = np.linspace(u_min, u_max, u_res + 1) - v_values = np.linspace(v_min, v_max, v_res + 1) - - return u_values, v_values - - def setup_in_uv_space(self): - u_values, v_values = self.get_u_values_and_v_values() - faces = VGroup() - for i in range(len(u_values) - 1): - for j in range(len(v_values) - 1): - u1, u2 = u_values[i:i + 2] - v1, v2 = v_values[j:j + 2] - face = ThreeDVMobject() - face.set_points_as_corners([ - [u1, v1, 0], - [u2, v1, 0], - [u2, v2, 0], - [u1, v2, 0], - [u1, v1, 0], - ]) - faces.add(face) - face.u_index = i - face.v_index = j - face.u1 = u1 - face.u2 = u2 - face.v1 = v1 - face.v2 = v2 - faces.set_fill( - color=self.fill_color, - opacity=self.fill_opacity + def init_points(self): + epsilon = 1e-6 # For differentials + nu, nv = self.resolution + u_range = np.linspace(*self.u_range, nu + 1) + v_range = np.linspace(*self.v_range, nv + 1) + # List of three grids, [Pure uv values, those nudged by du, those nudged by dv] + uv_grids = [ + np.array([[[u, v] for v in v_range] for u in u_range]) + for (du, dv) in [(0, 0), (epsilon, 0), (0, epsilon)] + ] + point_grid, points_nudged_du, points_nudged_dv = [ + np.apply_along_axis(lambda p: self.uv_func(*p), 2, uv_grid) + for uv_grid in uv_grids + ] + normal_grid = np.cross( + (points_nudged_du - point_grid) / epsilon, + (points_nudged_dv - point_grid) / epsilon, ) - faces.set_stroke( - color=self.stroke_color, - width=self.stroke_width, - opacity=self.stroke_opacity, + + self.set_points( + self.get_triangle_ready_array_from_grid(point_grid), + self.get_triangle_ready_array_from_grid(normal_grid), ) - self.add(*faces) - if self.checkerboard_colors: - self.set_fill_by_checkerboard(*self.checkerboard_colors) - def set_fill_by_checkerboard(self, *colors, opacity=None): - n_colors = len(colors) - for face in self: - c_index = (face.u_index + face.v_index) % n_colors - face.set_fill(colors[c_index], opacity=opacity) + # self.points = point_grid[indices] + + def get_triangle_ready_array_from_grid(self, grid): + # Given a grid, say of points or normals, this returns an Nx3 array + # whose rows are elements from this grid in such such a way that successive + # triplets of points form triangles covering the grid. + nu = grid.shape[0] - 1 + nv = grid.shape[1] - 1 + dim = grid.shape[2] + arr = np.zeros((nu * nv * 6, dim)) + # To match the triangles covering this surface + arr[0::6] = grid[:-1, :-1].reshape((nu * nv, 3)) # Top left + arr[1::6] = grid[+1:, :-1].reshape((nu * nv, 3)) # Bottom left + arr[2::6] = grid[:-1, +1:].reshape((nu * nv, 3)) # Top right + arr[3::6] = grid[:-1, +1:].reshape((nu * nv, 3)) # Top right + arr[4::6] = grid[+1:, :-1].reshape((nu * nv, 3)) # Bottom left + arr[5::6] = grid[+1:, +1:].reshape((nu * nv, 3)) # Bottom right + return arr + + def func(self, u, v): + pass -# Specific shapes - +# Sphere, cylinder, cube, prism class Sphere(ParametricSurface): CONFIG = { "resolution": (12, 24), "radius": 1, - "u_min": 0.001, - "u_max": PI - 0.001, - "v_min": 0, - "v_max": TAU, + "u_range": (0, PI), + "v_range": (0, TAU), } - def __init__(self, **kwargs): - ParametricSurface.__init__( - self, self.func, **kwargs - ) - self.scale(self.radius) - def func(self, u, v): - return np.array([ + return self.radius * np.array([ np.cos(v) * np.sin(u), np.sin(v) * np.sin(u), np.cos(u) diff --git a/manimlib/mobject/types/surface_mobject.py b/manimlib/mobject/types/surface_mobject.py new file mode 100644 index 00000000..b1774fef --- /dev/null +++ b/manimlib/mobject/types/surface_mobject.py @@ -0,0 +1,76 @@ +import numpy as np +import moderngl + +# from PIL import Image + +from manimlib.constants import * +from manimlib.mobject.mobject import Mobject +from manimlib.utils.color import color_to_rgba + + +class SurfaceMobject(Mobject): + CONFIG = { + "color": GREY, + "opacity": 1, + "gloss": 1.0, + "render_primative": moderngl.TRIANGLES, + # "render_primative": moderngl.TRIANGLE_STRIP, + "vert_shader_file": "surface_vert.glsl", + "frag_shader_file": "surface_frag.glsl", + "shader_dtype": [ + ('point', np.float32, (3,)), + ('normal', np.float32, (3,)), + ('color', np.float32, (4,)), + ('gloss', np.float32, (1,)), + # ('im_coords', np.float32, (2,)), + ] + } + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def init_points(self): + self.points = np.zeros((0, self.dim)) + self.normals = np.zeros((0, self.dim)) + + def init_colors(self): + self.set_color(self.color, self.opacity) + + def set_points(self, points, normals=None): + self.points = np.array(points) + if normals is None: + v01 = points[1:-1] - points[:-2] + v02 = points[2:] - points[:-2] + crosses = np.cross(v01, v02) + crosses[1::2] *= -1 # Because of reversed orientation of every other triangle in the strip + self.normals = np.vstack([ + crosses, + crosses[-1:].repeat(2, 0) # Repeat last entry twice + ]) + else: + self.normals = np.array(normals) + + def set_color(self, color, opacity): + # TODO, allow for multiple colors + rgba = color_to_rgba(color, opacity) + self.rgbas = np.array([rgba]) + + def apply_function(self, function, **kwargs): + # Apply it to infinitesimal neighbors to preserve normals + pass + + def rotate(self, axis, angle, **kwargs): + # Account for normals + pass + + def stretch(self, factor, dim, **kwargs): + # Account for normals + pass + + def get_shader_data(self): + data = self.get_blank_shader_data_array(len(self.points)) + data["point"] = self.points + data["normal"] = self.normals + data["color"] = self.rgbas + data["gloss"] = self.gloss + return data diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 67070995..d4631deb 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -1,12 +1,13 @@ import itertools as it +import operator as op import moderngl from colour import Color +from functools import reduce from manimlib.constants import * from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Point -from manimlib.mobject.three_d_utils import get_3d_vmob_gradient_start_and_end_points from manimlib.utils.bezier import bezier from manimlib.utils.bezier import get_smooth_handle_points from manimlib.utils.bezier import get_quadratic_approximation_of_cubic @@ -19,10 +20,12 @@ from manimlib.utils.iterables import make_even from manimlib.utils.iterables import stretch_array_to_length from manimlib.utils.iterables import stretch_array_to_length_with_interpolation from manimlib.utils.iterables import listify -from manimlib.utils.space_ops import cross2d -from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import angle_between_vectors +from manimlib.utils.space_ops import cross2d from manimlib.utils.space_ops import earclip_triangulation +from manimlib.utils.space_ops import get_norm +from manimlib.utils.space_ops import get_unit_normal +from manimlib.utils.space_ops import z_to_vector from manimlib.utils.shaders import get_shader_info @@ -58,21 +61,26 @@ class VMobject(Mobject): "fill_frag_shader_file": "quadratic_bezier_fill_frag.glsl", # Could also be Bevel, Miter, Round "joint_type": "auto", + # Positive gloss up to 1 makes it reflect the light. + "gloss": 0.2, "render_primative": moderngl.TRIANGLES, "triangulation_locked": False, "fill_dtype": [ ('point', np.float32, (3,)), + ('unit_normal', np.float32, (3,)), ('color', np.float32, (4,)), ('fill_all', np.float32, (1,)), - ('orientation', np.float32, (1,)), + ('gloss', np.float32, (1,)), ], "stroke_dtype": [ ("point", np.float32, (3,)), ("prev_point", np.float32, (3,)), ("next_point", np.float32, (3,)), + ('unit_normal', np.float32, (3,)), ("stroke_width", np.float32, (1,)), ("color", np.float32, (4,)), ("joint_type", np.float32, (1,)), + ("gloss", np.float32, (1,)), ] } @@ -228,6 +236,14 @@ class VMobject(Mobject): super().fade(darkness, family) return self + def set_gloss(self, gloss, family=True): + if family: + for sm in self.get_family(): + sm.gloss = gloss + else: + self.gloss = gloss + return self + def get_fill_rgbas(self): try: return self.fill_rgbas @@ -654,6 +670,14 @@ class VMobject(Mobject): self.get_end_anchors(), )))) + def get_points_without_null_curves(self, atol=1e-9): + nppc = self.n_points_per_curve + distinct_curves = reduce(op.or_, [ + (abs(self.points[i::nppc] - self.points[0::nppc]) > atol).any(1) + for i in range(1, nppc) + ]) + return self.points[distinct_curves.repeat(nppc)] + def get_arc_length(self, n_sample_points=None): if n_sample_points is None: n_sample_points = 4 * self.get_num_curves() + 1 @@ -665,6 +689,38 @@ class VMobject(Mobject): norms = np.array([get_norm(d) for d in diffs]) return norms.sum() + def get_area_vector(self): + # Returns a vector whose length is the area bound by + # the polygon formed by the anchor points, pointing + # in a direction perpendicular to the polygon according + # to the right hand rule. + if self.has_no_points(): + return np.zeros(3) + + nppc = self.n_points_per_curve + p0 = self.points[0::nppc] + p1 = self.points[nppc - 1::nppc] + + # Each term goes through all edges [(x1, y1, z1), (x2, y2, z2)] + return 0.5 * np.array([ + sum((p0[:, 1] + p1[:, 1]) * (p1[:, 2] - p0[:, 2])), # Add up (y1 + y2)*(z2 - z1) + sum((p0[:, 2] + p1[:, 2]) * (p1[:, 0] - p0[:, 0])), # Add up (z1 + z2)*(x2 - x1) + sum((p0[:, 0] + p1[:, 0]) * (p1[:, 1] - p0[:, 1])), # Add up (x1 + x2)*(y2 - y1) + ]) + + def get_unit_normal_vector(self): + if len(self.points) < 3: + return OUT + area_vect = self.get_area_vector() + area = get_norm(area_vect) + if area > 0: + return area_vect / area + else: + return get_unit_normal( + self.points[1] - self.points[0], + self.points[2] - self.points[1], + ) + # Alignment def align_points(self, vmobject): self.align_rgbas(vmobject) @@ -896,15 +952,20 @@ class VMobject(Mobject): if len(stroke_width) > 1: stroke_width = self.stretched_style_array_matching_points(stroke_width) - data = self.get_blank_shader_data_array(len(self.points), "stroke_data") - data['point'] = self.points - data['prev_point'][:3] = self.points[-3:] - data['prev_point'][3:] = self.points[:-3] - data['next_point'][:-3] = self.points[3:] - data['next_point'][-3:] = self.points[:3] - data['stroke_width'][:, 0] = stroke_width - data['color'] = rgbas - data['joint_type'] = joint_type_to_code[self.joint_type] + points = self.get_points_without_null_curves() + nppc = self.n_points_per_curve + + data = self.get_blank_shader_data_array(len(points), "stroke_data") + data["point"] = points + data["prev_point"][:nppc] = points[-nppc:] + data["prev_point"][nppc:] = points[:-nppc] + data["next_point"][:-nppc] = points[nppc:] + data["next_point"][-nppc:] = points[:nppc] + data["unit_normal"] = self.get_unit_normal_vector() + data["stroke_width"][:, 0] = stroke_width + data["color"] = rgbas + data["joint_type"] = joint_type_to_code[self.joint_type] + data["gloss"] = self.gloss return data def lock_triangulation(self, family=True): @@ -912,7 +973,6 @@ class VMobject(Mobject): for mob in mobs: mob.triangulation_locked = False mob.saved_triangulation = mob.get_triangulation() - mob.saved_orientation = mob.get_orientation() mob.triangulation_locked = True return self @@ -925,26 +985,12 @@ class VMobject(Mobject): if sm.triangulation_locked: sm.lock_triangulation(family=False) - def get_signed_polygonal_area(self): - nppc = self.n_points_per_curve - p0 = self.points[0::nppc] - p1 = self.points[nppc - 1::nppc] - # Add up (x1 + x2)*(y2 - y1) for all edges (x1, y1), (x2, y2) - return sum((p0[:, 0] + p1[:, 0]) * (p1[:, 1] - p0[:, 1])) - - def get_orientation(self): - if self.triangulation_locked: - return self.saved_orientation - if self.has_no_points(): - return 0 - return np.sign(self.get_signed_polygonal_area()) - - def get_triangulation(self, orientation=None): + def get_triangulation(self, normal_vector=None): # Figure out how to triangulate the interior to know # how to send the points as to the vertex shader. # First triangles come directly from the points - if orientation is None: - orientation = self.get_orientation() + if normal_vector is None: + normal_vector = self.get_unit_normal_vector() if self.triangulation_locked: return self.saved_triangulation @@ -952,7 +998,9 @@ class VMobject(Mobject): if len(self.points) <= 1: return [] - points = self.points + # Rotate points such that unit normal vector is OUT + # TODO, 99% of the time this does nothing. Do a check for that? + points = np.dot(self.points, z_to_vector(normal_vector)) indices = np.arange(len(points), dtype=int) b0s = points[0::3] @@ -961,9 +1009,8 @@ class VMobject(Mobject): v01s = b1s - b0s v12s = b2s - b1s - # TODO, account for 3d crosses = cross2d(v01s, v12s) - convexities = orientation * np.sign(crosses) + convexities = np.sign(crosses) atol = self.tolerance_for_point_equality end_of_loop = np.zeros(len(b0s), dtype=bool) @@ -983,31 +1030,28 @@ class VMobject(Mobject): # Triangulate inner_verts = points[inner_vert_indices] - inner_tri_indices = inner_vert_indices[ - earclip_triangulation(inner_verts, rings) - ] + inner_tri_indices = inner_vert_indices[earclip_triangulation(inner_verts, rings)] tri_indices = np.hstack([indices, inner_tri_indices]) return tri_indices def get_fill_shader_data(self): points = self.points - - orientation = self.get_orientation() - tri_indices = self.get_triangulation(orientation) + unit_normal = self.get_unit_normal_vector() + tri_indices = self.get_triangulation(unit_normal) # TODO, best way to enable multiple colors? rgbas = self.get_fill_rgbas()[:1] data = self.get_blank_shader_data_array(len(tri_indices), "fill_data") data["point"] = points[tri_indices] + data["unit_normal"] = unit_normal data["color"] = rgbas # Assume the triangulation is such that the first n_points points # are on the boundary, and the rest are in the interior data["fill_all"][:len(points)] = 0 data["fill_all"][len(points):] = 1 - data["orientation"] = orientation - + data["gloss"] = self.gloss return data diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 7bc34bb3..412c64a3 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -120,6 +120,8 @@ class Scene(Container): # Stack depth of 2 means the shell will use # the namespace of the caller, not this method shell(stack_depth=2) + # End scene when exiting an embed. + raise EndSceneEarlyException() def __str__(self): return self.__class__.__name__ @@ -527,6 +529,9 @@ class Scene(Container): def on_mouse_drag(self, point, d_point, buttons, modifiers): self.mouse_drag_point.move_to(point) + # Only if 3d rotation is enabled? + self.camera.frame.increment_theta(-d_point[0]) + self.camera.frame.increment_phi(d_point[1]) def on_mouse_press(self, point, button, mods): pass @@ -548,7 +553,7 @@ class Scene(Container): def on_key_press(self, symbol, modifiers): if chr(symbol) == "r": - self.camera.frame.restore() + self.camera.frame.to_default_state() elif chr(symbol) == "z": self.zoom_on_scroll = True elif chr(symbol) == "q": diff --git a/manimlib/scene/scene_file_writer.py b/manimlib/scene/scene_file_writer.py index e110ad4b..fd234bc0 100644 --- a/manimlib/scene/scene_file_writer.py +++ b/manimlib/scene/scene_file_writer.py @@ -226,6 +226,7 @@ class SceneFileWriter(object): command += [ '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', + # '-pix_fmt', 'yuv444p14le', ] command += [temp_file_path] self.writing_process = sp.Popen(command, stdin=sp.PIPE) diff --git a/manimlib/scene/three_d_scene.py b/manimlib/scene/three_d_scene.py index 1d4979ea..5cb847b8 100644 --- a/manimlib/scene/three_d_scene.py +++ b/manimlib/scene/three_d_scene.py @@ -1,5 +1,4 @@ from manimlib.animation.transform import ApplyMethod -from manimlib.camera.three_d_camera import ThreeDCamera from manimlib.constants import DEGREES from manimlib.constants import PRODUCTION_QUALITY_CAMERA_CONFIG from manimlib.mobject.coordinate_systems import ThreeDAxes @@ -12,9 +11,10 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import merge_dicts_recursively +# TODO, these seem deprecated. + class ThreeDScene(Scene): CONFIG = { - "camera_class": ThreeDCamera, "ambient_camera_rotation": None, "default_angled_camera_orientation_kwargs": { "phi": 70 * DEGREES, diff --git a/manimlib/shaders/add_light.glsl b/manimlib/shaders/add_light.glsl new file mode 100644 index 00000000..1c1b61cb --- /dev/null +++ b/manimlib/shaders/add_light.glsl @@ -0,0 +1,23 @@ +vec4 add_light(vec4 raw_color, vec3 point, vec3 unit_normal, vec3 light_coords, float gloss){ + if(gloss == 0.0) return raw_color; + + // TODO, do we actually want this? For VMobjects its nice to just choose whichever unit normal + // is pointing towards the camera. + if(unit_normal.z < 0){ + unit_normal *= -1; + } + + float camera_distance = 6; // TODO, read this in as a uniform? + // Assume everything has already been rotated such that camera is in the z-direction + vec3 to_camera = vec3(0, 0, camera_distance) - point; + vec3 to_light = light_coords - point; + vec3 light_reflection = -to_light + 2 * unit_normal * dot(to_light, unit_normal); + float dot_prod = dot(normalize(light_reflection), normalize(to_camera)); + // float shine = gloss * exp(-3 * pow(1 - dot_prod, 2)); + float shine = 2 * gloss * exp(-1 * pow(1 - dot_prod, 2)); + float dp2 = dot(normalize(to_light), unit_normal); + return vec4( + mix(0.5, 1.0, max(dp2, 0)) * mix(raw_color.rgb, vec3(1.0), shine), + raw_color.a + ); +} \ No newline at end of file diff --git a/manimlib/shaders/get_gl_Position.glsl b/manimlib/shaders/get_gl_Position.glsl new file mode 100644 index 00000000..071ebb52 --- /dev/null +++ b/manimlib/shaders/get_gl_Position.glsl @@ -0,0 +1,14 @@ +// Assumes the following uniforms exist in the surrounding context: +// uniform float aspect_ratio; +// uniform float focal_distance; + +vec4 get_gl_Position(vec3 point){ + point.x /= aspect_ratio; + point.z /= focal_distance; + point.xy /= max(1 - point.z, 0); + // Todo, does this discontinuity add weirdness? Theoretically, by this point, + // the z-coordiante of gl_Position only matter for z-indexing. The reason + // for thie line is to avoid agressive clipping of distant points. + if(point.z < 0) point.z *= 0.1; + return vec4(point.xy, -point.z, 1); +} \ No newline at end of file diff --git a/manimlib/shaders/get_unit_normal.glsl b/manimlib/shaders/get_unit_normal.glsl new file mode 100644 index 00000000..ed1b975d --- /dev/null +++ b/manimlib/shaders/get_unit_normal.glsl @@ -0,0 +1,22 @@ +vec3 get_unit_normal(in vec3[3] points){ + float tol = 1e-6; + vec3 v1 = normalize(points[1] - points[0]); + vec3 v2 = normalize(points[2] - points[0]); + vec3 cp = cross(v1, v2); + float cp_norm = length(cp); + if(cp_norm < tol){ + // Three points form a line, so find a normal vector + // to that line in the plane shared with the z-axis + vec3 k_hat = vec3(0.0, 0.0, 1.0); + vec3 new_cp = cross(cross(v2, k_hat), v2); + float new_cp_norm = length(new_cp); + if(new_cp_norm < tol){ + // We only come here if all three points line up + // on the z-axis. + return vec3(0.0, 1.0, 0.0); + // return k_hat; + } + return new_cp / new_cp_norm; + } + return cp / cp_norm; +} \ No newline at end of file diff --git a/manimlib/shaders/image_vert.glsl b/manimlib/shaders/image_vert.glsl index 1e81a472..f7e48932 100644 --- a/manimlib/shaders/image_vert.glsl +++ b/manimlib/shaders/image_vert.glsl @@ -1,8 +1,9 @@ #version 330 -uniform float scale; uniform float aspect_ratio; -uniform vec3 frame_center; +uniform float anti_alias_width; +uniform mat4 to_screen_space; +uniform float focal_distance; uniform sampler2D Texture; @@ -14,16 +15,11 @@ out vec2 v_im_coords; out float v_opacity; // Analog of import for manim only -#INSERT rotate_point_for_frame.glsl -#INSERT scale_and_shift_point_for_frame.glsl +#INSERT get_gl_Position.glsl +#INSERT position_point_into_frame.glsl void main(){ v_im_coords = im_coords; v_opacity = opacity; - gl_Position = vec4( - rotate_point_for_frame( - scale_and_shift_point_for_frame(point) - ), - 1.0 - ); + gl_Position = get_gl_Position(position_point_into_frame(point)); } \ No newline at end of file diff --git a/manimlib/shaders/position_point_into_frame.glsl b/manimlib/shaders/position_point_into_frame.glsl new file mode 100644 index 00000000..91818735 --- /dev/null +++ b/manimlib/shaders/position_point_into_frame.glsl @@ -0,0 +1,9 @@ +// Must be used in an environment with the following uniforms: +// uniform mat4 to_screen_space; +// uniform float focal_distance; + +vec3 position_point_into_frame(vec3 point){ + // Apply the pre-computed to_screen_space matrix. + vec4 new_point = to_screen_space * vec4(point, 1); + return new_point.xyz; +} diff --git a/manimlib/shaders/quadratic_bezier_distance.glsl b/manimlib/shaders/quadratic_bezier_distance.glsl index c928b254..d7fd8ccd 100644 --- a/manimlib/shaders/quadratic_bezier_distance.glsl +++ b/manimlib/shaders/quadratic_bezier_distance.glsl @@ -1,6 +1,8 @@ // This file is not a shader, it's just a set of // functions meant to be inserted into other shaders. +// Must be inserted in a context with a definition for modify_distance_for_endpoints + // All of this is with respect to a curve that's been rotated/scaled // so that b0 = (0, 0) and b1 = (1, 0). That is, b2 entirely // determines the shape of the curve @@ -16,36 +18,6 @@ vec2 bezier(float t, vec2 b2){ ); } -void compute_C_and_grad_C(float a, float b, vec2 p, out float Cxy, out vec2 grad_Cxy){ - // Curve has the implicit form x = a*y + b*sqrt(y), which is also - // 0 = -x^2 + 2axy + b^2 y - a^2 y^2. - Cxy = -p.x*p.x + 2 * a * p.x*p.y + b*b * p.y - a*a * p.y*p.y; - - // Approximate distance to curve using the gradient of -x^2 + 2axy + b^2 y - a^2 y^2 - grad_Cxy = vec2( - -2 * p.x + 2 * a * p.y, // del C / del x - 2 * a * p.x + b*b - 2 * a*a * p.y // del C / del y - ); -} - -// This function is flawed. -float cheap_dist_to_curve(vec2 p, vec2 b2){ - float a = (b2.x - 2.0) / b2.y; - float b = sign(b2.y) * 2.0 / sqrt(abs(b2.y)); - float x = p.x; - float y = p.y; - - // Curve has the implicit form x = a*y + b*sqrt(y), which is also - // 0 = -x^2 + 2axy + b^2 y - a^2 y^2. - float Cxy = -x * x + 2 * a * x * y + sign(b2.y) * b * b * y - a * a * y * y; - - // Approximate distance to curve using the gradient of -x^2 + 2axy + b^2 y - a^2 y^2 - vec2 grad_Cxy = 2 * vec2( - -x + a * y, // del C / del x - a * x + b * b / 2 - a * a * y // del C / del y - ); - return abs(Cxy / length(grad_Cxy)); -} float cube_root(float x){ return sign(x) * pow(abs(x), 1.0 / 3.0); @@ -98,9 +70,9 @@ int cubic_solve(float a, float b, float c, float d, out float roots[3]){ float dist_to_line(vec2 p, vec2 b2){ float t = clamp(p.x / b2.x, 0, 1); float dist; - if(t == 0) dist = length(p); + if(t == 0) dist = length(p); else if(t == 1) dist = distance(p, b2); - else dist = abs(p.y); + else dist = abs(p.y); return modify_distance_for_endpoints(p, dist, t); } @@ -114,10 +86,9 @@ float dist_to_point_on_curve(vec2 p, float t, vec2 b2){ } -float min_dist_to_curve(vec2 p, vec2 b2, float degree, bool quick_approx){ +float min_dist_to_curve(vec2 p, vec2 b2, float degree){ // Check if curve is really a a line if(degree == 1) return dist_to_line(p, b2); - if(quick_approx) return cheap_dist_to_curve(p, b2); // Try finding the exact sdf by solving the equation // (d/dt) dist^2(t) = 0, which amount to the following diff --git a/manimlib/shaders/quadratic_bezier_fill_frag.glsl b/manimlib/shaders/quadratic_bezier_fill_frag.glsl index 785d9a90..76b934db 100644 --- a/manimlib/shaders/quadratic_bezier_fill_frag.glsl +++ b/manimlib/shaders/quadratic_bezier_fill_frag.glsl @@ -1,22 +1,23 @@ #version 330 +uniform vec3 light_source_position; +uniform mat4 to_screen_space; + in vec4 color; -in float fill_type; +in float fill_all; // Either 0 or 1e in float uv_anti_alias_width; +in vec3 xyz_coords; +in vec3 global_unit_normal; +in float orientation; in vec2 uv_coords; -in vec2 wz_coords; in vec2 uv_b2; in float bezier_degree; +in float gloss; out vec4 frag_color; -const float FILL_INSIDE = 0; -const float FILL_OUTSIDE = 1; -const float FILL_ALL = 2; - - -// Needed for quadratic_bezier_distance +// Needed for quadratic_bezier_distance insertion below float modify_distance_for_endpoints(vec2 p, float dist, float t){ return dist; } @@ -25,28 +26,48 @@ float modify_distance_for_endpoints(vec2 p, float dist, float t){ // so to share functionality between this and others, the caller // replaces this line with the contents of quadratic_bezier_sdf.glsl #INSERT quadratic_bezier_distance.glsl - - -bool is_inside_curve(){ - if(bezier_degree < 2) return false; - - float value = wz_coords.x * wz_coords.x - wz_coords.y; - if(fill_type == FILL_INSIDE) return value < 0; - if(fill_type == FILL_OUTSIDE) return value > 0; - return false; -} +#INSERT add_light.glsl float sdf(){ - if(is_inside_curve()) return -1; - return min_dist_to_curve(uv_coords, uv_b2, bezier_degree, false); + if(bezier_degree < 2){ + return abs(uv_coords[1]); + } + float u2 = uv_b2.x; + float v2 = uv_b2.y; + // For really flat curves, just take the distance to x-axis + if(abs(v2 / u2) < 0.1 * uv_anti_alias_width){ + return abs(uv_coords[1]); + } + // For flat-ish curves, take the curve + else if(abs(v2 / u2) < 0.5 * uv_anti_alias_width){ + return min_dist_to_curve(uv_coords, uv_b2, bezier_degree); + } + // I know, I don't love this amount of arbitrary-seeming branching either, + // but a number of strange dimples and bugs pop up otherwise. + + // This converts uv_coords to yet another space where the bezier points sit on + // (0, 0), (1/2, 0) and (1, 1), so that the curve can be expressed implicityly + // as y = x^2. + mat2 to_simple_space = mat2( + v2, 0, + 2 - u2, 4 * v2 + ); + vec2 p = to_simple_space * uv_coords; + // Sign takes care of whether we should be filling the inside or outside of curve. + float sn = orientation * sign(v2); + float Fp = sn * (p.x * p.x - p.y); + vec2 grad = vec2( + -2 * p.x * v2, // del C / del u + 4 * v2 - 4 * p.x * (2 - u2) // del C / del v + ); + return Fp / length(grad); } void main() { if (color.a == 0) discard; - frag_color = color; - if (fill_type == FILL_ALL) return; + frag_color = add_light(color, xyz_coords, global_unit_normal, light_source_position, gloss); + if (fill_all == 1.0) return; frag_color.a *= smoothstep(1, 0, sdf() / uv_anti_alias_width); - // frag_color.a += 0.2; -} \ No newline at end of file +} diff --git a/manimlib/shaders/quadratic_bezier_fill_geom.glsl b/manimlib/shaders/quadratic_bezier_fill_geom.glsl index 5d850350..8d0da6f1 100644 --- a/manimlib/shaders/quadratic_bezier_fill_geom.glsl +++ b/manimlib/shaders/quadratic_bezier_fill_geom.glsl @@ -3,176 +3,119 @@ layout (triangles) in; layout (triangle_strip, max_vertices = 5) out; -uniform float scale; -uniform float aspect_ratio; uniform float anti_alias_width; -uniform vec3 frame_center; +// Needed for get_gl_Position +uniform float aspect_ratio; +uniform float focal_distance; in vec3 bp[3]; +in vec3 v_global_unit_normal[3]; in vec4 v_color[3]; in float v_fill_all[3]; -in float v_orientation[3]; +in float v_gloss[3]; out vec4 color; -out float fill_type; +out float gloss; +out float fill_all; out float uv_anti_alias_width; +out vec3 xyz_coords; +out vec3 global_unit_normal; +out float orientation; // uv space is where b0 = (0, 0), b1 = (1, 0), and transform is orthogonal out vec2 uv_coords; out vec2 uv_b2; -// wz space is where b0 = (0, 0), b1 = (0.5, 0), b2 = (1, 1) -out vec2 wz_coords; - out float bezier_degree; -const float FILL_INSIDE = 0; -const float FILL_OUTSIDE = 1; -const float FILL_ALL = 2; - -const float SQRT5 = 2.236068; - - // To my knowledge, there is no notion of #include for shaders, // so to share functionality between this and others, the caller -// replaces this line with the contents of named file +// in manim replaces this line with the contents of named file #INSERT quadratic_bezier_geometry_functions.glsl -#INSERT scale_and_shift_point_for_frame.glsl +#INSERT get_gl_Position.glsl +#INSERT get_unit_normal.glsl -mat3 get_xy_to_wz(vec2 b0, vec2 b1, vec2 b2){ - // If linear or null, this matrix is not needed - if(bezier_degree < 2) return mat3(1.0); - - vec2 inv_col1 = 2 * (b1 - b0); - vec2 inv_col2 = b2 - 2 * b1 + b0; - float inv_det = cross(inv_col1, inv_col2); - - mat3 transform = mat3( - inv_col2.y, -inv_col1.y, 0, - -inv_col2.x, inv_col1.x, 0, - 0, 0, inv_det - ) / inv_det; - - mat3 shift = mat3( - 1, 0, 0, - 0, 1, 0, - -b0.x, -b0.y, 1 - ); - return transform * shift; +void emit_vertex_wrapper(vec3 point, int index){ + color = v_color[index]; + gloss = v_gloss[index]; + global_unit_normal = v_global_unit_normal[index]; + xyz_coords = point; + gl_Position = get_gl_Position(xyz_coords); + EmitVertex(); } void emit_simple_triangle(){ for(int i = 0; i < 3; i++){ - color = v_color[i]; - gl_Position = vec4( - scale_and_shift_point_for_frame(bp[i]), - 1.0 - ); - EmitVertex(); + emit_vertex_wrapper(bp[i], i); } EndPrimitive(); } -void emit_pentagon(vec2 bp0, vec2 bp1, vec2 bp2, float orientation){ +void emit_pentagon(vec3[3] points, vec3 normal){ + vec3 p0 = points[0]; + vec3 p1 = points[1]; + vec3 p2 = points[2]; // Tangent vectors - vec2 t01 = normalize(bp1 - bp0); - vec2 t12 = normalize(bp2 - bp1); - - // Inside and left turn -> rot right -> -1 - // Outside and left turn -> rot left -> +1 - // Inside and right turn -> rot left -> +1 - // Outside and right turn -> rot right -> -1 - float c_orient = (cross(t01, t12) > 0) ? 1 : -1; - c_orient *= orientation; - - bool fill_in = (c_orient > 0); - fill_type = fill_in ? FILL_INSIDE : FILL_OUTSIDE; - - // float orient = in_or_out * c_orient; - - // Normal vectors - // Rotate tangent vector 90-degrees clockwise - // if the curve is positively oriented, otherwise - // rotate it 90-degrees counterclockwise - vec2 n01 = orientation * vec2(t01.y, -t01.x); - vec2 n12 = orientation * vec2(t12.y, -t12.x); + vec3 t01 = normalize(p1 - p0); + vec3 t12 = normalize(p2 - p1); + // Vectors perpendicular to the curve in the plane of the curve pointing outside the curve + vec3 p0_perp = cross(t01, normal); + vec3 p2_perp = cross(t12, normal); + bool fill_in = orientation > 0; float aaw = anti_alias_width; - vec2 nudge1 = fill_in ? 0.5 * aaw * (n01 + n12) : vec2(0); - vec2 corners[5] = vec2[5]( - bp0 + aaw * n01, - bp0, - bp1 + nudge1, - bp2, - bp2 + aaw * n12 - ); + vec3 corners[5]; + if(fill_in){ + // Note, straight lines will also fall into this case, and since p0_perp and p2_perp + // will point to the right of the curve, it's just what we want + corners = vec3[5]( + p0 + aaw * p0_perp, + p0, + p1 + 0.5 * aaw * (p0_perp + p2_perp), + p2, + p2 + aaw * p2_perp + ); + }else{ + corners = vec3[5]( + p0, + p0 - aaw * p0_perp, + p1, + p2 - aaw * p2_perp, + p2 + ); + } - int coords_index_map[5] = int[5](0, 1, 2, 3, 4); - if(!fill_in) coords_index_map = int[5](1, 0, 2, 4, 3); - - mat3 xy_to_uv = get_xy_to_uv(bp0, bp1); - mat3 xy_to_wz = get_xy_to_wz(bp0, bp1, bp2); - uv_b2 = (xy_to_uv * vec3(bp2, 1)).xy; - uv_anti_alias_width = anti_alias_width / length(bp1 - bp0); + mat4 xyz_to_uv = get_xyz_to_uv(p0, p1, normal); + uv_b2 = (xyz_to_uv * vec4(p2, 1)).xy; + uv_anti_alias_width = anti_alias_width / length(p1 - p0); for(int i = 0; i < 5; i++){ - vec2 corner = corners[coords_index_map[i]]; - uv_coords = (xy_to_uv * vec3(corner, 1)).xy; - wz_coords = (xy_to_wz * vec3(corner, 1)).xy; - float z; - // I haven't a clue why an index map doesn't work just - // as well here, but for some reason it doesn't. - if(i < 2){ - color = v_color[0]; - z = bp[0].z; - } - else if(i == 2){ - color = v_color[1]; - z = bp[1].z; - } - else{ - color = v_color[2]; - z = bp[2].z; - } - gl_Position = vec4( - scale_and_shift_point_for_frame(vec3(corner, z)), - 1.0 - ); - EmitVertex(); + vec3 corner = corners[i]; + uv_coords = (xyz_to_uv * vec4(corner, 1)).xy; + int j = int(sign(i - 1) + 1); // Maps i = [0, 1, 2, 3, 4] onto j = [0, 0, 1, 2, 2] + emit_vertex_wrapper(corner, j); } EndPrimitive(); } void main(){ - float fill_all = v_fill_all[0]; + fill_all = v_fill_all[0]; + vec3 local_unit_normal = get_unit_normal(vec3[3](bp[0], bp[1], bp[2])); + orientation = sign(dot(v_global_unit_normal[0], local_unit_normal)); if(fill_all == 1){ - fill_type = FILL_ALL; emit_simple_triangle(); - }else{ - vec2 new_bp[3]; - int n = get_reduced_control_points(bp[0].xy, bp[1].xy, bp[2].xy, new_bp); - bezier_degree = float(n); - float orientation = v_orientation[0]; - - vec2 bp0, bp1, bp2; - if(n == 0){ - return; // Don't emit any vertices - } - else if(n == 1){ - bp0 = new_bp[0]; - bp2 = new_bp[1]; - bp1 = 0.5 * (bp0 + bp2); - }else{ - bp0 = new_bp[0]; - bp1 = new_bp[1]; - bp2 = new_bp[2]; - } - - emit_pentagon(bp0, bp1, bp2, orientation); + return; } + + vec3 new_bp[3]; + bezier_degree = get_reduced_control_points(vec3[3](bp[0], bp[1], bp[2]), new_bp); + if(bezier_degree >= 1){ + emit_pentagon(new_bp, local_unit_normal); + } + // Don't emit any vertices for bezier_degree 0 } diff --git a/manimlib/shaders/quadratic_bezier_fill_vert.glsl b/manimlib/shaders/quadratic_bezier_fill_vert.glsl index 48486841..9f7ab752 100644 --- a/manimlib/shaders/quadratic_bezier_fill_vert.glsl +++ b/manimlib/shaders/quadratic_bezier_fill_vert.glsl @@ -1,24 +1,28 @@ #version 330 +uniform mat4 to_screen_space; + in vec3 point; +in vec3 unit_normal; in vec4 color; -// fill_all is 0 or 1 -in float fill_all; -// orientation is +1 for counterclockwise curves, -1 otherwise -in float orientation; +in float fill_all; // Either 0 or 1 +in float gloss; out vec3 bp; // Bezier control point +out vec3 v_global_unit_normal; out vec4 v_color; out float v_fill_all; -out float v_orientation; - - -#INSERT rotate_point_for_frame.glsl +out float v_gloss; +// To my knowledge, there is no notion of #include for shaders, +// so to share functionality between this and others, the caller +// replaces this line with the contents of named file +#INSERT position_point_into_frame.glsl void main(){ - bp = rotate_point_for_frame(point); + bp = position_point_into_frame(point); + v_global_unit_normal = normalize(position_point_into_frame(unit_normal)); v_color = color; v_fill_all = fill_all; - v_orientation = orientation; + v_gloss = gloss; } \ No newline at end of file diff --git a/manimlib/shaders/quadratic_bezier_geometry_functions.glsl b/manimlib/shaders/quadratic_bezier_geometry_functions.glsl index d5cd01d0..25ab9bd4 100644 --- a/manimlib/shaders/quadratic_bezier_geometry_functions.glsl +++ b/manimlib/shaders/quadratic_bezier_geometry_functions.glsl @@ -1,26 +1,33 @@ // This file is not a shader, it's just a set of // functions meant to be inserted into other shaders. -float cross(vec2 v, vec2 w){ +float cross2d(vec2 v, vec2 w){ return v.x * w.y - w.x * v.y; } -// Matrix to convert to a uv space defined so that +// Orthogonal matrix to convert to a uv space defined so that // b0 goes to [0, 0] and b1 goes to [1, 0] -mat3 get_xy_to_uv(vec2 b0, vec2 b1){ - vec2 T = b1 - b0; - - mat3 shift = mat3( - 1, 0, 0, - 0, 1, 0, - -b0.x, -b0.y, 1 +mat4 get_xyz_to_uv(vec3 b0, vec3 b1, vec3 unit_normal){ + mat4 shift = mat4( + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + -b0.x, -b0.y, -b0.z, 1 ); - mat3 rotate_and_scale = mat3( - T.x, -T.y, 0, - T.y, T.x, 0, - 0, 0, 1 - ) / dot(T, T); - return rotate_and_scale * shift; + + float scale_factor = length(b1 - b0); + vec3 I = (b1 - b0) / scale_factor; + vec3 K = unit_normal; + vec3 J = cross(K, I); + // Transpose (hence inverse) of matrix taking + // i-hat to I, k-hat to unit_normal, and j-hat to their cross + mat4 rotate = mat4( + I.x, J.x, K.x, 0, + I.y, J.y, K.y, 0, + I.z, J.z, K.z, 0, + 0, 0, 0 , 1 + ); + return (1 / scale_factor) * rotate * shift; } @@ -29,31 +36,40 @@ mat3 get_xy_to_uv(vec2 b0, vec2 b1){ // which for quadratics will be the same, but for linear and null // might change. The idea is to inform the caller of the degree, // while also passing tangency information in the linear case. -int get_reduced_control_points(vec2 b0, vec2 b1, vec2 b2, out vec2 new_points[3]){ - float epsilon = 1e-6; - vec2 v01 = (b1 - b0); - vec2 v12 = (b2 - b1); - bool distinct_01 = length(v01) > epsilon; // v01 is considered nonzero - bool distinct_12 = length(v12) > epsilon; // v12 is considered nonzero +// float get_reduced_control_points(vec3 b0, vec3 b1, vec3 b2, out vec3 new_points[3]){ +float get_reduced_control_points(in vec3 points[3], out vec3 new_points[3]){ + float length_threshold = 1e-6; + float angle_threshold = 1e-3; + + vec3 p0 = points[0]; + vec3 p1 = points[1]; + vec3 p2 = points[2]; + vec3 v01 = (p1 - p0); + vec3 v12 = (p2 - p1); + + float dot_prod = clamp(dot(normalize(v01), normalize(v12)), -1, 1); + bool aligned = acos(dot_prod) < angle_threshold; + bool distinct_01 = length(v01) > length_threshold; // v01 is considered nonzero + bool distinct_12 = length(v12) > length_threshold; // v12 is considered nonzero int n_uniques = int(distinct_01) + int(distinct_12); - if(n_uniques == 2){ - bool linear = dot(normalize(v01), normalize(v12)) > 1 - epsilon; - if(linear){ - new_points[0] = b0; - new_points[1] = b2; - return 1; - }else{ - new_points[0] = b0; - new_points[1] = b1; - new_points[2] = b2; - return 2; - } - }else if(n_uniques == 1){ - new_points[0] = b0; - new_points[1] = b2; - return 1; + + bool quadratic = (n_uniques == 2) && !aligned; + bool linear = (n_uniques == 1) || ((n_uniques == 2) && aligned); + bool constant = (n_uniques == 0); + if(quadratic){ + new_points[0] = p0; + new_points[1] = p1; + new_points[2] = p2; + return 2.0; + }else if(linear){ + new_points[0] = p0; + new_points[1] = (p0 + p2) / 2.0; + new_points[2] = p2; + return 1.0; }else{ - new_points[0] = b0; - return 0; + new_points[0] = p0; + new_points[1] = p0; + new_points[2] = p0; + return 0.0; } } \ No newline at end of file diff --git a/manimlib/shaders/quadratic_bezier_stroke_frag.glsl b/manimlib/shaders/quadratic_bezier_stroke_frag.glsl index 8055bfc5..67c5c08c 100644 --- a/manimlib/shaders/quadratic_bezier_stroke_frag.glsl +++ b/manimlib/shaders/quadratic_bezier_stroke_frag.glsl @@ -1,10 +1,16 @@ #version 330 +uniform mat4 to_screen_space; +uniform vec3 light_source_position; + +in vec3 xyz_coords; +in vec3 global_unit_normal; in vec2 uv_coords; in vec2 uv_b2; in float uv_stroke_width; in vec4 color; +in float gloss; in float uv_anti_alias_width; in float has_prev; @@ -19,7 +25,7 @@ in float bezier_degree; out vec4 frag_color; -float cross(vec2 v, vec2 w){ +float cross2d(vec2 v, vec2 w){ return v.x * w.y - w.x * v.y; } @@ -64,8 +70,8 @@ float modify_distance_for_endpoints(vec2 p, float dist, float t){ ); vec2 v21_unit = v21 / length(v21); float bevel_d = max( - abs(cross(p - uv_b2, v21_unit)), - abs(cross((rot * (p - uv_b2)), v21_unit)) + abs(cross2d(p - uv_b2, v21_unit)), + abs(cross2d((rot * (p - uv_b2)), v21_unit)) ); return min(dist, bevel_d); } @@ -78,14 +84,19 @@ float modify_distance_for_endpoints(vec2 p, float dist, float t){ // so to share functionality between this and others, the caller // replaces this line with the contents of named file #INSERT quadratic_bezier_distance.glsl +#INSERT add_light.glsl void main() { if (uv_stroke_width == 0) discard; - frag_color = color; - float dist_to_curve = min_dist_to_curve(uv_coords, uv_b2, bezier_degree, false); + // Add lighting if needed + frag_color = add_light(color, xyz_coords, global_unit_normal, light_source_position, gloss); + + float dist_to_curve = min_dist_to_curve(uv_coords, uv_b2, bezier_degree); // An sdf for the region around the curve we wish to color. float signed_dist = abs(dist_to_curve) - 0.5 * uv_stroke_width; frag_color.a *= smoothstep(0.5, -0.5, signed_dist / uv_anti_alias_width); + + // frag_color.a += 0.3; } \ No newline at end of file diff --git a/manimlib/shaders/quadratic_bezier_stroke_geom.glsl b/manimlib/shaders/quadratic_bezier_stroke_geom.glsl index 9434fb74..2cf39a29 100644 --- a/manimlib/shaders/quadratic_bezier_stroke_geom.glsl +++ b/manimlib/shaders/quadratic_bezier_stroke_geom.glsl @@ -3,21 +3,24 @@ layout (triangles) in; layout (triangle_strip, max_vertices = 5) out; -uniform float scale; +// Needed for get_gl_Position uniform float aspect_ratio; +uniform float focal_distance; uniform float anti_alias_width; -uniform vec3 frame_center; in vec3 bp[3]; in vec3 prev_bp[3]; in vec3 next_bp[3]; +in vec3 v_global_unit_normal[3]; in vec4 v_color[3]; in float v_stroke_width[3]; in float v_joint_type[3]; +in float v_gloss[3]; out vec4 color; out float uv_stroke_width; +out float gloss; out float uv_anti_alias_width; out float has_prev; @@ -29,6 +32,8 @@ out float angle_to_next; out float bezier_degree; +out vec3 xyz_coords; +out vec3 global_unit_normal; out vec2 uv_coords; out vec2 uv_b2; @@ -43,36 +48,47 @@ const float MITER_JOINT = 3; // so to share functionality between this and others, the caller // replaces this line with the contents of named file #INSERT quadratic_bezier_geometry_functions.glsl -#INSERT scale_and_shift_point_for_frame.glsl +#INSERT get_gl_Position.glsl +#INSERT get_unit_normal.glsl -float angle_between_vectors(vec2 v1, vec2 v2){ - vec2 nv1 = normalize(v1); - vec2 nv2 = normalize(v2); +float get_aaw_scalar(vec3 normal){ + return min(abs(normal.z), 5); +} + + +float angle_between_vectors(vec3 v1, vec3 v2, vec3 normal){ + float v1_norm = length(v1); + float v2_norm = length(v2); + if(v1_norm == 0 || v2_norm == 0) return 0; + vec3 nv1 = v1 / v1_norm; + vec3 nv2 = v2 / v2_norm; + // float signed_area = clamp(dot(cross(nv1, nv2), normal), -1, 1); + // return asin(signed_area); float unsigned_angle = acos(clamp(dot(nv1, nv2), -1, 1)); - float sn = sign(cross(nv1, nv2)); + float sn = sign(dot(cross(nv1, nv2), normal)); return sn * unsigned_angle; } -bool find_intersection(vec2 p0, vec2 v0, vec2 p1, vec2 v1, out vec2 intersection){ +bool find_intersection(vec3 p0, vec3 v0, vec3 p1, vec3 v1, vec3 normal, out vec3 intersection){ // Find the intersection of a line passing through // p0 in the direction v0 and one passing through p1 in // the direction p1. // That is, find a solutoin to p0 + v0 * t = p1 + v1 * s // float det = -v0.x * v1.y + v1.x * v0.y; - float det = cross(v1, v0); + float det = dot(cross(v1, v0), normal); if(det == 0){ // intersection = p0; return false; } - float t = cross(p0 - p1, v1) / det; + float t = dot(cross(p0 - p1, v1), normal) / det; intersection = p0 + v0 * t; return true; } -bool is_between(vec2 p, vec2 a, vec2 b){ +bool is_between(vec3 p, vec3 a, vec3 b){ // Assumes three points fall on a line, returns whether // or not p sits between a and b. float d_pa = distance(p, a); @@ -84,18 +100,18 @@ bool is_between(vec2 p, vec2 a, vec2 b){ // Tries to detect if one of the corners defined by the buffer around // b0 and b2 should be modified to form a better convex hull -bool should_motify_corner(vec2 c, vec2 from_c, vec2 o1, vec2 o2, vec2 from_o, float buff){ - vec2 int1; - vec2 int2; - find_intersection(c, from_c, o1, from_o, int1); - find_intersection(c, from_c, o2, from_o, int2); +bool should_motify_corner(vec3 c, vec3 from_c, vec3 o1, vec3 o2, vec3 from_o, vec3 normal, float buff){ + vec3 int1; + vec3 int2; + find_intersection(c, from_c, o1, from_o, normal, int1); + find_intersection(c, from_c, o2, from_o, normal, int2); return !is_between(int2, c + 1 * from_c * buff, int1); } -void create_joint(float angle, vec2 unit_tan, float buff, float should_bevel, - vec2 static_c0, out vec2 changing_c0, - vec2 static_c1, out vec2 changing_c1){ +void create_joint(float angle, vec3 unit_tan, float buff, float should_bevel, + vec3 static_c0, out vec3 changing_c0, + vec3 static_c1, out vec3 changing_c1){ float shift; float joint_type = v_joint_type[0]; bool miter = ( @@ -119,122 +135,70 @@ void create_joint(float angle, vec2 unit_tan, float buff, float should_bevel, // This function is responsible for finding the corners of // a bounding region around the bezier curve, which can be // emitted as a triangle fan -int get_corners(vec2 controls[3], int degree, out vec2 corners[5]){ - // Unit vectors for directions between - // Various control points - vec2 v02, v20, v10, v01, v12, v21; +int get_corners(vec3 controls[3], vec3 normal, int degree, out vec3 corners[5]){ + vec3 p0 = controls[0]; + vec3 p1 = controls[1]; + vec3 p2 = controls[2]; - vec2 p0 = controls[0]; - vec2 p2 = controls[degree]; - v02 = normalize(p2 - p0); - v20 = -v02; - if(degree == 2){ - v10 = normalize(p0 - controls[1]); - v12 = normalize(p2 - controls[1]); - }else{ - v10 = v20; - v12 = v02; - } - v01 = -v10; - v21 = -v12; + // Unit vectors for directions between control points + vec3 v10 = normalize(p0 - p1); + vec3 v12 = normalize(p2 - p1); + vec3 v01 = -v10; + vec3 v21 = -v12; - // Find bounding points around ends - vec2 p0_perp = vec2(-v01.y, v01.x); - vec2 p2_perp = vec2(-v21.y, v21.x); + // + vec3 p0_perp = cross(normal, v01); // Pointing to the left of the curve from p0 + vec3 p2_perp = cross(normal, v12); // Pointing to the left of the curve from p2 - float buff0 = 0.5 * v_stroke_width[0] + anti_alias_width; - float buff2 = 0.5 * v_stroke_width[2] + anti_alias_width; - float aaw0 = (1 - has_prev) * anti_alias_width; - float aaw2 = (1 - has_next) * anti_alias_width; + // aaw is the added width given around the polygon for antialiasing. + // In case the normal is faced away from (0, 0, 1), the vector to the + // camera, this is scaled up. + float aaw = anti_alias_width / get_aaw_scalar(normal); + float buff0 = 0.5 * v_stroke_width[0] + aaw; + float buff2 = 0.5 * v_stroke_width[2] + aaw; + float aaw0 = (1 - has_prev) * aaw; + float aaw2 = (1 - has_next) * aaw; - vec2 c0 = p0 - buff0 * p0_perp + aaw0 * v10; - vec2 c1 = p0 + buff0 * p0_perp + aaw0 * v10; - vec2 c2 = p2 - p2_perp * buff2 + aaw2 * v12; - vec2 c3 = p2 + p2_perp * buff2 + aaw2 * v12; + vec3 c0 = p0 - buff0 * p0_perp + aaw0 * v10; + vec3 c1 = p0 + buff0 * p0_perp + aaw0 * v10; + vec3 c2 = p2 + buff2 * p2_perp + aaw2 * v12; + vec3 c3 = p2 - buff2 * p2_perp + aaw2 * v12; // Account for previous and next control points - if(has_prev == 1){ - create_joint(angle_from_prev, v01, buff0, bevel_start, c0, c0, c1, c1); - } - if(has_next == 1){ - create_joint(-angle_to_next, v21, buff2, bevel_end, c2, c2, c3, c3); - } + if(has_prev > 0) create_joint(angle_from_prev, v01, buff0, bevel_start, c0, c0, c1, c1); + if(has_next > 0) create_joint(angle_to_next, v21, buff2, bevel_end, c3, c3, c2, c2); - // Linear case is the simplets + // Linear case is the simplest if(degree == 1){ // Swap between 2 and 3 is deliberate, the order of corners // should be for a triangle_strip. Last entry is a dummy - corners = vec2[5](c0, c1, c3, c2, vec2(0.0)); + corners = vec3[5](c0, c1, c3, c2, vec3(0.0)); return 4; } - - // Some admitedly complicated logic to (hopefully efficiently) - // make sure corners forms a convex hull around the curve. - if(cross(v10, v12) > 0){ - bool change_c0 = ( - // has_prev == 0 && - dot(v21, v20) > 0 && - should_motify_corner(c0, v01, c2, c3, v21, buff0) - ); - if(change_c0) c0 = p0 + p2_perp * buff0; - - bool change_c3 = ( - // has_next == 0 && - dot(v01, v02) > 0 && - should_motify_corner(c3, v21, c1, c0, v01, buff2) - ); - if(change_c3) c3 = p2 - p0_perp * buff2; - - vec2 i12; - find_intersection(c1, v01, c2, v21, i12); - corners = vec2[5](c1, c0, i12, c3, c2); - }else{ - bool change_c1 = ( - // has_prev == 0 && - dot(v21, v20) > 0 && - should_motify_corner(c1, v01, c3, c2, v21, buff0) - ); - if(change_c1) c1 = p0 - p2_perp * buff0; - - bool change_c2 = ( - // has_next == 0 && - dot(v01, v02) > 0 && - should_motify_corner(c2, v21, c0, c1, v01, buff2) - ); - if(change_c2) c2 = p2 + p0_perp * buff2; - - vec2 i03; - find_intersection(c0, v01, c3, v21, i03); - corners = vec2[5](c0, c1, i03, c2, c3); - } + // Otherwise, form a pentagon around the curve + float orientation = sign(dot(cross(v01, v12), normal)); // Positive for ccw curves + if(orientation > 0) corners = vec3[5](c0, c1, p1, c2, c3); + else corners = vec3[5](c1, c0, p1, c3, c2); + // Replace corner[2] with convex hull point accounting for stroke width + find_intersection(corners[0], v01, corners[4], v21, normal, corners[2]); return 5; } -void set_adjascent_info(vec2 c0, vec2 tangent, - int degree, int mult, int flip, - vec2 adj[3], - out float has, +void set_adjascent_info(vec3 c0, vec3 tangent, + int degree, + vec3 normal, + vec3 adj[3], out float bevel, out float angle ){ float joint_type = v_joint_type[0]; - - has = 0; - bevel = 0; - angle = 0; - - vec2 new_adj[3]; - int adj_degree = get_reduced_control_points( - adj[0], adj[1], adj[2], new_adj - ); - has = float(adj_degree > 0); - if(has == 1){ - vec2 adj = new_adj[mult * adj_degree - flip]; - angle = flip * angle_between_vectors(c0 - adj, tangent); - } + vec3 new_adj[3]; + float adj_degree = get_reduced_control_points(adj, new_adj); + // Check if adj_degree is zero? + angle = angle_between_vectors(c0 - new_adj[1], tangent, normal); // Decide on joint type - bool one_linear = (degree == 1 || adj_degree == 1); + bool one_linear = (degree == 1 || adj_degree == 1.0); bool should_bevel = ( (joint_type == AUTO_JOINT && one_linear) || joint_type == BEVEL_JOINT @@ -243,73 +207,68 @@ void set_adjascent_info(vec2 c0, vec2 tangent, } -void set_previous_and_next(vec2 controls[3], int degree){ - float a_tol = 1e-10; +void set_previous_and_next(vec3 controls[3], int degree, vec3 normal){ + float a_tol = 1e-8; - if(distance(prev_bp[2], bp[0]) < a_tol){ - vec2 tangent = controls[1] - controls[0]; + // Made as floats not bools so they can be passed to the frag shader + has_prev = float(distance(prev_bp[2], bp[0]) < a_tol); + has_next = float(distance(next_bp[0], bp[2]) < a_tol); + + if(has_prev > 0){ + vec3 tangent = controls[1] - controls[0]; set_adjascent_info( - controls[0], tangent, degree, 1, 1, - vec2[3](prev_bp[0].xy, prev_bp[1].xy, prev_bp[2].xy), - has_prev, bevel_start, angle_from_prev + controls[0], tangent, degree, normal, + vec3[3](prev_bp[0], prev_bp[1], prev_bp[2]), + bevel_start, angle_from_prev ); } - if(distance(next_bp[0], bp[2]) < a_tol){ - vec2 tangent = controls[degree - 1] - controls[degree]; + if(has_next > 0){ + vec3 tangent = controls[1] - controls[2]; set_adjascent_info( - controls[degree], tangent, degree, 0, -1, - vec2[3](next_bp[0].xy, next_bp[1].xy, next_bp[2].xy), - has_next, bevel_end, angle_to_next + controls[2], tangent, degree, normal, + vec3[3](next_bp[0], next_bp[1], next_bp[2]), + bevel_end, angle_to_next ); + angle_to_next *= -1; } } void main() { - vec2 controls[3]; - int degree = get_reduced_control_points(bp[0].xy, bp[1].xy, bp[2].xy, controls); - bezier_degree = float(degree); + vec3 unit_normal = v_global_unit_normal[0]; + // anti_alias_width /= cos(0.5 * acos(abs(unit_normal.z))); - // Null curve or linear with higher index than needed + vec3 controls[3]; + bezier_degree = get_reduced_control_points(vec3[3](bp[0], bp[1], bp[2]), controls); + int degree = int(bezier_degree); + + // Null curve if(degree == 0) return; - set_previous_and_next(controls, degree); + set_previous_and_next(controls, degree, unit_normal); // Find uv conversion matrix - mat3 xy_to_uv = get_xy_to_uv(controls[0], controls[1]); + mat4 xyz_to_uv = get_xyz_to_uv(controls[0], controls[1], unit_normal); float scale_factor = length(controls[1] - controls[0]); - uv_anti_alias_width = anti_alias_width / scale_factor; - uv_b2 = (xy_to_uv * vec3(controls[degree], 1.0)).xy; + uv_anti_alias_width = anti_alias_width / scale_factor / get_aaw_scalar(unit_normal); + uv_b2 = (xyz_to_uv * vec4(controls[2], 1.0)).xy; // Corners of a bounding region around curve - vec2 corners[5]; - int n_corners = get_corners(controls, degree, corners); + vec3 corners[5]; + int n_corners = get_corners(controls, unit_normal, degree, corners); - // Get style info aligned to the corners - float stroke_widths[5]; - vec4 stroke_colors[5]; - float z_values[5]; - int index_map[5]; - if(n_corners == 4) index_map = int[5](0, 0, 2, 2, 2); - else index_map = int[5](0, 0, 1, 2, 2); - for(int i = 0; i < 5; i++){ - stroke_widths[i] = v_stroke_width[index_map[i]]; - stroke_colors[i] = v_color[index_map[i]]; - z_values[i] = bp[index_map[i]].z; // TODO, seems clunky - } + int index_map[5] = int[5](0, 0, 1, 2, 2); + if(n_corners == 4) index_map[2] = 2; // Emit each corner for(int i = 0; i < n_corners; i++){ - vec2 corner = corners[i]; - uv_coords = (xy_to_uv * vec3(corner, 1.0)).xy; - - uv_stroke_width = stroke_widths[i] / scale_factor; - color = stroke_colors[i]; - - gl_Position = vec4( - scale_and_shift_point_for_frame(vec3(corner, z_values[i])), - 1.0 - ); + xyz_coords = corners[i]; + uv_coords = (xyz_to_uv * vec4(xyz_coords, 1.0)).xy; + uv_stroke_width = v_stroke_width[index_map[i]] / scale_factor; + color = v_color[index_map[i]]; + gloss = v_gloss[index_map[i]]; + global_unit_normal = v_global_unit_normal[index_map[i]]; + gl_Position = get_gl_Position(xyz_coords); EmitVertex(); } EndPrimitive(); diff --git a/manimlib/shaders/quadratic_bezier_stroke_vert.glsl b/manimlib/shaders/quadratic_bezier_stroke_vert.glsl index 60b70580..74e4242c 100644 --- a/manimlib/shaders/quadratic_bezier_stroke_vert.glsl +++ b/manimlib/shaders/quadratic_bezier_stroke_vert.glsl @@ -1,34 +1,44 @@ #version 330 +uniform mat4 to_screen_space; +uniform float focal_distance; + in vec3 point; in vec3 prev_point; in vec3 next_point; +in vec3 unit_normal; in float stroke_width; in vec4 color; in float joint_type; +in float gloss; -out vec3 bp; // Bezier control point +// Bezier control point +out vec3 bp; out vec3 prev_bp; out vec3 next_bp; +out vec3 v_global_unit_normal; out float v_stroke_width; out vec4 v_color; out float v_joint_type; +out float v_gloss; -// TODO, this should maybe depend on scale -const float STROKE_WIDTH_CONVERSION = 0.01; - - -#INSERT rotate_point_for_frame.glsl +const float STROKE_WIDTH_CONVERSION = 0.0025; +// To my knowledge, there is no notion of #include for shaders, +// so to share functionality between this and others, the caller +// replaces this line with the contents of named file +#INSERT position_point_into_frame.glsl void main(){ + bp = position_point_into_frame(point); + prev_bp = position_point_into_frame(prev_point); + next_bp = position_point_into_frame(next_point); + v_global_unit_normal = normalize(position_point_into_frame(unit_normal)); + v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width; v_color = color; v_joint_type = joint_type; - - bp = rotate_point_for_frame(point); - prev_bp = rotate_point_for_frame(prev_point); - next_bp = rotate_point_for_frame(next_point); + v_gloss = gloss; } \ No newline at end of file diff --git a/manimlib/shaders/rotate_point_for_frame.glsl b/manimlib/shaders/rotate_point_for_frame.glsl deleted file mode 100644 index 5e36077e..00000000 --- a/manimlib/shaders/rotate_point_for_frame.glsl +++ /dev/null @@ -1,4 +0,0 @@ -vec3 rotate_point_for_frame(vec3 point){ - // TODO, orient in 3d based on certain rotation matrices - return point; -} diff --git a/manimlib/shaders/scale_and_shift_point_for_frame.glsl b/manimlib/shaders/scale_and_shift_point_for_frame.glsl index fabeca1d..4c13eced 100644 --- a/manimlib/shaders/scale_and_shift_point_for_frame.glsl +++ b/manimlib/shaders/scale_and_shift_point_for_frame.glsl @@ -1,11 +1,8 @@ -// Assumes theese uniforms exist in the surrounding context -// uniform float scale; +// Assumes the following uniforms exist in the surrounding context: // uniform float aspect_ratio; -// uniform float frame_center; +// TODO, rename -vec3 scale_and_shift_point_for_frame(vec3 point){ - point -= frame_center; - point /= scale; +vec3 get_gl_Position(vec3 point){ point.x /= aspect_ratio; return point; } \ No newline at end of file diff --git a/manimlib/shaders/simple_vert.glsl b/manimlib/shaders/simple_vert.glsl index 829ab02e..c5d8c546 100644 --- a/manimlib/shaders/simple_vert.glsl +++ b/manimlib/shaders/simple_vert.glsl @@ -1,15 +1,16 @@ #version 330 -uniform float scale; uniform float aspect_ratio; uniform float anti_alias_width; -uniform vec3 frame_center; +uniform mat4 to_screen_space; +uniform float focal_distance; in vec3 point; // Analog of import for manim only -#INSERT set_gl_Position.glsl +#INSERT get_gl_Position.glsl +#INSERT position_point_into_frame.glsl void main(){ - set_gl_Position(point); + gl_Position = get_gl_Position(position_point_into_frame(point)); } \ No newline at end of file diff --git a/manimlib/shaders/surface_frag.glsl b/manimlib/shaders/surface_frag.glsl new file mode 100644 index 00000000..3f0eb69c --- /dev/null +++ b/manimlib/shaders/surface_frag.glsl @@ -0,0 +1,12 @@ +#version 330 + +// uniform sampler2D Texture; + +// in vec2 v_im_coords; +in vec4 v_color; + +out vec4 frag_color; + +void main() { + frag_color = v_color; +} \ No newline at end of file diff --git a/manimlib/shaders/surface_vert.glsl b/manimlib/shaders/surface_vert.glsl new file mode 100644 index 00000000..2a8284de --- /dev/null +++ b/manimlib/shaders/surface_vert.glsl @@ -0,0 +1,31 @@ +#version 330 + +uniform float aspect_ratio; +uniform float anti_alias_width; +uniform mat4 to_screen_space; +uniform float focal_distance; +uniform vec3 light_source_position; + +// uniform sampler2D Texture; + +in vec3 point; +in vec3 normal; +// in vec2 im_coords; +in vec4 color; +in float gloss; + +// out vec2 v_im_coords; +out vec4 v_color; + +// Analog of import for manim only +#INSERT position_point_into_frame.glsl +#INSERT get_gl_Position.glsl +#INSERT add_light.glsl + +void main(){ + vec3 xyz_coords = position_point_into_frame(point); + vec3 unit_normal = normalize(position_point_into_frame(normal)); + // v_im_coords = im_coords; + v_color = add_light(color, xyz_coords, unit_normal, light_source_position, gloss); + gl_Position = get_gl_Position(xyz_coords); +} \ No newline at end of file diff --git a/manimlib/utils/space_ops.py b/manimlib/utils/space_ops.py index 6bc974fb..19bf005b 100644 --- a/manimlib/utils/space_ops.py +++ b/manimlib/utils/space_ops.py @@ -3,9 +3,10 @@ import math import itertools as it from mapbox_earcut import triangulate_float32 as earcut +from manimlib.constants import RIGHT +from manimlib.constants import UP from manimlib.constants import OUT from manimlib.constants import PI -from manimlib.constants import RIGHT from manimlib.constants import TAU from manimlib.utils.iterables import adjacent_pairs @@ -84,6 +85,22 @@ def thick_diagonal(dim, thickness=2): return (np.abs(row_indices - col_indices) < thickness).astype('uint8') +def rotation_matrix_transpose_from_quaternion(quat): + quat_inv = quaternion_conjugate(quat) + return [ + quaternion_mult(quat, [0, *basis], quat_inv)[1:] + for basis in [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ] + ] + + +def rotation_matrix_from_quaternion(quat): + return np.transpose(rotation_matrix_transpose_from_quaternion(quat)) + + def rotation_matrix_transpose(angle, axis): if axis[0] == 0 and axis[1] == 0: # axis = [0, 0, z] case is common enough it's worth @@ -97,15 +114,7 @@ def rotation_matrix_transpose(angle, axis): [0, 0, 1], ] quat = quaternion_from_angle_axis(angle, axis) - quat_inv = quaternion_conjugate(quat) - return [ - quaternion_mult(quat, [0, *basis], quat_inv)[1:] - for basis in [ - [1, 0, 0], - [0, 1, 0], - [0, 0, 1], - ] - ] + return rotation_matrix_transpose_from_quaternion(quat) def rotation_matrix(angle, axis): @@ -128,25 +137,11 @@ def z_to_vector(vector): Returns some matrix in SO(3) which takes the z-axis to the (normalized) vector provided as an argument """ - norm = get_norm(vector) - if norm == 0: + cp = cross(OUT, vector) + if get_norm(cp) == 0: return np.identity(3) - v = np.array(vector) / norm - phi = np.arccos(v[2]) - if any(v[:2]): - # projection of vector to unit circle - axis_proj = v[:2] / get_norm(v[:2]) - theta = np.arccos(axis_proj[0]) - if axis_proj[1] < 0: - theta = -theta - else: - theta = 0 - phi_down = np.array([ - [math.cos(phi), 0, math.sin(phi)], - [0, 1, 0], - [-math.sin(phi), 0, math.cos(phi)] - ]) - return np.dot(rotation_about_z(theta), phi_down) + angle = np.arccos(np.dot(OUT, normalize(vector))) + return rotation_matrix(angle, axis=cp) def angle_of_vector(vector): @@ -188,8 +183,19 @@ def cross(v1, v2): ]) -def get_unit_normal(v1, v2): - return normalize(cross(v1, v2)) +def get_unit_normal(v1, v2, tol=1e-6): + v1 = normalize(v1) + v2 = normalize(v2) + cp = cross(v1, v2) + cp_norm = get_norm(cp) + if cp_norm < tol: + # Vectors align, so find a normal to them in the plane shared with the z-axis + new_cp = cross(cross(v1, OUT), v1) + new_cp_norm = get_norm(new_cp) + if new_cp_norm < tol: + return UP + return new_cp / new_cp_norm + return cp / cp_norm ###