Preview back propogation scene tweaked to my liking

This commit is contained in:
Grant Sanderson 2017-10-02 19:20:53 -07:00
parent 5048fe80a5
commit 151f733f09

View file

@ -103,6 +103,8 @@ class PreviewLearning(NetworkScene):
"max_stroke_width" : 3, "max_stroke_width" : 3,
"stroke_width_exp" : 3, "stroke_width_exp" : 3,
"eta" : 5.0, "eta" : 5.0,
"positive_change_color" : GREEN_B,
"negative_change_color" : RED_B,
} }
def construct(self): def construct(self):
self.initialize_network() self.initialize_network()
@ -115,7 +117,7 @@ class PreviewLearning(NetworkScene):
self.color_network_edges() self.color_network_edges()
def add_training_words(self): def add_training_words(self):
words = TextMobject("Training in \\\\ progress $\\dots$") words = TextMobject("Training in \\\\ progress$\\dots$")
words.scale(1.5) words.scale(1.5)
words.to_corner(UP+LEFT) words.to_corner(UP+LEFT)
@ -127,12 +129,14 @@ class PreviewLearning(NetworkScene):
image = get_training_image_group(train_in, train_out) image = get_training_image_group(train_in, train_out)
self.activate_network(train_in, FadeIn(image)) self.activate_network(train_in, FadeIn(image))
self.backprop_one_example( self.backprop_one_example(
train_in, train_out, FadeOut(image) train_in, train_out,
FadeOut(image), self.network_mob.layers.restore
) )
def activate_network(self, train_in, *added_anims): def activate_network(self, train_in, *added_anims):
network_mob = self.network_mob network_mob = self.network_mob
layers = network_mob.layers layers = network_mob.layers
layers.save_state()
activations = self.network.get_activation_of_all_layers(train_in) activations = self.network.get_activation_of_all_layers(train_in)
active_layers = [ active_layers = [
self.network_mob.get_active_layer(i, vect) self.network_mob.get_active_layer(i, vect)
@ -173,6 +177,8 @@ class PreviewLearning(NetworkScene):
delta_neuron_groups, neuron_groups, delta_neuron_groups, neuron_groups,
delta_edge_groups, edge_groups delta_edge_groups, edge_groups
) )
pc_color = self.positive_change_color
nc_color = self.negative_change_color
for i, nb, nw, delta_neurons, neurons, delta_edges, edges in reversed(tups): for i, nb, nw, delta_neurons, neurons, delta_edges, edges in reversed(tups):
shown_nw = self.get_adjusted_first_matrix(nw) shown_nw = self.get_adjusted_first_matrix(nw)
if np.max(shown_nw) == 0: if np.max(shown_nw) == 0:
@ -180,42 +186,45 @@ class PreviewLearning(NetworkScene):
max_b = np.max(np.abs(nb)) max_b = np.max(np.abs(nb))
max_w = np.max(np.abs(shown_nw)) max_w = np.max(np.abs(shown_nw))
for neuron, b in zip(delta_neurons, nb): for neuron, b in zip(delta_neurons, nb):
color = RED_E if b > 0 else GREEN_E color = nc_color if b > 0 else pc_color
# neuron.set_fill(color, abs(b)/max_b) # neuron.set_fill(color, abs(b)/max_b)
neuron.set_stroke(color, 3) neuron.set_stroke(color, 3)
for edge, w in zip(delta_edges.split(), shown_nw.T.flatten()): for edge, w in zip(delta_edges.split(), shown_nw.T.flatten()):
edge.set_stroke( edge.set_stroke(
RED_E if w > 0 else GREEN_E, nc_color if w > 0 else pc_color,
3*abs(w)/max_w 3*abs(w)/max_w
) )
edge.rotate_in_place(np.pi) edge.rotate_in_place(np.pi)
if i == 0: if i == 2:
delta_edges.submobjects = [ delta_edges.submobjects = [
delta_edges[j] delta_edges[-(j+1)]
for j in np.argsort(shown_nw.T.flatten()) for j in np.argsort(shown_nw.T.flatten())
] ]
network = self.network network = self.network
network.weights[i] -= self.eta*nw network.weights[i] -= self.eta*nw
network.biases[i] -= self.eta*nb network.biases[i] -= self.eta*nb
reversed_delta_edges = VGroup(*reversed(delta_edge_groups)) reversed_delta_edges = VGroup(*it.chain(*reversed(delta_edge_groups)))
reversed_delta_neurons = VGroup(*reversed(delta_neuron_groups)) reversed_delta_neurons = VGroup(*reversed(delta_neuron_groups))
edge_groups.save_state() edge_groups.save_state()
self.play( self.play(
ShowCreation( LaggedStart(
ShowCreation,
reversed_delta_edges, reversed_delta_edges,
run_time = 2, run_time = 1.5,
submobject_mode = "lagged_start", lag_ratio = 0.15,
lag_factor = 6,
), ),
FadeIn( FadeIn(
reversed_delta_neurons, reversed_delta_neurons,
run_time = 2, run_time = 2,
submobject_mode = "lagged_start", submobject_mode = "lagged_start",
lag_factor = 4, lag_factor = 4,
rate_func = None,
) )
) )
self.color_network_edges() self.color_network_edges()
self.remove(edge_groups)
self.play(*it.chain( self.play(*it.chain(
[ReplacementTransform( [ReplacementTransform(
edge_groups.saved_state, edge_groups, edge_groups.saved_state, edge_groups,
@ -281,6 +290,8 @@ class TrainingVsTestData(Scene):
self.get_examples() for x in range(2) self.get_examples() for x in range(2)
] ]
training_examples.next_to(ORIGIN, LEFT)
test_examples.next_to(ORIGIN, RIGHT)
self.play( self.play(
LaggedStart(FadeIn, training_examples), LaggedStart(FadeIn, training_examples),
LaggedStart(FadeIn, test_examples), LaggedStart(FadeIn, test_examples),
@ -339,7 +350,7 @@ class TrainingVsTestData(Scene):
self.remove(train_ex) self.remove(train_ex)
self.add(new_ex) self.add(new_ex)
new_ex[0][0].highlight(color) new_ex[0][0].highlight(color)
self.dither(1./10) self.dither(1./30)
training_examples = new_examples training_examples = new_examples
class NotSciFi(TeacherStudentsScene): class NotSciFi(TeacherStudentsScene):
@ -348,20 +359,51 @@ class NotSciFi(TeacherStudentsScene):
self.student_says( self.student_says(
"Machines learning?!?", "Machines learning?!?",
student_index = 0, student_index = 0,
target_mode = "confused", target_mode = "pleading",
run_time = 1,
) )
bubble = students[0].bubble bubble = students[0].bubble
students[0].bubble = None students[0].bubble = None
self.student_says( self.student_says(
"Run!", student_index = 2, "Should we \\\\ be worried?", student_index = 2,
target_mode = "pleading", target_mode = "confused",
bubble_kwargs = {"direction" : LEFT} bubble_kwargs = {"direction" : LEFT},
run_time = 1,
) )
self.dither() self.dither()
students[0].bubble = bubble students[0].bubble = bubble
self.teacher_says(
"It's actually \\\\ just calculus.",
run_time = 1
)
self.teacher.bubble = None
self.dither()
self.student_says(
"Even worse!",
target_mode = "horrified",
bubble_kwargs = {
"direction" : LEFT,
"width" : 3,
"height" : 2,
},
)
self.dither(2)
class FunctionMinmization(GraphScene):
CONFIG = {
"x_labeled_nums" : range(-1, 10),
}
def construct(self):
self.setup_axes()
def func(x):
x -= 5
return 0.1*(x**3 - 9*x) + 4
graph = self.get_graph(func)
graph_label = self.get_graph_label(graph, "C(x)")
self.add(graph, graph_label)
dot = Dot(color = YELLOW)
x =