
Global Wheat Challenge 2021

Tutorial with PyTorch, Torchvision and PyTorch Lightning!

A short introduction to training your first detector!

By etienne_david

Hello and welcome to the Global Wheat Challenge 2021!

If you are new to object detection, or want some insights into the dataset and its format, please take a look at this short tutorial, which covers all aspects of the competition!



Global Wheat Competition 2021 - Starting notebook

  • The goal of this notebook is to help you train your first model and submit your predictions!
  • We will use PyTorch, Torchvision and PyTorch Lightning to build and train this first model!
  • Before starting, please check in Edit / Preferences that a GPU is selected.



Install aicrowd-cli 📚

It lets you download the dataset and make submissions directly from the notebook.

In [ ]:
!pip install aicrowd-cli

Download Data

The first step is to download our train and test data. We will train a model on the train data, make predictions on the test data, and then submit those predictions.

Please enter your API Key from here.

In [ ]:
API_KEY = ""  # paste your AIcrowd API key between the quotes
!aicrowd login --api-key $API_KEY
In [ ]:
!aicrowd dataset download --challenge global-wheat-challenge-2021
In [ ]:
!unzip train.zip
In [ ]:
!unzip test.zip

Install Necessary Packages

In [ ]:
!pip install albumentations==0.4.6
!pip install pytorch_lightning
In [ ]:
# Common imports
import math
import sys
import time
from tqdm.notebook import tqdm
import numpy as np
from pathlib import Path
import pandas as pd
import random
import cv2
import matplotlib.pyplot as plt

# Torch imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision.transforms as transforms
from torchvision.ops.boxes import box_iou
from torchvision.models.detection._utils import Matcher
from torchvision.ops import nms, box_convert

# Albumentations is used for the Data Augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# PyTorch Lightning imports (top-level imports work across Lightning versions)
from pytorch_lightning import LightningModule, Trainer, seed_everything

Dataloader and Dataset

We will write two subclasses of the Dataset class:

  • One that will read the train.csv to get images, boxes and domain
  • One that will read the submission.csv to retrieve images to predict and their associated domains.

💻 Labels

  • All boxes are stored in a CSV file (train.csv) with three columns: image_name, BoxesString and domain.
  • image_name is the name of the image, without the file extension. All images are .png files.
  • BoxesString is a string containing all the boxes of the image, each in the format [x_min, y_min, x_max, y_max]. The coordinates of a box are separated by one space (" ") and the boxes are separated by one semicolon (";"); the PredString of a submission is built the same way. If an image has no box, BoxesString is equal to "no_box".
  • domain gives the domain of each image.
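As an illustration of the BoxesString encoding, the pure-Python sketch below parses one string into a list of [x_min, y_min, x_max, y_max] boxes. The helper name `parse_boxes_string` and the sample coordinates are made up for this example; unlike `decodeString` in the dataset class, it raises an error on a malformed box instead of returning an empty array.

```python
def parse_boxes_string(boxes_string):
    """Parse a BoxesString into a list of [x_min, y_min, x_max, y_max] integer boxes.

    Hypothetical helper: boxes are separated by ";", coordinates by one space,
    and "no_box" means the image contains no box.
    """
    if boxes_string == "no_box":
        return []
    boxes = []
    for box in boxes_string.split(";"):
        coords = [int(value) for value in box.split(" ")]
        if len(coords) != 4:
            raise ValueError(f"expected 4 coordinates, got {len(coords)}: {box!r}")
        boxes.append(coords)
    return boxes


# Made-up example with two boxes, then an image without any box
print(parse_boxes_string("99 692 160 764;641 27 697 115"))  # [[99, 692, 160, 764], [641, 27, 697, 115]]
print(parse_boxes_string("no_box"))  # []
```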
In [ ]:
class WheatDataset(Dataset):
    """A dataset example for GWC 2021 competition."""

    def __init__(self, csv_file,root_dir , transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional data augmentation to be applied
                on a sample.
        """

        self.root_dir = Path(root_dir)
        annotations = pd.read_csv(csv_file)

        self.image_list = annotations["image_name"].values
        self.domain_list = annotations["domain"].values
        self.boxes = [self.decodeString(item) for item in annotations["BoxesString"]]
        
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        
        imgp = str(self.root_dir / (self.image_list[idx]+".png"))
        domain = self.domain_list[idx] # We don't use the domain information but you could !
        bboxes = self.boxes[idx]
        img = cv2.imread(imgp)
        image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # OpenCV loads images in BGR order by default

        if self.transform:
            transformed = self.transform(image=image,bboxes=bboxes,class_labels=["wheat_head"]*len(bboxes)) #Albumentations can transform images and boxes
            image = transformed["image"]
            bboxes = transformed["bboxes"]

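        if len(bboxes) > 0:
            # Force float32: Faster R-CNN expects float boxes, and albumentations may return float64 coordinates
            bboxes = torch.stack([torch.tensor(item, dtype=torch.float32) for item in bboxes])
        else:
            bboxes = torch.zeros((0, 4))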
        return image, bboxes, domain


    def decodeString(self, BoxesString):
        """
        Decode a BoxesString into an array of [x_min, y_min, x_max, y_max] boxes.
        """
        if BoxesString == "no_box":
            return np.zeros((0,4))
        try:
            boxes = np.array([np.array([int(i) for i in box.split(" ")])
                              for box in BoxesString.split(";")])
            return boxes
        except (ValueError, AttributeError):
            print(BoxesString)
            print("BoxesString is not well formatted: empty boxes will be returned")
            return np.zeros((0,4))
In [ ]:
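transform = A.Compose([
    A.LongestMaxSize(1024, p=1), # rescale so that the longest side is 1024 pixels
    A.PadIfNeeded(min_height=1024, min_width=1024, p=1, border_mode=cv2.BORDER_CONSTANT, value=0), # zero-pad to 1024 x 1024
    A.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    ), # ImageNet statistics
    ToTensorV2(),
# pascal_voc = [x_min, y_min, x_max, y_max] in pixels; boxes smaller than 20 px² after transformation are dropped
],bbox_params=A.BboxParams(format='pascal_voc',label_fields=['class_labels'],min_area=20))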
In [ ]:
dataset = WheatDataset("train.csv","train",transform=None)

Sanity check of the Dataset

Here we load some random images to check that the dataset returns the right images and boxes.

In [ ]:
hlines = []
for i in range(10):
  vlines = []
  for j in range(5):
    img , bboxes , metadata = dataset[random.randrange(len(dataset))] # randrange excludes len(dataset), unlike randint

    for (x,y,xx,yy) in bboxes:
      cv2.rectangle(img,(int(x.item()),int(y.item())),(int(xx.item()),int(yy.item())),(255,255,0),5)

    vlines.append(img)
  hlines.append(cv2.vconcat(vlines))

final_img = cv2.hconcat(hlines)

fig ,ax = plt.subplots(1,1,figsize=(20,20))
plt.imshow(final_img)

Set up the Dataloader

The DataLoader is the utility that loads the images and groups them into batches.
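The default collate function tries to `torch.stack` every field of a batch, which fails as soon as two images carry a different number of boxes. The sketch below, assuming PyTorch is installed and using made-up tensor sizes, reproduces that failure and then applies a list-based collate function equivalent to the `collate_fn` defined in the next cell.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two fake samples: same image size, different number of boxes (3 vs 1)
batch = [
    (torch.zeros(3, 64, 64), torch.rand(3, 4), 0),
    (torch.zeros(3, 64, 64), torch.rand(1, 4), 1),
]

try:
    default_collate(batch)
except RuntimeError as err:
    # torch.stack cannot stack box tensors of shape [3, 4] and [1, 4]
    print("default_collate failed:", type(err).__name__)


def collate_fn(batch):
    """Stack the images (same size after padding) but keep boxes and domains as lists."""
    images, targets, domains = zip(*batch)
    return torch.stack(images, dim=0), list(targets), list(domains)


images, targets, domains = collate_fn(batch)
print(images.shape, [t.shape[0] for t in targets], domains)
```

Only the images need to share a shape, which is why the training transform pads every image to 1024 x 1024.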

In [ ]:
def collate_fn(batch):
    """
    Since each image may have a different number of boxes, the default collate function cannot stack them,
    so we pass this custom collate function to the DataLoader.

    :param batch: an iterable of N (image, boxes, domain) tuples from __getitem__()
    :return: a tensor of N images, a list of N box tensors of varying size, and a list of N domains
    """

    images = list()
    targets=list()
    metadatas = list()

    for i, t, m in batch:
        images.append(i)
        targets.append(t)
        metadatas.append(m)
    images = torch.stack(images, dim=0)

    return images, targets,metadatas


dataset = WheatDataset("train.csv","train",transform=transform)
train_size = int(len(dataset)*0.9)
val_size = len(dataset)-train_size
train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size]) # We sample 10% of the images as a validation dataset
train_dataloader = torch.utils.data.DataLoader(train_set,
                                               batch_size=4,
                                               shuffle=True, # reshuffle the training images at every epoch
                                               collate_fn=collate_fn)
val_dataloader = torch.utils.data.DataLoader(val_set,
                                             batch_size=4,
                                             shuffle=False,
                                             collate_fn=collate_fn)
In [ ]:
batch = next(iter(train_dataloader)) # Check that the DataLoader works

Train Faster R-CNN with PyTorch Lightning and torchvision

We propose to fine-tune torchvision's Faster R-CNN with a ResNet-50 FPN backbone (pretrained on COCO), using PyTorch Lightning to run the training loop.
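In training mode, torchvision's detection models take a list of images plus one target dict per image: `boxes` is a float tensor of shape [N, 4] in (x_min, y_min, x_max, y_max) pixel coordinates and `labels` is an int64 tensor of shape [N]. Label 0 is reserved for the background, so the model below is built with two classes and every wheat head gets label 1. Here is a minimal sketch of that conversion, assuming PyTorch is installed; `make_targets` is a hypothetical helper name.

```python
import torch


def make_targets(boxes_per_image, device="cpu"):
    """Convert per-image box tensors into torchvision detection targets.

    Every box is a wheat head, so all labels are 1 (0 is the background class).
    """
    targets = []
    for boxes in boxes_per_image:
        boxes = boxes.to(device=device, dtype=torch.float32)
        labels = torch.ones(boxes.shape[0], dtype=torch.int64, device=device)
        targets.append({"boxes": boxes, "labels": labels})
    return targets


# One image with two boxes, one image without any box
targets = make_targets([torch.tensor([[10, 20, 50, 60], [5, 5, 30, 40]]), torch.zeros((0, 4))])
print(targets[0]["labels"])       # tensor([1, 1])
print(targets[1]["boxes"].shape)  # torch.Size([0, 4])
```

Images without boxes are passed as empty [0, 4] tensors, which recent torchvision versions accept as negative-only examples.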

In [ ]:
seed_everything(25081992)

class FasterRCNN(LightningModule):
    def __init__(self,n_classes):
        super().__init__()
        self.detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        in_features = self.detector.roi_heads.box_predictor.cls_score.in_features
        self.detector.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes)
        self.lr = 1e-4

    def forward(self, imgs,targets=None):
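      # Torchvision's Faster R-CNN returns the losses in training mode and the boxes in eval mode.
      # forward() is only used for inference here, so it always switches the detector to eval mode.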
      self.detector.eval()
      return self.detector(imgs)

    def configure_optimizers(self):
      optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
      return optimizer

    def training_step(self, batch, batch_idx):

      imgs = batch[0]

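      # forward() leaves the detector in eval mode, so switch it back to training mode first
      self.detector.train()

      targets = []
      for boxes in batch[1]:
        target = {}
        target["boxes"] = boxes.to(self.device)
        # every box is a wheat head: label 1 (label 0 is the background)
        target["labels"] = torch.ones(len(boxes), dtype=torch.int64, device=self.device)
        targets.append(target)

      # In training mode, Faster R-CNN takes both images and targets and returns a dict of losses
      loss_dict = self.detector(imgs, targets)
      loss = sum(loss for loss in loss_dict.values())
      self.log_dict(loss_dict, batch_size=len(imgs))
      return loss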

    def validation_step(self, batch, batch_idx):
      img, boxes, metadata = batch
      pred_boxes = self.forward(img)

      val_accuracy = torch.mean(torch.stack([self.accuracy(b,pb["boxes"],iou_threshold=0.5) for b,pb in zip(boxes,pred_boxes)]))
      self.log("val_accuracy", val_accuracy, prog_bar=True, batch_size=len(img)) # monitored by EarlyStopping
      return val_accuracy

    def test_step(self, batch, batch_idx):
      img, boxes, metadata = batch
      pred_boxes = self.forward(img) # in eval mode, Faster R-CNN returns the predicted boxes
      test_accuracy = torch.mean(torch.stack([self.accuracy(b,pb["boxes"],iou_threshold=0.5) for b,pb in zip(boxes,pred_boxes)]))
      self.log("test_accuracy", test_accuracy, batch_size=len(img))
      return test_accuracy

    def accuracy(self, src_boxes,pred_boxes ,  iou_threshold = 1.):
      """
      Per-image accuracy TP / (TP + FP + FN). It is not the metric used by the evaluator, but it is very similar.
      """
      total_gt = len(src_boxes)
      total_pred = len(pred_boxes)
      if total_gt > 0 and total_pred > 0:


        # Define the matcher and distance matrix based on iou
        matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False) 
        match_quality_matrix = box_iou(src_boxes,pred_boxes)

        results = matcher(match_quality_matrix)
        
        true_positive = torch.count_nonzero(results.unique() != -1)
        matched_elements = results[results > -1]
        
        # Several predictions can be matched to the same ground-truth box: the extra matches count as false positives
        false_positive = torch.count_nonzero(results == -1) + (len(matched_elements) - len(matched_elements.unique()))
        false_negative = total_gt - true_positive

        return true_positive / (true_positive + false_positive + false_negative)

      elif total_gt == 0:
          # No ground truth: the score is 1 only if nothing was predicted
          if total_pred > 0:
              return torch.tensor(0., device=pred_boxes.device)
          else:
              return torch.tensor(1., device=pred_boxes.device)
      elif total_gt > 0 and total_pred == 0:
          return torch.tensor(0., device=pred_boxes.device)
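The `accuracy` method scores one image as TP / (TP + FP + FN). A prediction is matched when its best IoU with a ground-truth box reaches the threshold; a ground-truth box counts as found once, however many predictions land on it, and the surplus predictions count as false positives. The pure-Python sketch below mirrors that logic without torchvision's `Matcher`, so it can be checked by hand. `iou` and `image_accuracy` are illustrative names, and like the method above this is not the official evaluator.

```python
def iou(a, b):
    """Intersection over union of two [x_min, y_min, x_max, y_max] boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def image_accuracy(gt_boxes, pred_boxes, iou_threshold=0.5):
    """TP / (TP + FP + FN) for one image, matching each prediction to its best ground truth."""
    if not gt_boxes:
        return 1.0 if not pred_boxes else 0.0
    if not pred_boxes:
        return 0.0
    matched = []  # ground-truth index claimed by each matched prediction
    false_positive = 0
    for pred in pred_boxes:
        scores = [iou(gt, pred) for gt in gt_boxes]
        best = max(range(len(gt_boxes)), key=scores.__getitem__)
        if scores[best] >= iou_threshold:
            matched.append(best)
        else:
            false_positive += 1
    true_positive = len(set(matched))
    # Extra predictions on an already-matched ground-truth box are false positives
    false_positive += len(matched) - true_positive
    false_negative = len(gt_boxes) - true_positive
    return true_positive / (true_positive + false_positive + false_negative)


gt = [[0, 0, 10, 10], [20, 20, 30, 30]]
pred = [[0, 0, 10, 10], [1, 1, 10, 10], [50, 50, 60, 60]]
print(image_accuracy(gt, pred))  # 0.25
```

In this example both of the first two predictions hit the first ground-truth box (IoU 1.0 and 0.81), so TP = 1, FP = 2 (one duplicate, one miss) and FN = 1, giving 1 / 4.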
In [ ]:
detector = FasterRCNN(2) # 2 classes: background (0) and wheat head (1)

Training time!

In [ ]:
from pytorch_lightning.callbacks import EarlyStopping
early_stop_callback = EarlyStopping(
   monitor='val_accuracy',
   min_delta=0.00,
   patience=3,
   verbose=False,
   mode='max'
)


# EarlyStopping (defined above) stops training once val_accuracy has not improved for 3 validation epochs.
# With max_epochs=1 it cannot trigger; increase max_epochs to train longer.
trainer = Trainer(accelerator="gpu", devices=1, max_epochs=1, deterministic=False, callbacks=[early_stop_callback])

# The learning rate (self.lr = 1e-4) was found once with Lightning's learning-rate finder, so it is not re-tuned here
trainer.fit(detector, train_dataloader, val_dataloader)
In [ ]:
# You can save your model for ensembling with: torch.save(detector, "path/to/model.pth")

Sanity check before submission

For prediction we need another Dataset, which only normalizes the images (no resizing, padding or augmentation) and does not read any labels.
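Because the model was trained on normalized images, the prediction images go through the same `A.Normalize` step. With its default `max_pixel_value=255`, albumentations computes `(pixel / 255 - mean) / std` per channel, and `ToTensorV2` then only converts the HWC array into a CHW tensor. The NumPy sketch below reproduces that formula and its inverse, which is handy for turning a normalized image back into something viewable; `normalize` and `denormalize` are hypothetical names and only NumPy is assumed.

```python
import numpy as np

# ImageNet statistics, as passed to A.Normalize in this notebook
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize(image_rgb):
    """(H, W, 3) uint8 RGB image -> float32 image, same formula as A.Normalize with max_pixel_value=255."""
    return (image_rgb.astype(np.float32) / 255.0 - MEAN) / STD


def denormalize(image_norm):
    """Invert normalize() and return a uint8 RGB image that can be displayed."""
    image = (image_norm * STD + MEAN) * 255.0
    return np.clip(np.rint(image), 0, 255).astype(np.uint8)


pixel = np.full((1, 1, 3), 128, dtype=np.uint8)
print(normalize(pixel))
print(denormalize(normalize(pixel))[0, 0])
```

The dataset below sidesteps the inverse by returning the raw RGB image alongside the normalized one.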

In [ ]:
class WheatDatasetPredict(Dataset):
    """A dataset example for GWC 2021 competition."""

    def __init__(self, csv_file,root_dir):
        """
        Args:
            csv_file (string): Path to the csv file listing the images to predict (here submission.csv).
            root_dir (string): Directory with all the images.
        """

        self.root_dir = Path(root_dir)
        annotations = pd.read_csv(csv_file)

        self.image_list = annotations["image_name"].values
        self.domain_list = annotations["domain"].values
        
        self.transform = A.Compose([
          A.Normalize(
          mean=[0.485, 0.456, 0.406],
          std=[0.229, 0.224, 0.225],
          ),
          ToTensorV2(),
          ])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        
        imgp = str(self.root_dir / (self.image_list[idx]+".png"))
        domain = self.domain_list[idx]
        img = cv2.imread(imgp)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        transformed = self.transform(image=img)
        image = transformed["image"] # normalized tensor fed to the model
        return image, img, self.image_list[idx], domain # img is the raw RGB image, kept for visualisation
In [ ]:
detector.freeze()
test_dataset = WheatDatasetPredict("submission.csv","test") #the domain information is included in the submission file


hlines = []
for i in range(4):
  vlines = []
  for j in range(4):
    idx = random.randrange(len(test_dataset)) # randrange excludes len(test_dataset), unlike randint

    norm_img, img , img_name, domain = test_dataset[idx] # norm_img is used for prediction and img for visualisation

    predictions = detector(norm_img.unsqueeze(dim=0).to(detector.device))

    pboxes = predictions[0]["boxes"]
    scores = predictions[0]["scores"]
    pboxes = pboxes[scores > 0.5]


    for (x,y,xx,yy) in pboxes:
      cv2.rectangle(img,(int(x.item()),int(y.item())),(int(xx.item()),int(yy.item())),(0,255,255),5)

    vlines.append(img)
  hlines.append(cv2.vconcat(vlines))

final_img = cv2.hconcat(hlines)

fig ,ax = plt.subplots(1,1,figsize=(20,20))
plt.imshow(final_img)

Writing submission file
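Each prediction is turned into a PredString by keeping only the boxes whose confidence score exceeds 0.5, joining their integer coordinates with spaces and the boxes with semicolons, or writing "no_box" when nothing survives. Here is a dependency-free sketch of that step; `to_pred_string` is a hypothetical name, and 0.5 is simply the threshold used in this notebook rather than a competition requirement.

```python
def to_pred_string(boxes, scores, score_threshold=0.5):
    """Keep boxes with score > threshold and encode them as a PredString."""
    kept = [box for box, score in zip(boxes, scores) if score > score_threshold]
    if not kept:
        return "no_box"
    # int() truncates float coordinates, like encode_boxes in this notebook
    return ";".join(" ".join(str(int(coord)) for coord in box) for box in kept)


boxes = [[12.7, 30.2, 55.9, 80.1], [100.0, 40.5, 140.3, 90.0], [5.0, 5.0, 9.0, 9.0]]
scores = [0.92, 0.61, 0.12]
print(to_pred_string(boxes, scores))           # 12 30 55 80;100 40 140 90
print(to_pred_string(boxes, [0.1, 0.2, 0.3]))  # no_box
```

Raising the threshold trades recall for precision; since every surplus box counts as a false positive in the per-image accuracy, it is worth tuning on the validation split.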

In [ ]:
detector.detector.eval()
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=1,
                                          shuffle=False)


def encode_boxes(boxes):

  if len(boxes) >0:
    boxes = [" ".join([str(int(i)) for i in item]) for item in boxes]
    BoxesString = ";".join(boxes)
  else:
    BoxesString = "no_box"
  return BoxesString

results = []
for batch in tqdm(test_dataloader):
  norm_img, img , img_names , metadata = batch

  predictions = detector.detector(norm_img.to(detector.device))

  for img_name, pred, domain in zip(img_names,predictions,metadata):
    boxes = pred["boxes"]
    scores = pred["scores"]
    boxes = boxes[scores > 0.5].cpu().numpy()
    PredString = encode_boxes(boxes)
    results.append([img_name,PredString,domain.item()])
In [ ]:
results = pd.DataFrame(results,columns =["image_name","PredString","domain"])
results.to_csv("submission_final.csv", index=False) # do not write the DataFrame index as an extra column
In [ ]:
results

Making a Direct Submission through the AIcrowd CLI

In [ ]:
!aicrowd submission create -c global-wheat-challenge-2021 -f submission_final.csv