Poor initialization of model weights can be problematic in at least two different cases:
1. Final Layer Bias Initialization
It is common to see a “hockey stick” shape in the loss curve of a neural network during the first few iterations of training. This happens while the network is first learning to scale its logits into the same range as the labels.
To mitigate this, the bias of the final layer can be set based on the type of ML task and the balance of the dataset. In the Makemore example, Andrej sets the bias to zero, which corresponds to a uniform distribution over characters. He could also have inspected the distribution of letters in the dataset and set the bias accordingly - this is a bit like Empirical Bayes! (A concrete sketch follows the quote below.)
init well. Initialize the final layer weights correctly. E.g. if you are regressing some values that have a mean of 50 then initialize the final bias to 50. If you have an imbalanced dataset of a ratio 1:10 of positives:negatives, set the bias on your logits such that your network predicts probability of 0.1 at initialization. Setting these correctly will speed up convergence and eliminate “hockey stick” loss curves where in the first few iterations your network is basically just learning the bias.
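As a minimal sketch of the "Empirical Bayes"-style idea (this is my own illustration, not from the lecture): for a softmax classifier you can set the final-layer bias to the log of the empirical class frequencies, so that the untrained network already predicts the label distribution. The labels tensor here is a hypothetical stand-in for the real training labels.

import torch

# hypothetical stand-in for the real training labels, 27 classes
labels = torch.randint(0, 27, (10000,))
counts = torch.bincount(labels, minlength=27).float()
priors = counts / counts.sum()            # empirical class probabilities
final_bias = torch.log(priors + 1e-8)     # softmax(final_bias) ≈ priors at initialization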
2. Weight Initialization
Large pre-activation values can lead to strange behaviour when passed through activation functions. For example, activation functions with asymptotes (tanh, sigmoid) can “saturate”, leaving a large fraction of outputs pinned at their extremes (±1 for tanh, 0 or 1 for sigmoid), where the gradients are essentially zero. This compounds in deep networks. Ideally we want activations throughout the network to stay roughly unit Gaussian - mean of 0, std of 1 (sketched below).
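As a quick illustration (my own sketch, not from the lecture), deliberately over-scaled weights produce large pre-activations, and almost every tanh output ends up saturated:

import torch

x = torch.randn(1000, 100)
w = torch.randn(100, 100) * 5            # deliberately over-scaled weights
h = torch.tanh(x @ w)                    # pre-activations are large, so tanh saturates
print((h.abs() > 0.97).float().mean())   # fraction of saturated activations - close to 1
print(h.std())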
We can fix this by initializing weights intelligently. In principle, this can be done manually: scale down randomly initialized weights and inspect the effect on the distribution of activations. However, this approach does not scale to larger networks.
Kaiming (a.k.a. He) Initialization is a more principled way of doing this. It is based on a mathematical analysis of how an activation function changes the statistics (mean and variance) of the signal passing through a layer.
This is functionally the same as what Andrej shows in the lecture - scaling the unit normal weights by \(\frac{gain}{\sqrt{n_l}}\), where \(n_l\) is the “fan-in”, the number of inputs to the layer, and \(gain\) is an activation-specific constant. These constants are defined in pytorch too:
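For reference, the gains can be queried directly from torch.nn.init:

import torch

print(torch.nn.init.calculate_gain('tanh'))    # 5/3 ≈ 1.667
print(torch.nn.init.calculate_gain('relu'))    # sqrt(2) ≈ 1.414
print(torch.nn.init.calculate_gain('linear'))  # 1.0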
Let’s imagine some neuron inputs x and neuron weights w. We can compute the matrix product y and pass the result through a ReLU activation to get activations h. Comparing the standard deviation of h to that of x, we see that it has increased.
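The cell that produced this comparison isn't shown above, but it would look something like the following sketch (the shape of x is an assumption, chosen to match the w used in the next cell):

import torch

x = torch.randn(1000, 20)   # 1000 example inputs with 20 features each (assumed shape)
w = torch.randn(20, 200)    # naive init: unit normal, no scaling
y = x @ w                   # pre-activations
h = torch.relu(y)           # activations
print(x.std(), h.std())     # h.std() is noticeably larger than x.std()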
torch.randn draws values from a unit normal distribution, \(\mu=0, \sigma=1\) - so we can change the standard deviation by scaling it by a factor. For ReLU, our Kaiming initialization should have \(\sigma = \sqrt{\frac{2}{n_l}}\), where \(n_l\) is the “fan-in” - the number of inputs to the neuron.
w = torch.randn(20, 200) * (2/20)**0.5  # init w with He Initialization
y = x @ w                               # x is fixed from above
h = torch.relu(y)
print(x.std(), h.std())

plt.figure(figsize=(20, 5))
plt.subplot(121)
plt.hist(x.view(-1), density=True);
plt.subplot(122)
plt.hist(h.view(-1), density=True);
tensor(0.9962) tensor(0.8292)
We can see that this has brought the standard deviation of the activations back towards 1 (although the output of ReLU is obviously not normally distributed).
NOTE: Default parameter initialization in torch.nn.Linear
When taking a look at the source code of nn.Linear, it might seem like they are using the Kaiming init:
def reset_parameters(self) -> None:
    # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
    # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
    # https://github.com/pytorch/pytorch/issues/57109
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    if self.bias is not None:
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.bias, -bound, bound)
However from the comment it becomes clear that actually they are using the nn.init.kaiming_uniform_ function as a hack to achieve values from the uniform distribution \(U(-1/\sqrt{in\_features}, 1/\sqrt{in\_features})\)
To me it seems strange that they are not using an init defined in nn.init for this…
One answer in this older thread suggests that the initialisation resembles what is referred to as “LeCun Initialisation”. Quoting that answer:
This comment is probably long overdue, but pytorch does not implement LeCun or He/Kaiming initialisation for the Linear module.
If we go through the code (v1.5.0) of Linear.reset_parameters, the first line initialises the weight matrix: init.kaiming_uniform_(self.weight, a=math.sqrt(5)). If we take a look at how kaiming_uniform is implemented, we find that this line is equivalent to
fan = tensor.size(1)  # fan-in for linear, as computed by _calculate_correct_fan
gain = math.sqrt(2.0 / (1 + a ** 2)) # gain, as computed by calculate_gain
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std
with torch.no_grad():
return tensor.uniform_(-bound, bound)
Since a = math.sqrt(5) the weights are initialised with std = 1 / math.sqrt(3.0 * fan_in). For reference, LeCun initialisation would be 1 / math.sqrt(fan_in) and He initialisation uses math.sqrt(2 / fan_in).
The bias initialisation in Linear.reset_parameters reveals another problem. Although biases are normally initialised with zeros (for the sake of simplicity), the idea is probably to initialise the biases with std = math.sqrt(1 / fan_in) (cf. LeCun init). By using this value for the boundaries of the uniform distribution, the resulting distribution has std 1 / math.sqrt(3.0 * fan_in), which happens to be the same as the standard deviation for the weights.
A more reasonable default for me would be to use LeCun initialisation (since this has been the go-to standard since 1998). I could also understand Kaiming initialisation as the default, because everyone is using ReLU activation functions everywhere anyway (although I have a feeling that this is not necessarily the case for people working with fully connected networks). Some time ago, I submitted a pull request to adopt LeCun initialisation as the default, but there seems to be little incentive to actually make changes due to backward compatibility.
This probably also explains why pytorch ended up with its own initialisation strategy for fully connected networks. Someone must have forgotten about the fact that a uniform distribution with bounds -b, b has a standard deviation of b / math.sqrt(3) instead of just b. Due to backwards compatibility this got stuck and no-one is willing to make the change to the more widely accepted and standard initialisation.
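We can sanity-check this claim empirically (a quick sketch; the printed values should agree to within sampling noise):

import math
import torch
import torch.nn as nn

fan_in = 200
layer = nn.Linear(fan_in, 300)
print(layer.weight.std().item())      # empirical std of PyTorch's default init
print(1 / math.sqrt(3.0 * fan_in))    # predicted std = 1/sqrt(3*fan_in) ≈ 0.0408
print(1 / math.sqrt(fan_in))          # LeCun init std, for comparison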
Bias Initialization
For biases other than the final layer’s, Kaiming Initialization and most other schemes simply set the bias to 0.
Batch Normalization
This methodology was proposed by Ioffe and Szegedy in 2015 to address issues with the stability of training deeper models.
In effect, the idea is to calculate the mean and standard deviation of each unit’s pre-activations across the batch, and use these to center and scale them to roughly unit normal.
Two additional parameters, \(\gamma\) and \(\beta\), are learned as part of the optimisation process, and they restore the representational power of the network after the centering and scaling.
\[ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \] (a very small \(\epsilon\) is used to avoid divide-by-zero issues)
\[ y_i = \gamma \hat{x}_i + \beta \]
One problem with this approach is that it couples together all samples within a batch when it comes to backpropagation. It can also be the source of painful bugs when used in the wild.
Since then, numerous other methods have been proposed to address these issues, e.g. Layer Normalization.
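For a sense of the difference (a sketch, not a full implementation): LayerNorm normalizes each sample across its features instead of normalizing each feature across the batch, so there is no coupling between samples and no running statistics to maintain.

import torch

x = torch.randn(32, 100)  # (batch, features)

# BatchNorm-style statistics: per feature, across the batch (couples samples together)
bn_mean, bn_std = x.mean(0, keepdim=True), x.std(0, keepdim=True)

# LayerNorm-style statistics: per sample, across the features (no coupling)
ln_mean, ln_std = x.mean(1, keepdim=True), x.std(1, keepdim=True)

x_ln = (x - ln_mean) / (ln_std + 1e-5)    # each row is now ~zero mean, unit std
print(x_ln.mean(1)[:3], x_ln.std(1)[:3])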
E01: I did not get around to seeing what happens when you initialize all weights and biases to zero. Try this and train the neural net. You might think either that 1) the network trains just fine or 2) the network doesn’t train at all, but actually it is 3) the network trains but only partially, and achieves a pretty bad final performance. Inspect the gradients and activations to figure out what is happening and why the network is only partially training, and what part is being trained exactly.
When all weights and biases are set to zero, we can empirically see that there is not much going on when we train the model. In fact, the only parameter that seems to change is the final-layer bias. We can reason about this by thinking about the forward and backward passes when all parameters are initialized to 0.
Forward pass
\(y = Wx + b = 0 \cdot x + 0 = 0\)
\(\tanh(y) = \tanh(0) = 0\)
Based on this, the activation produced by each neuron will be 0, which in turn means that the inputs to every neuron (except those in the first layer) will also be zero.
Backward Pass
For the backward pass, we can look at the local derivatives of the final-layer expression to see what is happening with the gradients.
\(y = Wx + b\)
\(\frac{\partial y}{\partial W} = x\)
\(\frac{\partial y}{\partial b} = 1\)
Using the chain rule, we multiply these local gradients by the upstream (global) gradient to get the gradient used for the update. The kicker here is that \(\frac{\partial y}{\partial W} = x = 0\), as \(x\) is the output of the previous layer, which we have shown is \(\tanh(0) = 0\) - so the weights will not be updated from zero. Similarly, \(\frac{\partial y}{\partial x} = W = 0\), which means no gradient flows through to the earlier layers in the network at all!
Additionally, since the local gradient with respect to the final-layer bias is 1, this parameter does receive a gradient and is updated away from 0 as training progresses.
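A tiny sketch (my own, separate from the full experiment below) makes this concrete: with a two-layer MLP initialized entirely to zero, only the final-layer bias receives a non-zero gradient.

import torch
import torch.nn.functional as F

# two-layer MLP, all parameters zero
W1 = torch.zeros(10, 20, requires_grad=True)
b1 = torch.zeros(20, requires_grad=True)
W2 = torch.zeros(20, 5, requires_grad=True)
b2 = torch.zeros(5, requires_grad=True)

x = torch.randn(8, 10)
y = torch.randint(0, 5, (8,))

h = torch.tanh(x @ W1 + b1)   # all zeros, since W1 and b1 are zero
logits = h @ W2 + b2          # all zeros as well
loss = F.cross_entropy(logits, y)
loss.backward()

print(W1.grad.abs().sum(), b1.grad.abs().sum())  # zero: no gradient reaches the first layer
print(W2.grad.abs().sum())                       # zero, because h is zero
print(b2.grad.abs().sum())                       # non-zero: the only parameter that learns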
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt  # for making figures
%matplotlib inline

# read in all the words
words = open('names.txt', 'r').read().splitlines()
words[:8]
len(words)

# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s: i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

# build the dataset
block_size = 3  # context length: how many characters do we take to predict the next one?

def build_dataset(words):
  X, Y = [], []
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix]  # crop and append
  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

Xtr, Ytr = build_dataset(words[:n1])      # 80%
Xdev, Ydev = build_dataset(words[n1:n2])  # 10%
Xte, Yte = build_dataset(words[n2:])      # 10%
# Let's train a deeper network
# The classes we create here are the same API as nn.Module in PyTorch

class Linear:

  def __init__(self, fan_in, fan_out, bias=True):
    self.weight = torch.randn((fan_in, fan_out), generator=g) / fan_in**0.5
    self.bias = torch.zeros(fan_out) if bias else None

  def __call__(self, x):
    self.out = x @ self.weight
    if self.bias is not None:
      self.out += self.bias
    return self.out

  def parameters(self):
    return [self.weight] + ([] if self.bias is None else [self.bias])


class BatchNorm1d:

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update')
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)

  def __call__(self, x):
    # calculate the forward pass
    if self.training:
      xmean = x.mean(0, keepdim=True)  # batch mean
      xvar = x.var(0, keepdim=True)    # batch variance
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    # update the buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]


class Tanh:
  def __call__(self, x):
    self.out = torch.tanh(x)
    return self.out
  def parameters(self):
    return []


n_embd = 10     # the dimensionality of the character embedding vectors
n_hidden = 100  # the number of neurons in the hidden layer of the MLP
g = torch.Generator().manual_seed(2147483647)  # for reproducibility

C = torch.randn((vocab_size, n_embd), generator=g)
# layers = [
#   Linear(n_embd * block_size, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
#   Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
#   Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
#   Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
#   Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
#   Linear(           n_hidden, vocab_size, bias=False), BatchNorm1d(vocab_size),
# ]
layers = [
  Linear(n_embd * block_size, n_hidden), Tanh(),
  Linear(           n_hidden, n_hidden), Tanh(),
  Linear(           n_hidden, n_hidden), Tanh(),
  Linear(           n_hidden, n_hidden), Tanh(),
  Linear(           n_hidden, n_hidden), Tanh(),
  Linear(           n_hidden, vocab_size),
]

with torch.no_grad():
  # last layer: make less confident
  layers[-1].bias *= 0    # 0.1
  layers[-1].weight *= 0  # 0.1
  # all other layers: apply gain
  for layer in layers[:-1]:
    if isinstance(layer, Linear):
      layer.weight *= 0  # 1.0 #5/3

parameters = [C] + [p for layer in layers for p in layer.parameters()]
print(sum(p.nelement() for p in parameters))  # number of parameters in total
for p in parameters:
  p.requires_grad = True
46497
# same optimization as last time
max_steps = 200000
batch_size = 32
lossi = []
ud = []
wdist = []

for i in range(max_steps):

  # minibatch construct
  ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
  Xb, Yb = Xtr[ix], Ytr[ix]  # batch X,Y

  # forward pass
  emb = C[Xb]                     # embed the characters into vectors
  x = emb.view(emb.shape[0], -1)  # concatenate the vectors
  for layer in layers:
    x = layer(x)
  loss = F.cross_entropy(x, Yb)   # loss function

  # backward pass
  for layer in layers:
    layer.out.retain_grad()  # AFTER_DEBUG: would take out retain_graph
  for p in parameters:
    p.grad = None
  loss.backward()

  # update
  lr = 0.1 if i < 150000 else 0.01  # step learning rate decay
  for p in parameters:
    p.data += -lr * p.grad

  # track stats
  if i % 10000 == 0:  # print every once in a while
    print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
  lossi.append(loss.log10().item())
  with torch.no_grad():
    ud.append([((lr*p.grad).std() / p.data.std()).log10().item() for p in parameters])
    # wdist.append([ for l in layers])

  if i >= 1000:
    break  # AFTER_DEBUG: would take out obviously to run full optimization
0/ 200000: 3.2958
# visualize histograms
plt.figure(figsize=(20, 4))  # width and height of the plot
legends = []
for i, layer in enumerate(layers[:-1]):  # note: exclude the output layer
  if isinstance(layer, Tanh):
    t = layer.out
    print('layer %d (%10s): mean %+.2f, std %.2f, saturated: %.2f%%' %
          (i, layer.__class__.__name__, t.mean(), t.std(), (t.abs() > 0.97).float().mean()*100))
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach())
    legends.append(f'layer {i} ({layer.__class__.__name__})')
plt.legend(legends);
plt.title('activation distribution');
layer 1 ( Tanh): mean +0.00, std 0.00, saturated: 0.00%
layer 3 ( Tanh): mean +0.00, std 0.00, saturated: 0.00%
layer 5 ( Tanh): mean +0.00, std 0.00, saturated: 0.00%
layer 7 ( Tanh): mean +0.00, std 0.00, saturated: 0.00%
layer 9 ( Tanh): mean +0.00, std 0.00, saturated: 0.00%
# visualize histograms
plt.figure(figsize=(20, 4))  # width and height of the plot
legends = []
for i, layer in enumerate(layers):
  if isinstance(layer, Linear):
    t = layer.weight
    print('layer %d (%10s): mean %+.2f, std %.2f' % (i, layer.__class__.__name__, t.mean(), t.std()))
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach())
    legends.append(f'layer {i} ({layer.__class__.__name__})')
plt.legend(legends);
plt.title('weights distribution');
layer 0 ( Linear): mean +0.00, std 0.00
layer 2 ( Linear): mean +0.00, std 0.00
layer 4 ( Linear): mean +0.00, std 0.00
layer 6 ( Linear): mean +0.00, std 0.00
layer 8 ( Linear): mean +0.00, std 0.00
layer 10 ( Linear): mean +0.00, std 0.00
# visualize histograms
plt.figure(figsize=(20, 4))  # width and height of the plot
legends = []
for i, layer in enumerate(layers):
  if isinstance(layer, Linear):
    t = layer.bias
    print('layer %d (%10s): mean %+.2f, std %.2f' % (i, layer.__class__.__name__, t.mean(), t.std()))
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach())
    legends.append(f'layer {i} ({layer.__class__.__name__})')
plt.legend(legends);
plt.title('bias distribution');
layer 0 ( Linear): mean +0.00, std 0.00
layer 2 ( Linear): mean +0.00, std 0.00
layer 4 ( Linear): mean +0.00, std 0.00
layer 6 ( Linear): mean +0.00, std 0.00
layer 8 ( Linear): mean +0.00, std 0.00
layer 10 ( Linear): mean -0.00, std 0.92
# we can see that this is the only parameter with non-zero values
layers[-1].bias
# visualize histograms
plt.figure(figsize=(20, 4))  # width and height of the plot
legends = []
for i, layer in enumerate(layers[:-1]):  # note: exclude the output layer
  if isinstance(layer, Tanh):
    t = layer.out.grad
    print('layer %d (%10s): mean %+f, std %e' % (i, layer.__class__.__name__, t.mean(), t.std()))
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach());
    legends.append(f'layer {i} ({layer.__class__.__name__})')
plt.legend(legends);
plt.title('gradient distribution');
layer 1 ( Tanh): mean +0.000000, std 0.000000e+00
layer 3 ( Tanh): mean +0.000000, std 0.000000e+00
layer 5 ( Tanh): mean +0.000000, std 0.000000e+00
layer 7 ( Tanh): mean +0.000000, std 0.000000e+00
layer 9 ( Tanh): mean +0.000000, std 0.000000e+00
# visualize histograms
plt.figure(figsize=(20, 4))  # width and height of the plot
legends = []
for i, p in enumerate(parameters):
  t = p.grad
  if p.ndim == 2:
    print('weight %10s | mean %+f | std %e | grad:data ratio %e' %
          (tuple(p.shape), t.mean(), t.std(), t.std() / p.std()))
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach())
    legends.append(f'{i} {tuple(p.shape)}')
plt.legend(legends)
plt.title('weights gradient distribution');
weight (27, 10) | mean +0.000000 | std 0.000000e+00 | grad:data ratio 0.000000e+00
weight (30, 100) | mean +0.000000 | std 0.000000e+00 | grad:data ratio nan
weight (100, 100) | mean +0.000000 | std 0.000000e+00 | grad:data ratio nan
weight (100, 100) | mean +0.000000 | std 0.000000e+00 | grad:data ratio nan
weight (100, 100) | mean +0.000000 | std 0.000000e+00 | grad:data ratio nan
weight (100, 100) | mean +0.000000 | std 0.000000e+00 | grad:data ratio nan
weight (100, 27) | mean +0.000000 | std 0.000000e+00 | grad:data ratio nan
plt.figure(figsize=(20, 4))
legends = []
for i, p in enumerate(parameters):
  plt.plot([ud[j][i] for j in range(len(ud))])
  legends.append('param %d' % i)
plt.plot([0, len(ud)], [-3, -3], 'k')  # these ratios should be ~1e-3, indicate on plot
plt.legend(legends);
plt.title('ratio of stdev of weight updates to weight values, log10 scale');
We can also see that this bias is merely adopting the frequency distribution of letters in the dataset 🤯
from string import ascii_lowercase

out = []
context = [0] * block_size
emb = C[torch.tensor([context])]  # (1, block_size, n_embd)
x = emb.view(emb.shape[0], -1)    # concatenate the vectors
for layer in layers:
  x = layer(x)
logits = x
probs = F.softmax(logits, dim=1)

letters_int = list(range(probs.shape[1]))
fig, ax = plt.subplots(1, 1)
plt.bar(letters_int, probs.tolist()[0]);
ax.set_xticks(letters_int)
ax.set_xticklabels(["."] + list(ascii_lowercase))
plt.title("probabilities from the model from an input of '...'");
from collections import Counter

words_ending = [word + "." for word in words]
counts = Counter("".join(words_ending))
normed_counts = sorted([(k, v / len("".join(words_ending))) for k, v in counts.items()])

fig, ax = plt.subplots(1, 1)
plt.bar(letters_int, [y for x, y in normed_counts]);
ax.set_xticks(letters_int)
ax.set_xticklabels(["."] + list(ascii_lowercase));
plt.title("actual frequency of letters in the training data");
The model also produces rubbish!
# sample from the model
g = torch.Generator().manual_seed(2147483647 + 10)

for _ in range(20):
    out = []
    context = [0] * block_size  # initialize with all ...
    while True:
      # forward pass the neural net
      emb = C[torch.tensor([context])]  # (1, block_size, n_embd)
      x = emb.view(emb.shape[0], -1)    # concatenate the vectors
      for layer in layers:
        x = layer(x)
      logits = x
      probs = F.softmax(logits, dim=1)
      # sample from the distribution
      ix = torch.multinomial(probs, num_samples=1, generator=g).item()
      # shift the context window and track the samples
      context = context[1:] + [ix]
      out.append(ix)
      # if we sample the special '.' token, break
      if ix == 0:
        break
    print(''.join(itos[i] for i in out))  # decode and print the generated word
ernaaimyazhseelmnd.
ryal.
.
etmrsjejdrleg.
adeeedieliihemy.
oealeieiseananarneltzimhlkaa.
n.
sadbvrgahimies.
.
n.
ir.
eelklxnteuofnu.
amnedar.
yidle.
ehs.
laajhysknyaa.
hya.
nalyaisun.
zajelveunen.
.
E02: BatchNorm, unlike other normalization layers like LayerNorm/GroupNorm etc. has the big advantage that after training, the batchnorm gamma/beta can be “folded into” the weights of the preceding Linear layers, effectively erasing the need to forward it at test time. Set up a small 3-layer MLP with batchnorms, train the network, then “fold” the batchnorm gamma/beta into the preceding Linear layer’s W,b by creating a new W2, b2 and erasing the batch norm. Verify that this gives the same forward pass during inference. i.e. we see that the batchnorm is there just for stabilizing the training, and can be thrown out after training is done! pretty cool.
First we can train the network with BatchNorm and run inference with it.
n_embd = 10     # the dimensionality of the character embedding vectors
n_hidden = 100  # the number of neurons in the hidden layer of the MLP
g = torch.Generator().manual_seed(2147483647)  # for reproducibility

C = torch.randn((vocab_size, n_embd), generator=g)
layers = [
  Linear(n_embd * block_size, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(           n_hidden, vocab_size, bias=False), BatchNorm1d(vocab_size),
]

with torch.no_grad():
  # last layer: make less confident
  layers[-1].gamma *= 0.1
  # all other layers: apply gain
  for layer in layers[:-1]:
    if isinstance(layer, Linear):
      layer.weight *= 1.0

parameters = [C] + [p for layer in layers for p in layer.parameters()]
for p in parameters:
  p.requires_grad = True
# same optimization as last time
max_steps = 200000
batch_size = 32

for i in range(max_steps):

  # minibatch construct
  ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
  Xb, Yb = Xtr[ix], Ytr[ix]  # batch X,Y

  # forward pass
  emb = C[Xb]                     # embed the characters into vectors
  x = emb.view(emb.shape[0], -1)  # concatenate the vectors
  for layer in layers:
    x = layer(x)
  loss = F.cross_entropy(x, Yb)   # loss function

  # backward pass
  for layer in layers:
    layer.out.retain_grad()  # AFTER_DEBUG: would take out retain_graph
  for p in parameters:
    p.grad = None
  loss.backward()

  # update
  lr = 0.1 if i < 150000 else 0.01  # step learning rate decay
  for p in parameters:
    p.data += -lr * p.grad

  if i >= 1000:
    break  # AFTER_DEBUG: would take out obviously to run full optimization
# setting the batchnorm layers to inference mode is key! Otherwise we
# overwrite the running mean and variance
for layer in layers:
  layer.training = False

# sample from the model
g = torch.Generator().manual_seed(2147483647 + 10)

for _ in range(1):
    out = []
    context = [0] * block_size  # initialize with all ...
    while True:
      # forward pass the neural net
      emb = C[torch.tensor([context])]  # (1, block_size, n_embd)
      x = emb.view(emb.shape[0], -1)    # concatenate the vectors
      for layer in layers:
        x = layer(x)
      logits = x
      probs = F.softmax(logits, dim=1)
      # sample from the distribution
      ix = torch.multinomial(probs, num_samples=1, generator=g).item()
      # shift the context window and track the samples
      context = context[1:] + [ix]
      out.append(ix)
      # if we sample the special '.' token, break
      if ix == 0:
        break
    print(''.join(itos[i] for i in out))  # decode and print the generated word
mria.
We can do some algebra to work out how to combine Linear and Batchnorm layers:
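Concretely, for a linear layer followed by a BatchNorm running in inference mode (using its running statistics \(\mu\), \(\sigma^2\) and learned \(\gamma\), \(\beta\), applied per output unit):

\[ \text{BN}(Wx + b) = \gamma \frac{(Wx + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta = \underbrace{\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} W}_{W_2}\, x + \underbrace{\frac{\gamma (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta}_{b_2} \]

so the folded layer uses the new weights \(W_2\) and bias \(b_2\), and the BatchNorm can be discarded.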
# now we combine each batchnorm layer with its prior linear layer
for i, layer in enumerate(layers):
  if i + 1 == len(layers):
    # skip last iteration
    break
  if isinstance(layer, Linear):
    bn = layers[i + 1]
    w = layer.weight
    b = layer.bias if layer.bias is not None else 0
    layer.weight = bn.gamma * w / torch.sqrt(bn.running_var + bn.eps)
    layer.bias = bn.gamma * (b - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) + bn.beta
# remove the BatchNorm layers
layers = [layer for layer in layers if not isinstance(layer, BatchNorm1d)]
# sample from the model
g = torch.Generator().manual_seed(2147483647 + 10)

for _ in range(1):
    out = []
    context = [0] * block_size  # initialize with all ...
    while True:
      # forward pass the neural net
      emb = C[torch.tensor([context])]  # (1, block_size, n_embd)
      x = emb.view(emb.shape[0], -1)    # concatenate the vectors
      for layer in layers:
        x = layer(x)
      logits = x
      probs = F.softmax(logits, dim=1)
      # sample from the distribution
      ix = torch.multinomial(probs, num_samples=1, generator=g).item()
      # shift the context window and track the samples
      context = context[1:] + [ix]
      out.append(ix)
      # if we sample the special '.' token, break
      if ix == 0:
        break
    print(''.join(itos[i] for i in out))  # decode and print the generated word
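Sampling with the same seed produces the same word, but we can also compare the forward passes directly. This is a hedged sketch: layers_with_bn is a hypothetical deep copy of the layer list taken before the folding step (with training already set to False on the BatchNorm layers).

# sketch of a direct numerical check
# assumes, before folding, we had run:  layers_with_bn = copy.deepcopy(layers)
emb = C[Xdev[:32]]                          # any batch of validation examples
x_bn = x_folded = emb.view(emb.shape[0], -1)

for layer in layers_with_bn:                # original Linear + BatchNorm stack
  x_bn = layer(x_bn)
for layer in layers:                        # folded Linear-only stack
  x_folded = layer(x_folded)

print(torch.allclose(x_bn, x_folded, atol=1e-5))  # expect True: identical forward pass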