Tree Segmentation
[Getting Started Notebook] Trees Segmentation
A Getting Started notebook for the Trees Segmentation puzzle of BlitzX.
Starter Code for Trees Segmentation
What we are going to Learn¶
- Getting started with image segmentation using PyTorch.
- Using models provided by segmentation_models.pytorch for image segmentation.
- Training and testing a Unet model with PyTorch.
Note: Create a copy of the notebook and use the copy for submission. Go to File > Save a Copy in Drive to create a new copy.
Setting up Environment¶
Downloading Dataset¶
First, we need to install aicrowd-cli, the Python library by AIcrowd that lets us download the dataset using just our API key.
In [ ]:
!pip install aicrowd-cli
%load_ext aicrowd.magic
In [ ]:
%aicrowd login
In [ ]:
# Downloading the Dataset
!rm -rf data
!mkdir data
%aicrowd ds dl -c tree-segmentation -o data
In [ ]:
!unzip data/train.zip -d data/train > /dev/null
!unzip data/test.zip -d data/test > /dev/null
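Before building a data loader, it can help to confirm that every training image has a matching mask. The sketch below is only an illustration: the helper name `find_unpaired` is hypothetical, and it assumes that an image and its mask share the same file name stem, which this notebook does not state. The directory names are the ones passed to the dataset class later in this notebook.

```python
from pathlib import Path


def find_unpaired(image_dir, mask_dir):
    """Return (image stems without a mask, mask stems without an image).

    Assumption: an image and its mask share the same file name stem.
    """
    image_stems = {p.stem for p in Path(image_dir).iterdir() if p.is_file()}
    mask_stems = {p.stem for p in Path(mask_dir).iterdir() if p.is_file()}
    return sorted(image_stems - mask_stems), sorted(mask_stems - image_stems)


# Only run the check when the extracted folders are actually present
image_dir = Path("data/train/image")
mask_dir = Path("data/train/segmentation")
if image_dir.is_dir() and mask_dir.is_dir():
    missing_masks, missing_images = find_unpaired(image_dir, mask_dir)
    print("Images without a mask:", missing_masks)
    print("Masks without an image:", missing_images)
```

If both lists are empty, sorting the two directory listings the same way should line up each image with its mask.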
Downloading & Importing Libraries¶
Here we are going to use segmentation_models.pytorch, a popular library that provides many different segmentation models for PyTorch, from basic Unets to DeepLabV3.
Along with that, we will also use the pytorch-argus library to help with training the model.
In [ ]:
!pip install git+https://github.com/qubvel/segmentation_models.pytorch pytorch-argus
In [ ]:
# Pytorch
import torch
from torch import nn
import segmentation_models_pytorch as smp
import argus
from torch.utils.data import Dataset, DataLoader
# Reading Dataset, vis and miscellaneous
from PIL import Image
import matplotlib.pyplot as plt
import os
import numpy as np
from tqdm.notebook import tqdm
import cv2
from natsort import natsorted
Training phase ⚙️¶
Creating the Dataloader¶
Here, we simply create a class that lets PyTorch load the dataset and feed it into the model.
In [ ]:
class TreeSegmentationDataset(Dataset):
    def __init__(self, img_directory=None, label_directory=None, train=True):
        self.img_directory = img_directory
        self.label_directory = label_directory
        self.train = train
        # List the images if the image directory was provided
        if img_directory is not None:
            self.img_list = natsorted(os.listdir(img_directory))
        # List the masks only if a label directory was provided
        if label_directory is not None:
            self.label_list = natsorted(os.listdir(label_directory))

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        # Reading the image
        img = Image.open(os.path.join(self.img_directory, self.img_list[idx]))
        if self.train:
            # Reading the mask image
            mask = Image.open(os.path.join(self.label_directory, self.label_list[idx]))
            img = np.array(img, dtype=np.float32)
            mask = np.array(mask, dtype=np.float32)
            # Change image channel ordering (channels last -> channels first)
            img = np.moveaxis(img, -1, 0)
            return img, mask
        # If reading the test dataset, only return the image
        else:
            img = np.array(img, dtype=np.float32)
            img = np.moveaxis(img, -1, 0)
            return img
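The `np.moveaxis(img, -1, 0)` call in the class above moves the channel axis from last to first, which is the channel-first layout that PyTorch convolution layers expect. A minimal sketch on a synthetic array (the 4x6 size and the red-only fill are arbitrary values chosen for illustration, not taken from the dataset):

```python
import numpy as np

# Synthetic RGB image in (H, W, C) layout, as np.array() of an RGB PIL image would give
hwc = np.zeros((4, 6, 3), dtype=np.float32)
hwc[..., 0] = 255.0  # fill the red channel only

# Move the last axis (channels) to the front: (H, W, C) -> (C, H, W)
chw = np.moveaxis(hwc, -1, 0)
print(hwc.shape, "->", chw.shape)

# Moving axis 0 back to the end restores the original layout,
# which is what the visualization code needs before calling imshow
restored = np.moveaxis(chw, 0, -1)
print(np.array_equal(restored, hwc))
```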
In [ ]:
# Creating the training dataset
train_dataset = TreeSegmentationDataset(img_directory="data/train/image", label_directory="data/train/segmentation")
train_loader = DataLoader(train_dataset, batch_size=4, num_workers=1, shuffle=False, drop_last=True)
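With `drop_last=True`, the loader discards a final batch that has fewer than `batch_size` samples, so the number of batches per epoch is the dataset size divided by the batch size, rounded down. The helper below is a hypothetical illustration of that arithmetic and is not part of PyTorch; the dataset size of 10 is an arbitrary example value.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last):
    """Number of batches per epoch for a dataset of num_samples items."""
    if drop_last:
        # Incomplete final batch is discarded
        return num_samples // batch_size
    # Incomplete final batch is kept
    return math.ceil(num_samples / batch_size)


print(batches_per_epoch(10, 4, drop_last=True))   # 2
print(batches_per_epoch(10, 4, drop_last=False))  # 3
```

For the loader above, `len(train_loader)` should agree with `batches_per_epoch(len(train_dataset), 4, drop_last=True)`.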
In [ ]:
# Reading the image and corresponding segmentation
image_batch, segmentation_batch = next(iter(train_loader))
image_batch.shape, segmentation_batch.shape
Visualizing Dataset¶
In [ ]:
plt.rcParams["figure.figsize"] = (30,5)
# Going through each image and segmentation
for image, segmentation in zip(image_batch, segmentation_batch):
    # Change the channel ordering
    image = np.moveaxis(image.numpy()/255, 0, -1)
    # Showing the image
    plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(image, 'gray', interpolation='none')
    plt.subplot(1,2,2)
    plt.imshow(image, 'gray', interpolation='none')
    plt.imshow(segmentation, 'jet', interpolation='none', alpha=0.7)
    plt.show()
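Alongside the visual overlay, a quick numeric check can show how much of each image the mask covers. The sketch below uses a hypothetical helper, `mask_coverage`, and assumes that any non-zero mask value marks a tree pixel; this notebook does not state how the masks are encoded, so verify that against the data before relying on it.

```python
import numpy as np


def mask_coverage(mask):
    """Fraction of pixels with a non-zero label (assumed here to mean 'tree')."""
    mask = np.asarray(mask)
    return float(np.count_nonzero(mask)) / mask.size


# Synthetic 4x4 mask with a 2x2 block of foreground pixels
demo_mask = np.zeros((4, 4), dtype=np.float32)
demo_mask[1:3, 1:3] = 255.0
print(mask_coverage(demo_mask))  # 0.25
```

Inside the loop above, this could be called as `mask_coverage(segmentation.numpy())` to print the coverage of each sample.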