mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
635 lines
17 KiB
Python
635 lines
17 KiB
Python
from manimlib.imports import *
|
|
import scipy.stats
|
|
|
|
|
|
# TeX source for the marks drawn for positive (CMARK) and negative (XMARK)
# outcomes elsewhere in this module.
CMARK_TEX = r"\text{\ding{51}}"
XMARK_TEX = r"\text{\ding{55}}"

# Fill colour looked up by get_coin for each known coin symbol.
COIN_COLOR_MAP = {"H": BLUE_E, "T": RED_E}
|
|
|
|
|
|
class Histogram(Group):
    """
    Bar chart of ``data`` drawn on its own axes.

    Each entry of ``data`` becomes one bar whose height is that entry's
    share of the total, and the y-axis is labelled in percent.  Layout and
    styling are read from ``CONFIG``.
    """
    CONFIG = {
        # Lengths of the y- and x-axis respectively.
        "height": 5,
        "width": 10,
        # Largest value on the y-axis.
        "y_max": 1,
        # Percent values that receive a label on the y-axis.
        "y_axis_numbers_to_show": range(20, 120, 20),
        "y_axis_label_height": 0.25,
        "y_tick_freq": 0.2,
        # Every x_label_freq-th bar gets an integer label below it.
        "x_label_freq": 1,
        # Whether a horizontal line is drawn at each y-axis tick.
        "include_h_lines": True,
        "h_line_style": {
            "stroke_width": 1,
            "stroke_color": LIGHT_GREY,
            # "draw_stroke_behind_fill": True,
        },
        "bar_style": {
            "stroke_width": 1,
            "stroke_color": WHITE,
            "fill_opacity": 1,
        },
        # Gradient endpoints applied across the bars.
        "bar_colors": [BLUE, GREEN]
    }

    def __init__(self, data, **kwargs):
        """Build axes, optional grid lines, bars and axis labels for `data`."""
        super().__init__(**kwargs)
        self.data = data

        # The order matters: the labels are positioned relative to the
        # bars and axes created by the earlier steps.
        self.add_axes()
        if self.include_h_lines:
            self.add_h_lines()
        self.add_bars(data)
        self.add_x_axis_labels()
        self.add_y_axis_labels()

    def add_axes(self):
        """Create centered axes with one x unit per data entry."""
        bar_count = len(self.data)
        x_axis_config = {
            "unit_size": self.width / bar_count,
            "include_tip": False,
        }
        y_axis_config = {
            "unit_size": self.height / self.y_max,
            "include_tip": False,
            "tick_frequency": self.y_tick_freq,
        }
        self.axes = Axes(
            x_min=0,
            x_max=bar_count,
            x_axis_config=x_axis_config,
            y_min=0,
            y_max=self.y_max,
            y_axis_config=y_axis_config,
        )
        self.axes.center()
        self.add(self.axes)

    def add_h_lines(self):
        """Add one horizontal line per y-axis tick, as wide as the x-axis."""
        axes = self.axes
        h_lines = VGroup()
        axes.h_lines = h_lines
        for tick in axes.y_axis.tick_marks:
            h_line = Line(**self.h_line_style)
            h_line.match_width(axes.x_axis)
            # Left end of the line sits on the tick's center.
            h_line.move_to(tick.get_center(), LEFT)
            h_lines.add(h_line)
        axes.add(h_lines)

    def add_bars(self, data):
        """Create the bars for `data`, store them as self.bars and add them."""
        bars = self.get_bars(data)
        self.bars = bars
        self.add(bars)

    def add_x_axis_labels(self):
        """Put the bar index under every x_label_freq-th bar."""
        axes = self.axes
        x_labels = VGroup()
        axes.x_labels = x_labels
        indexed_bars = list(enumerate(self.bars))
        for index, bar in indexed_bars[::self.x_label_freq]:
            number = Integer(index)
            number.set_height(0.25)
            number.next_to(bar, DOWN)
            x_labels.add(number)
        axes.add(x_labels)

    def add_y_axis_labels(self):
        """Label the y-axis with the percent values from CONFIG."""
        y_axis = self.axes.y_axis
        y_labels = VGroup()
        for percent in self.y_axis_numbers_to_show:
            number = Integer(percent, unit="\\%")
            fix_percent(number[-1][0])
            number.set_height(self.y_axis_label_height)
            # Axis values are fractions, labels are percentages.
            number.next_to(y_axis.n2p(0.01 * percent), LEFT)
            y_labels.add(number)
        self.axes.y_labels = y_labels
        y_axis.add(y_labels)

    # Bar manipulations
    def get_bars(self, data):
        """
        Return a VGroup with one styled rectangle per entry of `data`.

        The entries are normalised by their sum (all bars get zero height
        when the sum is zero); bar `i` has its lower-left corner at the
        axes coordinate (i, 0) and is one x unit wide.
        """
        shares = np.array(data).astype(float)
        total = shares.sum()
        if total != 0:
            shares /= total
        else:
            shares[:] = 0

        # Scene-space lengths of one x unit and one y unit.
        axes = self.axes
        origin = axes.c2p(0, 0)
        unit_width = get_norm(axes.c2p(1, 0) - origin)
        unit_height = get_norm(axes.c2p(0, 1) - origin)

        bars = VGroup()
        for index, share in enumerate(shares):
            bar = Rectangle()
            bar.set_width(unit_width)
            bar.set_height(unit_height * share, stretch=True)
            bar.move_to(axes.c2p(index, 0), DL)
            bars.add(bar)

        bars.set_submobject_colors_by_gradient(*self.bar_colors)
        bars.set_style(**self.bar_style)
        return bars
|
|
|
|
|
|
# Images of randomness
|
|
|
|
def fix_percent(sym):
    """
    Split the points of `sym` between `sym` and a new child mobject.

    The points belonging to the first two subpaths stay on `sym`; the
    remaining points move to a copy that is added as a submobject of
    `sym`.  Modifies `sym` in place and returns nothing.
    """
    # Really need to make this unneeded...
    remainder = sym.copy()
    subpath_sizes = [len(subpath) for subpath in sym.get_subpaths()]
    split_at = sum(subpath_sizes[:2])
    all_points = sym.points
    sym.points = all_points[:split_at]
    remainder.points = all_points[split_at:]
    sym.add(remainder)
    sym.lock_triangulation()
|
|
|
|
|
|
def get_random_process(choices, shuffle_time=2, total_time=3, change_rate=0.05,
                       h_buff=0.1, v_buff=0.1):
    """
    Return an invisible box that keeps swapping in a random item of `choices`.

    The box is sized around ``choices[0]`` plus `h_buff` / `v_buff` of
    padding on each side and initially contains that item.  An updater
    advances an internal clock; during the first `shuffle_time` of every
    `total_time` period, and no sooner than `change_rate` after the
    previous swap, the content is replaced by a random choice placed at a
    random offset (up to twice the buff) from the box's lower-left corner.
    """
    initial = choices[0]

    container = Square()
    container.set_opacity(0)
    container.set_width(initial.get_width() + 2 * h_buff, stretch=True)
    container.set_height(initial.get_height() + 2 * v_buff, stretch=True)
    container.move_to(initial)
    container.add(initial)
    container.time = 0
    container.last_change_time = 0

    def shuffle(box, dt):
        box.time += dt

        in_shuffle_window = (box.time % total_time) < shuffle_time
        swap_is_due = box.time - box.last_change_time > change_rate
        if not (in_shuffle_window and swap_is_due):
            return

        current = box.submobjects[0]
        replacement = random.choice(choices)
        replacement.match_height(current)
        replacement.move_to(box, DL)
        replacement.shift(2 * np.random.random() * h_buff * RIGHT)
        replacement.shift(2 * np.random.random() * v_buff * UP)
        box.set_submobjects([replacement])
        box.last_change_time = box.time

    container.add_updater(shuffle)
    return container
|
|
|
|
|
|
def get_die_faces():
    """
    Return a VGroup of the six die faces, arranged left to right.

    Each face is a VGroup of a rounded square and the dots for that
    value, ordered from one to six.
    """
    pip = Dot()
    pip.set_width(0.15)
    pip.set_color(BLUE_B)

    square = Square()
    square.round_corners(0.25)
    square.set_stroke(WHITE, 2)
    square.set_fill(DARKER_GREY, 1)
    # The square is temporarily shrunk so that its bounding-box points
    # give the dot positions; it is resized to full width further down.
    square.set_width(0.6)

    # Bounding-box directions of the dots for the values one to six.
    edge_groups = [
        (ORIGIN,),
        (UL, DR),
        (UL, ORIGIN, DR),
        (UL, UR, DL, DR),
        (UL, UR, ORIGIN, DL, DR),
        (UL, UR, LEFT, RIGHT, DL, DR),
    ]

    arrangements = VGroup()
    for edge_group in edge_groups:
        pips = VGroup()
        for direction in edge_group:
            placed = pip.copy()
            placed.move_to(square.get_bounding_box_point(direction))
            pips.add(placed)
        arrangements.add(pips)
    square.set_width(1)

    faces = VGroup()
    for arrangement in arrangements:
        faces.add(VGroup(square.copy(), arrangement))
    faces.arrange(RIGHT)

    return faces
|
|
|
|
|
|
def get_random_die(**kwargs):
    """Return a random process cycling through die faces.

    `kwargs` are forwarded to get_random_process.
    """
    faces = get_die_faces()
    return get_random_process(faces, **kwargs)
|
|
|
|
|
|
def get_random_card(height=1, **kwargs):
    """Return a random process cycling through a DeckOfCards.

    The deck is scaled to `height` first; `kwargs` are forwarded to
    get_random_process.
    """
    deck = DeckOfCards()
    deck.set_height(height)
    process = get_random_process(deck, **kwargs)
    return process
|
|
|
|
|
|
# Coins
|
|
def get_coin(symbol, color=None):
    """
    Return a coin: a filled circle of height 1 with `symbol` written on it.

    When `color` is None the fill comes from COIN_COLOR_MAP, falling back
    to GREY_E for unknown symbols.  The returned VGroup holds the circle
    and the label (in that order) and carries `symbol` as an attribute.
    """
    fill = COIN_COLOR_MAP.get(symbol, GREY_E) if color is None else color

    disc = Circle()
    disc.set_fill(fill, 1)
    disc.set_stroke(WHITE, 1)
    disc.set_height(1)

    text = TextMobject(symbol)
    text.set_height(0.5 * disc.get_height())
    text.move_to(disc)

    coin = VGroup()
    coin.add(disc, text)
    coin.symbol = symbol
    coin.lock_triangulation()
    return coin
|
|
|
|
|
|
def get_random_coin(**kwargs):
    """Return a random process flipping between an "H" and a "T" coin.

    `kwargs` are forwarded to get_random_process.
    """
    sides = [get_coin("H"), get_coin("T")]
    return get_random_process(sides, **kwargs)
|
|
|
|
|
|
def get_prob_coin_label(symbol="H", color=None, p=0.5, num_decimal_places=2):
    """
    Return a label of the form ``P(<coin>) = <p>``.

    The coin comes from get_coin(symbol, color) and takes the place of a
    "00" placeholder in the TeX; `p` is shown as a DecimalNumber with
    `num_decimal_places` places, appended to the right of the label.
    """
    label = TexMobject("P", "(", "00", ")", "=",)
    placeholder = label.get_part_by_tex("00")

    coin = get_coin(symbol, color)
    coin.replace(placeholder)
    placeholder_index = label.index_of_part(placeholder)
    label.replace_submobject(placeholder_index, coin)

    value = DecimalNumber(p, num_decimal_places=num_decimal_places)
    value.next_to(label, RIGHT, buff=MED_SMALL_BUFF)
    label.add(value)
    return label
|
|
|
|
|
|
def get_q_box(mob):
    """
    Return a filled rectangle surrounding `mob` with "???" centered on it.

    The question marks are shrunk, if necessary, to fit within 80% of the
    box's width and height, and are added to the box as a submobject.
    """
    box = SurroundingRectangle(mob)
    box.set_stroke(WHITE, 1)
    box.set_fill(GREY_E, 1)

    q_marks = TexMobject("???")
    width_limit = 0.8 * box.get_width()
    height_limit = 0.8 * box.get_height()

    # Width is constrained first, then height is checked on the result.
    if q_marks.get_width() > width_limit:
        q_marks.set_width(width_limit)
    if q_marks.get_height() > height_limit:
        q_marks.set_height(height_limit)

    q_marks.move_to(box)
    box.add(q_marks)
    return box
|
|
|
|
|
|
def get_coin_grid(bools, height=6):
    """
    Return a grid of coins, one per entry of `bools`.

    Truthy entries become "H" coins, falsy ones "T" coins; the whole grid
    is scaled to `height`.
    """
    coins = VGroup()
    for heads in bools:
        symbol = "H" if heads else "T"
        coins.add(get_coin(symbol))
    coins.arrange_in_grid()
    coins.set_height(height)
    return coins
|
|
|
|
|
|
def get_prob_positive_experience_label(include_equals=False,
                                       include_decimal=False,
                                       include_q_mark=False):
    """
    Return a ``P(Positive experience)`` label with optional trailing parts.

    include_equals appends an "=" sign, include_decimal appends a yellow
    0.95 (also stored as ``label.decimal``), and include_q_mark appends a
    "?" that keeps itself next to whatever was the last part of the label
    at the time it was created.
    """
    label = TexMobject(
        "P", "(", "00000", ")",
    )

    # Swap the "00000" placeholder for the two-line green text.
    text = TextMobject("Positive\\\\experience")
    text.set_color(GREEN)
    text.replace(label[2], dim_to_match=0)
    label.replace_submobject(2, text)

    # Stretch the parentheses to the height of the text.
    parens = VGroup(label[1], label[3])
    parens.match_height(
        text, stretch=True, about_edge=DOWN,
    )

    if include_equals:
        equals = TexMobject("=")
        equals.next_to(label, RIGHT)
        label.add(equals)

    if include_decimal:
        decimal = DecimalNumber(0.95)
        decimal.next_to(label, RIGHT)
        decimal.set_color(YELLOW)
        label.decimal = decimal
        label.add(decimal)

    if include_q_mark:
        q_mark = TexMobject("?")
        q_mark.relative_mob = label[-1]
        q_mark.add_updater(
            lambda m: m.next_to(m.relative_mob, RIGHT, SMALL_BUFF)
        )
        label.add(q_mark)

    return label
|
|
|
|
|
|
def get_beta_dist_axes(y_max=20, y_unit=2, label_y=False, **kwargs):
    """
    Return axes with x in [0, 1] and y in [0, y_max], placed in the
    lower-right corner.

    The x-axis is stretched to width 11 and the y-axis to height 6 about
    the origin, and the x-axis is numbered in steps of 0.2.  With
    `label_y` the y-axis is numbered every `y_unit` and a
    "Probability density" label is added (also stored as
    ``result.y_axis_label``).  `kwargs` is accepted but not used here.
    """
    x_axis_config = {
        "unit_size": 0.1,
        "tick_frequency": 0.1,
        "include_tip": False,
    }
    y_axis_config = {
        "unit_size": 1,
        "tick_frequency": y_unit,
        "include_tip": False,
    }
    result = Axes(
        x_min=0,
        x_max=1,
        x_axis_config=x_axis_config,
        y_min=0,
        y_max=y_max,
        y_axis_config=y_axis_config,
    )

    # Stretch each axis independently while keeping the origin fixed.
    stretch_kwargs = {
        "about_point": result.c2p(0, 0),
        "stretch": True,
    }
    result.x_axis.set_width(11, **stretch_kwargs)
    result.y_axis.set_height(6, **stretch_kwargs)

    x_vals = np.arange(0, 1, 0.2) + 0.2
    result.x_axis.add_numbers(
        *x_vals,
        number_config={"num_decimal_places": 1}
    )

    if label_y:
        y_vals = np.arange(y_unit, y_max, y_unit)
        result.y_axis.add_numbers(*y_vals)

        label = TextMobject("Probability density")
        label.scale(0.5)
        label.next_to(result.y_axis.get_top(), UR, SMALL_BUFF)
        label.next_to(result.y_axis, UP, SMALL_BUFF)
        label.align_to(result.y_axis.numbers, LEFT)
        result.add(label)
        result.y_axis_label = label

    result.to_corner(DR, LARGE_BUFF)

    return result
|
|
|
|
|
|
def scaled_pdf_axes(scale_factor=3.5):
    """
    Return labelled beta-distribution axes whose y-axis is stretched
    vertically by `scale_factor` about the origin.

    Only a subset of the y-axis numbers is kept, and each kept number is
    stretched back by 1 / scale_factor.
    """
    axes = get_beta_dist_axes(
        label_y=True,
        y_unit=1,
    )
    y_axis = axes.y_axis

    # Keep the first five numbers, then every fifth one.
    # NOTE(review): both slices include index 4 -- confirm whether the
    # duplicate entry is intended.
    first_numbers = y_axis.numbers[:5]
    sparse_numbers = y_axis.numbers[4::5]
    y_axis.numbers.set_submobjects([*first_numbers, *sparse_numbers])

    y_axis.stretch(scale_factor, 1, about_point=axes.c2p(0, 0))
    inverse_factor = 1 / scale_factor
    for number in y_axis.numbers:
        number.stretch(inverse_factor, 1)

    axes.y_axis_label.to_edge(LEFT)
    axes.y_axis_label.add_background_rectangle(opacity=1)
    axes.set_stroke(background=True)
    return axes
|
|
|
|
|
|
def close_off_graph(axes, graph):
    """
    Extend `graph` with lines down to the x-axis and back to the origin.

    From the graph's end point a line is added to the x-axis at the same
    x value, then to the axes origin.  Modifies and returns `graph`.
    """
    end_x = axes.x_axis.p2n(graph.get_end())
    for corner in (axes.c2p(end_x, 0), axes.c2p(0, 0)):
        graph.add_line_to(corner)
    graph.lock_triangulation()
    return graph
|
|
|
|
|
|
def get_beta_graph(axes, n_plus, n_minus, **kwargs):
    """
    Return the filled graph of the Beta(n_plus + 1, n_minus + 1) pdf.

    The graph is drawn on `axes` (with `kwargs` forwarded to
    axes.get_graph), closed off down to the x-axis, and styled with a
    blue stroke and fill.
    """
    pdf = scipy.stats.beta(n_plus + 1, n_minus + 1).pdf
    graph = close_off_graph(axes, axes.get_graph(pdf, **kwargs))
    graph.set_stroke(BLUE, 2)
    graph.set_fill(BLUE_E, 1)
    graph.lock_triangulation()
    return graph
|
|
|
|
|
|
def get_beta_label(n_plus, n_minus, point=ORIGIN):
    """
    Return a ``Beta(a, b)`` label with a = n_plus + 1 in green and
    b = n_minus + 1 in red.

    `point` is accepted but not used here.
    """
    template = TextMobject("Beta(", "00", ",", "00", ")")
    template.scale(1.5)

    a_label = Integer(n_plus + 1)
    a_label.set_color(GREEN)
    b_label = Integer(n_minus + 1)
    b_label.set_color(RED)

    # Swap each "00" placeholder for the matching number.
    replacements = ((1, a_label), (3, b_label))
    for index, number in replacements:
        placeholder = template[index]
        number.match_height(placeholder)
        number.move_to(placeholder, DOWN)
        template.replace_submobject(index, number)

    # Re-space the parts horizontally, then restore each part's
    # bottom edge from the saved state.
    template.save_state()
    template.arrange(RIGHT, buff=0.15)
    for part, saved_part in zip(template, template.saved_state):
        part.align_to(saved_part, DOWN)

    return template
|
|
|
|
|
|
def get_plusses_and_minuses(n_rows=15, n_cols=20, p=0.95):
    """
    Return an n_rows x n_cols grid of random check and cross marks.

    Each mark is independently a green check with probability `p`,
    otherwise a red cross, and records the outcome in ``is_plus``.  The
    grid is scaled to width 5.5.
    """
    result = VGroup()
    for _ in range(n_rows * n_cols):
        is_plus = bool(random.random() < p)
        mob = TexMobject(CMARK_TEX if is_plus else XMARK_TEX)
        mob.set_color(GREEN if is_plus else RED)
        mob.is_plus = is_plus
        mob.set_width(1)
        result.add(mob)

    result.arrange_in_grid(n_rows, n_cols)
    result.set_width(5.5)
    return result
|
|
|
|
|
|
def get_checks_and_crosses(bools, width=12):
    """
    Return a row of marks, one per entry of `bools`.

    Truthy entries give a green check, falsy ones a red cross; each mark
    stores its entry in ``positive``.  The row is scaled to `width`.
    """
    result = VGroup()
    for positive in bools:
        tex, color = (CMARK_TEX, GREEN) if positive else (XMARK_TEX, RED)
        mob = TexMobject(tex)
        mob.set_color(color)
        mob.positive = positive
        mob.set_width(0.5)
        result.add(mob)
    result.arrange(RIGHT, buff=MED_SMALL_BUFF)
    result.set_width(width)
    return result
|
|
|
|
|
|
def get_underlines(marks):
    """
    Return one Underline per mark, all aligned to a common bottom edge.

    The bottom edge of the last underline is used as the reference.
    """
    underlines = VGroup(*[Underline(mark) for mark in marks])
    for underline in underlines:
        underline.align_to(underlines[-1], DOWN)
    return underlines
|
|
|
|
|
|
def get_random_checks_and_crosses(n=50, s=0.95, width=12):
    """Return a row of `n` random marks, each a check with probability `s`.

    The row is built by get_checks_and_crosses and scaled to `width`.
    """
    outcomes = np.random.random(n) < s
    return get_checks_and_crosses(bools=outcomes, width=width)
|
|
|
|
|
|
def get_random_num_row(s, n=10):
    """
    Return a row of `n` random decimals with a mark above each one.

    A number below `s` counts as positive: it is green with a check mark
    above it; otherwise it is red with a cross.  The returned VGroup
    holds (nums, syms), exposes them as ``row.nums`` / ``row.syms``,
    stores the count of positives in ``row.n_positive``, and is scaled to
    width 10 at the top edge of the frame.
    """
    nums = VGroup()
    syms = VGroup()
    for index, value in enumerate(np.random.random(n)):
        num = DecimalNumber(value)
        num.set_height(0.25)
        num.move_to(index * RIGHT)
        num.positive = (num.get_value() < s)

        num.set_color(GREEN if num.positive else RED)
        sym = TexMobject(CMARK_TEX if num.positive else XMARK_TEX)
        sym.match_color(num)
        sym.match_height(num)
        sym.positive = num.positive
        sym.next_to(num, UP)

        nums.add(num)
        syms.add(sym)

    row = VGroup(nums, syms)
    row.nums = nums
    row.syms = syms
    row.n_positive = sum([num.positive for num in nums])

    row.set_width(10)
    row.center().to_edge(UP)
    return row
|
|
|
|
|
|
def get_prob_review_label(n_positive, n_negative, s=0.95):
    """
    Return the label ``P(n_positive checks, n_negative crosses | s = ...)``.

    The check-mark part is green, the cross-mark part red, and the part
    showing `s` (formatted to two decimals) yellow.
    """
    s_str = "{:.2f}".format(s)
    label = TexMobject(
        "P(",
        f"{n_positive}\\,{CMARK_TEX}", ",\\,",
        f"{n_negative}\\,{XMARK_TEX}",
        "\\,|\\,",
        "s = " + s_str,
        ")",
    )
    label.set_color_by_tex_to_color_map({
        CMARK_TEX: GREEN,
        XMARK_TEX: RED,
        # Keyed on the formatted value so the highlight follows `s`
        # instead of only matching the default of 0.95.
        s_str: YELLOW,
    })
    return label
|
|
|
|
|
|
def get_binomial_formula(n, k, p):
    """
    Return the formula ``(n over k) (p)^k (1 - p)^(n - k)`` with the
    actual values of `n`, `k` and `p` filled in.

    The formula is first typeset with letter placeholders (one letter per
    character of the value it stands for), then every placeholder is
    replaced by a copy of the corresponding number mobject: n in white,
    k in green, n - k in red and p in yellow.
    """
    n_mob = Integer(n, color=WHITE)
    k_mob = Integer(k, color=GREEN)
    nmk_mob = Integer(n - k, color=RED)
    p_mob = DecimalNumber(p, color=YELLOW)

    # Placeholders as wide as the value they stand for.  The width of the
    # probability placeholder follows p_mob (it was previously sized by
    # k_mob, which did not match the decimal that replaces it).
    n_str = "N" * len(n_mob)
    k_str = "K" * len(k_mob)
    p_str = "P" * len(p_mob)
    nmk_str = "M" * len(nmk_mob)

    formula = TexMobject(
        "\\left(",
        "{" + n_str,
        "\\over",
        k_str + "}",
        "\\right)",
        "(", p_str, ")",
        "^{" + k_str + "}",
        "(1 - ", p_str, ")",
        "^{" + nmk_str + "}",
    )
    # Pull the big parentheses closer together, then drop the fraction
    # bar so that n sits directly above k.
    parens = VGroup(formula[0], formula[4])
    parens.space_out_submobjects(0.7)
    formula.remove(formula.get_part_by_tex("\\over"))

    pairs = (
        (n_mob, n_str),
        (k_mob, k_str),
        (nmk_mob, nmk_str),
        (p_mob, p_str),
    )
    for mob, tex in pairs:
        parts = formula.get_parts_by_tex(tex)
        for part in parts:
            mob_copy = mob.copy()
            # Looked up again on every pass: the previously replaced part
            # is no longer in the formula, so this finds the next one.
            i = formula.index_of_part_by_tex(tex)
            mob_copy.match_height(part)
            mob_copy.move_to(part, DOWN)
            formula.replace_submobject(i, mob_copy)

    # Re-space the terms horizontally while keeping their original
    # vertical positions.
    terms = VGroup(
        formula[:4],
        formula[4:7],
        formula[7],
        formula[8:11],
        formula[11],
    )
    ys = [term.get_y() for term in terms]
    terms.arrange(RIGHT, buff=SMALL_BUFF)
    terms[0].shift(SMALL_BUFF * LEFT)
    for term, y in zip(terms, ys):
        term.set_y(y)

    return formula
|
|
|
|
|
|
def get_check_count_label(nc, nx, include_rect=True):
    """
    Return a label reading ``<nc> check <nx> cross``.

    With `include_rect` a filled surrounding rectangle is added behind
    the label as its first submobject.
    """
    check_count = Integer(nc)
    check_mark = TexMobject(CMARK_TEX, color=GREEN)
    cross_count = Integer(nx)
    cross_mark = TexMobject(XMARK_TEX, color=RED)

    result = VGroup(check_count, check_mark, cross_count, cross_mark)
    result.arrange(RIGHT, buff=SMALL_BUFF)
    # Extra gap between the check pair and the cross pair.
    result[2:].shift(SMALL_BUFF * RIGHT)

    if include_rect:
        rect = SurroundingRectangle(result)
        rect.set_stroke(WHITE, 1)
        rect.set_fill(GREY_E, 1)
        result.add_to_back(rect)

    return result
|
|
|
|
|
|
def reverse_smooth(t):
    """Return smooth evaluated at 1 - t."""
    mirrored_t = 1 - t
    return smooth(mirrored_t)
|
|
|
|
|
|
def get_region_under_curve(axes, graph, min_x, max_x):
    """
    Return the region between `graph` and the x-axis from `min_x` to
    `max_x`, with a green stroke and a half-opaque green fill.

    The part of the graph between the two x values is found with a
    binary search, then closed off along the x-axis.  The inputs are kept
    on the result as ``axes``, ``graph``, ``min_x`` and ``max_x``.
    """
    def x_at(alpha):
        return axes.x_axis.p2n(graph.pfp(alpha))

    # NOTE(review): the search bounds are the axes' x-range, while the
    # searched variable is the argument passed to graph.pfp -- confirm
    # that these ranges are meant to coincide.
    props = []
    for x in (min_x, max_x):
        props.append(binary_search(
            function=x_at,
            target=x,
            lower_bound=axes.x_min,
            upper_bound=axes.x_max,
        ))

    region = graph.copy()
    region.pointwise_become_partial(graph, *props)
    # Close the outline: down to the axis, back along it, up to the start.
    region.add_line_to(axes.c2p(max_x, 0))
    region.add_line_to(axes.c2p(min_x, 0))
    region.add_line_to(region.get_start())

    region.set_stroke(GREEN, 2)
    region.set_fill(GREEN, 0.5)

    region.axes = axes
    region.graph = graph
    region.min_x = min_x
    region.max_x = max_x

    return region
|