makemore: becoming a backprop ninja
swole doge style

# there is no change in the first several cells from the last lecture
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline
# read in all the words
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])
32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']
# build the vocabulary of characters and mappings to/from integers
= sorted(list(set(''.join(words))))
chars = {s:i+1 for i,s in enumerate(chars)}
stoi '.'] = 0
stoi[= {i:s for s,i in stoi.items()}
itos = len(itos)
vocab_size print(itos)
print(vocab_size)
{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
  X, Y = [], []

  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%
torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])
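# (added illustration, not in the original notebook) a quick peek at what build_dataset
# produces: the sliding (context -> target) pairs for a single word. Xe/Ye are names
# used only for this example.
Xe, Ye = build_dataset(['emma'])
for x, y in zip(Xe.tolist(), Ye.tolist()):
  print(''.join(itos[i] for i in x), '-->', itos[y])
# ... --> e
# ..e --> m
# .em --> m
# emm --> a
# mma --> .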
# ok boilerplate done, now we get to the action:

# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff} | shapes: {dt.shape} vs. {t.shape}')
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# Note: I am initializing many of these parameters in non-standard ways
# because sometimes initializing with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
parameters_named = dict(
  C=C,
  W1=W1,
  b1=b1,
  W2=W2,
  b2=b2,
  bngain=bngain,
  bnbias=bnbias,
)
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True
4137
batch_size = 32
n = batch_size # a shorter variable also, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
          embcat, emb]:
  t.retain_grad()
loss.backward()
loss
tensor(3.3390, grad_fn=<NegBackward0>)
from torchviz import make_dot
make_dot(loss, params=parameters_named)
# Exercise 1: backprop through the whole thing manually,
# backpropagating through exactly all of the variables
# as they are defined in the forward pass above, one by one
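# (added note, a sketch of the reasoning behind the first step below) the loss is
#   loss = -1/n * sum_i logprobs[i, Yb[i]]
# so d(loss)/d(logprobs[i, j]) = -1/n when j == Yb[i] and 0 everywhere else,
# which is exactly the tensor that the first two lines below construct.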
dlogprobs = torch.zeros_like(logprobs)
# get all rows, and index into the correct column for the labels
dlogprobs[range(n), Yb] = -1/n
dprobs = dlogprobs * 1 / probs
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
dcounts_sum = dcounts_sum_inv * -counts_sum**-2
dcounts = counts_sum_inv * dprobs
# gradients also flow through the other vertex!
dcounts += dcounts_sum * 1

dnorm_logits = dcounts * norm_logits.exp()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)

dlogits = dnorm_logits.clone()
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes

dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)

dhpreact = dh * (1 - h**2)

dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)
dbnraw = dhpreact * bngain

dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True)
dbndiff = bnvar_inv * dbnraw
dbnvar = dbnvar_inv * -.5 * (bnvar + 1e-5)**-1.5

dbndiff2 = torch.ones_like(bndiff2) * 1 / (n-1) * dbnvar
dbndiff += (2*bndiff) * dbndiff2
dhprebn = dbndiff.clone()

dbnmeani = (-dbndiff).sum(0, keepdim=True)
dhprebn += dbnmeani * torch.ones_like(hprebn) * 1 / n

dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
demb = torch.ones_like(emb) * dembcat.view(n, 3, -1)

dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
  for j in range(Xb.shape[1]):
    ix = Xb[k,j]
    dC[ix] += demb[k,j]
cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
cmp('bnmeani', dbnmeani, bnmeani)
cmp('hprebn', dhprebn, hprebn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)
logprobs | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 27]) vs. torch.Size([32, 27])
probs | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 27]) vs. torch.Size([32, 27])
counts_sum_inv | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 1]) vs. torch.Size([32, 1])
counts_sum | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 1]) vs. torch.Size([32, 1])
counts | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 27]) vs. torch.Size([32, 27])
norm_logits | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 27]) vs. torch.Size([32, 27])
logit_maxes | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 1]) vs. torch.Size([32, 1])
logits | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 27]) vs. torch.Size([32, 27])
h | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 64]) vs. torch.Size([32, 64])
W2 | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([64, 27]) vs. torch.Size([64, 27])
b2 | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([27]) vs. torch.Size([27])
hpreact | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 64]) vs. torch.Size([32, 64])
bngain | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([1, 64]) vs. torch.Size([1, 64])
bnbias | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([1, 64]) vs. torch.Size([1, 64])
bnraw | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 64]) vs. torch.Size([32, 64])
bnvar_inv | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([1, 64]) vs. torch.Size([1, 64])
bnvar | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([1, 64]) vs. torch.Size([1, 64])
bndiff2 | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 64]) vs. torch.Size([32, 64])
bndiff | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 64]) vs. torch.Size([32, 64])
bnmeani | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([1, 64]) vs. torch.Size([1, 64])
hprebn | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 64]) vs. torch.Size([32, 64])
embcat | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 30]) vs. torch.Size([32, 30])
W1 | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([30, 64]) vs. torch.Size([30, 64])
b1 | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([64]) vs. torch.Size([64])
emb | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([32, 3, 10]) vs. torch.Size([32, 3, 10])
C | exact: True | approximate: True | maxdiff: 0.0 | shapes: torch.Size([27, 10]) vs. torch.Size([27, 10])
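# (optional sketch, not part of the original exercise) an autograd-independent sanity
# check: estimate one entry of dW2 by central finite differences and compare it to the
# manual gradient above. `eps` and the (0, 0) entry are arbitrary choices for illustration.
with torch.no_grad():
  eps = 1e-3
  W2_pert = W2.clone()
  W2_pert[0, 0] += eps
  loss_plus = F.cross_entropy(h @ W2_pert + b2, Yb)
  W2_pert[0, 0] -= 2 * eps
  loss_minus = F.cross_entropy(h @ W2_pert + b2, Yb)
  print('numerical:', ((loss_plus - loss_minus) / (2 * eps)).item(), 'manual:', dW2[0, 0].item())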
# Exercise 2: backprop through cross_entropy but all in one go
# to complete this challenge look at the mathematical expression of the loss,
# take the derivative, simplify the expression, and just write it out
# forward pass
# before:
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()
# now:
loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())
3.3377387523651123 diff: 2.384185791015625e-07
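# (added derivation sketch) for the fused backward pass below: with p = softmax(logits),
# the per-example loss is loss_i = -log(p[i, Yb[i]]), and
#   d(loss_i)/d(logits[i, j]) = p[i, j] - (1 if j == Yb[i] else 0).
# Averaging over the n examples in the batch scales everything by 1/n, which is exactly
# the three lines in the next cell: softmax, subtract 1 at the label positions, divide by n.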
# backward pass
dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n

cmp('logits', dlogits, logits) # I can only get approximate to be true, my maxdiff is 6e-9
logits | exact: False | approximate: True | maxdiff: 5.122274160385132e-09
logits.shape, Yb.shape
(torch.Size([32, 27]), torch.Size([32]))
F.softmax(logits, 1)[0]
tensor([0.0719, 0.0881, 0.0193, 0.0493, 0.0169, 0.0864, 0.0226, 0.0356, 0.0165,
0.0314, 0.0364, 0.0383, 0.0424, 0.0279, 0.0317, 0.0142, 0.0085, 0.0195,
0.0152, 0.0555, 0.0450, 0.0236, 0.0250, 0.0662, 0.0616, 0.0269, 0.0239],
grad_fn=<SelectBackward0>)
dlogits[0] * n
tensor([ 0.0719, 0.0881, 0.0193, 0.0493, 0.0169, 0.0864, 0.0226, 0.0356,
-0.9835, 0.0314, 0.0364, 0.0383, 0.0424, 0.0279, 0.0317, 0.0142,
0.0085, 0.0195, 0.0152, 0.0555, 0.0450, 0.0236, 0.0250, 0.0662,
0.0616, 0.0269, 0.0239], grad_fn=<MulBackward0>)
dlogits[0].sum()
tensor(1.3970e-09, grad_fn=<SumBackward0>)
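# (added note) each row of dlogits is a softmax distribution with 1 subtracted at the
# label position, so it sums to (probabilities summing to 1) - 1 = 0 up to float error:
# the gradient pushes every logit down in proportion to its probability and pulls the
# correct one up, and those forces exactly balance.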
plt.figure(figsize=(4, 4))
plt.imshow(dlogits.detach(), cmap='gray')
# Exercise 3: backprop through batchnorm but all in one go
# to complete this challenge look at the mathematical expression of the output of batchnorm,
# take the derivative w.r.t. its input, simplify the expression, and just write it out
# forward pass
# before:
# bnmeani = 1/n*hprebn.sum(0, keepdim=True)
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff**2
# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# bnvar_inv = (bnvar + 1e-5)**-0.5
# bnraw = bndiff * bnvar_inv
# hpreact = bngain * bnraw + bnbias
# now:
hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias
print('max diff:', (hpreact_fast - hpreact).abs().max())
max diff: tensor(4.7684e-07, grad_fn=<MaxBackward1>)
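# (added derivation sketch) writing the batchnorm output for one column as
#   y_i = gain * (x_i - mu) / sqrt(var + eps) + bias,
# with mu and var computed over the batch (var with Bessel's correction), differentiating
# through the three appearances of x_i (directly, via mu, and via var) and collecting
# terms gives
#   dx_i = gain * bnvar_inv / n * (n*dy_i - sum_j dy_j - n/(n-1) * xhat_i * sum_j dy_j*xhat_j)
# where xhat is bnraw and dy is dhpreact. This is the single expression used for dhprebn below.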
# backward pass
# before we had:
# dbnraw = bngain * dhpreact
# dbndiff = bnvar_inv * dbnraw
# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv
# dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar
# dbndiff += (2*bndiff) * dbndiff2
# dhprebn = dbndiff.clone()
# dbnmeani = (-dbndiff).sum(0)
# dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)
# calculate dhprebn given dhpreact (i.e. backprop through the batchnorm)
# (you'll also need to use some of the variables from the forward pass up above)
dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))

cmp('hprebn', dhprebn, hprebn) # I can only get approximate to be true, my maxdiff is 9e-10
hprebn | exact: False | approximate: True | maxdiff: 9.313225746154785e-10
dhprebn.shape, bngain.shape, bnvar_inv.shape, dbnraw.shape, dbnraw.sum(0).shape
(torch.Size([32, 64]),
torch.Size([1, 64]),
torch.Size([1, 64]),
torch.Size([32, 64]),
torch.Size([64]))
# Exercise 4: putting it all together!
# Train the MLP neural net with your own backward pass
# init
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

# same optimization as last time
max_steps = 200000
batch_size = 32
n = batch_size # convenience
lossi = []

# use this context manager for efficiency once your backward pass is written (TODO)
with torch.no_grad():

  # kick off optimization
  for i in range(max_steps):

    # minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

    # forward pass
    emb = C[Xb] # embed the characters into vectors
    embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
    # Linear layer
    hprebn = embcat @ W1 + b1 # hidden layer pre-activation
    # BatchNorm layer
    # -------------------------------------------------------------
    bnmean = hprebn.mean(0, keepdim=True)
    bnvar = hprebn.var(0, keepdim=True, unbiased=True)
    bnvar_inv = (bnvar + 1e-5)**-0.5
    bnraw = (hprebn - bnmean) * bnvar_inv
    hpreact = bngain * bnraw + bnbias
    # -------------------------------------------------------------
    # Non-linearity
    h = torch.tanh(hpreact) # hidden layer
    logits = h @ W2 + b2 # output layer
    loss = F.cross_entropy(logits, Yb) # loss function

    # backward pass
    for p in parameters:
      p.grad = None
    #loss.backward() # use this for correctness comparisons, delete it later!

    # manual backprop! #swole_doge_meme
    # -----------------
    dlogits = F.softmax(logits, 1)
    dlogits[range(n), Yb] -= 1
    dlogits /= n
    # 2nd layer backprop
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0)
    # tanh
    dhpreact = (1.0 - h**2) * dh
    # batchnorm backprop
    dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
    dbnbias = dhpreact.sum(0, keepdim=True)
    dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))
    # 1st layer
    dembcat = dhprebn @ W1.T
    dW1 = embcat.T @ dhprebn
    db1 = dhprebn.sum(0)
    # embedding
    demb = dembcat.view(emb.shape)
    dC = torch.zeros_like(C)
    for k in range(Xb.shape[0]):
      for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += demb[k,j]
    grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]
    # -----------------

    # update
    lr = 0.1 if i < 100000 else 0.01 # step learning rate decay
    for p, grad in zip(parameters, grads):
      #p.data += -lr * p.grad # old way of cheems doge (using PyTorch grad from .backward())
      p.data += -lr * grad # new way of swole doge TODO: enable

    # track stats
    if i % 10000 == 0: # print every once in a while
      print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())

    # if i >= 100: # TODO: delete early breaking when you're ready to train the full net
    #   break
12297
0/ 200000: 3.7805
10000/ 200000: 2.1775
20000/ 200000: 2.3957
30000/ 200000: 2.5032
40000/ 200000: 2.0065
50000/ 200000: 2.3873
60000/ 200000: 2.3378
70000/ 200000: 2.0640
80000/ 200000: 2.3497
90000/ 200000: 2.1093
100000/ 200000: 1.9132
110000/ 200000: 2.2229
120000/ 200000: 1.9912
130000/ 200000: 2.4441
140000/ 200000: 2.3198
150000/ 200000: 2.1857
160000/ 200000: 2.0296
170000/ 200000: 1.8391
180000/ 200000: 2.0436
190000/ 200000: 1.9200
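# (optional, a quick sketch) visualize the training curve collected in lossi, averaging
# the log10 losses in chunks of 1000 steps to smooth out the minibatch noise
plt.plot(torch.tensor(lossi).view(-1, 1000).mean(1))
plt.xlabel('step (thousands)')
plt.ylabel('log10 loss')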
# useful for checking your gradients
# for p,g in zip(parameters, grads):
# cmp(str(tuple(p.shape)), g, p)
# calibrate the batch norm at the end of training

with torch.no_grad():
  # pass the training set through
  emb = C[Xtr]
  embcat = emb.view(emb.shape[0], -1)
  hpreact = embcat @ W1 + b1
  # measure the mean/std over the entire training set
  bnmean = hpreact.mean(0, keepdim=True)
  bnvar = hpreact.var(0, keepdim=True, unbiased=True)
# evaluate train and val loss

@torch.no_grad() # this decorator disables gradient tracking
def split_loss(split):
  x, y = {
    'train': (Xtr, Ytr),
    'val': (Xdev, Ydev),
    'test': (Xte, Yte),
  }[split]
  emb = C[x] # (N, block_size, n_embd)
  embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)
  hpreact = embcat @ W1 + b1
  hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias
  h = torch.tanh(hpreact) # (N, n_hidden)
  logits = h @ W2 + b2 # (N, vocab_size)
  loss = F.cross_entropy(logits, y)
  print(split, loss.item())

split_loss('train')
split_loss('val')
train 2.070523500442505
val 2.109893560409546
# I achieved:
# train 2.0718822479248047
# val 2.1162495613098145
# sample from the model
g = torch.Generator().manual_seed(2147483647 + 10)

for _ in range(20):

  out = []
  context = [0] * block_size # initialize with all ...
  while True:
    # ------------
    # forward pass:
    # Embedding
    emb = C[torch.tensor([context])] # (1,block_size,d)
    embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)
    hpreact = embcat @ W1 + b1
    hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias
    h = torch.tanh(hpreact) # (N, n_hidden)
    logits = h @ W2 + b2 # (N, vocab_size)
    # ------------
    # Sample
    probs = F.softmax(logits, dim=1)
    ix = torch.multinomial(probs, num_samples=1, generator=g).item()
    context = context[1:] + [ix]
    out.append(ix)
    if ix == 0:
      break

  print(''.join(itos[i] for i in out))
carmahzamille.
khi.
mreigeet.
khalaysie.
mahnen.
delynn.
jareen.
nellara.
chaiiv.
kaleigh.
ham.
joce.
quinn.
shoison.
jadiquintero.
dearyxi.
jace.
pinsley.
dae.
iia.