
Research Paper Classification

Solution for submission 148639

A detailed solution for submission 148639 submitted for challenge Research Paper Classification

salim_shaikh
In [1]:
# Install the Hugging Face transformers library
!pip install --upgrade transformers

!pip install --upgrade simpletransformers

# AIcrowd CLI for downloading the dataset
!pip install aicrowd-cli
Successfully installed huggingface-hub-0.0.12 sacremoses-0.0.45 tokenizers-0.10.3 transformers-4.8.1
ERROR: google-colab 1.0.0 has requirement ipykernel~=4.10, but you'll have ipykernel 5.5.5 which is incompatible.
ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible.
ERROR: datasets 1.8.0 has requirement tqdm<4.50.0,>=4.27, but you'll have tqdm 4.61.1 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
Successfully installed GitPython-3.1.18 base58-2.1.0 blinker-1.4 configparser-5.0.2 datasets-1.8.0 docker-pycreds-0.4.0 fsspec-2021.6.1 gitdb-4.0.7 ipykernel-5.5.5 pathtools-0.1.2 pydeck-0.6.2 sentencepiece-0.1.96 sentry-sdk-1.1.0 seqeval-1.2.2 shortuuid-1.0.1 simpletransformers-0.61.9 smmap-4.0.0 streamlit-0.83.0 subprocess32-3.5.4 tensorboardx-2.3 tqdm-4.61.1 validators-0.18.2 wandb-0.10.32 watchdog-2.1.3 xxhash-2.0.2
Successfully installed aicrowd-cli-0.1.7 colorama-0.4.4 commonmark-0.9.1 requests-2.25.1 requests-toolbelt-0.9.1 rich-10.4.0
In [2]:
# Downloading the Dataset
!mkdir data
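# NOTE: the AIcrowd login/download commands were not captured in this cell. The lines
# below are a hedged reconstruction inferred from the output that follows; the challenge
# slug "research-paper-classification" and the API key placeholder are assumptions.
API_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"  # your AIcrowd API key
!aicrowd login --api-key $API_KEY
!aicrowd dataset download --challenge research-paper-classification
!mv *.csv data/  # the notebook reads train.csv, val.csv and test.csv from data/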
API Key valid
Saved API Key successfully!
val.csv: 100% 883k/883k [00:00<00:00, 1.22MB/s]
test.csv: 100% 3.01M/3.01M [00:00<00:00, 3.09MB/s]
train.csv: 100% 8.77M/8.77M [00:01<00:00, 7.09MB/s]
In [1]:
import os
import re
import random
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import BertTokenizer
from transformers import BertModel
from tqdm import tqdm

%matplotlib inline
In [7]:
train_dataset = pd.read_csv("data/train.csv")
validation_dataset = pd.read_csv("data/val.csv")[1:]
# Merge the validation split into the training data (it is also kept separately for evaluation below)
train_dataset = pd.concat([train_dataset, validation_dataset], axis=0)
test_dataset = pd.read_csv("data/test.csv")

X_train = train_dataset.text.values
y_train = train_dataset.label.values
X_val = validation_dataset.text.values
y_val = validation_dataset.label.values
In [8]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print('No GPU available, using the CPU instead.')
There are 1 GPU(s) available.
Device name: Tesla P100-PCIE-16GB
In [9]:
def text_preprocessing(text):
    """Minimal preprocessing: lower-case the text.
    @param    text (str): a string to be processed.
    @return   text (str): the processed string.
    """
    text = text.lower()

    return text
In [10]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Create a function to tokenize a set of texts
def preprocessing_for_bert(data):
    """Perform required preprocessing steps for pretrained BERT.
    @param    data (np.array): Array of texts to be processed.
    @return   input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
    @return   attention_masks (torch.Tensor): Tensor of indices specifying which
                  tokens should be attended to by the model.
    """
    # Create empty lists to store outputs
    input_ids = []
    attention_masks = []

    # For every sentence...
    for sent in data:
        # `encode_plus` will:
        #    (1) Tokenize the sentence
        #    (2) Add the `[CLS]` and `[SEP]` token to the start and end
        #    (3) Truncate/Pad sentence to max length
        #    (4) Map tokens to their IDs
        #    (5) Create attention mask
        #    (6) Return a dictionary of outputs
        encoded_sent = tokenizer.encode_plus(
            text=text_preprocessing(sent),  # Preprocess sentence
            add_special_tokens=True,        # Add `[CLS]` and `[SEP]`
            max_length=MAX_LEN,                  # Max length to truncate/pad
            pad_to_max_length=True,         # Pad to max length (deprecated; padding='max_length' is the current equivalent)
            #return_tensors='pt',           # Return PyTorch tensor
            return_attention_mask=True,      # Return attention mask
            truncation=True)
        
        # Add the outputs to the lists
        input_ids.append(encoded_sent.get('input_ids'))
        attention_masks.append(encoded_sent.get('attention_mask'))

    # Convert lists to tensors
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    return input_ids, attention_masks
In [11]:
# Concatenate the train data and validation data
all_data = np.concatenate([X_train, X_val])

# Encode our concatenated data
encoded_data = [tokenizer.encode(sent, add_special_tokens=True) for sent in all_data]

# Find the maximum length
max_len = max([len(sent) for sent in encoded_data])
print('Max length: ', max_len)
Max length:  112
In [12]:
# Specify `MAX_LEN` (slightly below the observed maximum of 112; longer sequences are truncated)
MAX_LEN = 100

# Print sentence 0 and its encoded token ids
token_ids = list(preprocessing_for_bert([X_train[0]])[0].squeeze().numpy())
print('Original: ', X_train[0])
print('Token IDs: ', token_ids)
Original:  we propose deep network models and learning algorithms for learning binary hash codes given image representations under both unsupervised and supervised manners . the novelty of our network design is that we constrain one hidden layer to directly output the binary codes . resulting optimizations involving these binary, independence, and balance constraints are difficult to solve .
Token IDs:  [101, 2057, 16599, 2784, 2897, 4275, 1998, 4083, 13792, 2005, 4083, 12441, 23325, 9537, 2445, 3746, 15066, 2104, 2119, 4895, 6342, 4842, 11365, 2098, 1998, 13588, 14632, 1012, 1996, 21160, 1997, 2256, 2897, 2640, 2003, 2008, 2057, 9530, 20528, 2378, 2028, 5023, 6741, 2000, 3495, 6434, 1996, 12441, 9537, 1012, 4525, 20600, 2015, 5994, 2122, 12441, 1010, 4336, 1010, 1998, 5703, 14679, 2024, 3697, 2000, 9611, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:2132: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  FutureWarning,
In [13]:
# Run function `preprocessing_for_bert` on the train set and the validation set
print('Tokenizing data...')
train_inputs, train_masks = preprocessing_for_bert(X_train)
val_inputs, val_masks = preprocessing_for_bert(X_val)

# Convert other data types to torch.Tensor
train_labels = torch.tensor(y_train)
val_labels = torch.tensor(y_val)
Tokenizing data...
/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:2132: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  FutureWarning,
In [14]:
%%time

# Create the BertClassfier class
class BertClassifier(nn.Module):
    """Bert Model for Classification Tasks.
    """
    def __init__(self, freeze_bert=False):
        """
        @param    bert: a BertModel object
        @param    classifier: a torch.nn.Module classifier
        @param    freeze_bert (bool): Set `False` to fine-tune the BERT model
        """
        super(BertClassifier, self).__init__()
        # Specify hidden size of BERT, hidden size of our classifier, and number of labels
        D_in, H, D_out = 768, 16, 4

        # Instantiate BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Instantiate a one-hidden-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(D_in, H),
            nn.ReLU(),
            #nn.Dropout(0.5),
            nn.Linear(H, D_out)
        )

        # Freeze the BERT model
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        
    def forward(self, input_ids, attention_mask):
        """
        Feed input to BERT and the classifier to compute logits.
        @param    input_ids (torch.Tensor): an input tensor with shape (batch_size,
                      max_length)
        @param    attention_mask (torch.Tensor): a tensor that hold attention mask
                      information with shape (batch_size, max_length)
        @return   logits (torch.Tensor): an output tensor with shape (batch_size,
                      num_labels)
        """
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        
        # Extract the last hidden state of the token `[CLS]` for classification task
        last_hidden_state_cls = outputs[0][:, 0, :]

        # Feed input to classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)

        return logits
    
def initialize_model(epochs=4):
    """Initialize the Bert Classifier, the optimizer and the learning rate scheduler.
    """
    # Instantiate Bert Classifier
    bert_classifier = BertClassifier(freeze_bert=False)

    # Tell PyTorch to run the model on GPU
    bert_classifier.to(device)

    # Create the optimizer
    optimizer = AdamW(bert_classifier.parameters(),
                      lr=3e-5,    # Default learning rate
                      eps=1e-8    # Default epsilon value
                      )

    # Total number of training steps
    total_steps = len(train_dataloader) * epochs

    # Set up the learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0, # Default value
                                                num_training_steps=total_steps)
    return bert_classifier, optimizer, scheduler
CPU times: user 25 µs, sys: 0 ns, total: 25 µs
Wall time: 27.9 µs
In [15]:
# Specify loss function
loss_fn = nn.CrossEntropyLoss()

def set_seed(seed_value=42):
    """Set seed for reproducibility.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False):
    """Train the BertClassifier model.
    """
    # Start training loop
    print("Start training...\n")
    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================
        # Print the header of the result table
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
        print("-"*70)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()

        # Reset tracking variables at the beginning of each epoch
        total_loss, batch_loss, batch_counts = 0, 0, 0

        # Put the model into the training mode
        model.train()

        # For each batch of training data...
        for step, batch in enumerate(train_dataloader):
            batch_counts +=1
            # Load batch to GPU
            b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

            # Zero out any previously calculated gradients
            model.zero_grad()

            # Perform a forward pass. This will return logits.
            logits = model(b_input_ids, b_attn_mask)

            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels)
            batch_loss += loss.item()
            total_loss += loss.item()

            # Perform a backward pass to calculate gradients
            loss.backward()

            # Clip the norm of the gradients to 1.0 to prevent "exploding gradients"
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Update parameters and the learning rate
            optimizer.step()
            scheduler.step()

            # Print the loss values and time elapsed for every 20 batches
            if (step % 20 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                # Calculate time elapsed for 20 batches
                time_elapsed = time.time() - t0_batch

                # Print training results
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")

                # Reset batch tracking variables
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_dataloader)

        print("-"*70)
        # =======================================
        #               Evaluation
        # =======================================
        if evaluation:
            # After the completion of each training epoch, measure the model's performance
            # on our validation set.
            val_loss, val_accuracy = evaluate(model, val_dataloader)

            # Print performance over the entire training data
            time_elapsed = time.time() - t0_epoch
            
            print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
            print("-"*70)
        print("\n")
    
    print("Training complete!")


def evaluate(model, val_dataloader):
    """After the completion of each training epoch, measure the model's performance
    on our validation set.
    """
    # Put the model into the evaluation mode. The dropout layers are disabled during
    # the test time.
    model.eval()

    # Tracking variables
    val_accuracy = []
    val_loss = []

    # For each batch in our validation set...
    for batch in val_dataloader:
        # Load batch to GPU
        b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

        # Compute logits
        with torch.no_grad():
            logits = model(b_input_ids, b_attn_mask)

        # Compute loss
        loss = loss_fn(logits, b_labels)
        val_loss.append(loss.item())

        # Get the predictions
        preds = torch.argmax(logits, dim=1).flatten()

        # Calculate the accuracy rate
        accuracy = (preds == b_labels).cpu().numpy().mean() * 100
        val_accuracy.append(accuracy)

    # Compute the average accuracy and loss over the validation set.
    val_loss = np.mean(val_loss)
    val_accuracy = np.mean(val_accuracy)

    return val_loss, val_accuracy

def bert_predict(model, test_dataloader):
    """Perform a forward pass on the trained BERT model to predict probabilities
    on the test set.
    """
    # Put the model into the evaluation mode. The dropout layers are disabled during
    # the test time.
    model.eval()

    all_logits = []

    # For each batch in our test set...
    for batch in test_dataloader:
        # Load batch to GPU
        b_input_ids, b_attn_mask = tuple(t.to(device) for t in batch)[:2]

        # Compute logits
        with torch.no_grad():
            logits = model(b_input_ids, b_attn_mask)
        all_logits.append(logits)
    
    # Concatenate logits from each batch
    all_logits = torch.cat(all_logits, dim=0)

    # Apply softmax to calculate probabilities
    probs = F.softmax(all_logits, dim=1).cpu().numpy()

    return probs
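For reference, here is a minimal sketch of how bert_predict could be applied to the test split loaded earlier. It is not part of the recorded run: it assumes the fine-tuned bert_classifier produced by the training cell below, and the submission formatting is omitted because it is not shown in this section.

# Tokenize the test texts and wrap them in a sequential DataLoader (assumed usage)
test_inputs, test_masks = preprocessing_for_bert(test_dataset.text.values)
test_data = TensorDataset(test_inputs, test_masks)
test_dataloader = DataLoader(test_data, sampler=SequentialSampler(test_data), batch_size=8)

# Class probabilities from the trained model, then the predicted label per paper
probs = bert_predict(bert_classifier, test_dataloader)
predicted_labels = probs.argmax(axis=1)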
In [16]:
# For fine-tuning BERT, the authors recommend a batch size of 16 or 32;
# a smaller batch size of 8 is used here.
batch_size = 8

# Create the DataLoader for our training set
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# Create the DataLoader for our validation set
val_data = TensorDataset(val_inputs, val_masks, val_labels)
val_sampler = SequentialSampler(val_data)
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)

# Concatenate the train set and the validation set
full_train_data = torch.utils.data.ConcatDataset([train_data, val_data])
full_train_sampler = RandomSampler(full_train_data)
full_train_dataloader = DataLoader(full_train_data, sampler=full_train_sampler, batch_size=8)

# Train the Bert Classifier on the entire training data
set_seed(42)
bert_classifier, optimizer, scheduler = initialize_model(epochs=3)
train(bert_classifier, train_dataloader, epochs=3)
# train(bert_classifier, full_train_dataloader, epochs=1)
evaluate(bert_classifier, val_dataloader)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Start training...

 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   1    |   20    |   1.212518   |     -      |     -     |   2.46   
   1    |   40    |   1.026087   |     -      |     -     |   2.21   
   1    |   60    |   0.838594   |     -      |     -     |   2.22   
   1    |   80    |   0.785460   |     -      |     -     |   2.22   
   1    |   100   |   0.781694   |     -      |     -     |   2.22   
   1    |   120   |   0.577379   |     -      |     -     |   2.22   
   1    |   140   |   0.659050   |     -      |     -     |   2.22   
   1    |   160   |   0.566129   |     -      |     -     |   2.21   
   1    |   180   |   0.680715   |     -      |     -     |   2.22   
   1    |   200   |   0.697724   |     -      |     -     |   2.21   
   1    |   220   |   0.589925   |     -      |     -     |   2.21   
   1    |   240   |   0.629778   |     -      |     -     |   2.22   
   1    |   260   |   0.632385   |     -      |     -     |   2.22   
   1    |   280   |   0.520268   |     -      |     -     |   2.22   
   1    |   300   |   0.619638   |     -      |     -     |   2.21   
   1    |   320   |   0.460602   |     -      |     -     |   2.22   
   1    |   340   |   0.505591   |     -      |     -     |   2.22   
   1    |   360   |   0.695262   |     -      |     -     |   2.22   
   1    |   380   |   0.409775   |     -      |     -     |   2.21   
   1    |   400   |   0.612260   |     -      |     -     |   2.21   
   1    |   420   |   0.428344   |     -      |     -     |   2.21   
   1    |   440   |   0.618245   |     -      |     -     |   2.21   
   1    |   460   |   0.637948   |     -      |     -     |   2.21   
   1    |   480   |   0.425736   |     -      |     -     |   2.21   
   1    |   500   |   0.672390   |     -      |     -     |   2.22   
   1    |   520   |   0.557015   |     -      |     -     |   2.21   
   1    |   540   |   0.374424   |     -      |     -     |   2.22   
   1    |   560   |   0.492142   |     -      |     -     |   2.21   
   1    |   580   |   0.523808   |     -      |     -     |   2.21   
   1    |   600   |   0.512978   |     -      |     -     |   2.22   
   1    |   620   |   0.627485   |     -      |     -     |   2.22   
   1    |   640   |   0.479528   |     -      |     -     |   2.21   
   1    |   660   |   0.473359   |     -      |     -     |   2.21   
   1    |   680   |   0.601399   |     -      |     -     |   2.22   
   1    |   700   |   0.562857   |     -      |     -     |   2.22   
   1    |   720   |   0.464541   |     -      |     -     |   2.22   
   1    |   740   |   0.584000   |     -      |     -     |   2.22   
   1    |   760   |   0.440258   |     -      |     -     |   2.22   
   1    |   780   |   0.461804   |     -      |     -     |   2.22   
   1    |   800   |   0.435884   |     -      |     -     |   2.22   
   1    |   820   |   0.409103   |     -      |     -     |   2.21   
   1    |   840   |   0.508785   |     -      |     -     |   2.22   
   1    |   860   |   0.389992   |     -      |     -     |   2.22   
   1    |   880   |   0.530578   |     -      |     -     |   2.21   
   1    |   900   |   0.473513   |     -      |     -     |   2.21   
   1    |   920   |   0.434002   |     -      |     -     |   2.22   
   1    |   940   |   0.367459   |     -      |     -     |   2.21   
   1    |   960   |   0.484282   |     -      |     -     |   2.22   
   1    |   980   |   0.469113   |     -      |     -     |   2.21   
   1    |  1000   |   0.449016   |     -      |     -     |   2.22   
   1    |  1020   |   0.399781   |     -      |     -     |   2.21   
   1    |  1040   |   0.527049   |     -      |     -     |   2.22   
   1    |  1060   |   0.415336   |     -      |     -     |   2.21   
   1    |  1080   |   0.478728   |     -      |     -     |   2.23   
   1    |  1100   |   0.564300   |     -      |     -     |   2.21   
   1    |  1120   |   0.407980   |     -      |     -     |   2.21   
   1    |  1140   |   0.493881   |     -      |     -     |   2.21   
   1    |  1160   |   0.506344   |     -      |     -     |   2.22   
   1    |  1180   |   0.394473   |     -      |     -     |   2.22   
   1    |  1200   |   0.597872   |     -      |     -     |   2.21   
   1    |  1220   |   0.495467   |     -      |     -     |   2.22   
   1    |  1240   |   0.336521   |     -      |     -     |   2.21   
   1    |  1260   |   0.444600   |     -      |     -     |   2.21   
   1    |  1280   |   0.427214   |     -      |     -     |   2.22   
   1    |  1300   |   0.714254   |     -      |     -     |   2.22   
   1    |  1320   |   0.514492   |     -      |     -     |   2.22   
   1    |  1340   |   0.448602   |     -      |     -     |   2.22   
   1    |  1360   |   0.484858   |     -      |     -     |   2.22   
   1    |  1380   |   0.526166   |     -      |     -     |   2.22   
   1    |  1400   |   0.478206   |     -      |     -     |   2.22   
   1    |  1420   |   0.404973   |     -      |     -     |   2.22   
   1    |  1440   |   0.507713   |     -      |     -     |   2.22   
   1    |  1460   |   0.539945   |     -      |     -     |   2.22   
   1    |  1480   |   0.564581   |     -      |     -     |   2.22   
   1    |  1500   |   0.510978   |     -      |     -     |   2.21   
   1    |  1520   |   0.504789   |     -      |     -     |   2.22   
   1    |  1540   |   0.385453   |     -      |     -     |   2.21   
   1    |  1560   |   0.578877   |     -      |     -     |   2.22   
   1    |  1580   |   0.297245   |     -      |     -     |   2.21   
   1    |  1600   |   0.664166   |     -      |     -     |   2.22   
   1    |  1620   |   0.311165   |     -      |     -     |   2.21   
   1    |  1640   |   0.537473   |     -      |     -     |   2.22   
   1    |  1660   |   0.417328   |     -      |     -     |   2.22   
   1    |  1680   |   0.491010   |     -      |     -     |   2.22   
   1    |  1700   |   0.468689   |     -      |     -     |   2.22   
   1    |  1720   |   0.528900   |     -      |     -     |   2.22   
   1    |  1740   |   0.464334   |     -      |     -     |   2.21   
   1    |  1760   |   0.390161   |     -      |     -     |   2.22   
   1    |  1780   |   0.517798   |     -      |     -     |   2.22   
   1    |  1800   |   0.473682   |     -      |     -     |   2.22   
   1    |  1820   |   0.363510   |     -      |     -     |   2.22   
   1    |  1840   |   0.415706   |     -      |     -     |   2.22   
   1    |  1860   |   0.445254   |     -      |     -     |   2.22   
   1    |  1880   |   0.487103   |     -      |     -     |   2.22   
   1    |  1900   |   0.288645   |     -      |     -     |   2.22   
   1    |  1920   |   0.317319   |     -      |     -     |   2.21   
   1    |  1940   |   0.370192   |     -      |     -     |   2.21   
   1    |  1960   |   0.512695   |     -      |     -     |   2.22   
   1    |  1980   |   0.513689   |     -      |     -     |   2.22   
   1    |  2000   |   0.355159   |     -      |     -     |   2.21   
   1    |  2020   |   0.340034   |     -      |     -     |   2.22   
   1    |  2040   |   0.368006   |     -      |     -     |   2.22   
   1    |  2060   |   0.461745   |     -      |     -     |   2.22   
   1    |  2080   |   0.501412   |     -      |     -     |   2.22   
   1    |  2100   |   0.448082   |     -      |     -     |   2.22   
   1    |  2120   |   0.478640   |     -      |     -     |   2.22   
   1    |  2140   |   0.455679   |     -      |     -     |   2.22   
   1    |  2160   |   0.388774   |     -      |     -     |   2.22   
   1    |  2180   |   0.474274   |     -      |     -     |   2.22   
   1    |  2200   |   0.380503   |     -      |     -     |   2.22   
   1    |  2220   |   0.503582   |     -      |     -     |   2.22   
   1    |  2240   |   0.401817   |     -      |     -     |   2.22   
   1    |  2260   |   0.528856   |     -      |     -     |   2.23   
   1    |  2280   |   0.229861   |     -      |     -     |   2.22   
   1    |  2300   |   0.432785   |     -      |     -     |   2.22   
   1    |  2320   |   0.391586   |     -      |     -     |   2.22   
   1    |  2340   |   0.442966   |     -      |     -     |   2.22   
   1    |  2360   |   0.369986   |     -      |     -     |   2.22   
   1    |  2380   |   0.480832   |     -      |     -     |   2.22   
   1    |  2400   |   0.486392   |     -      |     -     |   2.22   
   1    |  2420   |   0.582519   |     -      |     -     |   2.22   
   1    |  2440   |   0.528408   |     -      |     -     |   2.22   
   1    |  2460   |   0.433433   |     -      |     -     |   2.22   
   1    |  2480   |   0.389820   |     -      |     -     |   2.21   
   1    |  2500   |   0.403192   |     -      |     -     |   2.22   
   1    |  2520   |   0.535581   |     -      |     -     |   2.21   
   1    |  2540   |   0.584072   |     -      |     -     |   2.22   
   1    |  2560   |   0.425891   |     -      |     -     |   2.22   
   1    |  2580   |   0.397607   |     -      |     -     |   2.22   
   1    |  2600   |   0.434766   |     -      |     -     |   2.22   
   1    |  2620   |   0.391318   |     -      |     -     |   2.22   
   1    |  2640   |   0.427866   |     -      |     -     |   2.22   
   1    |  2660   |   0.345405   |     -      |     -     |   2.22   
   1    |  2680   |   0.407680   |     -      |     -     |   2.21   
   1    |  2700   |   0.483474   |     -      |     -     |   2.22   
   1    |  2720   |   0.354000   |     -      |     -     |   2.22   
   1    |  2740   |   0.447533   |     -      |     -     |   2.22   
   1    |  2760   |   0.448948   |     -      |     -     |   2.22   
   1    |  2780   |   0.496571   |     -      |     -     |   2.22   
   1    |  2800   |   0.457702   |     -      |     -     |   2.22   
   1    |  2820   |   0.295824   |     -      |     -     |   2.21   
   1    |  2840   |   0.420868   |     -      |     -     |   2.22   
   1    |  2860   |   0.452107   |     -      |     -     |   2.22   
   1    |  2880   |   0.606943   |     -      |     -     |   2.22   
   1    |  2900   |   0.450085   |     -      |     -     |   2.23   
   1    |  2920   |   0.353490   |     -      |     -     |   2.22   
   1    |  2940   |   0.364078   |     -      |     -     |   2.22   
   1    |  2960   |   0.394412   |     -      |     -     |   2.22   
   1    |  2980   |   0.454151   |     -      |     -     |   2.22   
   1    |  3000   |   0.418116   |     -      |     -     |   2.22   
   1    |  3020   |   0.420169   |     -      |     -     |   2.22   
   1    |  3040   |   0.256342   |     -      |     -     |   2.22   
   1    |  3060   |   0.271901   |     -      |     -     |   2.22   
   1    |  3080   |   0.422501   |     -      |     -     |   2.21   
   1    |  3100   |   0.430202   |     -      |     -     |   2.22   
   1    |  3120   |   0.472013   |     -      |     -     |   2.22   
   1    |  3140   |   0.462826   |     -      |     -     |   2.22   
   1    |  3160   |   0.430959   |     -      |     -     |   2.23   
   1    |  3180   |   0.341863   |     -      |     -     |   2.22   
   1    |  3200   |   0.456700   |     -      |     -     |   2.22   
   1    |  3220   |   0.425651   |     -      |     -     |   2.23   
   1    |  3240   |   0.348138   |     -      |     -     |   2.22   
   1    |  3260   |   0.497292   |     -      |     -     |   2.22   
   1    |  3280   |   0.498105   |     -      |     -     |   2.22   
   1    |  3300   |   0.329272   |     -      |     -     |   2.22   
   1    |  3320   |   0.464986   |     -      |     -     |   2.22   
   1    |  3340   |   0.273993   |     -      |     -     |   2.22   
   1    |  3360   |   0.430756   |     -      |     -     |   2.22   
   1    |  3380   |   0.318651   |     -      |     -     |   2.22   
   1    |  3400   |   0.454394   |     -      |     -     |   2.22   
   1    |  3420   |   0.483313   |     -      |     -     |   2.22   
   1    |  3440   |   0.379902   |     -      |     -     |   2.22   
   1    |  3460   |   0.328204   |     -      |     -     |   2.21   
   1    |  3480   |   0.294298   |     -      |     -     |   2.22   
   1    |  3500   |   0.484010   |     -      |     -     |   2.21   
   1    |  3520   |   0.485009   |     -      |     -     |   2.22   
   1    |  3540   |   0.371409   |     -      |     -     |   2.23   
   1    |  3560   |   0.456556   |     -      |     -     |   2.22   
   1    |  3580   |   0.357437   |     -      |     -     |   2.22   
   1    |  3600   |   0.360371   |     -      |     -     |   2.22   
   1    |  3620   |   0.500041   |     -      |     -     |   2.22   
   1    |  3640   |   0.430853   |     -      |     -     |   2.22   
   1    |  3660   |   0.336077   |     -      |     -     |   2.22   
   1    |  3680   |   0.386160   |     -      |     -     |   2.21   
   1    |  3700   |   0.381758   |     -      |     -     |   2.22   
   1    |  3720   |   0.328207   |     -      |     -     |   2.21   
   1    |  3740   |   0.350814   |     -      |     -     |   2.22   
   1    |  3760   |   0.298917   |     -      |     -     |   2.22   
   1    |  3780   |   0.551071   |     -      |     -     |   2.22   
   1    |  3800   |   0.400124   |     -      |     -     |   2.22   
   1    |  3820   |   0.558554   |     -      |     -     |   2.22   
   1    |  3840   |   0.375092   |     -      |     -     |   2.22   
   1    |  3860   |   0.366992   |     -      |     -     |   2.22   
   1    |  3880   |   0.459604   |     -      |     -     |   2.22   
   1    |  3900   |   0.395427   |     -      |     -     |   2.22   
   1    |  3920   |   0.497407   |     -      |     -     |   2.22   
   1    |  3940   |   0.332747   |     -      |     -     |   2.22   
   1    |  3960   |   0.424219   |     -      |     -     |   2.22   
   1    |  3980   |   0.302402   |     -      |     -     |   2.23   
   1    |  4000   |   0.412850   |     -      |     -     |   2.22   
   1    |  4020   |   0.496752   |     -      |     -     |   2.22   
   1    |  4040   |   0.412804   |     -      |     -     |   2.22   
   1    |  4060   |   0.365924   |     -      |     -     |   2.22   
   1    |  4080   |   0.431395   |     -      |     -     |   2.23   
   1    |  4100   |   0.482764   |     -      |     -     |   2.22   
   1    |  4120   |   0.402700   |     -      |     -     |   2.22   
   1    |  4140   |   0.381043   |     -      |     -     |   2.22   
   1    |  4160   |   0.207558   |     -      |     -     |   2.22   
   1    |  4180   |   0.388669   |     -      |     -     |   2.22   
   1    |  4200   |   0.452013   |     -      |     -     |   2.22   
   1    |  4220   |   0.410671   |     -      |     -     |   2.23   
   1    |  4240   |   0.264566   |     -      |     -     |   2.22   
   1    |  4260   |   0.487760   |     -      |     -     |   2.23   
   1    |  4274   |   0.524053   |     -      |     -     |   1.54   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   2    |   20    |   0.178235   |     -      |     -     |   2.33   
   2    |   40    |   0.236248   |     -      |     -     |   2.22   
   2    |   60    |   0.417528   |     -      |     -     |   2.22   
   2    |   80    |   0.403732   |     -      |     -     |   2.22   
   2    |   100   |   0.281601   |     -      |     -     |   2.21   
   2    |   120   |   0.465623   |     -      |     -     |   2.22   
   2    |   140   |   0.207516   |     -      |     -     |   2.21   
   2    |   160   |   0.314487   |     -      |     -     |   2.22   
   2    |   180   |   0.365850   |     -      |     -     |   2.22   
   2    |   200   |   0.258800   |     -      |     -     |   2.21   
   2    |   220   |   0.393041   |     -      |     -     |   2.22   
   2    |   240   |   0.228221   |     -      |     -     |   2.22   
   2    |   260   |   0.260176   |     -      |     -     |   2.22   
   2    |   280   |   0.383128   |     -      |     -     |   2.21   
   2    |   300   |   0.344241   |     -      |     -     |   2.22   
   2    |   320   |   0.169637   |     -      |     -     |   2.21   
   2    |   340   |   0.358171   |     -      |     -     |   2.22   
   2    |   360   |   0.354750   |     -      |     -     |   2.22   
   2    |   380   |   0.283153   |     -      |     -     |   2.21   
   2    |   400   |   0.345911   |     -      |     -     |   2.22   
   2    |   420   |   0.465970   |     -      |     -     |   2.21   
   2    |   440   |   0.287812   |     -      |     -     |   2.22   
   2    |   460   |   0.322058   |     -      |     -     |   2.22   
   2    |   480   |   0.417622   |     -      |     -     |   2.22   
   2    |   500   |   0.236024   |     -      |     -     |   2.22   
   2    |   520   |   0.252402   |     -      |     -     |   2.21   
   2    |   540   |   0.325842   |     -      |     -     |   2.21   
   2    |   560   |   0.346707   |     -      |     -     |   2.21   
   2    |   580   |   0.273158   |     -      |     -     |   2.21   
   2    |   600   |   0.447459   |     -      |     -     |   2.22   
   2    |   620   |   0.278262   |     -      |     -     |   2.22   
   2    |   640   |   0.257653   |     -      |     -     |   2.22   
   2    |   660   |   0.265684   |     -      |     -     |   2.22   
   2    |   680   |   0.481832   |     -      |     -     |   2.22   
   2    |   700   |   0.341643   |     -      |     -     |   2.22   
   2    |   720   |   0.333686   |     -      |     -     |   2.22   
   2    |   740   |   0.386808   |     -      |     -     |   2.22   
   2    |   760   |   0.338864   |     -      |     -     |   2.22   
   2    |   780   |   0.273587   |     -      |     -     |   2.22   
   2    |   800   |   0.437048   |     -      |     -     |   2.22   
   2    |   820   |   0.403763   |     -      |     -     |   2.22   
   2    |   840   |   0.271199   |     -      |     -     |   2.22   
   2    |   860   |   0.431815   |     -      |     -     |   2.23   
   2    |   880   |   0.408145   |     -      |     -     |   2.22   
   2    |   900   |   0.339019   |     -      |     -     |   2.22   
   2    |   920   |   0.259007   |     -      |     -     |   2.22   
   2    |   940   |   0.253530   |     -      |     -     |   2.22   
   2    |   960   |   0.341351   |     -      |     -     |   2.22   
   2    |   980   |   0.383076   |     -      |     -     |   2.21   
   2    |  1000   |   0.244079   |     -      |     -     |   2.21   
   2    |  1020   |   0.288099   |     -      |     -     |   2.21   
   2    |  1040   |   0.443806   |     -      |     -     |   2.22   
   2    |  1060   |   0.466626   |     -      |     -     |   2.23   
   2    |  1080   |   0.248390   |     -      |     -     |   2.22   
   2    |  1100   |   0.219219   |     -      |     -     |   2.22   
   2    |  1120   |   0.267054   |     -      |     -     |   2.22   
   2    |  1140   |   0.375691   |     -      |     -     |   2.22   
   2    |  1160   |   0.287939   |     -      |     -     |   2.22   
   2    |  1180   |   0.405043   |     -      |     -     |   2.23   
   2    |  1200   |   0.355432   |     -      |     -     |   2.23   
   2    |  1220   |   0.299337   |     -      |     -     |   2.22   
   2    |  1240   |   0.322129   |     -      |     -     |   2.22   
   2    |  1260   |   0.321425   |     -      |     -     |   2.22   
   2    |  1280   |   0.282665   |     -      |     -     |   2.22   
   2    |  1300   |   0.348964   |     -      |     -     |   2.22   
   2    |  1320   |   0.317224   |     -      |     -     |   2.22   
   2    |  1340   |   0.313979   |     -      |     -     |   2.22   
   2    |  1360   |   0.318665   |     -      |     -     |   2.23   
   2    |  1380   |   0.325975   |     -      |     -     |   2.23   
   2    |  1400   |   0.305999   |     -      |     -     |   2.22   
   2    |  1420   |   0.266321   |     -      |     -     |   2.22   
   2    |  1440   |   0.270573   |     -      |     -     |   2.21   
   2    |  1460   |   0.366601   |     -      |     -     |   2.22   
   2    |  1480   |   0.314479   |     -      |     -     |   2.23   
   2    |  1500   |   0.237390   |     -      |     -     |   2.22   
   2    |  1520   |   0.528444   |     -      |     -     |   2.23   
   2    |  1540   |   0.297206   |     -      |     -     |   2.22   
   2    |  1560   |   0.408237   |     -      |     -     |   2.22   
   2    |  1580   |   0.377525   |     -      |     -     |   2.23   
   2    |  1600   |   0.222596   |     -      |     -     |   2.22   
   2    |  1620   |   0.316801   |     -      |     -     |   2.22   
   2    |  1640   |   0.229774   |     -      |     -     |   2.21   
   2    |  1660   |   0.172352   |     -      |     -     |   2.21   
   2    |  1680   |   0.388276   |     -      |     -     |   2.22   
   2    |  1700   |   0.321816   |     -      |     -     |   2.21   
   2    |  1720   |   0.274526   |     -      |     -     |   2.22   
   2    |  1740   |   0.212466   |     -      |     -     |   2.22   
   2    |  1760   |   0.303641   |     -      |     -     |   2.22   
   2    |  1780   |   0.148262   |     -      |     -     |   2.21   
   2    |  1800   |   0.367426   |     -      |     -     |   2.22   
   2    |  1820   |   0.328495   |     -      |     -     |   2.21   
   2    |  1840   |   0.276786   |     -      |     -     |   2.22   
   2    |  1860   |   0.231517   |     -      |     -     |   2.21   
   2    |  1880   |   0.300054   |     -      |     -     |   2.22   
   2    |  1900   |   0.205822   |     -      |     -     |   2.22   
   2    |  1920   |   0.198258   |     -      |     -     |   2.22   
   2    |  1940   |   0.376904   |     -      |     -     |   2.22   
   2    |  1960   |   0.263306   |     -      |     -     |   2.21   
   2    |  1980   |   0.225420   |     -      |     -     |   2.20   
   2    |  2000   |   0.220761   |     -      |     -     |   2.21   
   2    |  2020   |   0.356617   |     -      |     -     |   2.22   
   2    |  2040   |   0.249233   |     -      |     -     |   2.21   
   2    |  2060   |   0.226963   |     -      |     -     |   2.21   
   2    |  2080   |   0.439219   |     -      |     -     |   2.22   
   2    |  2100   |   0.383887   |     -      |     -     |   2.22   
   2    |  2120   |   0.424001   |     -      |     -     |   2.22   
   2    |  2140   |   0.212143   |     -      |     -     |   2.22   
   2    |  2160   |   0.385160   |     -      |     -     |   2.23   
   2    |  2180   |   0.398621   |     -      |     -     |   2.22   
   2    |  2200   |   0.266952   |     -      |     -     |   2.22   
   2    |  2220   |   0.125873   |     -      |     -     |   2.21   
   2    |  2240   |   0.586491   |     -      |     -     |   2.22   
   2    |  2260   |   0.202121   |     -      |     -     |   2.22   
   2    |  2280   |   0.282162   |     -      |     -     |   2.22   
   2    |  2300   |   0.291807   |     -      |     -     |   2.22   
   2    |  2320   |   0.353804   |     -      |     -     |   2.22   
   2    |  2340   |   0.353146   |     -      |     -     |   2.23   
   2    |  2360   |   0.185794   |     -      |     -     |   2.22   
   2    |  2380   |   0.327677   |     -      |     -     |   2.21   
   2    |  2400   |   0.232972   |     -      |     -     |   2.21   
   2    |  2420   |   0.282569   |     -      |     -     |   2.22   
   2    |  2440   |   0.288749   |     -      |     -     |   2.23   
   2    |  2460   |   0.345081   |     -      |     -     |   2.22   
   2    |  2480   |   0.337567   |     -      |     -     |   2.22   
   2    |  2500   |   0.392647   |     -      |     -     |   2.22   
   2    |  2520   |   0.447326   |     -      |     -     |   2.23   
   2    |  2540   |   0.271528   |     -      |     -     |   2.22   
   2    |  2560   |   0.338073   |     -      |     -     |   2.22   
   2    |  2580   |   0.303100   |     -      |     -     |   2.22   
   2    |  2600   |   0.254633   |     -      |     -     |   2.22   
   2    |  2620   |   0.228749   |     -      |     -     |   2.21   
   2    |  2640   |   0.380493   |     -      |     -     |   2.23   
   2    |  2660   |   0.296292   |     -      |     -     |   2.21   
   2    |  2680   |   0.312422   |     -      |     -     |   2.21   
   2    |  2700   |   0.169324   |     -      |     -     |   2.22   
   2    |  2720   |   0.322657   |     -      |     -     |   2.22   
   2    |  2740   |   0.316887   |     -      |     -     |   2.23   
   2    |  2760   |   0.230352   |     -      |     -     |   2.22   
   2    |  2780   |   0.382647   |     -      |     -     |   2.22   
   2    |  2800   |   0.447534   |     -      |     -     |   2.23   
   2    |  2820   |   0.258827   |     -      |     -     |   2.22   
   2    |  2840   |   0.373634   |     -      |     -     |   2.22   
   2    |  2860   |   0.344052   |     -      |     -     |   2.23   
   2    |  2880   |   0.260472   |     -      |     -     |   2.23   
   2    |  2900   |   0.368974   |     -      |     -     |   2.23   
   2    |  2920   |   0.222077   |     -      |     -     |   2.22   
   2    |  2940   |   0.251587   |     -      |     -     |   2.23   
   2    |  2960   |   0.334251   |     -      |     -     |   2.22   
   2    |  2980   |   0.251267   |     -      |     -     |   2.21   
   2    |  3000   |   0.446619   |     -      |     -     |   2.23   
   2    |  3020   |   0.454647   |     -      |     -     |   2.22   
   2    |  3040   |   0.360200   |     -      |     -     |   2.22   
   2    |  3060   |   0.268775   |     -      |     -     |   2.23   
   2    |  3080   |   0.399917   |     -      |     -     |   2.23   
   2    |  3100   |   0.288121   |     -      |     -     |   2.22   
   2    |  3120   |   0.228023   |     -      |     -     |   2.23   
   2    |  3140   |   0.297294   |     -      |     -     |   2.23   
   2    |  3160   |   0.251520   |     -      |     -     |   2.22   
   2    |  3180   |   0.294835   |     -      |     -     |   2.22   
   2    |  3200   |   0.409996   |     -      |     -     |   2.22   
   2    |  3220   |   0.140834   |     -      |     -     |   2.22   
   2    |  3240   |   0.397301   |     -      |     -     |   2.22   
   2    |  3260   |   0.244492   |     -      |     -     |   2.22   
   2    |  3280   |   0.266942   |     -      |     -     |   2.22   
   2    |  3300   |   0.238524   |     -      |     -     |   2.21   
   2    |  3320   |   0.565675   |     -      |     -     |   2.23   
   2    |  3340   |   0.310869   |     -      |     -     |   2.22   
   2    |  3360   |   0.247724   |     -      |     -     |   2.22   
   2    |  3380   |   0.347890   |     -      |     -     |   2.23   
   2    |  3400   |   0.205551   |     -      |     -     |   2.22   
   2    |  3420   |   0.476822   |     -      |     -     |   2.23   
   2    |  3440   |   0.304471   |     -      |     -     |   2.23   
   2    |  3460   |   0.330230   |     -      |     -     |   2.22   
   2    |  3480   |   0.221099   |     -      |     -     |   2.22   
   2    |  3500   |   0.409206   |     -      |     -     |   2.23   
   2    |  3520   |   0.417929   |     -      |     -     |   2.23   
   2    |  3540   |   0.200517   |     -      |     -     |   2.21   
   2    |  3560   |   0.364602   |     -      |     -     |   2.23   
   2    |  3580   |   0.328577   |     -      |     -     |   2.22   
   2    |  3600   |   0.302296   |     -      |     -     |   2.22   
   2    |  3620   |   0.394715   |     -      |     -     |   2.23   
   2    |  3640   |   0.297546   |     -      |     -     |   2.22   
   2    |  3660   |   0.475172   |     -      |     -     |   2.22   
   2    |  3680   |   0.297590   |     -      |     -     |   2.22   
   2    |  3700   |   0.308858   |     -      |     -     |   2.22   
   2    |  3720   |   0.330256   |     -      |     -     |   2.23   
   2    |  3740   |   0.376832   |     -      |     -     |   2.23   
   2    |  3760   |   0.376163   |     -      |     -     |   2.23   
   2    |  3780   |   0.400901   |     -      |     -     |   2.23   
   2    |  3800   |   0.351140   |     -      |     -     |   2.23   
   2    |  3820   |   0.465799   |     -      |     -     |   2.23   
   2    |  3840   |   0.261218   |     -      |     -     |   2.22   
   2    |  3860   |   0.267973   |     -      |     -     |   2.22   
   2    |  3880   |   0.299458   |     -      |     -     |   2.22   
   2    |  3900   |   0.267373   |     -      |     -     |   2.22   
   2    |  3920   |   0.283054   |     -      |     -     |   2.22   
   2    |  3940   |   0.372673   |     -      |     -     |   2.23   
   2    |  3960   |   0.181575   |     -      |     -     |   2.22   
   2    |  3980   |   0.354681   |     -      |     -     |   2.23   
   2    |  4000   |   0.395429   |     -      |     -     |   2.22   
   2    |  4020   |   0.271367   |     -      |     -     |   2.23   
   2    |  4040   |   0.325781   |     -      |     -     |   2.23   
   2    |  4060   |   0.430173   |     -      |     -     |   2.23   
   2    |  4080   |   0.255085   |     -      |     -     |   2.23   
   2    |  4100   |   0.368935   |     -      |     -     |   2.22   
   2    |  4120   |   0.297810   |     -      |     -     |   2.23   
   2    |  4140   |   0.332633   |     -      |     -     |   2.22   
   2    |  4160   |   0.246938   |     -      |     -     |   2.22   
   2    |  4180   |   0.209579   |     -      |     -     |   2.22   
   2    |  4200   |   0.231430   |     -      |     -     |   2.22   
   2    |  4220   |   0.176111   |     -      |     -     |   2.21   
   2    |  4240   |   0.221054   |     -      |     -     |   2.22   
   2    |  4260   |   0.231887   |     -      |     -     |   2.21   
   2    |  4274   |   0.406195   |     -      |     -     |   1.54   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   3    |   20    |   0.287907   |     -      |     -     |   2.34   
   3    |   40    |   0.169149   |     -      |     -     |   2.22   
   3    |   60    |   0.146852   |     -      |     -     |   2.22   
   3    |   80    |   0.340404   |     -      |     -     |   2.22   
   3    |   100   |   0.187099   |     -      |     -     |   2.22   
   3    |   120   |   0.242595   |     -      |     -     |   2.22   
   3    |   140   |   0.280586   |     -      |     -     |   2.22   
   3    |   160   |   0.133286   |     -      |     -     |   2.22   
   3    |   180   |   0.306578   |     -      |     -     |   2.22   
   3    |   200   |   0.149105   |     -      |     -     |   2.21   
   3    |   220   |   0.164732   |     -      |     -     |   2.21   
   3    |   240   |   0.176946   |     -      |     -     |   2.21   
   3    |   260   |   0.214384   |     -      |     -     |   2.22   
   3    |   280   |   0.193741   |     -      |     -     |   2.21   
   3    |   300   |   0.121526   |     -      |     -     |   2.21   
   3    |   320   |   0.316858   |     -      |     -     |   2.22   
   3    |   340   |   0.219560   |     -      |     -     |   2.22   
   3    |   360   |   0.292677   |     -      |     -     |   2.23   
   3    |   380   |   0.108325   |     -      |     -     |   2.21   
   3    |   400   |   0.252165   |     -      |     -     |   2.21   
   3    |   420   |   0.192938   |     -      |     -     |   2.22   
   3    |   440   |   0.192702   |     -      |     -     |   2.21   
   3    |   460   |   0.198925   |     -      |     -     |   2.21   
   3    |   480   |   0.128961   |     -      |     -     |   2.21   
   3    |   500   |   0.261096   |     -      |     -     |   2.21   
   3    |   520   |   0.162804   |     -      |     -     |   2.21   
   3    |   540   |   0.186448   |     -      |     -     |   2.21   
   3    |   560   |   0.118215   |     -      |     -     |   2.21   
   3    |   580   |   0.224330   |     -      |     -     |   2.22   
   3    |   600   |   0.368840   |     -      |     -     |   2.22   
   3    |   620   |   0.228781   |     -      |     -     |   2.21   
   3    |   640   |   0.142742   |     -      |     -     |   2.22   
   3    |   660   |   0.098649   |     -      |     -     |   2.21   
   3    |   680   |   0.319665   |     -      |     -     |   2.21   
   3    |   700   |   0.254994   |     -      |     -     |   2.21   
   3    |   720   |   0.297118   |     -      |     -     |   2.22   
   3    |   740   |   0.174241   |     -      |     -     |   2.22   
   3    |   760   |   0.150964   |     -      |     -     |   2.22   
   3    |   780   |   0.242742   |     -      |     -     |   2.22   
   3    |   800   |   0.233357   |     -      |     -     |   2.22   
   3    |   820   |   0.179140   |     -      |     -     |   2.21   
   3    |   840   |   0.256957   |     -      |     -     |   2.22   
   3    |   860   |   0.194072   |     -      |     -     |   2.21   
   3    |   880   |   0.169564   |     -      |     -     |   2.21   
   3    |   900   |   0.140018   |     -      |     -     |   2.21   
   3    |   920   |   0.260728   |     -      |     -     |   2.22   
   3    |   940   |   0.196077   |     -      |     -     |   2.21   
   3    |   960   |   0.138152   |     -      |     -     |   2.22   
   3    |   980   |   0.339784   |     -      |     -     |   2.22   
   3    |  1000   |   0.198123   |     -      |     -     |   2.22   
   3    |  1020   |   0.184109   |     -      |     -     |   2.21   
   3    |  1040   |   0.192801   |     -      |     -     |   2.21   
   3    |  1060   |   0.260482   |     -      |     -     |   2.22   
   3    |  1080   |   0.135853   |     -      |     -     |   2.22   
   3    |  1100   |   0.198053   |     -      |     -     |   2.21   
   3    |  1120   |   0.184740   |     -      |     -     |   2.21   
   3    |  1140   |   0.337751   |     -      |     -     |   2.23   
   3    |  1160   |   0.164092   |     -      |     -     |   2.22   
   3    |  1180   |   0.241024   |     -      |     -     |   2.21   
   3    |  1200   |   0.289141   |     -      |     -     |   2.22   
   3    |  1220   |   0.215991   |     -      |     -     |   2.22   
   3    |  1240   |   0.217216   |     -      |     -     |   2.23   
   3    |  1260   |   0.128413   |     -      |     -     |   2.22   
   3    |  1280   |   0.160155   |     -      |     -     |   2.21   
   3    |  1300   |   0.216679   |     -      |     -     |   2.22   
   3    |  1320   |   0.167887   |     -      |     -     |   2.22   
   3    |  1340   |   0.155127   |     -      |     -     |   2.21   
   3    |  1360   |   0.203865   |     -      |     -     |   2.21   
   3    |  1380   |   0.246050   |     -      |     -     |   2.21   
   3    |  1400   |   0.142925   |     -      |     -     |   2.21   
   3    |  1420   |   0.180820   |     -      |     -     |   2.22   
   3    |  1440   |   0.219821   |     -      |     -     |   2.22   
   3    |  1460   |   0.170561   |     -      |     -     |   2.21   
   3    |  1480   |   0.262509   |     -      |     -     |   2.22   
   3    |  1500   |   0.255177   |     -      |     -     |   2.21   
   3    |  1520   |   0.199802   |     -      |     -     |   2.22   
   3    |  1540   |   0.302800   |     -      |     -     |   2.22   
   3    |  1560   |   0.272904   |     -      |     -     |   2.21   
   3    |  1580   |   0.302736   |     -      |     -     |   2.21   
   3    |  1600   |   0.183896   |     -      |     -     |   2.21   
   3    |  1620   |   0.200647   |     -      |     -     |   2.21   
   3    |  1640   |   0.181364   |     -      |     -     |   2.22   
   3    |  1660   |   0.097935   |     -      |     -     |   2.21   
   3    |  1680   |   0.146785   |     -      |     -     |   2.21   
   3    |  1700   |   0.277353   |     -      |     -     |   2.22   
   3    |  1720   |   0.168417   |     -      |     -     |   2.21   
   3    |  1740   |   0.189893   |     -      |     -     |   2.21   
   3    |  1760   |   0.401158   |     -      |     -     |   2.23   
   3    |  1780   |   0.187433   |     -      |     -     |   2.22   
   3    |  1800   |   0.201111   |     -      |     -     |   2.21   
   3    |  1820   |   0.239526   |     -      |     -     |   2.22   
   3    |  1840   |   0.260277   |     -      |     -     |   2.20   
   3    |  1860   |   0.105640   |     -      |     -     |   2.21   
   3    |  1880   |   0.217080   |     -      |     -     |   2.22   
   3    |  1900   |   0.175066   |     -      |     -     |   2.22   
   3    |  1920   |   0.290558   |     -      |     -     |   2.22   
   3    |  1940   |   0.195370   |     -      |     -     |   2.22   
   3    |  1960   |   0.289704   |     -      |     -     |   2.21   
   3    |  1980   |   0.205946   |     -      |     -     |   2.21   
   3    |  2000   |   0.378616   |     -      |     -     |   2.22   
   3    |  2020   |   0.339383   |     -      |     -     |   2.22   
   3    |  2040   |   0.096656   |     -      |     -     |   2.21   
   3    |  2060   |   0.208559   |     -      |     -     |   2.21   
   3    |  2080   |   0.223843   |     -      |     -     |   2.21   
   3    |  2100   |   0.169819   |     -      |     -     |   2.22   
   3    |  2120   |   0.124758   |     -      |     -     |   2.22   
   3    |  2140   |   0.271942   |     -      |     -     |   2.22   
   3    |  2160   |   0.324865   |     -      |     -     |   2.21   
   3    |  2180   |   0.174191   |     -      |     -     |   2.22   
   3    |  2200   |   0.237353   |     -      |     -     |   2.22   
   3    |  2220   |   0.261197   |     -      |     -     |   2.21   
   3    |  2240   |   0.204816   |     -      |     -     |   2.21   
   3    |  2260   |   0.136572   |     -      |     -     |   2.23   
   3    |  2280   |   0.170786   |     -      |     -     |   2.22   
   3    |  2300   |   0.193347   |     -      |     -     |   2.22   
   3    |  2320   |   0.115297   |     -      |     -     |   2.22   
   3    |  2340   |   0.283210   |     -      |     -     |   2.21   
   3    |  2360   |   0.136321   |     -      |     -     |   2.21   
   3    |  2380   |   0.139913   |     -      |     -     |   2.21   
   3    |  2400   |   0.284366   |     -      |     -     |   2.22   
   3    |  2420   |   0.272027   |     -      |     -     |   2.22   
   3    |  2440   |   0.190968   |     -      |     -     |   2.21   
   3    |  2460   |   0.128395   |     -      |     -     |   2.21   
   3    |  2480   |   0.135913   |     -      |     -     |   2.20   
   3    |  2500   |   0.096405   |     -      |     -     |   2.21   
   3    |  2520   |   0.332776   |     -      |     -     |   2.22   
   3    |  2540   |   0.131247   |     -      |     -     |   2.21   
   3    |  2560   |   0.256691   |     -      |     -     |   2.21   
   3    |  2580   |   0.186974   |     -      |     -     |   2.21   
   3    |  2600   |   0.333957   |     -      |     -     |   2.22   
   3    |  2620   |   0.299003   |     -      |     -     |   2.22   
   3    |  2640   |   0.261602   |     -      |     -     |   2.22   
   3    |  2660   |   0.250151   |     -      |     -     |   2.22   
   3    |  2680   |   0.210117   |     -      |     -     |   2.21   
   3    |  2700   |   0.374809   |     -      |     -     |   2.22   
   3    |  2720   |   0.180602   |     -      |     -     |   2.21   
   3    |  2740   |   0.262619   |     -      |     -     |   2.23   
   3    |  2760   |   0.190584   |     -      |     -     |   2.22   
   3    |  2780   |   0.298435   |     -      |     -     |   2.22   
   3    |  2800   |   0.206946   |     -      |     -     |   2.23   
   3    |  2820   |   0.157499   |     -      |     -     |   2.22   
   3    |  2840   |   0.108108   |     -      |     -     |   2.21   
   3    |  2860   |   0.152643   |     -      |     -     |   2.20   
   3    |  2880   |   0.133671   |     -      |     -     |   2.21   
   3    |  2900   |   0.162220   |     -      |     -     |   2.21   
   3    |  2920   |   0.141410   |     -      |     -     |   2.21   
   3    |  2940   |   0.358039   |     -      |     -     |   2.22   
   3    |  2960   |   0.137600   |     -      |     -     |   2.21   
   3    |  2980   |   0.109006   |     -      |     -     |   2.21   
   3    |  3000   |   0.214440   |     -      |     -     |   2.22   
   3    |  3020   |   0.172844   |     -      |     -     |   2.21   
   3    |  3040   |   0.221561   |     -      |     -     |   2.22   
   3    |  3060   |   0.290055   |     -      |     -     |   2.21   
   3    |  3080   |   0.128516   |     -      |     -     |   2.22   
   3    |  3100   |   0.179072   |     -      |     -     |   2.21   
   3    |  3120   |   0.159905   |     -      |     -     |   2.21   
   3    |  3140   |   0.061597   |     -      |     -     |   2.21   
   3    |  3160   |   0.230245   |     -      |     -     |   2.22   
   3    |  3180   |   0.298810   |     -      |     -     |   2.22   
   3    |  3200   |   0.147486   |     -      |     -     |   2.22   
   3    |  3220   |   0.148169   |     -      |     -     |   2.21   
   3    |  3240   |   0.198698   |     -      |     -     |   2.22   
   3    |  3260   |   0.344509   |     -      |     -     |   2.22   
   3    |  3280   |   0.145770   |     -      |     -     |   2.21   
   3    |  3300   |   0.208897   |     -      |     -     |   2.22   
   3    |  3320   |   0.262981   |     -      |     -     |   2.22   
   3    |  3340   |   0.131428   |     -      |     -     |   2.21   
   3    |  3360   |   0.163636   |     -      |     -     |   2.20   
   3    |  3380   |   0.253057   |     -      |     -     |   2.21   
   3    |  3400   |   0.088063   |     -      |     -     |   2.21   
   3    |  3420   |   0.153364   |     -      |     -     |   2.22   
   3    |  3440   |   0.154884   |     -      |     -     |   2.21   
   3    |  3460   |   0.144953   |     -      |     -     |   2.22   
   3    |  3480   |   0.236155   |     -      |     -     |   2.21   
   3    |  3500   |   0.050307   |     -      |     -     |   2.21   
   3    |  3520   |   0.318051   |     -      |     -     |   2.22   
   3    |  3540   |   0.151670   |     -      |     -     |   2.22   
   3    |  3560   |   0.101136   |     -      |     -     |   2.21   
   3    |  3580   |   0.154310   |     -      |     -     |   2.21   
   3    |  3600   |   0.255000   |     -      |     -     |   2.22   
   3    |  3620   |   0.254025   |     -      |     -     |   2.21   
   3    |  3640   |   0.150957   |     -      |     -     |   2.21   
   3    |  3660   |   0.252347   |     -      |     -     |   2.21   
   3    |  3680   |   0.163677   |     -      |     -     |   2.21   
   3    |  3700   |   0.189354   |     -      |     -     |   2.21   
   3    |  3720   |   0.200923   |     -      |     -     |   2.21   
   3    |  3740   |   0.173451   |     -      |     -     |   2.21   
   3    |  3760   |   0.264323   |     -      |     -     |   2.21   
   3    |  3780   |   0.185939   |     -      |     -     |   2.21   
   3    |  3800   |   0.139420   |     -      |     -     |   2.21   
   3    |  3820   |   0.248304   |     -      |     -     |   2.21   
   3    |  3840   |   0.264812   |     -      |     -     |   2.21   
   3    |  3860   |   0.208504   |     -      |     -     |   2.21   
   3    |  3880   |   0.249622   |     -      |     -     |   2.21   
   3    |  3900   |   0.377659   |     -      |     -     |   2.22   
   3    |  3920   |   0.116282   |     -      |     -     |   2.21   
   3    |  3940   |   0.161975   |     -      |     -     |   2.21   
   3    |  3960   |   0.235030   |     -      |     -     |   2.21   
   3    |  3980   |   0.256480   |     -      |     -     |   2.22   
   3    |  4000   |   0.196103   |     -      |     -     |   2.21   
   3    |  4020   |   0.251700   |     -      |     -     |   2.22   
   3    |  4040   |   0.221106   |     -      |     -     |   2.22   
   3    |  4060   |   0.237219   |     -      |     -     |   2.21   
   3    |  4080   |   0.080999   |     -      |     -     |   2.21   
   3    |  4100   |   0.142421   |     -      |     -     |   2.21   
   3    |  4120   |   0.078593   |     -      |     -     |   2.20   
   3    |  4140   |   0.237669   |     -      |     -     |   2.21   
   3    |  4160   |   0.251462   |     -      |     -     |   2.21   
   3    |  4180   |   0.219768   |     -      |     -     |   2.22   
   3    |  4200   |   0.086318   |     -      |     -     |   2.21   
   3    |  4220   |   0.172586   |     -      |     -     |   2.22   
   3    |  4240   |   0.216208   |     -      |     -     |   2.22   
   3    |  4260   |   0.123733   |     -      |     -     |   2.21   
   3    |  4274   |   0.190683   |     -      |     -     |   1.54   
----------------------------------------------------------------------


Training complete!
Out[16]:
(0.10784337061668123, 97.30029585798816)
In [17]:
# Run `preprocessing_for_bert` on the test set
test_data = pd.read_csv('data/test.csv')
print('Tokenizing data...')
test_inputs, test_masks = preprocessing_for_bert(test_data.text)

# Create the DataLoader for our test set
test_dataset = TensorDataset(test_inputs, test_masks)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=8)
Tokenizing data...
/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:2132: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  FutureWarning,
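The FutureWarning above is raised because preprocessing_for_bert still passes the deprecated pad_to_max_length argument to the tokenizer. As a hedged sketch (assuming the tokenizer and MAX_LEN objects defined earlier in the notebook), the equivalent call with the current argument name would look like this; the output tokens are unchanged, only the argument name differs:

# Sketch only: same behaviour as pad_to_max_length=True, expressed with the
# non-deprecated arguments. `tokenizer`, `MAX_LEN` and `sentence` are assumed
# to be the objects used inside preprocessing_for_bert.
encoded = tokenizer.encode_plus(
    text=sentence,
    add_special_tokens=True,        # add [CLS] and [SEP]
    max_length=MAX_LEN,             # fixed sequence length
    padding='max_length',           # pad every sequence to max_length
    truncation=True,                # truncate longer sequences
    return_attention_mask=True
)
input_ids = encoded.get('input_ids')
attention_mask = encoded.get('attention_mask')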
In [18]:
# Compute predicted probabilities on the test set
probs = bert_predict(bert_classifier, test_dataloader)

# Get predictions from the probabilities
# Since this is a multi-class problem, take the argmax over the class probabilities to get the label
preds = np.argmax(probs, axis=1)
test_data['label'] = preds
test_data['label'].value_counts(normalize=True)
Out[18]:
3    0.633241
0    0.135463
1    0.130648
2    0.100648
Name: label, dtype: float64
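For context, bert_predict was defined earlier in the notebook; a prediction helper of this shape typically puts the model in eval mode, disables gradient tracking, and applies a softmax over the collected logits. A minimal illustrative sketch (the function name and the model call signature here are assumptions, not the exact earlier definition):

import torch
import torch.nn.functional as F

def predict_probs(model, dataloader, device):
    # Illustrative only: run the classifier over a DataLoader of
    # (input_ids, attention_mask) batches and return class probabilities.
    model.eval()
    all_logits = []
    with torch.no_grad():
        for input_ids, attention_mask in dataloader:
            logits = model(input_ids.to(device), attention_mask.to(device))
            all_logits.append(logits)
    return F.softmax(torch.cat(all_logits, dim=0), dim=1).cpu().numpy()

Taking np.argmax over these probabilities, as done above, then yields one predicted class per test paper.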
In [19]:
!mkdir assets

# Save the predictions as submission.csv in the assets directory
test_data.to_csv(os.path.join("assets", "submission.csv"), index=False)
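One small robustness note: !mkdir assets fails if the cell is re-run and the directory already exists. A Python equivalent that tolerates re-runs (purely optional, same end result) is:

import os
# Create the assets directory if it does not exist yet; no error on re-run.
os.makedirs("assets", exist_ok=True)
test_data.to_csv(os.path.join("assets", "submission.csv"), index=False)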