mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
NN Part 1 published
This commit is contained in:
parent
6fac1a578c
commit
5c1a8f9a32
4 changed files with 544 additions and 42 deletions
|
@ -117,9 +117,12 @@ class Mobject(object):
|
|||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def generate_target(self):
|
||||
def generate_target(self, use_deepcopy = False):
|
||||
self.target = None #Prevent exponential explosion
|
||||
self.target = self.copy()
|
||||
if use_deepcopy:
|
||||
self.target = self.deepcopy()
|
||||
else:
|
||||
self.target = self.copy()
|
||||
return self.target
|
||||
|
||||
#### Transforming operations ######
|
||||
|
|
386
nn/part1.py
386
nn/part1.py
|
@ -632,6 +632,8 @@ class LayOutPlan(TeacherStudentsScene, NetworkScene):
|
|||
self.remove(self.network_mob)
|
||||
|
||||
def construct(self):
|
||||
self.force_skipping()
|
||||
|
||||
self.show_words()
|
||||
self.show_network()
|
||||
self.show_math()
|
||||
|
@ -757,6 +759,8 @@ class LayOutPlan(TeacherStudentsScene, NetworkScene):
|
|||
def show_videos(self):
|
||||
network_mob = self.network_mob
|
||||
learning = self.learning_word
|
||||
structure = TextMobject("Structure")
|
||||
structure.highlight(YELLOW)
|
||||
videos = VGroup(*[
|
||||
VideoIcon().set_fill(RED)
|
||||
for x in range(2)
|
||||
|
@ -770,13 +774,17 @@ class LayOutPlan(TeacherStudentsScene, NetworkScene):
|
|||
network_mob.target.move_to(videos[0])
|
||||
learning.generate_target()
|
||||
learning.target.next_to(videos[1], UP)
|
||||
structure.next_to(videos[0], UP)
|
||||
structure.shift(0.5*SMALL_BUFF*UP)
|
||||
|
||||
self.revert_to_original_skipping_status()
|
||||
self.play(
|
||||
MoveToTarget(network_mob),
|
||||
MoveToTarget(learning)
|
||||
)
|
||||
self.play(
|
||||
DrawBorderThenFill(videos[0]),
|
||||
FadeIn(structure),
|
||||
self.get_student_changes(*["pondering"]*3)
|
||||
)
|
||||
self.dither()
|
||||
|
@ -1192,13 +1200,29 @@ class IntroduceEachLayer(PreviewMNistNetwork):
|
|||
network_mob = self.network_mob
|
||||
neurons = self.neurons
|
||||
layer = network_mob.layers[0]
|
||||
layer.save_state()
|
||||
layer.rotate(np.pi/2)
|
||||
layer.center()
|
||||
layer.brace_label.rotate_in_place(-np.pi/2)
|
||||
n = network_mob.max_shown_neurons/2
|
||||
|
||||
rows = VGroup(*[
|
||||
VGroup(*neurons[28*i:28*(i+1)])
|
||||
for i in range(28)
|
||||
])
|
||||
|
||||
self.play(
|
||||
FadeOut(self.braces),
|
||||
FadeOut(self.brace_labels),
|
||||
FadeOut(VGroup(*self.num_pixels_equation[:-1]))
|
||||
)
|
||||
|
||||
self.play(rows.space_out_submobjects, 1.2)
|
||||
self.play(
|
||||
rows.arrange_submobjects, RIGHT, buff = SMALL_BUFF,
|
||||
path_arc = np.pi/2,
|
||||
run_time = 2
|
||||
)
|
||||
self.play(
|
||||
ReplacementTransform(
|
||||
VGroup(*neurons[:n]),
|
||||
|
@ -1212,15 +1236,15 @@ class IntroduceEachLayer(PreviewMNistNetwork):
|
|||
VGroup(*neurons[-n:]),
|
||||
VGroup(*layer.neurons[-n:]),
|
||||
),
|
||||
FadeIn(self.corner_image)
|
||||
)
|
||||
self.play(
|
||||
ReplacementTransform(
|
||||
self.num_pixels_equation[-1],
|
||||
layer.brace_label
|
||||
),
|
||||
FadeIn(layer.brace)
|
||||
FadeIn(layer.brace),
|
||||
)
|
||||
self.play(layer.restore, FadeIn(self.corner_image))
|
||||
self.dither()
|
||||
for edge_group, layer in zip(network_mob.edge_groups, network_mob.layers[1:]):
|
||||
self.play(
|
||||
|
@ -1320,6 +1344,69 @@ class IntroduceEachLayer(PreviewMNistNetwork):
|
|||
self.remove_random_edges(0.7)
|
||||
self.feed_forward(self.image_vect)
|
||||
|
||||
class DiscussChoiceForHiddenLayers(TeacherStudentsScene):
|
||||
def construct(self):
|
||||
network_mob = MNistNetworkMobject(
|
||||
layer_to_layer_buff = 2.5,
|
||||
neuron_stroke_color = WHITE,
|
||||
)
|
||||
network_mob.scale_to_fit_height(4)
|
||||
network_mob.to_edge(UP, buff = LARGE_BUFF)
|
||||
layers = VGroup(*network_mob.layers[1:3])
|
||||
rects = VGroup(*map(SurroundingRectangle, layers))
|
||||
self.add(network_mob)
|
||||
|
||||
two_words = TextMobject("2 hidden layers")
|
||||
two_words.highlight(YELLOW)
|
||||
sixteen_words = TextMobject("16 neurons each")
|
||||
sixteen_words.highlight(MAROON_B)
|
||||
for words in two_words, sixteen_words:
|
||||
words.next_to(rects, UP)
|
||||
|
||||
neurons_anim = LaggedStart(
|
||||
Indicate,
|
||||
VGroup(*it.chain(*[layer.neurons for layer in layers])),
|
||||
rate_func = there_and_back,
|
||||
scale_factor = 2,
|
||||
color = MAROON_B,
|
||||
)
|
||||
|
||||
self.play(
|
||||
ShowCreation(rects),
|
||||
Write(two_words, run_time = 1),
|
||||
self.teacher.change, "raise_right_hand",
|
||||
)
|
||||
self.dither()
|
||||
self.play(
|
||||
FadeOut(rects),
|
||||
ReplacementTransform(two_words, sixteen_words),
|
||||
neurons_anim
|
||||
)
|
||||
self.dither()
|
||||
self.play(self.teacher.change, "shruggie")
|
||||
self.change_student_modes("erm", "confused", "sassy")
|
||||
self.dither()
|
||||
self.student_says(
|
||||
"Why 2 \\\\ layers?",
|
||||
student_index = 1,
|
||||
bubble_kwargs = {"direction" : RIGHT},
|
||||
run_time = 1,
|
||||
target_mode = "raise_left_hand",
|
||||
)
|
||||
self.play(self.teacher.change, "happy")
|
||||
self.dither()
|
||||
self.student_says(
|
||||
"Why 16?",
|
||||
student_index = 0,
|
||||
run_time = 1,
|
||||
)
|
||||
self.play(neurons_anim, run_time = 3)
|
||||
self.play(
|
||||
self.teacher.change, "shruggie",
|
||||
RemovePiCreatureBubble(self.students[0]),
|
||||
)
|
||||
self.dither()
|
||||
|
||||
class MoreHonestMNistNetworkPreview(IntroduceEachLayer):
|
||||
CONFIG = {
|
||||
"network_mob_config" : {
|
||||
|
@ -1597,7 +1684,7 @@ class BreakUpMacroPatterns(IntroduceEachLayer):
|
|||
|
||||
def show_upper_loop_activation(self):
|
||||
neuron = self.network_mob.layers[-2].neurons[0]
|
||||
words = TextMobject("Upper loop neuron...mabye...")
|
||||
words = TextMobject("Upper loop neuron...maybe...")
|
||||
words.scale(0.8)
|
||||
words.next_to(neuron, UP)
|
||||
words.shift(RIGHT)
|
||||
|
@ -1619,11 +1706,11 @@ class BreakUpMacroPatterns(IntroduceEachLayer):
|
|||
]
|
||||
|
||||
self.play(FadeIn(nine))
|
||||
self.add_foreground_mobject(self.patterns)
|
||||
self.play(
|
||||
ShowCreation(rect),
|
||||
Write(words)
|
||||
)
|
||||
self.add_foreground_mobject(self.patterns)
|
||||
self.feed_forward(np.random.random(784))
|
||||
self.dither(2)
|
||||
|
||||
|
@ -2204,7 +2291,7 @@ class IntroduceWeights(IntroduceEachLayer):
|
|||
pixels.next_to(neuron, RIGHT, LARGE_BUFF)
|
||||
rect = SurroundingRectangle(pixels, color = BLUE)
|
||||
|
||||
pixels_to_detect = self.get_pixels_to_detect()
|
||||
pixels_to_detect = self.get_pixels_to_detect(pixels)
|
||||
|
||||
self.play(
|
||||
FadeIn(rect),
|
||||
|
@ -2249,7 +2336,8 @@ class IntroduceWeights(IntroduceEachLayer):
|
|||
p_labels[-1].shift(SMALL_BUFF*RIGHT)
|
||||
|
||||
def get_alpha_func(i, start = 0):
|
||||
m = int(5*np.sin(2*np.pi*i/128.))
|
||||
# m = int(5*np.sin(2*np.pi*i/128.))
|
||||
m = random.randint(1, 10)
|
||||
return lambda a : start + (1-2*start)*np.sin(np.pi*a*m)**2
|
||||
|
||||
decimals = VGroup()
|
||||
|
@ -2660,12 +2748,16 @@ class IntroduceSigmoid(GraphScene):
|
|||
name = TextMobject("Sigmoid")
|
||||
name.next_to(ORIGIN, RIGHT, LARGE_BUFF)
|
||||
name.to_edge(UP)
|
||||
char = self.x_axis_label.replace("$", "")
|
||||
equation = TexMobject(
|
||||
"\\sigma(x) = \\frac{1}{1+e^{-x}}"
|
||||
"\\sigma(%s) = \\frac{1}{1+e^{-%s}}"%(char, char)
|
||||
)
|
||||
equation.next_to(name, DOWN)
|
||||
self.add(equation, name)
|
||||
|
||||
self.equation = equation
|
||||
self.sigmoid_name = name
|
||||
|
||||
def add_graph(self):
|
||||
graph = self.get_graph(
|
||||
lambda x : 1./(1+np.exp(-x)),
|
||||
|
@ -2675,6 +2767,8 @@ class IntroduceSigmoid(GraphScene):
|
|||
self.play(ShowCreation(graph))
|
||||
self.dither()
|
||||
|
||||
self.sigmoid_graph = graph
|
||||
|
||||
###
|
||||
|
||||
def show_part(self, x_min, x_max, color):
|
||||
|
@ -3494,9 +3588,9 @@ class IntroduceWeightMatrix(NetworkScene):
|
|||
"w_{%s, 0}"%i,
|
||||
"w_{%s, 1}"%i,
|
||||
"\\cdots",
|
||||
"w_{%s, k}"%i,
|
||||
"w_{%s, n}"%i,
|
||||
]))
|
||||
for i in "1", "n"
|
||||
for i in "1", "k"
|
||||
]
|
||||
dots_row = VGroup(*map(TexMobject, [
|
||||
"\\vdots", "\\vdots", "\\ddots", "\\vdots"
|
||||
|
@ -3591,12 +3685,65 @@ class IntroduceWeightMatrix(NetworkScene):
|
|||
FadeIn, VGroup(*result_terms[1:])
|
||||
))
|
||||
self.dither(2)
|
||||
self.show_meaning_of_lower_rows(
|
||||
arrow, brace, top_row_rect, result_terms
|
||||
)
|
||||
self.play(*map(FadeOut, [
|
||||
result_terms, result_brackets, equals,
|
||||
arrow, brace,
|
||||
top_row_rect, column_rect
|
||||
result_terms, result_brackets, equals, column_rect
|
||||
]))
|
||||
|
||||
def show_meaning_of_lower_rows(self, arrow, brace, row_rect, result_terms):
|
||||
n1, n2, nk = neurons = VGroup(*[
|
||||
self.network_mob.layers[1].neurons[i]
|
||||
for i in 0, 1, -1
|
||||
])
|
||||
for n in neurons:
|
||||
n.save_state()
|
||||
n.edges_in.save_state()
|
||||
|
||||
rect2 = SurroundingRectangle(result_terms[1])
|
||||
rectk = SurroundingRectangle(result_terms[-1])
|
||||
VGroup(rect2, rectk).highlight(WHITE)
|
||||
row2 = self.lower_matrix_rows[0]
|
||||
rowk = self.lower_matrix_rows[-1]
|
||||
|
||||
def show_edges(neuron):
|
||||
self.play(LaggedStart(
|
||||
ShowCreationThenDestruction,
|
||||
neuron.edges_in.copy().set_stroke(GREEN, 5),
|
||||
lag_ratio = 0.7,
|
||||
run_time = 1,
|
||||
))
|
||||
|
||||
self.play(
|
||||
row_rect.move_to, row2,
|
||||
n1.fade,
|
||||
n1.set_fill, None, 0,
|
||||
n1.edges_in.set_stroke, None, 1,
|
||||
n2.set_stroke, WHITE, 3,
|
||||
n2.edges_in.set_stroke, None, 3,
|
||||
ReplacementTransform(arrow, rect2),
|
||||
FadeOut(brace),
|
||||
)
|
||||
show_edges(n2)
|
||||
self.play(
|
||||
row_rect.move_to, rowk,
|
||||
n2.restore,
|
||||
n2.edges_in.restore,
|
||||
nk.set_stroke, WHITE, 3,
|
||||
nk.edges_in.set_stroke, None, 3,
|
||||
ReplacementTransform(rect2, rectk),
|
||||
)
|
||||
show_edges(nk)
|
||||
self.play(
|
||||
n1.restore,
|
||||
n1.edges_in.restore,
|
||||
nk.restore,
|
||||
nk.edges_in.restore,
|
||||
FadeOut(rectk),
|
||||
FadeOut(row_rect),
|
||||
)
|
||||
|
||||
def add_bias_vector(self):
|
||||
bias = self.bias
|
||||
bias_name = self.bias_name
|
||||
|
@ -4054,7 +4201,7 @@ class NextVideo(MoreHonestMNistNetworkPreview, PiCreatureScene):
|
|||
content = self.content
|
||||
|
||||
video = VideoIcon()
|
||||
video.scale_to_fit_height(2)
|
||||
video.scale_to_fit_height(3)
|
||||
video.set_fill(RED, 0.8)
|
||||
video.next_to(morty, UP+LEFT)
|
||||
|
||||
|
@ -4098,12 +4245,12 @@ class NextVideo(MoreHonestMNistNetworkPreview, PiCreatureScene):
|
|||
)
|
||||
bang = subscribe_word[1]
|
||||
subscribe_word.to_corner(DOWN+RIGHT)
|
||||
subscribe_word.shift(2*UP)
|
||||
subscribe_word.shift(3*UP)
|
||||
q_mark = TextMobject("?")
|
||||
q_mark.move_to(bang, LEFT)
|
||||
arrow = Arrow(ORIGIN, DOWN, color = RED, buff = 0)
|
||||
arrow.next_to(subscribe_word, DOWN)
|
||||
arrow.shift(RIGHT)
|
||||
arrow.shift(MED_LARGE_BUFF * RIGHT)
|
||||
|
||||
self.play(
|
||||
Write(subscribe_word),
|
||||
|
@ -4120,7 +4267,7 @@ class NextVideo(MoreHonestMNistNetworkPreview, PiCreatureScene):
|
|||
morty = self.pi_creature
|
||||
|
||||
network_mob, rect, video, words = self.video
|
||||
network_mob.generate_target()
|
||||
network_mob.generate_target(use_deepcopy = True)
|
||||
network_mob.target.scale_to_fit_height(5)
|
||||
network_mob.target.to_corner(UP+LEFT)
|
||||
neurons = VGroup(*network_mob.target.layers[-1].neurons[:2])
|
||||
|
@ -4216,6 +4363,209 @@ class NNPatreonThanks(PatreonThanks):
|
|||
]
|
||||
}
|
||||
|
||||
class PiCreatureGesture(PiCreatureScene):
|
||||
def construct(self):
|
||||
self.play(self.pi_creature.change, "raise_right_hand")
|
||||
self.dither(5)
|
||||
self.play(self.pi_creature.change, "happy")
|
||||
self.dither(4)
|
||||
|
||||
class IntroduceReLU(IntroduceSigmoid):
|
||||
CONFIG = {
|
||||
"x_axis_label" : "$a$"
|
||||
}
|
||||
def construct(self):
|
||||
self.setup_axes()
|
||||
self.add_title()
|
||||
self.add_graph()
|
||||
self.old_school()
|
||||
self.show_ReLU()
|
||||
self.label_input_regions()
|
||||
|
||||
def old_school(self):
|
||||
sigmoid_graph = self.sigmoid_graph
|
||||
sigmoid_title = VGroup(
|
||||
self.sigmoid_name,
|
||||
self.equation
|
||||
)
|
||||
cross = Cross(sigmoid_title)
|
||||
old_school = TextMobject("Old school")
|
||||
old_school.to_corner(UP+RIGHT)
|
||||
old_school.highlight(RED)
|
||||
arrow = Arrow(
|
||||
old_school.get_bottom(),
|
||||
self.equation.get_right(),
|
||||
color = RED
|
||||
)
|
||||
|
||||
self.play(ShowCreation(cross))
|
||||
self.play(
|
||||
Write(old_school, run_time = 1),
|
||||
GrowArrow(arrow)
|
||||
)
|
||||
self.dither(2)
|
||||
self.play(
|
||||
ApplyMethod(
|
||||
VGroup(cross, sigmoid_title).shift,
|
||||
SPACE_WIDTH*RIGHT,
|
||||
rate_func = running_start
|
||||
),
|
||||
FadeOut(old_school),
|
||||
FadeOut(arrow),
|
||||
)
|
||||
self.play(ShowCreation(
|
||||
self.sigmoid_graph,
|
||||
rate_func = lambda t : smooth(1-t),
|
||||
remover = True
|
||||
))
|
||||
|
||||
def show_ReLU(self):
|
||||
graph = VGroup(
|
||||
Line(
|
||||
self.coords_to_point(-7, 0),
|
||||
self.coords_to_point(0, 0),
|
||||
),
|
||||
Line(
|
||||
self.coords_to_point(0, 0),
|
||||
self.coords_to_point(4, 4),
|
||||
),
|
||||
)
|
||||
graph.highlight(YELLOW)
|
||||
char = self.x_axis_label.replace("$", "")
|
||||
equation = TextMobject("ReLU($%s$) = max$(0, %s)$"%(char, char))
|
||||
equation.shift(SPACE_WIDTH*LEFT/2)
|
||||
equation.to_edge(UP)
|
||||
equation.add_background_rectangle()
|
||||
name = TextMobject("Rectified linear unit")
|
||||
name.move_to(equation)
|
||||
name.add_background_rectangle()
|
||||
|
||||
self.play(Write(equation))
|
||||
self.play(ShowCreation(graph), Animation(equation))
|
||||
self.dither(2)
|
||||
self.play(
|
||||
Write(name),
|
||||
equation.shift, DOWN
|
||||
)
|
||||
self.dither(2)
|
||||
|
||||
self.ReLU_graph = graph
|
||||
|
||||
def label_input_regions(self):
|
||||
l1, l2 = self.ReLU_graph
|
||||
neg_words = TextMobject("Inactive")
|
||||
neg_words.highlight(RED)
|
||||
neg_words.next_to(self.coords_to_point(-2, 0), UP)
|
||||
|
||||
pos_words = TextMobject("Same as $f(a) = a$")
|
||||
pos_words.highlight(GREEN)
|
||||
pos_words.next_to(
|
||||
self.coords_to_point(1, 1),
|
||||
DOWN+RIGHT
|
||||
)
|
||||
|
||||
self.revert_to_original_skipping_status()
|
||||
self.play(ShowCreation(l1.copy().highlight(RED)))
|
||||
self.play(Write(neg_words))
|
||||
self.dither()
|
||||
self.play(ShowCreation(l2.copy().highlight(GREEN)))
|
||||
self.play(Write(pos_words))
|
||||
self.dither(2)
|
||||
|
||||
class CompareSigmoidReLUOnDeepNetworks(PiCreatureScene):
|
||||
def construct(self):
|
||||
morty, lisha = self.morty, self.lisha
|
||||
sigmoid_graph = FunctionGraph(
|
||||
sigmoid,
|
||||
x_min = -5,
|
||||
x_max = 5,
|
||||
)
|
||||
sigmoid_graph.stretch_to_fit_width(3)
|
||||
sigmoid_graph.highlight(YELLOW)
|
||||
sigmoid_graph.next_to(lisha, UP+LEFT)
|
||||
sigmoid_graph.shift_onto_screen()
|
||||
sigmoid_name = TextMobject("Sigmoid")
|
||||
sigmoid_name.next_to(sigmoid_graph, UP)
|
||||
sigmoid_graph.add(sigmoid_name)
|
||||
|
||||
slow_learner = TextMobject("Slow learner")
|
||||
slow_learner.highlight(YELLOW)
|
||||
slow_learner.to_corner(UP+LEFT)
|
||||
slow_arrow = Arrow(
|
||||
slow_learner.get_bottom(),
|
||||
sigmoid_graph.get_top(),
|
||||
)
|
||||
|
||||
relu_graph = VGroup(
|
||||
Line(2*LEFT, ORIGIN),
|
||||
Line(ORIGIN, np.sqrt(2)*(RIGHT+UP)),
|
||||
)
|
||||
relu_graph.highlight(BLUE)
|
||||
relu_graph.next_to(lisha, UP+RIGHT)
|
||||
relu_name = TextMobject("ReLU")
|
||||
relu_name.move_to(relu_graph, UP)
|
||||
relu_graph.add(relu_name)
|
||||
|
||||
network_mob = NetworkMobject(Network(
|
||||
sizes = [6, 4, 5, 4, 3, 5, 2]
|
||||
))
|
||||
network_mob.scale(0.8)
|
||||
network_mob.to_edge(UP, buff = MED_SMALL_BUFF)
|
||||
network_mob.shift(RIGHT)
|
||||
edge_update = ContinualEdgeUpdate(
|
||||
network_mob, stroke_width_exp = 1,
|
||||
)
|
||||
|
||||
self.play(
|
||||
FadeIn(sigmoid_name),
|
||||
ShowCreation(sigmoid_graph),
|
||||
lisha.change, "raise_left_hand",
|
||||
morty.change, "pondering"
|
||||
)
|
||||
self.play(
|
||||
Write(slow_learner, run_time = 1),
|
||||
GrowArrow(slow_arrow)
|
||||
)
|
||||
self.dither()
|
||||
self.play(
|
||||
FadeIn(relu_name),
|
||||
ShowCreation(relu_graph),
|
||||
lisha.change, "raise_right_hand",
|
||||
morty.change, "thinking"
|
||||
)
|
||||
self.play(FadeIn(network_mob))
|
||||
self.add(edge_update)
|
||||
self.dither(10)
|
||||
|
||||
|
||||
|
||||
###
|
||||
def create_pi_creatures(self):
|
||||
morty = Mortimer()
|
||||
morty.shift(SPACE_WIDTH*RIGHT/2).to_edge(DOWN)
|
||||
lisha = PiCreature(color = BLUE_C)
|
||||
lisha.shift(SPACE_WIDTH*LEFT/2).to_edge(DOWN)
|
||||
self.morty, self.lisha = morty, lisha
|
||||
return morty, lisha
|
||||
|
||||
class ShowAmplify(PiCreatureScene):
|
||||
def construct(self):
|
||||
morty = self.pi_creature
|
||||
rect = ScreenRectangle(height = 5)
|
||||
rect.to_corner(UP+LEFT)
|
||||
rect.shift(DOWN)
|
||||
email = TextMobject("3blue1brown@amplifypartners.com")
|
||||
email.next_to(rect, UP)
|
||||
|
||||
self.play(
|
||||
ShowCreation(rect),
|
||||
morty.change, "raise_right_hand"
|
||||
)
|
||||
self.dither(2)
|
||||
self.play(Write(email))
|
||||
self.play(morty.change, "happy", rect)
|
||||
self.dither(10)
|
||||
|
||||
class Thumbnail(NetworkScene):
|
||||
CONFIG = {
|
||||
"network_mob_config" : {
|
||||
|
@ -4225,11 +4575,13 @@ class Thumbnail(NetworkScene):
|
|||
def construct(self):
|
||||
network_mob = self.network_mob
|
||||
network_mob.scale_to_fit_height(2*SPACE_HEIGHT - 1)
|
||||
for layer in network_mob.layers:
|
||||
layer.neurons.set_stroke(width = 5)
|
||||
|
||||
edge_update = ContinualEdgeUpdate(
|
||||
network_mob,
|
||||
max_stroke_width = 10,
|
||||
stroke_width_exp = 5,
|
||||
stroke_width_exp = 4,
|
||||
)
|
||||
edge_update.internal_time = 3
|
||||
edge_update.update(0)
|
||||
|
|
192
nn/part2.py
192
nn/part2.py
|
@ -102,9 +102,9 @@ class PreviewLearning(NetworkScene):
|
|||
"n_examples" : 15,
|
||||
"max_stroke_width" : 3,
|
||||
"stroke_width_exp" : 3,
|
||||
"eta" : 5.0,
|
||||
"positive_change_color" : GREEN_B,
|
||||
"negative_change_color" : RED_B,
|
||||
"eta" : 3.0,
|
||||
"positive_change_color" : average_color(*2*[GREEN] + [YELLOW]),
|
||||
"negative_change_color" : average_color(*2*[RED] + [YELLOW]),
|
||||
}
|
||||
def construct(self):
|
||||
self.initialize_network()
|
||||
|
@ -197,31 +197,20 @@ class PreviewLearning(NetworkScene):
|
|||
edge.rotate_in_place(np.pi)
|
||||
if i == 2:
|
||||
delta_edges.submobjects = [
|
||||
delta_edges[-(j+1)]
|
||||
delta_edges[j]
|
||||
for j in np.argsort(shown_nw.T.flatten())
|
||||
]
|
||||
network = self.network
|
||||
network.weights[i] -= self.eta*nw
|
||||
network.biases[i] -= self.eta*nb
|
||||
|
||||
reversed_delta_edges = VGroup(*it.chain(*reversed(delta_edge_groups)))
|
||||
reversed_delta_neurons = VGroup(*reversed(delta_neuron_groups))
|
||||
|
||||
self.play(
|
||||
LaggedStart(
|
||||
ShowCreation,
|
||||
reversed_delta_edges,
|
||||
run_time = 1.5,
|
||||
lag_ratio = 0.15,
|
||||
),
|
||||
FadeIn(
|
||||
reversed_delta_neurons,
|
||||
run_time = 2,
|
||||
submobject_mode = "lagged_start",
|
||||
lag_factor = 4,
|
||||
rate_func = None,
|
||||
self.play(
|
||||
ShowCreation(
|
||||
delta_edges, submobject_mode = "all_at_once"
|
||||
),
|
||||
FadeIn(delta_neurons),
|
||||
run_time = 0.5
|
||||
)
|
||||
)
|
||||
edge_groups.save_state()
|
||||
self.color_network_edges()
|
||||
self.remove(edge_groups)
|
||||
|
@ -229,7 +218,7 @@ class PreviewLearning(NetworkScene):
|
|||
[ReplacementTransform(
|
||||
edge_groups.saved_state, edge_groups,
|
||||
)],
|
||||
map(FadeOut, [reversed_delta_edges, reversed_delta_neurons]),
|
||||
map(FadeOut, [delta_edge_groups, delta_neuron_groups]),
|
||||
added_outro_anims,
|
||||
))
|
||||
|
||||
|
@ -424,7 +413,164 @@ class FunctionMinmization(GraphScene):
|
|||
])
|
||||
self.dither(10)
|
||||
|
||||
|
||||
class IntroduceCostFunction(PreviewLearning):
|
||||
def construct(self):
|
||||
self.force_skipping()
|
||||
|
||||
self.isolate_one_neuron()
|
||||
self.reminder_of_weights_and_bias()
|
||||
self.initialize_randomly()
|
||||
self.feed_in_example()
|
||||
self.make_fun_of_output()
|
||||
self.need_a_cost_function()
|
||||
self.show_cost_function()
|
||||
|
||||
def isolate_one_neuron(self):
|
||||
network_mob = self.network_mob
|
||||
network_mob.shift(LEFT)
|
||||
neuron_groups = VGroup(*[
|
||||
layer.neurons
|
||||
for layer in network_mob.layers[1:]
|
||||
])
|
||||
edge_groups = network_mob.edge_groups
|
||||
neuron = neuron_groups[0][7].deepcopy()
|
||||
output_labels = network_mob.output_labels
|
||||
kwargs = {
|
||||
"submobject_mode" : "lagged_start",
|
||||
"run_time" : 2,
|
||||
}
|
||||
self.play(
|
||||
FadeOut(edge_groups, **kwargs),
|
||||
FadeOut(neuron_groups, **kwargs),
|
||||
FadeOut(output_labels, **kwargs),
|
||||
Animation(neuron),
|
||||
neuron.edges_in.set_stroke, None, 2,
|
||||
)
|
||||
self.dither()
|
||||
|
||||
self.neuron = neuron
|
||||
|
||||
def reminder_of_weights_and_bias(self):
|
||||
neuron = self.neuron
|
||||
layer0 = self.network_mob.layers[0]
|
||||
active_layer0 = self.network_mob.get_active_layer(
|
||||
0, np.random.random(len(layer0.neurons))
|
||||
)
|
||||
prev_neurons = layer0.neurons
|
||||
|
||||
weights = 4*(np.random.random(len(neuron.edges_in))-0.5)
|
||||
weighted_edges = VGroup(*[
|
||||
edge.copy().set_stroke(
|
||||
color = GREEN if w > 0 else RED,
|
||||
width = abs(w)
|
||||
)
|
||||
for w, edge in zip(weights, neuron.edges_in)
|
||||
])
|
||||
|
||||
formula = TexMobject(
|
||||
"=", "\\sigma(",
|
||||
"w_1", "a_1", "+",
|
||||
"w_2", "a_2", "+",
|
||||
"\\cdots", "+",
|
||||
"w_n", "a_n", "+", "b", ")"
|
||||
)
|
||||
w_labels = formula.get_parts_by_tex("w_")
|
||||
a_labels = formula.get_parts_by_tex("a_")
|
||||
b = formula.get_part_by_tex("b")
|
||||
sigma = VGroup(
|
||||
formula.get_part_by_tex("\\sigma"),
|
||||
formula.get_part_by_tex(")"),
|
||||
)
|
||||
symbols = VGroup(*[
|
||||
formula.get_parts_by_tex(tex)
|
||||
for tex in "=", "+", "dots"
|
||||
])
|
||||
|
||||
w_labels.highlight(GREEN)
|
||||
b.highlight(BLUE)
|
||||
sigma.highlight(YELLOW)
|
||||
# formula.to_edge(UP)
|
||||
formula.next_to(neuron, RIGHT)
|
||||
|
||||
weights_word = TextMobject("Weights")
|
||||
weights_word.next_to(neuron.edges_in, RIGHT, aligned_edge = UP)
|
||||
weights_word.highlight(GREEN)
|
||||
weights_arrow = Arrow(
|
||||
weights_word.get_bottom(),
|
||||
neuron.edges_in[0].get_center(),
|
||||
color = GREEN
|
||||
)
|
||||
|
||||
alt_weights_arrows = VGroup(*[
|
||||
Arrow(
|
||||
weights_word.get_bottom(),
|
||||
w_label.get_top(),
|
||||
color = GREEN
|
||||
)
|
||||
for w_label in w_labels
|
||||
])
|
||||
|
||||
bias_word = TextMobject("Bias")
|
||||
bias_arrow = Vector(DOWN, color = BLUE)
|
||||
bias_arrow.next_to(b, UP, SMALL_BUFF)
|
||||
bias_word.next_to(bias_arrow, UP, SMALL_BUFF)
|
||||
bias_word.highlight(BLUE)
|
||||
|
||||
self.revert_to_original_skipping_status()
|
||||
self.play(
|
||||
Transform(layer0, active_layer0),
|
||||
FadeIn(a_labels),
|
||||
FadeIn(symbols),
|
||||
run_time = 2,
|
||||
submobject_mode = "lagged_start"
|
||||
)
|
||||
self.play(
|
||||
Write(weights_word),
|
||||
GrowArrow(weights_arrow),
|
||||
Transform(neuron.edges_in, weighted_edges),
|
||||
run_time = 1,
|
||||
)
|
||||
self.dither()
|
||||
self.play(
|
||||
ReplacementTransform(
|
||||
weighted_edges.copy(), w_labels,
|
||||
),
|
||||
ReplacementTransform(
|
||||
VGroup(weights_arrow),
|
||||
alt_weights_arrows
|
||||
)
|
||||
)
|
||||
self.dither()
|
||||
self.play(
|
||||
Write(b),
|
||||
Write(bias_word),
|
||||
GrowArrow(bias_arrow),
|
||||
run_time = 1
|
||||
)
|
||||
self.play(Write(sigma))
|
||||
self.dither(2)
|
||||
|
||||
def initialize_randomly(self):
|
||||
pass
|
||||
|
||||
def feed_in_example(self):
|
||||
pass
|
||||
|
||||
def make_fun_of_output(self):
|
||||
pass
|
||||
|
||||
def need_a_cost_function(self):
|
||||
pass
|
||||
|
||||
def show_cost_function(self):
|
||||
pass
|
||||
|
||||
|
||||
####
|
||||
|
||||
def activate_network(self, train_in, *added_anims):
|
||||
##TODO
|
||||
PreviewLearning.activate_network(self, train_in, *added_anims)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -328,6 +328,7 @@ class PiCreatureBubbleIntroduction(AnimationGroup):
|
|||
bubble_class = self.bubble_class,
|
||||
**self.bubble_kwargs
|
||||
)
|
||||
Group(bubble, bubble.content).shift_onto_screen()
|
||||
|
||||
pi_creature.generate_target()
|
||||
pi_creature.target.change_mode(self.target_mode)
|
||||
|
|
Loading…
Add table
Reference in a new issue