Goal

Let's create a bear classifier that can distinguish between teddy bears, black bears, and grizzly bears.

First, let's get the training data by downloading images from DuckDuckGo.

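from pathlib import Path
from shutil import rmtree  # only needed for the optional cleanup line below
from fastai.vision.all import *  # DataBlock, cnn_learner, PILImage, load_learner, ...

root = Path.cwd()/"images"

#rmtree(root)  # uncomment to delete all previously downloaded images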

from jmd_imagescraper.core import *
# Each call saves up to max_results images into root/<label>, e.g. images/Grizzly
duckduckgo_search(root, "Grizzly", "Grizzly bears", max_results=300)
duckduckgo_search(root, "Black", "Black bears", max_results=300)
duckduckgo_search(root, "Teddy", "Teddy bears", max_results=300)
#duckduckgo_search(root, "Random", "Random images", max_results=300)
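Before training, it may be worth checking how many images each search actually saved, since a search can return fewer than max_results or some downloads can fail. The sketch below is a hypothetical helper (count_images_per_label is not part of jmd_imagescraper); it assumes one subfolder per label under root and treats a fixed, assumed set of file extensions as images.

```python
from collections import Counter
from pathlib import Path

# Extensions treated as images (an assumption; adjust for your downloads).
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}

def count_images_per_label(root):
    """Hypothetical helper: count image files per label folder under `root`.

    Assumes the layout root/<label>/<file>, so the label is the parent folder name.
    """
    counts = Counter()
    for f in Path(root).rglob("*"):
        if f.is_file() and f.suffix.lower() in IMAGE_EXTS:
            counts[f.parent.name] += 1
    return dict(counts)

# Usage (assuming the same `images` folder as above):
# print(count_images_per_label(Path.cwd() / "images"))
```

A large imbalance between the three folders would be a reason to download more images for the smaller classes.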

Let's view the data and remove any images that don't belong

from jmd_imagescraper.imagecleaner import *

display_image_cleaner(root)

Creating the DataBlock

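bears = DataBlock(
    blocks=(ImageBlock, CategoryBlock),  # independent variable: images; dependent variable: a single category
    get_items=get_image_files,  # collect every image file under root
    splitter=RandomSplitter(valid_pct=0.2, seed=42),  # hold out a random 20% for validation; the seed makes it reproducible
    get_y=parent_label,  # the label is the name of the folder the image is in
    item_tfms=RandomResizedCrop(224, min_scale=0.5),  # random crop covering at least 50% of the image, resized to 224x224
    batch_tfms=aug_transforms())  # standard data augmentation applied to whole batches

dls = bears.dataloaders(root)
dls.valid.show_batch(max_n=4, nrows=1)  # view 4 validation images in one row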
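The labeling and splitting steps in the DataBlock can be sketched in plain Python: the label is the parent folder's name, and a fixed seed shuffles the items before 20% are held out. This is an illustration of the idea only; parent_label_sketch and random_split_sketch are hypothetical names, and fastai's RandomSplitter uses its own random number generator, so the exact images it picks will differ.

```python
import random
from pathlib import Path

def parent_label_sketch(path):
    # Same idea as fastai's parent_label: the label is the parent folder's name.
    return Path(path).parent.name

def random_split_sketch(items, valid_pct=0.2, seed=42):
    """Shuffle indices with a fixed seed and hold out valid_pct of the items."""
    rng = random.Random(seed)  # private RNG so the split is the same on every run
    idxs = list(range(len(items)))
    rng.shuffle(idxs)
    cut = int(valid_pct * len(items))
    valid = [items[i] for i in idxs[:cut]]
    train = [items[i] for i in idxs[cut:]]
    return train, valid

files = [f"images/Grizzly/{i}.jpg" for i in range(10)]
train, valid = random_split_sketch(files)
print(len(train), len(valid), parent_label_sketch(files[0]))
```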

Creating the model and training it

learner = cnn_learner(dls, resnet18, metrics=accuracy)  # newer fastai versions call this vision_learner
learner.fine_tune(2)
epoch train_loss valid_loss accuracy time
0 0.955830 0.151302 0.961111 00:09
epoch train_loss valid_loss accuracy time
0 0.198822 0.065566 0.988889 00:11
1 0.122963 0.038051 0.994444 00:11
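fine_tune(2) first trains for one epoch with the pretrained ResNet body frozen, then unfreezes it and trains for two more epochs, which is why two tables are printed. The accuracy column is the fraction of validation images whose highest-scoring class matches the true label. Below is a plain-Python sketch of that calculation (accuracy_sketch is a hypothetical name; fastai's accuracy does the same comparison on tensors with argmax), using made-up scores.

```python
def accuracy_sketch(probs, targets):
    """Fraction of items whose highest-scoring class index equals the target index."""
    correct = 0
    for row, target in zip(probs, targets):
        predicted = max(range(len(row)), key=lambda i: row[i])  # argmax
        correct += predicted == target
    return correct / len(targets)

# Hypothetical scores for 4 validation images over 3 classes.
scores = [[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
labels = [1, 0, 2, 1]
print(accuracy_sketch(scores, labels))  # → 0.75 (3 of 4 correct)
```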

Now let's export the model

Exporting saves the trained Learner (its architecture, weights, and data-processing pipeline) to a single file, so we can reuse it later without retraining.

learner.export(fname='bear.pkl')
path = Path()
path.ls(file_exts='.pkl')
(#2) [Path('bear.pkl'),Path('export.pkl')]

Let's test it with our own images

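from fastai.vision.widgets import *  # provides `widgets` (ipywidgets) and VBox

learn_inf = load_learner(path/'bear.pkl')  # reload the exported model for inference
uploader = widgets.FileUpload()
uploader  # displays an upload button; upload a photo before running the next lines
img = PILImage.create(uploader.data[0])  # `.data` is the ipywidgets 7 API (removed in ipywidgets 8)
img.to_thumb(192)
learn_inf.predict(img)  # returns (label, label index, probabilities)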
('Teddy', TensorImage(2), TensorImage([7.3293e-08, 9.1892e-06, 9.9999e-01]))
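For a single-label model, predict picks the class with the highest probability and returns its name, its index in the vocab, and all the probabilities. A sketch of that decoding step, assuming dls.vocab is ['Black', 'Grizzly', 'Teddy'] (the same alphabetical order shown for dls2.vocab later); decode_single_label is a hypothetical name, not a fastai function.

```python
def decode_single_label(probs, vocab):
    """Mimic the (label, index, probabilities) tuple returned by predict for a single-label model."""
    idx = max(range(len(probs)), key=lambda i: probs[i])  # class with the highest probability
    return vocab[idx], idx, probs

vocab = ['Black', 'Grizzly', 'Teddy']  # assumed vocab order
label, idx, _ = decode_single_label([7.3293e-08, 9.1892e-06, 9.9999e-01], vocab)
print(label, idx)  # → Teddy 2
```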

It predicted Teddy with a probability of about 0.99999, which is correct!

Now let's try an image that contains more than one kind of bear

uploader = widgets.FileUpload()
uploader
img = PILImage.create(uploader.data[0])
img.to_thumb(192)
learn_inf.predict(img)
('Grizzly', TensorImage(1), TensorImage([2.9330e-01, 7.0659e-01, 1.1200e-04]))

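It only identified one of the two bears. That's because our model is a single-label classifier: its softmax output forces the class probabilities to sum to 1, so it always picks exactly one category and can't predict multiple labels.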
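The difference between a single-label and a multi-label output layer comes down to softmax versus sigmoid. The sketch below applies both to the same hypothetical raw scores (logits) for a photo containing a black bear and a grizzly; the logit values are made up for illustration.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; the outputs always sum to 1.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(logits):
    # Each score is squashed to (0, 1) independently of the others (numerically stable form).
    out = []
    for x in logits:
        if x >= 0:
            out.append(1 / (1 + math.exp(-x)))
        else:
            e = math.exp(x)
            out.append(e / (1 + e))
    return out

logits = [1.0, 2.0, -6.0]  # hypothetical scores for Black, Grizzly, Teddy
print([round(p, 3) for p in softmax(logits)])
print([round(p, 3) for p in sigmoid(logits)])
```

With softmax the two bear classes compete for one unit of probability, so only one can exceed 0.5; with sigmoid both can.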

Let's turn our model into a multi-label classifier

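To build a multi-label classifier, our DataBlock needs two changes: use MultiCategoryBlock instead of CategoryBlock for the labels, and make get_y return a list of labels rather than a single string.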
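MultiCategoryBlock turns each item's list of labels into a multi-hot target vector, with one 0/1 entry per class in the vocab, and the model then scores every class independently. Here each training image still has exactly one label (its folder name), but the target becomes a vector. A hypothetical helper showing the encoding idea (encode_multi_hot is not fastai's implementation):

```python
def encode_multi_hot(labels, vocab):
    """Turn a list of label names into a 0/1 vector aligned with vocab."""
    label_set = set(labels)
    unknown = label_set - set(vocab)
    if unknown:
        raise ValueError(f"labels not in vocab: {sorted(unknown)}")
    return [1 if name in label_set else 0 for name in vocab]

vocab = ['Black', 'Grizzly', 'Teddy']
print(encode_multi_hot(['Grizzly'], vocab))           # a training image from the Grizzly folder
print(encode_multi_hot(['Black', 'Grizzly'], vocab))  # a photo with both bears
```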

def parent_label_multi(o):
    # MultiCategoryBlock expects a list of labels, so wrap the folder name in a list
    return [Path(o).parent.name]

bears2 = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock),  # independent variable: images; dependent variable: a list of labels
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label_multi,
    item_tfms=Resize(128))

dls2 = bears2.dataloaders(root)
dls2.show_batch()
dls2.vocab
['Black', 'Grizzly', 'Teddy']
learn2 = cnn_learner(dls2, resnet18, metrics=partial(accuracy_multi, thresh=0.2))  # threshold of 0.2 picked arbitrarily
learn2.fine_tune(2)
epoch train_loss valid_loss accuracy_multi time
0 0.760363 0.213807 0.840741 00:04
epoch train_loss valid_loss accuracy_multi time
0 0.299894 0.143817 0.892593 00:04
1 0.211261 0.125826 0.911111 00:04
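fastai's accuracy_multi applies a sigmoid to the raw outputs (its sigmoid parameter defaults to True), marks a class as predicted when the probability is above thresh, and then averages the matches over every (image, class) pair rather than per image. Because most pairs are negatives, the number can look higher than a per-image accuracy would. A plain-Python sketch (accuracy_multi_sketch is a hypothetical name; the probabilities are made up):

```python
import math

def _sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    e = math.exp(x)
    return e / (1 + e)

def accuracy_multi_sketch(outputs, targets, thresh=0.5, sigmoid=True):
    """Share of (image, class) pairs where the thresholded prediction matches the 0/1 target."""
    matches = total = 0
    for row, target_row in zip(outputs, targets):
        for x, t in zip(row, target_row):
            p = _sigmoid(x) if sigmoid else x
            matches += (p > thresh) == bool(t)
            total += 1
    return matches / total

# Hypothetical probabilities (sigmoid already applied) for 2 images over 3 classes.
probs = [[0.3, 0.9, 0.1], [0.1, 0.15, 0.8]]
targets = [[0, 1, 0], [0, 0, 1]]
print(accuracy_multi_sketch(probs, targets, thresh=0.2, sigmoid=False))  # 5 of 6 pairs match
print(accuracy_multi_sketch(probs, targets, thresh=0.5, sigmoid=False))  # all 6 pairs match
```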

Let's test it with the same image again

uploader = widgets.FileUpload()
uploader
img = PILImage.create(uploader.data[0])
img.to_thumb(200)
learn2.predict(img)
((#2) ['Black','Grizzly'],
 TensorImage([ True,  True, False]),
 TensorImage([6.0354e-01, 9.9354e-01, 1.6160e-04]))

It correctly identified both bears, with probabilities of about 0.60 for Black and 0.99 for Grizzly!
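learn2.predict decides each label independently: a class is included if its probability is above a threshold. In fastai the default loss for MultiCategoryBlock, BCEWithLogitsLossFlat, decodes with thresh=0.5, which is separate from the 0.2 passed to the accuracy_multi metric. A sketch reproducing the decoding of the output above (decode_multi_label is a hypothetical name):

```python
def decode_multi_label(probs, vocab, thresh=0.5):
    """Return (labels, mask, probs), mirroring the tuple printed by learn2.predict."""
    mask = [p > thresh for p in probs]  # one independent yes/no decision per class
    labels = [name for name, keep in zip(vocab, mask) if keep]
    return labels, mask, probs

vocab = ['Black', 'Grizzly', 'Teddy']  # dls2.vocab
labels, mask, _ = decode_multi_label([6.0354e-01, 9.9354e-01, 1.6160e-04], vocab)
print(labels, mask)  # → ['Black', 'Grizzly'] [True, True, False]
```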

Let's test it with a random image

uploader = widgets.FileUpload()
uploader
img = PILImage.create(uploader.data[0])
img.to_thumb(200)
learn2.predict(img)
((#0) [],
 TensorImage([False, False, False]),
 TensorImage([0.0265, 0.4757, 0.3557]))

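Notice that there is no prediction! Every probability is below the default decoding threshold of 0.5, so no label is assigned. This is the other benefit of a multi-label classifier: it can indicate that none of the known categories is present, instead of being forced to pick one.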
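The empty result does depend on the threshold: Grizzly, at 0.4757, is only just below 0.5. A quick sweep over the probabilities printed above shows how the answer changes; at the 0.2 used for the metric, this image would have been labeled Grizzly and Teddy, so the threshold may be worth tuning on the validation set. labels_above is a hypothetical helper.

```python
def labels_above(probs, vocab, thresh):
    """Names of classes whose probability is strictly above thresh."""
    return [name for name, p in zip(vocab, probs) if p > thresh]

vocab = ['Black', 'Grizzly', 'Teddy']
probs = [0.0265, 0.4757, 0.3557]  # probabilities printed for the random image
for thresh in (0.5, 0.4, 0.2):
    print(thresh, labels_above(probs, vocab, thresh))
```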

Now let's turn this into a simple application using Jupyter widgets

btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
lbl_pred = widgets.Label()
btn_run = widgets.Button(description='Classify')

Define what happens when the Classify button is clicked

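def on_click_classify(change):  # ipywidgets passes the clicked Button as the argument
    img = PILImage.create(btn_upload.data[-1])  # most recent upload (ipywidgets 7 `.data` API)
    out_pl.clear_output()
    with out_pl: display(img.to_thumb(128,128))  # show a thumbnail of the uploaded image
    pred,pred_idx,probs = learn_inf.predict(img)  # uses the single-label model
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
btn_run.on_click(on_click_classify)  # register the handler on the button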
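btn_run.on_click registers on_click_classify as a callback: nothing runs until the button is clicked, and then ipywidgets calls the function with the Button object as its only argument (so the parameter named change actually receives the button). The hypothetical MiniButton class below is a plain-Python stand-in that shows this callback pattern; it is not how ipywidgets is implemented.

```python
class MiniButton:
    """Hypothetical stand-in for widgets.Button, showing the callback pattern only."""

    def __init__(self, description=""):
        self.description = description
        self._handlers = []

    def on_click(self, handler):
        # Store the function; nothing runs until a click happens.
        self._handlers.append(handler)

    def click(self):
        # Simulate a click: call every registered handler with this button.
        for handler in self._handlers:
            handler(self)

clicks = []
btn = MiniButton(description="Classify")
btn.on_click(lambda b: clicks.append(f"{b.description} clicked"))
btn.click()
print(clicks)  # → ['Classify clicked']
```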

Use a VBox to stack all the widgets vertically

VBox([widgets.Label('Select your bear!'), 
      btn_upload, btn_run, out_pl, lbl_pred])