MABe 2022: Mouse-Triplets - Video Data

Unsupervised model - SimCLR - Mouse Video Data

Unsupervised model training using contrastive learning with modified SimCLR




How to use this notebook 📝

  1. Copy the notebook. This is a shared template and any edits you make here will not be saved. You should copy it into your own drive folder. For this, click the "File" menu (top-left), then "Save a Copy in Drive". You can edit your copy however you like.
  2. Link it to your AIcrowd account. In order to submit your predictions to AIcrowd, you need to provide your account's API key.

Problem Statement

Join the community!
chat on Discord

Setup AIcrowd Utilities 🛠

In [ ]:
!pip install -U aicrowd-cli
%load_ext aicrowd.magic

Login to AIcrowd

In [ ]:
%aicrowd login

Install packages 🗃

Please add all package installations in this section.

In [ ]:
!pip install torch==1.10.2 torchvision==0.11.3 simclr

Import necessary modules and packages 📚

In [ ]:
import os
import cv2
import numpy as np
from tqdm.auto import tqdm

import torch
import torchvision
import torchvision.transforms as T

from simclr import SimCLR
from simclr.modules import NT_Xent
from simclr.modules import LARS

Download and prepare the dataset 🔍

In [ ]:
aicrowd_challenge_name = "mabe-2022-mouse-triplets-video-data"
if not os.path.exists('data'):
    os.makedirs('data')
datafolder = 'data/'

## If data is already downloaded and stored on google drive, skip the download and point to the prepared directory
# datafolder = '/content/drive/MyDrive/mabe-2022-mouse-video/data/'

video_folder = f'{datafolder}video_clips/'
In [ ]:
## The download might take a while; we recommend copying the data to Google Drive if you want to run this multiple times.
%aicrowd ds dl -c {aicrowd_challenge_name} -o data *.npy* # Download the .npy files
# We'll download the 224x224 videos since they load quickly in the dataloader, but you can use the full-sized videos if you want
%aicrowd ds dl -c {aicrowd_challenge_name} -o data *resized_224* # Download the 224x224 videos
# %aicrowd ds dl -c {aicrowd_challenge_name} -o data *videos.zip* # Download the 512x512 videos
In [ ]:
!unzip -q data/submission_videos_resized_224.zip  -d {video_folder}
!unzip -q data/userTrain_videos_resized_224.zip -d {video_folder}

## Careful when running the below commands - For copying to Google Drive
# !rm data/submission_videos.zip data/userTrain_videos.zip 
# !cp -r data/ '/content/drive/MyDrive/mabe-2022-mouse-video/data/'

Train Unsupervised Baseline - SIMCLR 🏋️

  • We use contrastive learning as the baseline for the MABe video dataset. The code uses SimCLR (A Simple Framework for Contrastive Learning of Visual Representations, https://arxiv.org/abs/2002.05709), a popular and "simple" contrastive learning algorithm.

  • We modify SimCLR to exploit specific properties of the dataset, namely frame stacking and cropping around the animal subjects.

  • We use a ResNet50 encoder with the unofficial PyTorch simclr package (https://github.com/Spijkervet/SimCLR) to do unsupervised learning on the video data.

  • We also stack past and future frames with frame skipping to incorporate temporal information into each of the embeddings.
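The training loop relies on the `NT_Xent` loss from the simclr package. To illustrate what that loss computes, here is an independent NumPy sketch of NT-Xent as described in the SimCLR paper; it is not the package's own code, and the function name `nt_xent_numpy` is hypothetical. The two augmented views of each sample form a positive pair, and every other view in the batch of 2N views acts as a negative.

```python
import numpy as np


def nt_xent_numpy(z_i, z_j, temperature=0.5):
    """NT-Xent loss for two (N, d) batches of projections.

    Row k of z_i and row k of z_j are the two views of the same sample
    (the positive pair); the other 2N - 2 rows are negatives.
    """
    n = z_i.shape[0]
    z = np.concatenate([z_i, z_j], axis=0)            # (2N, d)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit-normalise each row
    logits = z @ z.T / temperature                    # cosine similarity / tau
    np.fill_diagonal(logits, -np.inf)                 # a view is not its own negative
    # Row k's positive partner is row k + N, and vice versa
    positives = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Numerically stable log-softmax over each row
    row_max = logits.max(axis=1, keepdims=True)
    log_norm = np.log(np.exp(logits - row_max).sum(axis=1, keepdims=True)) + row_max
    log_prob = logits - log_norm
    # Average the cross-entropy over all 2N views
    return -log_prob[np.arange(2 * n), positives].mean()


# Orthonormal embeddings with identical views: each row has one positive
# logit of 1 / 0.5 = 2 and 2N - 2 = 6 negative logits of 0.
z = np.eye(8)[:4]
loss = nt_xent_numpy(z, z.copy(), temperature=0.5)
```

In that toy case the loss is exactly log(1 + 6e^-2); pulling the two views apart or making negatives more similar both increase it, which is what training pushes against.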

Dataloader 📚

The dataloader reads the video files and seeks the required past and future frames, skipping frames in between.

Additionally, we crop the images to keep the animals in focus.
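The two index computations behind the dataloader can be sketched in plain Python/NumPy. This is a minimal illustration with hypothetical helper names (`locate_frame`, `stacked_frame_indices`), assuming that `frame_number_map[name] = (start, end)` is sorted with `end` exclusive, and that `frame_skip` counts the frames skipped between consecutive stacked frames (so the stride is `frame_skip + 1`):

```python
import numpy as np


def locate_frame(global_index, frame_number_map):
    """Map a global frame index to (video_name, local_frame_index).

    Assumes frame_number_map[name] = (start, end), sorted by start,
    start inclusive and end exclusive, so the clips tile [0, total_frames).
    """
    names = list(frame_number_map.keys())
    starts = np.array([frame_number_map[n][0] for n in names])
    # Last clip whose start is <= global_index
    video_idx = np.searchsorted(starts, global_index, side='right') - 1
    return names[video_idx], int(global_index - starts[video_idx])


def stacked_frame_indices(frame_index, num_prev_frames, num_next_frames,
                          frame_skip, num_video_frames):
    """Frame numbers stacked around frame_index, ordered past to future.

    frame_skip frames are skipped between consecutive stacked frames, so the
    stride is frame_skip + 1. Out-of-range positions are returned as None.
    """
    stride = frame_skip + 1
    indices = []
    for k in range(-num_prev_frames, num_next_frames + 1):
        fnum = frame_index + k * stride
        indices.append(fnum if 0 <= fnum < num_video_frames else None)
    return indices


print(locate_frame(1805, {'v1': (0, 1800), 'v2': (1800, 3600)}))  # ('v2', 5)
print(stacked_frame_indices(10, 2, 2, 1, 1800))                   # [6, 8, 10, 12, 14]
```

Near the start or end of a clip the missing positions come back as `None`; in the notebook's dataset those channels simply stay zero-filled.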

In [ ]:
class MouseVideoDataset(torch.utils.data.Dataset):
    """
    Reads frames from video files
    """
    def __init__(self, 
                 datafolder,
                 frame_number_map,
                 keypoints,
                 num_prev_frames,
                 num_next_frames,
                 frame_skip,
                 frame_size=(224, 224),
                 transform=None):
        """
        Initializing the dataset with images and labels
        """
        self.datafolder = datafolder
        self.transform = transform
        self.frame_number_map = frame_number_map
        self.num_prev_frames = num_prev_frames
        self.num_next_frames = num_next_frames
        self.frame_skip = frame_skip
        self.frame_size = frame_size
        self.keypoints = keypoints

        self._setup_frame_map()

    def set_transform(self, transform):
        self.transform = transform

    def _setup_frame_map(self):
        self._video_names = np.array(list(self.frame_number_map.keys()))
        # IMPORTANT: the frame number map should be sorted for self.get_frame_info to work
        frame_nums = np.array([self.frame_number_map[k] for k in self._video_names])
        self._frame_numbers = frame_nums[:, 0] - 1 # start values
        assert np.all(np.diff(self._frame_numbers) > 0), "Frame number map is not sorted"

        self.length = frame_nums[-1, 1] # last value is the total number of frames

    def get_frame_info(self, global_index):
        """ Returns corresponding video name and frame number"""
        video_idx = np.searchsorted(self._frame_numbers, global_index) - 1
        frame_index = global_index - (self._frame_numbers[video_idx] + 1)
        return self._video_names[video_idx], frame_index

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        video_name, frame_index = self.get_frame_info(idx)

        video_path = os.path.join(self.datafolder, video_name + '.avi')
        nf = self.num_next_frames + self.num_prev_frames + 1
        frames_array = np.zeros((*self.frame_size, nf), dtype=np.float32)
        if not os.path.exists(video_path):
            # Missing video: return an all-zero frame stack instead of raising
            # raise FileNotFoundError(video_path)
            if self.transform is not None:
                frames_array = self.transform(frames_array)
            return {"idx": idx,
                    "image": frames_array,
                    }

        cap = cv2.VideoCapture(video_path)
        num_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # frame_skip frames are skipped between stacked frames, so the stride is frame_skip + 1
        # and the window always yields exactly nf frames
        stride = self.frame_skip + 1
        for arridx, fnum in enumerate(range(frame_index - self.num_prev_frames * stride,
                                            frame_index + self.num_next_frames * stride + 1,
                                            stride)):
            if fnum < 0 or fnum >= num_video_frames:
                continue  # out-of-range frames stay zero-filled
            cap.set(cv2.CAP_PROP_POS_FRAMES, fnum)
            success, frame = cap.read()
            if success:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                # Scale to [0, 1]: T.ToTensor does not rescale float32 arrays
                frames_array[:, :, arridx] = frame / 255.0
        cap.release()

        if video_name in self.keypoints['sequences']:
            bbox = self.keypoints['sequences'][video_name]['bbox']
            if bbox.shape[0] > frame_index:
                bbox = bbox[frame_index]
                frames_array = frames_array[bbox[0]:bbox[2], bbox[1]:bbox[3]] # Crop the image so random crop is more useful
        if self.transform is not None:
            frames_array = self.transform(frames_array)

        return {"idx": idx,
                "image": frames_array,
                }

Utilities - Optimizer, Transforms and Augmentations 🔧

In [ ]:
def load_optimizer(optimizer, epochs, weight_decay, batch_size, model):

    scheduler = None
    if optimizer == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # TODO: LARS
    elif optimizer == "LARS":
        # optimized using LARS with linear learning rate scaling
        # (i.e. LearningRate = 0.3 × BatchSize/256) and weight decay of 10−6.
        learning_rate = 0.3 * batch_size / 256
        optimizer = LARS(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            exclude_from_weight_decay=["batch_normalization", "bias"],
        )

        # "decay the learning rate with the cosine decay schedule without restarts"
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, epochs, eta_min=0, last_epoch=-1
        )
    else:
        raise NotImplementedError(f"Unknown optimizer: {optimizer}")

    return optimizer, scheduler

def save_model(epoch, model_path, model, optimizer):
    out = os.path.join(model_path, "checkpoint_{}.tar".format(epoch))

    # To save a DataParallel model generically, save the model.module.state_dict().
    # This way, you have the flexibility to load the model any way you want to any device you want.
    if isinstance(model, torch.nn.DataParallel):
        torch.save(model.module.state_dict(), out)
    else:
        torch.save(model.state_dict(), out)
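With the default `optimizer_type = 'Adam'`, `load_optimizer` uses a fixed learning rate of 3e-4 and no scheduler. The LARS branch instead scales the learning rate linearly with batch size and decays it with a cosine schedule. The plain-Python sketch below reproduces those two formulas for illustration only; the helper names are hypothetical, and the cosine formula is the closed form of cosine annealing without restarts (with `eta_min=0`), evaluated once per epoch as the notebook's `scheduler.step()` does:

```python
import math


def linear_scaled_lr(batch_size, base_lr=0.3, base_batch_size=256):
    """Linear scaling rule from the LARS branch: lr = 0.3 * batch_size / 256."""
    return base_lr * batch_size / base_batch_size


def cosine_annealing_lr(eta_max, epoch, t_max, eta_min=0.0):
    """Cosine decay without restarts, one value per epoch."""
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


lr0 = linear_scaled_lr(32)  # batch_size = 32 in the training config
schedule = [cosine_annealing_lr(lr0, e, t_max=10) for e in range(11)]
```

With `batch_size = 32` the LARS learning rate starts at 0.0375, is halved at the midpoint of training and reaches 0 after the last epoch.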
In [ ]:
class TransformsSimCLR:
    def __init__(self, size, pretrained=True, n_channel=3, validation=False) -> None:
        # Taking the means of the normal distributions of the 3 channels
        # since we are moving to grayscale: mean of the means, std = sqrt(sum of variances / 9)
        gray_mean = np.mean([0.485, 0.456, 0.406]).repeat(n_channel)
        gray_std = np.sqrt((np.array([0.229, 0.224, 0.225])**2).sum()/9).repeat(n_channel)
        normalize = T.Normalize(mean=gray_mean, std=gray_std) if pretrained is True else T.Lambda(lambda x: x)

        self.train_transforms = T.Compose([
            T.ToTensor(),  # (H, W, C) float32 array -> (C, H, W) tensor
            T.RandomResizedCrop(size=size, scale=(0.25, 1.0)),
            normalize,
        ])

        self.validation_transforms = T.Compose([
            T.ToTensor(),
            T.Resize(size=size),  # bbox crops differ in size, so resize before batching
            normalize,
        ])
        self.validation = validation

    def __call__(self, x):
        if not self.validation:
            # Two independently augmented views of the same frame stack
            return self.train_transforms(x), self.train_transforms(x)
        else:
            return self.validation_transforms(x)
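The grayscale statistics used for `T.Normalize` can be derived from the ImageNet RGB statistics. Assuming, as a simplification, that the grayscale value is the plain average of the three channels and that the channels are uncorrelated:

```latex
g = \tfrac{1}{3}(R + G + B)

\mu_g = \tfrac{1}{3}(\mu_R + \mu_G + \mu_B) = \tfrac{1}{3}(0.485 + 0.456 + 0.406) \approx 0.449

\sigma_g = \sqrt{\tfrac{1}{9}\left(\sigma_R^2 + \sigma_G^2 + \sigma_B^2\right)} = \sqrt{\tfrac{0.229^2 + 0.224^2 + 0.225^2}{9}} \approx 0.130
```

Both values are then repeated for each of the `n_channel` stacked frames. They are approximations: `cv2.COLOR_BGR2GRAY` uses the weighted sum 0.299 R + 0.587 G + 0.114 B rather than a plain average, and the RGB channels of natural images are positively correlated, which makes the true grayscale standard deviation larger than this estimate.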

Bounding box creation 📦

Since most of the frame is empty, it is important that the mouse triplets are cropped correctly when doing SimCLR augmentations. We use the keypoints to create rough bounding boxes around them.

Note that these bounding boxes come from a simple scheme (the keypoint extent padded by a fixed number of pixels); feel free to change how the bounding boxes are generated.
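The per-frame loop that builds the boxes can also be written in vectorized NumPy. The sketch below is a hypothetical helper, not part of the notebook, and assumes the keypoints contain no NaN values (a NaN has no meaningful int32 value). It keeps the loop's conventions: pad the keypoint extent by `padbbox` pixels, clip to the 512x512 frame, then rescale to the 224x224 videos.

```python
import numpy as np


def keypoints_to_bboxes(kp, pad=50, frame_size=512, target_size=224):
    """Vectorised form of the per-frame bounding-box loop.

    kp: (num_frames, ..., 2) keypoint coordinates in the frame_size x frame_size video.
    Returns int32 (num_frames, 4) boxes ordered (min0, min1, max0, max1), padded by
    `pad` pixels, clipped to [0, frame_size] and rescaled to target_size.
    """
    coords = np.int32(kp.reshape(len(kp), -1, 2))          # (num_frames, num_points, 2)
    mins = np.maximum(coords.min(axis=1) - pad, 0)
    maxs = np.minimum(coords.max(axis=1) + pad, frame_size)
    boxes = np.concatenate([mins, maxs], axis=1)           # (num_frames, 4)
    return np.int32(boxes * target_size / frame_size)


# Two frames, one mouse with two keypoints per frame
kp = np.array([[[[100, 200], [150, 250]]],
               [[[0, 0], [500, 510]]]], dtype=np.float32)
boxes = keypoints_to_bboxes(kp)
```

The second frame shows the clipping: padding would push the box outside the frame, so it is clamped to the full 224x224 image.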

In [ ]:
######## Prepare bounding boxes from keypoints ##########

# Preparing some bounding box information to be used for cropping frames during training
keypoints = np.load(os.path.join(datafolder, 'submission_keypoints.npy'), allow_pickle=True).item()

padbbox = 50
crop_size = 512
for sk in tqdm(keypoints['sequences'].keys()):
    kp = keypoints['sequences'][sk]['keypoints']
    bboxes = []
    for frame_idx in range(len(kp)):
        allcoords = np.int32(kp[frame_idx].reshape(-1, 2))
        minvals = max(np.min(allcoords[:, 0]) - padbbox, 0), max(np.min(allcoords[:, 1]) - padbbox, 0)
        maxvals = min(np.max(allcoords[:, 0]) + padbbox, crop_size), min(np.max(allcoords[:, 1]) + padbbox, crop_size)
        bbox = (*minvals, *maxvals)
        bbox = np.array(bbox)
        bbox = np.int32(bbox * 224 / 512)  # rescale from the 512x512 frame to the 224x224 videos
        bboxes.append(bbox)
    keypoints['sequences'][sk]['bbox'] = np.array(bboxes)

# You can save it if you want and load it later
# np.save(os.path.join(datafolder, 'submission_keypoints_bbox.npy'), keypoints)

# keypoints = np.load(os.path.join(datafolder, 'submission_keypoints_bbox.npy'), allow_pickle=True).item()

Training ☑️

Below are hyperparameters you can play around with. The runs are pretty slow, so you can reduce the number of epochs and steps per epoch to find the parameters you want to use.

Note that we do not go over the entire dataset for each "epoch", because the whole dataset is huge.

This code will only use the submission videos for unsupervised training, but you can change it to use all the videos.
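Because each "epoch" stops after `steps_per_epoch` batches, it helps to know how much data one epoch actually covers. A small arithmetic sketch (hypothetical helper; the dataset size below is a made-up placeholder, in practice use `len(train_dataset)`):

```python
import math


def epoch_coverage(num_frames, batch_size, steps_per_epoch):
    """Samples drawn per shortened epoch, and how many such epochs add up to
    at least num_frames samples (roughly one pass over the data)."""
    samples_per_epoch = batch_size * steps_per_epoch
    epochs_per_pass = math.ceil(num_frames / samples_per_epoch)
    return samples_per_epoch, epochs_per_pass


# Placeholder dataset size, not the real MABe frame count
samples, epochs_needed = epoch_coverage(num_frames=1_000_000, batch_size=32, steps_per_epoch=1000)
```

If the loader shuffles, this is only a sample count: over that many epochs some frames will be drawn more than once and others not at all.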

In [ ]:
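################### CONFIG #########################
batch_size = 32
epochs = 10

# Stack frames with frame skip from the video sequences
# NOTE: the next four values are assumed defaults, not necessarily the baseline's settings; tune as needed
LEFT_WINDOW = 1       # number of past frames to stack
RIGHT_WINDOW = 1      # number of future frames to stack
FRAME_SKIP = 2        # frames skipped between consecutive stacked frames
IS_PRETRAINED = True  # start from ImageNet-pretrained ResNet50 weights
IMG_SIZE = 224
n_channel = LEFT_WINDOW + RIGHT_WINDOW + 1

# Check which batch size fits in memory when changing this
embedding_size = 128

# The full dataset is huge, so here we limit the number of steps per epoch
steps_per_epoch = 1000

videos_folder = os.path.join(datafolder, 'video_clips') # TODO: Change this to combined folder
frame_number_map = np.load(os.path.join(datafolder, 'frame_number_map.npy'), allow_pickle=True).item()

checkpoint_folder = "mouse_video_checkpoints/" # Can change this to a Google drive folder
if not os.path.exists(checkpoint_folder): 
    os.makedirs(checkpoint_folder)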
In [ ]:
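train_dataset = MouseVideoDataset(datafolder=videos_folder, frame_number_map=frame_number_map,
                                  keypoints=keypoints, num_prev_frames=LEFT_WINDOW,
                                  num_next_frames=RIGHT_WINDOW, frame_skip=FRAME_SKIP,
                                  frame_size=(224, 224),
                                  transform=TransformsSimCLR(size=(IMG_SIZE, IMG_SIZE),
                                                             pretrained=IS_PRETRAINED, n_channel=n_channel))

# drop_last=True because NT_Xent is built for a fixed batch size
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                           drop_last=True, num_workers=2)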
In [ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

##################### MODEL ############################
def get_simclr_model():
    resnet_encoder = torchvision.models.resnet50(pretrained=IS_PRETRAINED)

    ## Experimental setup for multiplying the grayscale channel
    ## https://stackoverflow.com/a/54777347
    weight = resnet_encoder.conv1.weight.clone()
    resnet_encoder.conv1 = torch.nn.Conv2d(n_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # normalize back by n_channels after tiling
    resnet_encoder.conv1.weight.data = weight.sum(dim=1, keepdim=True).tile(1, n_channel, 1, 1)/n_channel

    n_features = resnet_encoder.fc.in_features
    model = SimCLR(resnet_encoder, embedding_size, n_features)
    model = model.to(device)
    return model

model = get_simclr_model()

##################### UTILS ############################

optimizer_type = 'Adam'
weight_decay = 1e-6

optimizer, scheduler = load_optimizer(optimizer_type, epochs, weight_decay, batch_size, model)

world_size = 1
temperature = 0.5
criterion = NT_Xent(batch_size, temperature, world_size)
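`get_simclr_model` adapts the pretrained 3-channel `conv1` to `n_channel` stacked grayscale frames by summing the RGB kernels, tiling them and dividing by `n_channel`. A small NumPy check shows what this preserves. It is illustrative only: random weights stand in for the ResNet50 ones, and the convolution is evaluated at a single spatial position. When all stacked frames are identical, the adapted filter gives the same response as the pretrained filter applied to the grayscale image replicated into RGB.

```python
import numpy as np

rng = np.random.default_rng(0)
n_channel = 5  # e.g. LEFT_WINDOW + RIGHT_WINDOW + 1 = 2 + 2 + 1

# Random stand-in for the pretrained conv1 weight: (out_channels, RGB, 7, 7)
w_rgb = rng.normal(size=(64, 3, 7, 7))
# Same recipe as get_simclr_model: sum over RGB, tile over n_channel, divide by n_channel
w_stacked = np.tile(w_rgb.sum(axis=1, keepdims=True), (1, n_channel, 1, 1)) / n_channel

patch = rng.normal(size=(7, 7))                # one 7x7 grayscale patch
rgb_input = np.stack([patch] * 3)              # grayscale replicated into R, G, B
stacked_input = np.stack([patch] * n_channel)  # a static clip: all stacked frames equal

# Response of every output filter at a single spatial position
out_rgb = np.einsum('ocij,cij->o', w_rgb, rgb_input)
out_stacked = np.einsum('ocij,cij->o', w_stacked, stacked_input)
print(np.allclose(out_rgb, out_stacked))  # prints True
```

Dividing by `n_channel` is what keeps the activation scale unchanged; without it, the response to a static clip would be `n_channel` times larger than what the pretrained layers downstream expect.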
In [ ]:
# Basic Training loop
def train(epoch, train_loader, model, criterion, optimizer):
    model.train()
    loss_epoch = 0
    # The full train loader is huge, so here we limit each epoch to steps_per_epoch steps
    # tqdm_iter = tqdm(train_loader, total=len(train_loader))
    tqdm_iter = tqdm(train_loader, total=steps_per_epoch)

    tqdm_iter.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(tqdm_iter):
        optimizer.zero_grad()
        # Two augmented views of the same frame stacks
        x_i = batch['image'][0].to(device, non_blocking=True)
        x_j = batch['image'][1].to(device, non_blocking=True)

        # positive pair, with encoding
        h_i, h_j, z_i, z_j = model(x_i, x_j)

        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()
        tqdm_iter.set_postfix(loss=loss.item())
        if step + 1 >= steps_per_epoch:
            break

    return loss_epoch
In [ ]:
# The baseline submission on the leaderboard was trained for 100 epochs; adjust epochs to your needs
for epoch in range(epochs):
    lr = optimizer.param_groups[0]['lr']
    loss_epoch = train(epoch, train_loader, model, criterion, optimizer)

    if scheduler:
        scheduler.step()

    print(f"Epoch [{epoch}/{epochs}]\t Loss: {loss_epoch / steps_per_epoch}\t lr: {round(lr, 5)}")

    if (epoch % 3) == 0:
        save_model(epoch, checkpoint_folder, model, optimizer)

save_model(epochs, checkpoint_folder, model, optimizer)
In [ ]:
# Cleanup RAM
del model, optimizer
del train_loader, train_dataset

Predict Embeddings 🔮

Here we predict an embedding for every frame; this may take a long time.
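The rows of `submission` follow the global frame order defined by `frame_number_map`, which only holds if the prediction loader does not shuffle. A small hypothetical helper for pulling each video's embeddings back out of the flat array (useful for inspecting or plotting them), assuming `frame_number_map[name] = (start, end)` with `end` exclusive, as the dataset class treats it:

```python
import numpy as np


def split_embeddings_by_video(embeddings, frame_number_map):
    """Return {video_name: (num_frames, d) slice} from the flat embedding array.

    Assumes frame_number_map[name] = (start, end) with end exclusive, so the
    last end value equals the total number of rows.
    """
    return {name: embeddings[start:end] for name, (start, end) in frame_number_map.items()}


# Toy example: two clips of 3 and 2 frames with 2-dimensional embeddings
toy_map = {'clip_a': (0, 3), 'clip_b': (3, 5)}
toy_embeddings = np.arange(10, dtype=np.float32).reshape(5, 2)
per_video = split_embeddings_by_video(toy_embeddings, toy_map)
```

The slices are views into the original array, so this costs no extra memory even for the full submission.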

In [ ]:
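# Load latest model (the final checkpoint is saved as checkpoint_{epochs}.tar)
model = get_simclr_model()
checkpoint_path = os.path.join(checkpoint_folder, f'checkpoint_{epochs}.tar')
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)
model.eval()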
In [ ]:
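prediction_dataset = MouseVideoDataset(datafolder=videos_folder, frame_number_map=frame_number_map,
                                       keypoints=keypoints, num_prev_frames=LEFT_WINDOW,
                                       num_next_frames=RIGHT_WINDOW, frame_skip=FRAME_SKIP,
                                       frame_size=(224, 224),
                                       transform=TransformsSimCLR(size=(IMG_SIZE, IMG_SIZE), pretrained=IS_PRETRAINED,
                                                                  n_channel=n_channel, validation=True))

# Keep the frame order (no shuffling) so that rows line up with the submission format
prediction_loader = torch.utils.data.DataLoader(prediction_dataset, batch_size=batch_size, shuffle=False,
                                                drop_last=False, num_workers=2)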
In [ ]:
sample_submission = np.load(datafolder + 'sample_submission.npy')
submission = np.empty((sample_submission.shape[0], embedding_size), dtype=np.float32)
idx = 0
# This may take quite long, since predicting on all frames
for data in tqdm(prediction_loader, total=len(prediction_loader)):
    with torch.no_grad():
        images = data['image'].to(device)
        output = model.projector(model.encoder(images))
        output = output.cpu().numpy()
        submission[idx:idx+len(output)] = output
        idx += len(output)

Submission 🚀

In [ ]:
print("Embedding shape:", submission.shape)

Validate the submission ✅

The submission should follow these constraints:

  • It should be a numpy array
  • The embeddings should be a 2D array of dtype float32
  • The embedding size shouldn't exceed 128
  • The number of embeddings should match the total clip length given by the frame number map
  • The embeddings should contain no NaN or infinite values

You can use the helper function below to check these.
In [ ]:
def validate_submission(submission, frame_number_map):
    if not isinstance(submission, np.ndarray):
        print("Embeddings should be a numpy array")
        return False
    elif not len(submission.shape) == 2:
        print("Embeddings should be 2D array")
        return False
    elif not submission.shape[1] <= 128:
        print("Embeddings too large, max allowed is 128")
        return False
    elif not isinstance(submission[0, 0], np.float32):
        print(f"Embeddings are not float32")
        return False

    total_clip_length = frame_number_map[list(frame_number_map.keys())[-1]][1]
    if not len(submission) == total_clip_length:
        print("Embeddings length doesn't match submission clips total length")
        return False

    if not np.isfinite(submission).all():
        print("Embeddings contain NaN or infinity")
        return False

    print("All checks passed")
    return True
In [ ]:
validate_submission(submission, frame_number_map)
In [ ]:
np.save('submission_mouse_simclr.npy', submission)
In [ ]:
## Uploads may take a while; you can also run aicrowd-cli on your local machine with the prepared submission file
%aicrowd submission create --description "Mouse SimCLR Baseline" -c {aicrowd_challenge_name} -f submission_mouse_simclr.npy

