from fastai import *
from fastai.vision.all import * 
from torch import optim
def load_data(folder):
    """Load the digit images stored under ``path/folder`` as flat tensors.

    For every digit 0-9 the files in ``path/folder/<digit>`` are read in
    sorted order, converted to float and divided by 255. ``path`` is the
    module-level dataset root.

    Args:
        folder: Name of the sub-directory of ``path`` holding one folder
            per digit (e.g. ``"training"``).

    Returns:
        A ``(images, labels)`` tuple. ``images`` has one row of ``28*28``
        values per image; ``labels`` holds the digit for each row, in the
        same order.
    """
    per_digit_images = []
    labels = []
    for digit in range(10):
        files = (path/folder/f'{digit}').ls().sorted()

        # One tensor per digit holding all of that digit's images.
        pixels = torch.stack([tensor(Image.open(f)) for f in files])
        per_digit_images.append(pixels.float() / 255.0)

        # Every file in this folder carries the folder's digit as label.
        labels += [digit] * len(files)

    # Flatten each image into a single row.
    flat_images = torch.cat(per_digit_images).view(-1, 28*28)
    return flat_images, tensor(labels)
# Obtain the MNIST dataset root through fastai. load_data() reads this
# module-level `path` to locate the per-digit image folders.
path = untar_data(URLs.MNIST)
# Bare expression: echoes the path when run as a notebook cell.
path

# Flattened images (scaled by 1/255) and integer labels for each split.
x_train, y_train = load_data("training")
x_valid, y_valid = load_data("testing")
# Pair each image row with its label. 64 / 128 are passed positionally
# (presumably the batch size -- confirm against the DataLoader in scope).
# Training data is shuffled and drop_last=True is set; the validation
# loader keeps its order.
dls_train = DataLoader(list(zip(x_train, y_train)), 64, shuffle=True, drop_last=True) 
dls_test = DataLoader(list(zip(x_valid, y_valid)), 128, shuffle=False) 
# Two-layer network: 28*28 inputs -> 50 hidden units -> ReLU -> 10 outputs.
model = nn.Sequential(nn.Linear(28*28,50), 
                      nn.ReLU(), 
                      nn.Linear(50,10))
# Adam over all model parameters, with the optimiser's default settings.
opt = optim.Adam(model.parameters())
def accuracy(preds, yb):
    """Return the fraction of rows in ``preds`` whose arg-max matches ``yb``.

    Args:
        preds: Tensor of per-class scores; the predicted class of each row
            is the index of its largest value along dimension 1.
        yb: Tensor of target class indices, compared element-wise with the
            predicted classes.

    Returns:
        A scalar float tensor: the mean of the element-wise matches.
    """
    predicted_classes = preds.argmax(dim=1)
    is_correct = predicted_classes == yb
    return is_correct.float().mean()


def fit(epochs):
    """Train the module-level ``model`` and print validation metrics per epoch.

    Each epoch makes one optimisation pass over ``dls_train`` using ``opt``
    and cross-entropy loss, then evaluates on ``dls_test`` with gradients
    disabled. It prints the epoch index, the mean validation loss and the
    mean validation accuracy.

    Validation metrics are averaged per sample, not per batch, so a smaller
    final batch is not given the same weight as a full one.

    Args:
        epochs: Number of passes to make over ``dls_train``.
    """
    for epoch in range(epochs):
        model.train()
        for xb, yb in dls_train:
            batch_loss = F.cross_entropy(model(xb), yb)
            batch_loss.backward()
            opt.step()
            # Clear the gradients so the next batch starts from zero.
            opt.zero_grad()

        model.eval()
        total_loss, total_acc, n_seen = 0.0, 0.0, 0
        with torch.no_grad():
            for xb, yb in dls_test:
                pred = model(xb)
                n_batch = len(yb)
                # cross_entropy (default reduction) and accuracy() both
                # return a batch mean; scale back to a batch sum so the
                # final division yields a true per-sample average.
                total_loss += F.cross_entropy(pred, yb) * n_batch
                total_acc += accuracy(pred, yb) * n_batch
                n_seen += n_batch

        print(epoch, total_loss / n_seen, total_acc / n_seen)
# Train for 5 epochs; fit() prints the validation metrics after each one.
fit(5)
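# Recorded output of the fit(5) call above
# (epoch, validation loss, validation accuracy):
# 0 tensor(0.1057) tensor(0.9697)
# 1 tensor(0.0993) tensor(0.9717)
# 2 tensor(0.0972) tensor(0.9716)
# 3 tensor(0.0962) tensor(0.9718)
# 4 tensor(0.0932) tensor(0.9720)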