mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Up to weight matrix introduction of nn
This commit is contained in:
parent
2d20084573
commit
7f69b6aa93
4 changed files with 656 additions and 14 deletions
|
@ -36,6 +36,8 @@ python extract_scene.py -p example_scenes.py SquareToCircle
|
|||
|
||||
`-p` gives a preview of an animation, `-w` will write it to a file, and `-s` will show/save the final image in the animation.
|
||||
|
||||
You will probably want to change the MOVIE_DIR constant to be whatever direction you want video files to output to.
|
||||
|
||||
Look through the old_projects folder to see the code for previous 3b1b videos.
|
||||
|
||||
While developing a scene, the `-s` flag is helpful to just see what things look like at the end without having to generate the full animation. It can also be helpful to put `self.force_skipping()` at the top of the construct method, and `self.revert_to_original_skipping_status()` before the portion of the scene that you want to test, and run with the `-p` flag to just see a preview of one part of the scene.
|
||||
|
|
|
@ -330,6 +330,10 @@ class LaggedStart(Animation):
|
|||
anim.update(alpha)
|
||||
return self
|
||||
|
||||
def clean_up(self, *args, **kwargs):
|
||||
for anim in self.subanimations:
|
||||
anim.clean_up(*args, **kwargs)
|
||||
|
||||
class DelayByOrder(Animation):
|
||||
"""
|
||||
Modifier of animation.
|
||||
|
|
648
nn/scenes.py
648
nn/scenes.py
|
@ -153,6 +153,8 @@ class NetworkMobject(VGroup):
|
|||
radius = self.neuron_radius,
|
||||
stroke_color = self.neuron_stroke_color,
|
||||
stroke_width = self.neuron_stroke_width,
|
||||
fill_color = self.neuron_fill_color,
|
||||
fill_opacity = 0,
|
||||
)
|
||||
for x in range(n_neurons)
|
||||
])
|
||||
|
@ -277,13 +279,17 @@ class MNistNetworkMobject(NetworkMobject):
|
|||
class NetworkScene(Scene):
|
||||
CONFIG = {
|
||||
"layer_sizes" : [8, 6, 6, 4],
|
||||
"network_mob_config" : {},
|
||||
}
|
||||
def setup(self):
|
||||
self.add_network()
|
||||
|
||||
def add_network(self):
|
||||
self.network = Network(sizes = self.layer_sizes)
|
||||
self.network_mob = NetworkMobject(self.network)
|
||||
self.network_mob = NetworkMobject(
|
||||
self.network,
|
||||
**self.network_mob_config
|
||||
)
|
||||
self.add(self.network_mob)
|
||||
|
||||
def feed_forward(self, input_vector, false_confidence = False, added_anims = None):
|
||||
|
@ -731,12 +737,12 @@ class PreviewMNistNetwork(NetworkScene):
|
|||
def feed_in_image(self, in_vect):
|
||||
image = PixelsFromVect(in_vect)
|
||||
image.next_to(self.network_mob, LEFT, LARGE_BUFF, UP)
|
||||
big_rect = SurroundingRectangle(image, color = BLUE)
|
||||
image_rect = SurroundingRectangle(image, color = BLUE)
|
||||
start_neurons = self.network_mob.layers[0].neurons.copy()
|
||||
start_neurons.set_stroke(WHITE, width = 0)
|
||||
start_neurons.set_fill(WHITE, 0)
|
||||
|
||||
self.play(FadeIn(image), FadeIn(big_rect))
|
||||
self.play(FadeIn(image), FadeIn(image_rect))
|
||||
self.feed_forward(in_vect, added_anims = [
|
||||
self.get_image_to_layer_one_animation(image, start_neurons)
|
||||
])
|
||||
|
@ -749,10 +755,13 @@ class PreviewMNistNetwork(NetworkScene):
|
|||
self.network_mob.output_labels[n],
|
||||
))
|
||||
self.play(ShowCreation(rect))
|
||||
self.play(FadeOut(rect))
|
||||
self.reset_display(rect, image, image_rect)
|
||||
|
||||
def reset_display(self, answer_rect, image, image_rect):
|
||||
self.play(FadeOut(answer_rect))
|
||||
self.play(
|
||||
FadeOut(image),
|
||||
FadeOut(big_rect),
|
||||
FadeOut(image_rect),
|
||||
self.network_mob.deactivate_layers,
|
||||
)
|
||||
|
||||
|
@ -2787,23 +2796,33 @@ class ContinualEdgeUpdate(ContinualAnimation):
|
|||
def __init__(self, network_mob, **kwargs):
|
||||
n_cycles = 5
|
||||
edges = VGroup(*it.chain(*network_mob.edge_groups))
|
||||
self.move_to_targets = []
|
||||
for edge in edges:
|
||||
edge.colors = [edge.get_color()] + [
|
||||
edge.colors = [
|
||||
random.choice([GREEN, GREEN, GREEN, RED])
|
||||
for x in range(n_cycles)
|
||||
]
|
||||
edge.widths = [edge.get_stroke_width()] + [
|
||||
edge.widths = [
|
||||
3*random.random()**7
|
||||
for x in range(n_cycles)
|
||||
]
|
||||
edge.cycle_time = 1 + random.random()
|
||||
|
||||
edge.generate_target()
|
||||
edge.target.set_stroke(edge.colors[0], edge.widths[0])
|
||||
self.move_to_targets.append(MoveToTarget(edge))
|
||||
self.edges = edges
|
||||
ContinualAnimation.__init__(self, edges, **kwargs)
|
||||
|
||||
def update_mobject(self, dt):
|
||||
if self.internal_time < 1:
|
||||
alpha = smooth(self.internal_time)
|
||||
for move_to_target in self.move_to_targets:
|
||||
move_to_target.update(alpha)
|
||||
return
|
||||
for edge in self.edges:
|
||||
t = self.internal_time/edge.cycle_time
|
||||
alpha = (self.internal_time%edge.cycle_time)/edge.cycle_time
|
||||
t = (self.internal_time-1)/edge.cycle_time
|
||||
alpha = ((self.internal_time-1)%edge.cycle_time)/edge.cycle_time
|
||||
low_n = int(t)%len(edge.colors)
|
||||
high_n = int(t+1)%len(edge.colors)
|
||||
color = interpolate_color(edge.colors[low_n], edge.colors[high_n], alpha)
|
||||
|
@ -3029,6 +3048,617 @@ class ShowRemainingNetwork(IntroduceWeights):
|
|||
*added_anims
|
||||
)
|
||||
|
||||
class ImagineSettingByHand(Scene):
|
||||
def construct(self):
|
||||
randy = Randolph()
|
||||
randy.scale(0.7)
|
||||
randy.to_corner(DOWN+LEFT)
|
||||
|
||||
bubble = randy.get_bubble()
|
||||
network_mob = NetworkMobject(
|
||||
Network(sizes = [8, 6, 6, 4]),
|
||||
neuron_stroke_color = WHITE
|
||||
)
|
||||
network_mob.scale(0.7)
|
||||
network_mob.move_to(bubble.get_bubble_center())
|
||||
network_mob.shift(MED_SMALL_BUFF*RIGHT + SMALL_BUFF*(UP+RIGHT))
|
||||
|
||||
self.add(randy, bubble, network_mob)
|
||||
self.add(ContinualEdgeUpdate(network_mob))
|
||||
self.play(randy.change, "pondering")
|
||||
self.dither()
|
||||
self.play(Blink(randy))
|
||||
self.dither()
|
||||
self.play(randy.change, "horrified", network_mob)
|
||||
self.play(Blink(randy))
|
||||
self.dither(10)
|
||||
|
||||
class WhenTheNetworkFails(MoreHonestMNistNetworkPreview):
|
||||
CONFIG = {
|
||||
"network_mob_config" : {"layer_to_layer_buff" : 2}
|
||||
}
|
||||
def construct(self):
|
||||
self.setup_network_mob()
|
||||
self.black_box()
|
||||
self.incorrect_classification()
|
||||
self.ask_about_weights()
|
||||
|
||||
def setup_network_mob(self):
|
||||
self.network_mob.scale(0.8)
|
||||
self.network_mob.to_edge(DOWN)
|
||||
|
||||
def black_box(self):
|
||||
network_mob = self.network_mob
|
||||
layers = VGroup(*network_mob.layers[1:3])
|
||||
box = SurroundingRectangle(
|
||||
layers,
|
||||
stroke_color = WHITE,
|
||||
fill_color = BLACK,
|
||||
fill_opacity = 0.8,
|
||||
)
|
||||
words = TextMobject("...rather than treating this as a black box")
|
||||
words.next_to(box, UP, LARGE_BUFF)
|
||||
|
||||
self.play(
|
||||
Write(words, run_time = 2),
|
||||
DrawBorderThenFill(box)
|
||||
)
|
||||
self.dither()
|
||||
self.play(*map(FadeOut, [words, box]))
|
||||
|
||||
def incorrect_classification(self):
|
||||
network = self.network
|
||||
training_data, validation_data, test_data = load_data_wrapper()
|
||||
for in_vect, result in test_data[20:]:
|
||||
network_answer = np.argmax(network.feedforward(in_vect))
|
||||
if network_answer != result:
|
||||
break
|
||||
self.feed_in_image(in_vect)
|
||||
|
||||
wrong = TextMobject("Wrong!")
|
||||
wrong.highlight(RED)
|
||||
wrong.next_to(self.network_mob.layers[-1], UP+RIGHT)
|
||||
self.play(Write(wrong, run_time = 1))
|
||||
|
||||
def ask_about_weights(self):
|
||||
question = TextMobject(
|
||||
"What weights are used here?\\\\",
|
||||
"What are they doing?"
|
||||
)
|
||||
question.next_to(self.network_mob, UP)
|
||||
|
||||
self.add(ContinualEdgeUpdate(self.network_mob))
|
||||
self.play(Write(question))
|
||||
self.dither(10)
|
||||
|
||||
|
||||
###
|
||||
|
||||
def reset_display(self, *args):
|
||||
pass
|
||||
|
||||
class EvenWhenItWorks(TeacherStudentsScene):
|
||||
def construct(self):
|
||||
self.teacher_says(
|
||||
"Even when it works,\\\\",
|
||||
"dig into why."
|
||||
)
|
||||
self.change_student_modes(*["pondering"]*3)
|
||||
self.dither(7)
|
||||
|
||||
class IntroduceWeightMatrix(NetworkScene):
|
||||
CONFIG = {
|
||||
"network_mob_config" : {
|
||||
"neuron_stroke_color" : WHITE,
|
||||
"neuron_fill_color" : WHITE,
|
||||
"neuron_radius" : 0.35,
|
||||
"layer_to_layer_buff" : 2,
|
||||
},
|
||||
"layer_sizes" : [8, 6],
|
||||
}
|
||||
def construct(self):
|
||||
self.setup_network_mob()
|
||||
self.show_weighted_sum()
|
||||
self.organize_activations_into_column()
|
||||
self.organize_weights_as_matrix()
|
||||
self.show_meaning_of_matrix_row()
|
||||
self.connect_weighted_sum_to_matrix_multiplication()
|
||||
self.add_bias_vector()
|
||||
self.apply_sigmoid()
|
||||
self.write_clean_final_expression()
|
||||
|
||||
def setup_network_mob(self):
|
||||
self.network_mob.to_edge(LEFT, buff = LARGE_BUFF)
|
||||
self.network_mob.layers[1].neurons.shift(0.02*RIGHT)
|
||||
|
||||
def show_weighted_sum(self):
|
||||
self.fade_many_neurons()
|
||||
self.activate_first_layer()
|
||||
self.show_first_neuron_weighted_sum()
|
||||
self.add_bias()
|
||||
self.add_sigmoid()
|
||||
self.dither()
|
||||
##
|
||||
|
||||
def fade_many_neurons(self):
|
||||
anims = []
|
||||
neurons = self.network_mob.layers[1].neurons
|
||||
for neuron in neurons[1:]:
|
||||
neuron.save_state()
|
||||
neuron.edges_in.save_state()
|
||||
anims += [
|
||||
neuron.fade, 0.8,
|
||||
neuron.set_fill, None, 0,
|
||||
neuron.edges_in.fade, 0.8,
|
||||
]
|
||||
anims += [
|
||||
Animation(neurons[0]),
|
||||
Animation(neurons[0].edges_in),
|
||||
]
|
||||
self.play(*anims)
|
||||
|
||||
def activate_first_layer(self):
|
||||
layer = self.network_mob.layers[0]
|
||||
activations = 0.7*np.random.random(len(layer.neurons))
|
||||
active_layer = self.network_mob.get_active_layer(0, activations)
|
||||
a_labels = VGroup(*[
|
||||
TexMobject("a^{(0)}_%d"%d)
|
||||
for d in range(len(layer.neurons))
|
||||
])
|
||||
for label, neuron in zip(a_labels, layer.neurons):
|
||||
label.scale(0.75)
|
||||
label.move_to(neuron)
|
||||
|
||||
self.play(
|
||||
Transform(layer, active_layer),
|
||||
Write(a_labels, run_time = 2)
|
||||
)
|
||||
|
||||
self.a_labels = a_labels
|
||||
|
||||
def show_first_neuron_weighted_sum(self):
|
||||
neuron = self.network_mob.layers[1].neurons[0]
|
||||
a_labels = VGroup(*self.a_labels[:2]).copy()
|
||||
a_labels.generate_target()
|
||||
w_labels = VGroup(*[
|
||||
TexMobject("w_{0, %d}"%d)
|
||||
for d in range(len(a_labels))
|
||||
])
|
||||
weighted_sum = VGroup()
|
||||
symbols = VGroup()
|
||||
for a_label, w_label in zip(a_labels.target, w_labels):
|
||||
a_label.scale(1./0.75)
|
||||
plus = TexMobject("+")
|
||||
weighted_sum.add(w_label, a_label, plus)
|
||||
symbols.add(plus)
|
||||
weighted_sum.add(
|
||||
TexMobject("\\cdots"),
|
||||
TexMobject("+"),
|
||||
TexMobject("w_{0, n}"),
|
||||
TexMobject("a^{(0)}_n"),
|
||||
)
|
||||
|
||||
weighted_sum.arrange_submobjects(RIGHT)
|
||||
a1_label = TexMobject("a^{(1)}_0")
|
||||
a1_label.next_to(neuron, RIGHT)
|
||||
equals = TexMobject("=").next_to(a1_label, RIGHT)
|
||||
weighted_sum.next_to(equals, RIGHT)
|
||||
|
||||
symbols.add(*weighted_sum[-4:-2])
|
||||
w_labels.add(weighted_sum[-2])
|
||||
a_labels.add(self.a_labels[-1].copy())
|
||||
a_labels.target.add(weighted_sum[-1])
|
||||
a_labels.add(VGroup(*self.a_labels[2:-1]).copy())
|
||||
a_labels.target.add(VectorizedPoint(weighted_sum[-4].get_center()))
|
||||
|
||||
VGroup(a1_label, equals, weighted_sum).scale(
|
||||
0.75, about_point = a1_label.get_left()
|
||||
)
|
||||
|
||||
w_labels.highlight(GREEN)
|
||||
w_labels.shift(0.6*SMALL_BUFF*DOWN)
|
||||
a_labels.target.shift(0.5*SMALL_BUFF*UP)
|
||||
|
||||
self.play(
|
||||
Write(a1_label),
|
||||
Write(equals),
|
||||
neuron.set_fill, None, 0.3,
|
||||
run_time = 1
|
||||
)
|
||||
self.play(MoveToTarget(a_labels, run_time = 1.5))
|
||||
self.play(
|
||||
Write(w_labels),
|
||||
Write(symbols),
|
||||
)
|
||||
|
||||
self.a1_label = a1_label
|
||||
self.a1_equals = equals
|
||||
self.w_labels = w_labels
|
||||
self.a_labels_in_sum = a_labels
|
||||
self.symbols = symbols
|
||||
self.weighted_sum = VGroup(w_labels, a_labels, symbols)
|
||||
|
||||
def add_bias(self):
|
||||
weighted_sum = self.weighted_sum
|
||||
bias = TexMobject("+\\,", "b_0")
|
||||
bias.scale(0.75)
|
||||
bias.next_to(weighted_sum, RIGHT, SMALL_BUFF)
|
||||
bias.shift(0.5*SMALL_BUFF*DOWN)
|
||||
name = TextMobject("Bias")
|
||||
name.scale(0.75)
|
||||
name.next_to(bias, DOWN, MED_LARGE_BUFF)
|
||||
arrow = Arrow(name, bias, buff = SMALL_BUFF)
|
||||
VGroup(name, arrow, bias).highlight(BLUE)
|
||||
|
||||
self.play(
|
||||
FadeIn(name),
|
||||
FadeIn(bias),
|
||||
GrowArrow(arrow),
|
||||
)
|
||||
|
||||
self.weighted_sum.add(bias)
|
||||
|
||||
self.bias = bias
|
||||
self.bias_name = VGroup(name, arrow)
|
||||
|
||||
def add_sigmoid(self):
|
||||
weighted_sum = self.weighted_sum
|
||||
weighted_sum.generate_target()
|
||||
sigma, lp, rp = mob = TexMobject("\\sigma\\big(\\big)")
|
||||
# mob.scale(0.75)
|
||||
sigma.move_to(weighted_sum.get_left())
|
||||
sigma.shift(0.5*SMALL_BUFF*(DOWN+RIGHT))
|
||||
lp.next_to(sigma, RIGHT, SMALL_BUFF)
|
||||
weighted_sum.target.next_to(lp, RIGHT, SMALL_BUFF)
|
||||
rp.next_to(weighted_sum.target, RIGHT, SMALL_BUFF)
|
||||
|
||||
name = TextMobject("Sigmoid")
|
||||
name.next_to(sigma, UP, MED_LARGE_BUFF)
|
||||
arrow = Arrow(name, sigma, buff = SMALL_BUFF)
|
||||
sigmoid_name = VGroup(name, arrow)
|
||||
VGroup(sigmoid_name, mob).highlight(YELLOW)
|
||||
|
||||
self.play(
|
||||
FadeIn(mob),
|
||||
MoveToTarget(weighted_sum),
|
||||
MaintainPositionRelativeTo(self.bias_name, self.bias),
|
||||
)
|
||||
self.play(FadeIn(sigmoid_name))
|
||||
|
||||
self.sigma = sigma
|
||||
self.sigma_parens = VGroup(lp, rp)
|
||||
self.sigmoid_name = sigmoid_name
|
||||
|
||||
##
|
||||
|
||||
def organize_activations_into_column(self):
|
||||
a_labels = self.a_labels.copy()
|
||||
a_labels.generate_target()
|
||||
column = a_labels.target
|
||||
a_labels_in_sum = self.a_labels_in_sum
|
||||
|
||||
dots = TexMobject("\\vdots")
|
||||
mid_as = VGroup(*column[2:-1])
|
||||
Transform(mid_as, dots).update(1)
|
||||
last_a = column[-1]
|
||||
new_last_a = TexMobject(
|
||||
last_a.get_tex_string().replace("7", "n")
|
||||
)
|
||||
new_last_a.replace(last_a)
|
||||
Transform(last_a, new_last_a).update(1)
|
||||
|
||||
VGroup(
|
||||
*column[:2] + [mid_as] + [column[-1]]
|
||||
).arrange_submobjects(DOWN)
|
||||
column.shift(DOWN + 3.5*RIGHT)
|
||||
|
||||
pre_brackets = self.get_brackets(a_labels)
|
||||
post_bracketes = self.get_brackets(column)
|
||||
pre_brackets.set_fill(opacity = 0)
|
||||
|
||||
self.play(LaggedStart(
|
||||
Indicate, self.a_labels,
|
||||
rate_func = there_and_back
|
||||
))
|
||||
self.play(
|
||||
MoveToTarget(a_labels),
|
||||
Transform(pre_brackets, post_bracketes),
|
||||
run_time = 2
|
||||
)
|
||||
self.dither()
|
||||
self.play(*[
|
||||
LaggedStart(Indicate, mob, rate_func = there_and_back)
|
||||
for mob in a_labels, a_labels_in_sum
|
||||
])
|
||||
self.dither()
|
||||
|
||||
self.a_column = a_labels
|
||||
self.a_column_brackets = pre_brackets
|
||||
|
||||
def organize_weights_as_matrix(self):
|
||||
a_column = self.a_column
|
||||
a_column_brackets = self.a_column_brackets
|
||||
w_brackets = a_column_brackets.copy()
|
||||
w_brackets.next_to(a_column_brackets, LEFT, SMALL_BUFF)
|
||||
lwb, rwb = w_brackets
|
||||
|
||||
w_labels = self.w_labels.copy()
|
||||
w_labels.submobjects.insert(
|
||||
2, self.symbols[-2].copy()
|
||||
)
|
||||
w_labels.generate_target()
|
||||
w_labels.target.arrange_submobjects(RIGHT)
|
||||
w_labels.target.next_to(a_column[0], LEFT, buff = 0.8)
|
||||
lwb.next_to(w_labels.target, LEFT, SMALL_BUFF)
|
||||
lwb.align_to(rwb, UP)
|
||||
|
||||
row_1, row_k = [
|
||||
VGroup(*map(TexMobject, [
|
||||
"w_{%s, 0}"%i,
|
||||
"w_{%s, 1}"%i,
|
||||
"\\cdots",
|
||||
"w_{%s, k}"%i,
|
||||
]))
|
||||
for i in "1", "n"
|
||||
]
|
||||
dots_row = VGroup(*map(TexMobject, [
|
||||
"\\vdots", "\\vdots", "\\ddots", "\\vdots"
|
||||
]))
|
||||
|
||||
lower_rows = VGroup(row_1, dots_row, row_k)
|
||||
lower_rows.scale(0.75)
|
||||
last_row = w_labels.target
|
||||
for row in lower_rows:
|
||||
for target, mover in zip(last_row, row):
|
||||
mover.move_to(target)
|
||||
if "w" in mover.get_tex_string():
|
||||
mover.highlight(GREEN)
|
||||
row.next_to(last_row, DOWN, buff = 0.45)
|
||||
last_row = row
|
||||
|
||||
self.play(
|
||||
MoveToTarget(w_labels),
|
||||
Write(w_brackets, run_time = 1)
|
||||
)
|
||||
self.play(FadeIn(
|
||||
lower_rows,
|
||||
run_time = 3,
|
||||
submobject_mode = "lagged_start",
|
||||
))
|
||||
self.dither()
|
||||
|
||||
self.top_matrix_row = w_labels
|
||||
self.lower_matrix_rows = lower_rows
|
||||
self.matrix_brackets = w_brackets
|
||||
|
||||
def show_meaning_of_matrix_row(self):
|
||||
row = self.top_matrix_row
|
||||
edges = self.network_mob.layers[1].neurons[0].edges_in.copy()
|
||||
edges.set_stroke(GREEN, 5)
|
||||
rect = SurroundingRectangle(row, color = GREEN_B)
|
||||
|
||||
self.play(ShowCreation(rect))
|
||||
for x in range(2):
|
||||
self.play(LaggedStart(
|
||||
ShowCreationThenDestruction, edges,
|
||||
lag_ratio = 0.8
|
||||
))
|
||||
self.dither()
|
||||
|
||||
self.top_row_rect = rect
|
||||
|
||||
def connect_weighted_sum_to_matrix_multiplication(self):
|
||||
a_column = self.a_column
|
||||
a_brackets = self.a_column_brackets
|
||||
top_row_rect = self.top_row_rect
|
||||
|
||||
column_rect = SurroundingRectangle(a_column)
|
||||
|
||||
equals = TexMobject("=")
|
||||
equals.next_to(a_brackets, RIGHT)
|
||||
result_brackets = a_brackets.copy()
|
||||
result_terms = VGroup()
|
||||
for i in 0, 1, 4, -1:
|
||||
a = a_column[i]
|
||||
if i == 4:
|
||||
mob = TexMobject("\\vdots")
|
||||
else:
|
||||
# mob = Circle(radius = 0.2, color = YELLOW)
|
||||
mob = TexMobject("?").scale(1.3).highlight(YELLOW)
|
||||
result_terms.add(mob.move_to(a))
|
||||
VGroup(result_brackets, result_terms).next_to(equals, RIGHT)
|
||||
|
||||
brace = Brace(
|
||||
VGroup(self.w_labels, self.a_labels_in_sum), DOWN
|
||||
)
|
||||
arrow = Arrow(
|
||||
brace.get_bottom(),
|
||||
result_terms[0].get_top(),
|
||||
buff = SMALL_BUFF
|
||||
)
|
||||
|
||||
self.play(
|
||||
GrowArrow(arrow),
|
||||
GrowFromCenter(brace),
|
||||
)
|
||||
self.play(
|
||||
Write(equals),
|
||||
FadeIn(result_brackets),
|
||||
)
|
||||
self.play(ShowCreation(column_rect))
|
||||
self.play(ReplacementTransform(
|
||||
VGroup(top_row_rect, column_rect).copy(),
|
||||
result_terms[0]
|
||||
))
|
||||
self.play(LaggedStart(
|
||||
FadeIn, VGroup(*result_terms[1:])
|
||||
))
|
||||
self.dither(2)
|
||||
self.play(*map(FadeOut, [
|
||||
result_terms, result_brackets, equals,
|
||||
arrow, brace,
|
||||
top_row_rect, column_rect
|
||||
]))
|
||||
|
||||
def add_bias_vector(self):
|
||||
bias = self.bias
|
||||
bias_name = self.bias_name
|
||||
a_column_brackets = self.a_column_brackets
|
||||
a_column = self.a_column
|
||||
|
||||
plus = TexMobject("+")
|
||||
b_brackets = a_column_brackets.copy()
|
||||
b_column = VGroup(*map(TexMobject, [
|
||||
"b_0", "b_1", "\\vdots", "b_n",
|
||||
]))
|
||||
b_column.scale(0.85)
|
||||
b_column.arrange_submobjects(DOWN, buff = 0.35)
|
||||
b_column.move_to(a_column)
|
||||
b_column.highlight(BLUE)
|
||||
plus.next_to(a_column_brackets, RIGHT)
|
||||
VGroup(b_brackets, b_column).next_to(plus, RIGHT)
|
||||
|
||||
bias_rect = SurroundingRectangle(bias)
|
||||
|
||||
self.play(ShowCreation(bias_rect))
|
||||
self.play(FadeOut(bias_rect))
|
||||
self.play(
|
||||
Write(plus),
|
||||
Write(b_brackets),
|
||||
Transform(self.bias[1].copy(), b_column[0]),
|
||||
run_time = 1
|
||||
)
|
||||
self.play(LaggedStart(
|
||||
FadeIn, VGroup(*b_column[1:])
|
||||
))
|
||||
self.dither()
|
||||
|
||||
self.bias_plus = plus
|
||||
self.b_brackets = b_brackets
|
||||
self.b_column = b_column
|
||||
|
||||
def apply_sigmoid(self):
|
||||
expression_bounds = VGroup(
|
||||
self.matrix_brackets[0], self.b_brackets[1]
|
||||
)
|
||||
sigma = self.sigma.copy()
|
||||
slp, srp = self.sigma_parens.copy()
|
||||
|
||||
big_lp, big_rp = parens = TexMobject("()")
|
||||
parens.scale(3)
|
||||
parens.stretch_to_fit_height(expression_bounds.get_height())
|
||||
big_lp.next_to(expression_bounds, LEFT, SMALL_BUFF)
|
||||
big_rp.next_to(expression_bounds, RIGHT, SMALL_BUFF)
|
||||
parens.highlight(YELLOW)
|
||||
|
||||
self.play(
|
||||
sigma.scale, 2,
|
||||
sigma.next_to, big_lp, LEFT, SMALL_BUFF,
|
||||
Transform(slp, big_lp),
|
||||
Transform(srp, big_rp),
|
||||
)
|
||||
self.dither(2)
|
||||
|
||||
self.big_sigma_group = VGroup(VGroup(sigma), slp, srp)
|
||||
|
||||
def write_clean_final_expression(self):
|
||||
self.fade_weighted_sum()
|
||||
expression = TexMobject(
|
||||
"\\textbf{a}^{(1)}",
|
||||
"=",
|
||||
"\\sigma",
|
||||
"\\big(",
|
||||
"\\textbf{W}",
|
||||
"\\textbf{a}^{(0)}",
|
||||
"+",
|
||||
"\\textbf{b}",
|
||||
"\\big)",
|
||||
)
|
||||
expression.highlight_by_tex_to_color_map({
|
||||
"sigma" : YELLOW,
|
||||
"big" : YELLOW,
|
||||
"W" : GREEN,
|
||||
"\\textbf{b}" : BLUE
|
||||
})
|
||||
expression.next_to(self.big_sigma_group, UP, LARGE_BUFF)
|
||||
a1, equals, sigma, lp, W, a0, plus, b, rp = expression
|
||||
|
||||
neuron_anims = []
|
||||
neurons = VGroup(*self.network_mob.layers[1].neurons[1:])
|
||||
for neuron in neurons:
|
||||
neuron_anims += [
|
||||
neuron.restore,
|
||||
neuron.set_fill, None, random.random()
|
||||
]
|
||||
neuron_anims += [
|
||||
neuron.edges_in.restore
|
||||
]
|
||||
|
||||
self.play(ReplacementTransform(
|
||||
VGroup(
|
||||
self.top_matrix_row, self.lower_matrix_rows,
|
||||
self.matrix_brackets
|
||||
).copy(),
|
||||
VGroup(W),
|
||||
))
|
||||
self.play(ReplacementTransform(
|
||||
VGroup(self.a_column, self.a_column_brackets).copy(),
|
||||
VGroup(VGroup(a0)),
|
||||
))
|
||||
self.play(
|
||||
ReplacementTransform(
|
||||
VGroup(self.b_column, self.b_brackets).copy(),
|
||||
VGroup(VGroup(b))
|
||||
),
|
||||
ReplacementTransform(
|
||||
self.bias_plus.copy(), plus
|
||||
)
|
||||
)
|
||||
self.play(ReplacementTransform(
|
||||
self.big_sigma_group.copy(),
|
||||
VGroup(sigma, lp, rp)
|
||||
))
|
||||
self.dither()
|
||||
self.play(*neuron_anims, run_time = 2)
|
||||
self.play(
|
||||
ReplacementTransform(neurons.copy(), a1),
|
||||
FadeIn(equals)
|
||||
)
|
||||
self.dither(2)
|
||||
|
||||
def fade_weighted_sum(self):
|
||||
self.play(*map(FadeOut, [
|
||||
self.a1_label, self.a1_equals,
|
||||
self.sigma, self.sigma_parens,
|
||||
self.weighted_sum,
|
||||
self.bias_name,
|
||||
self.sigmoid_name,
|
||||
]))
|
||||
|
||||
|
||||
###
|
||||
|
||||
def get_brackets(self, mob):
|
||||
lb, rb = both = TexMobject("\\big[\\big]")
|
||||
both.scale_to_fit_width(mob.get_width())
|
||||
both.stretch_to_fit_height(1.2*mob.get_height())
|
||||
lb.next_to(mob, LEFT, SMALL_BUFF)
|
||||
rb.next_to(mob, RIGHT, SMALL_BUFF)
|
||||
return both
|
||||
|
||||
|
||||
|
||||
class EoLA3Wrapper(Scene):
|
||||
def construct(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -84,11 +84,7 @@ class Line(VMobject):
|
|||
self.account_for_buff()
|
||||
|
||||
def account_for_buff(self):
|
||||
anchors = self.get_anchors()
|
||||
length = sum([
|
||||
np.linalg.norm(a2-a1)
|
||||
for a1, a2 in zip(anchors, anchors[1:])
|
||||
])
|
||||
length = self.get_arc_length()
|
||||
if length < 2*self.buff or self.buff == 0:
|
||||
return
|
||||
buff_proportion = self.buff / length
|
||||
|
@ -117,6 +113,16 @@ class Line(VMobject):
|
|||
start, end = self.get_start_and_end()
|
||||
return np.linalg.norm(start - end)
|
||||
|
||||
def get_arc_length(self):
|
||||
if self.path_arc:
|
||||
anchors = self.get_anchors()
|
||||
return sum([
|
||||
np.linalg.norm(a2-a1)
|
||||
for a1, a2 in zip(anchors, anchors[1:])
|
||||
])
|
||||
else:
|
||||
return self.get_length()
|
||||
|
||||
def get_start_and_end(self):
|
||||
return self.get_start(), self.get_end()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue