]> git.angelumana.com Git - tweet_classification/.git/commitdiff
Delete BERT_neighbors_classification.ipynb
authoribidyouadu <60790401+ibidyouadu@users.noreply.github.com>
Mon, 11 May 2020 03:22:31 +0000 (23:22 -0400)
committerGitHub <noreply@github.com>
Mon, 11 May 2020 03:22:31 +0000 (23:22 -0400)
BERT_neighbors_classification.ipynb [deleted file]

diff --git a/BERT_neighbors_classification.ipynb b/BERT_neighbors_classification.ipynb
deleted file mode 100644 (file)
index 93d6c6f..0000000
+++ /dev/null
@@ -1,719 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "The first few cells import some libraries and a script. For the most part this is identical to [xhlulu's Kaggle code](https://www.kaggle.com/xhlulu/disaster-nlp-keras-bert-using-tfhub). The only differences are the first cell, which you only need to run if using Google Colab, and the importing of `re` and `nltk` in cell 3."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 196
-    },
-    "colab_type": "code",
-    "id": "z2EOfe64FcUh",
-    "outputId": "fb8d8809-4d89-40b4-833f-7836feff4577"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Collecting sentencepiece\n",
-      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/98/2c/8df20f3ac6c22ac224fff307ebc102818206c53fc454ecd37d8ac2060df5/sentencepiece-0.1.86-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
-      "\r",
-      "\u001b[K     |▎                               | 10kB 21.6MB/s eta 0:00:01\r",
-      "\u001b[K     |▋                               | 20kB 6.7MB/s eta 0:00:01\r",
-      "\u001b[K     |█                               | 30kB 7.7MB/s eta 0:00:01\r",
-      "\u001b[K     |█▎                              | 40kB 8.5MB/s eta 0:00:01\r",
-      "\u001b[K     |█▋                              | 51kB 7.7MB/s eta 0:00:01\r",
-      "\u001b[K     |██                              | 61kB 8.5MB/s eta 0:00:01\r",
-      "\u001b[K     |██▏                             | 71kB 8.9MB/s eta 0:00:01\r",
-      "\u001b[K     |██▌                             | 81kB 9.3MB/s eta 0:00:01\r",
-      "\u001b[K     |██▉                             | 92kB 9.6MB/s eta 0:00:01\r",
-      "\u001b[K     |███▏                            | 102kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███▌                            | 112kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███▉                            | 122kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████                            | 133kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████▍                           | 143kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████▊                           | 153kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████                           | 163kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████▍                          | 174kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████▊                          | 184kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████                          | 194kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████▎                         | 204kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████▋                         | 215kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████                         | 225kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████▎                        | 235kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████▋                        | 245kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████▉                        | 256kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████▏                       | 266kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████▌                       | 276kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████▉                       | 286kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████▏                      | 296kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████▌                      | 307kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████▊                      | 317kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████                      | 327kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████▍                     | 337kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████▊                     | 348kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████                     | 358kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████▍                    | 368kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████▋                    | 378kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████                    | 389kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████▎                   | 399kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████▋                   | 409kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████                   | 419kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████▎                  | 430kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████▌                  | 440kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████▉                  | 450kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████▏                 | 460kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████▌                 | 471kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████▉                 | 481kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████▏                | 491kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████▍                | 501kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████▊                | 512kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████                | 522kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████▍               | 532kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████▊               | 542kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████               | 552kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████▎              | 563kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████▋              | 573kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████              | 583kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████▎             | 593kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████▋             | 604kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████             | 614kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████▏            | 624kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████▌            | 634kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████▉            | 645kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████▏           | 655kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████▌           | 665kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████▉           | 675kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████▏          | 686kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████▍          | 696kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████▊          | 706kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████          | 716kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████▍         | 727kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████▊         | 737kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████████         | 747kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████████▎        | 757kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████████▋        | 768kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████████        | 778kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████████▎       | 788kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████████▋       | 798kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████████       | 808kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████████▏      | 819kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████████▌      | 829kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████████▉      | 839kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████████▏     | 849kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████████▌     | 860kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████████▉     | 870kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████████████     | 880kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████████████▍    | 890kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████████████▊    | 901kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████████████    | 911kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████████████▍   | 921kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████████████▊   | 931kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████████████   | 942kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████████████▎  | 952kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |█████████████████████████████▋  | 962kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████████████  | 972kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████████████▎ | 983kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████████████▋ | 993kB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |██████████████████████████████▉ | 1.0MB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████████████████▏| 1.0MB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████████████████▌| 1.0MB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |███████████████████████████████▉| 1.0MB 10.0MB/s eta 0:00:01\r",
-      "\u001b[K     |████████████████████████████████| 1.0MB 10.0MB/s \n",
-      "\u001b[?25hInstalling collected packages: sentencepiece\n",
-      "Successfully installed sentencepiece-0.1.86\n",
-      "[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
-      "[nltk_data]   Unzipping corpora/stopwords.zip.\n",
-      "[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
-      "[nltk_data]   Unzipping corpora/wordnet.zip.\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "True"
-      ]
-     },
-     "execution_count": 1,
-     "metadata": {
-      "tags": []
-     },
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "# Uncomment the following lines if you need the libraries; on Google Colab you will need them. neighborhood_generator also requires nltk.download('stopwords').\n",
-    "\n",
-    "# !pip install sentencepiece\n",
-    "# import nltk\n",
-    "# nltk.download('wordnet')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 0,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "gR6-jumvFhr9"
-   },
-   "outputs": [],
-   "source": [
-    "!wget --quiet https://raw.githubusercontent.com/tensorflow/models/master/official/nlp/bert/tokenization.py"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 35
-    },
-    "colab_type": "code",
-    "id": "75PqQCwcFlji",
-    "outputId": "41863890-3b6e-4b60-9b60-e4d5158a73dd"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "DONE\n"
-     ]
-    }
-   ],
-   "source": [
-    "import numpy as np\n",
-    "import pandas as pd\n",
-    "import tensorflow as tf\n",
-    "from tensorflow.keras.layers import Dense, Input\n",
-    "from tensorflow.keras.optimizers import Adam\n",
-    "from tensorflow.keras.models import Model\n",
-    "from tensorflow.keras.callbacks import ModelCheckpoint\n",
-    "import tensorflow_hub as hub\n",
-    "import re\n",
-    "from nltk.stem import WordNetLemmatizer\n",
-    "from nltk.corpus import stopwords\n",
-    "\n",
-    "\n",
-    "import tokenization"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "The following four cells are unchanged from the Kaggle code"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 0,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "ULdRsYGtFnro"
-   },
-   "outputs": [],
-   "source": [
-    "def bert_encode(texts, tokenizer, max_len=512):\n",
-    "    all_tokens = []\n",
-    "    all_masks = []\n",
-    "    all_segments = []\n",
-    "\n",
-    "\n",
-    "    for text in texts:\n",
-    "        text = tokenizer.tokenize(text)\n",
-    "           \n",
-    "        text = text[:max_len-2]\n",
-    "        input_sequence = [\"[CLS]\"] + text + [\"[SEP]\"]\n",
-    "        pad_len = max_len - len(input_sequence)\n",
-    "        \n",
-    "        tokens = tokenizer.convert_tokens_to_ids(input_sequence)\n",
-    "        tokens += [0] * pad_len\n",
-    "        pad_masks = [1] * len(input_sequence) + [0] * pad_len\n",
-    "        segment_ids = [0] * max_len\n",
-    "        \n",
-    "        all_tokens.append(tokens)\n",
-    "        all_masks.append(pad_masks)\n",
-    "        all_segments.append(segment_ids)\n",
-    "    \n",
-    "    return np.array(all_tokens), np.array(all_masks), np.array(all_segments)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 0,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "mo7LXhGLFrCz"
-   },
-   "outputs": [],
-   "source": [
-    "def build_model(bert_layer, max_len=512):\n",
-    "    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name=\"input_word_ids\")\n",
-    "    input_mask = Input(shape=(max_len,), dtype=tf.int32, name=\"input_mask\")\n",
-    "    segment_ids = Input(shape=(max_len,), dtype=tf.int32, name=\"segment_ids\")\n",
-    "\n",
-    "    _, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])\n",
-    "    clf_output = sequence_output[:, 0, :]\n",
-    "    out = Dense(1, activation='sigmoid')(clf_output)\n",
-    "    \n",
-    "    model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=out)\n",
-    "    model.compile(Adam(lr=2e-6), loss='binary_crossentropy', metrics=['accuracy'])\n",
-    "    \n",
-    "    return model"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 53
-    },
-    "colab_type": "code",
-    "id": "ukl7WfAiFtqL",
-    "outputId": "94997e5e-b150-4d00-ba59-ba515d56a169"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "CPU times: user 9.67 s, sys: 2.01 s, total: 11.7 s\n",
-      "Wall time: 13 s\n"
-     ]
-    }
-   ],
-   "source": [
-    "%%time\n",
-    "module_url = \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1\"\n",
-    "bert_layer = hub.KerasLayer(module_url, trainable=True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 0,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "l-N0Z4J_62Nh"
-   },
-   "outputs": [],
-   "source": [
-    "vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()\n",
-    "do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()\n",
-    "tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "This next cell is what separates this notebook from the Kaggle code. It defines the function used to generate neighborhoods. A small lemmatization function is also defined and used inside the generator function in order to find the keyword index. The basic BERT tokenizer, in conjunction with some regex filtering, is used to clean and tokenize the text, again to find the keyword index. Once the keyword index is found, the neighborhood is generated by extending to the left and to the right of the keyword until three significant tokens (tokens that are non-empty after cleaning) have been counted on each side, or until the start or end of the tweet is reached. These tokens are joined together with spaces and returned as output."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 0,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "x66XjPOMFvMd"
-   },
-   "outputs": [],
-   "source": [
-    "def lemmatize(x):\n",
-    "  lemmatizer = WordNetLemmatizer()\n",
-    "  return lemmatizer.lemmatize(lemmatizer.lemmatize(x,pos='v'))\n",
-    "\n",
-    "kwds = 'blackout duke dukeenergy electric electricity fpl outage power'.split()\n",
-    "basic_tok = tokenization.BasicTokenizer()\n",
-    "\n",
-    "def neighborhood_generator(text,kwd):\n",
-    "    parse_seq = r'[-_+<=>\\[\\]{}`&;\\/()#!@,.\\n?]|\\x80|\\:\\/\\/.*$'\n",
-    "    split_tokens = text.lower().split()\n",
-    "    stop_words = set(stopwords.words('english'))\n",
-    "    kwd_idx = 0\n",
-    "\n",
-    "    def tok_process(tok): #remove special chars, apply basic BERT tokenizer, lemmatize, and remove stop words\n",
-    "        parse_tok = re.sub(parse_seq, ' ', tok)\n",
-    "        tok_tokens = basic_tok.tokenize(parse_tok) #tokens of the token!\n",
-    "        tok_lemmatized_tokens = [lemmatize(tok_tok) for tok_tok in tok_tokens]\n",
-    "        # tok_lemmatized_tokens = [tok_tok for tok_tok in tok_lemmatized_tokens if tok_tok not in stop_words]\n",
-    "        return tok_lemmatized_tokens\n",
-    "\n",
-    "    if kwd in split_tokens: #see if we can get away without going through the text processing\n",
-    "        kwd_idx = [idx for idx in range(len(split_tokens)) if split_tokens[idx]==kwd][0]\n",
-    "    else:\n",
-    "        for idx in range(len(split_tokens)):\n",
-    "            tok = split_tokens[idx]\n",
-    "            tok_lemmatized_tokens = tok_process(tok)\n",
-    "            if kwd in tok_lemmatized_tokens:\n",
-    "                kwd_idx = idx\n",
-    "                break\n",
-    "\n",
-    "    neighborhood_radius = 3\n",
-    "    before_kwd = split_tokens[:kwd_idx]\n",
-    "    after_kwd = split_tokens[kwd_idx+1:]\n",
-    "\n",
-    "    before_idx = [idx for idx in range(kwd_idx)][::-1]\n",
-    "    after_idx = [idx for idx in range(kwd_idx+1,len(split_tokens))]\n",
-    "    start_tok_idx = 0 #if search extends beyond startin index, use the first token\n",
-    "    end_tok_idx = len(split_tokens) #if search extends beyond last index, use the last index\n",
-    "    before_sig_toks = 0\n",
-    "    after_sig_toks = 0\n",
-    "\n",
-    "    for idx in before_idx:\n",
-    "        tok = split_tokens[idx]\n",
-    "        tok_lemmatized_tokens = tok_process(tok)\n",
-    "        if len(tok_lemmatized_tokens) > 0: # significant token\n",
-    "            before_sig_toks += 1\n",
-    "            if before_sig_toks == neighborhood_radius:\n",
-    "                start_tok_idx = idx\n",
-    "                break\n",
-    "    \n",
-    "    for idx in after_idx:\n",
-    "        tok = split_tokens[idx]\n",
-    "        tok_lemmatized_tokens = tok_process(tok)\n",
-    "        if len(tok_lemmatized_tokens) > 0:\n",
-    "            after_sig_toks += 1\n",
-    "            if after_sig_toks == neighborhood_radius:\n",
-    "                end_tok_idx = idx\n",
-    "                break\n",
-    "                \n",
-    "    neighborhood = split_tokens[start_tok_idx:end_tok_idx+1]\n",
-    "    neighborhood = ' '.join(neighborhood)\n",
-    "    \n",
-    "    return neighborhood"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Iterate through all the tweets to generate the neighborhoods."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 0,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "UBn2jMFuBC27"
-   },
-   "outputs": [],
-   "source": [
-    "town = pd.DataFrame(columns=['text']) # a town is a collection of neighborhoods\n",
-    "neighborhood_vals = []\n",
-    "\n",
-    "data = pd.read_csv('irma_power_tweets.csv')\n",
-    "\n",
-    "for idx in data.index:\n",
-    "  row = data.iloc[idx]\n",
-    "  text = row.text\n",
-    "  kwd = row.kwd\n",
-    "  new_row = neighborhood_generator(text,kwd)\n",
-    "  neighborhood_vals.append(new_row)\n",
-    "town.text = neighborhood_vals\n",
-    "town['original_twt'] = data.text\n",
-    "town = town[['original_twt', 'text']]\n",
-    "\n",
-    "from sklearn.model_selection import train_test_split\n",
-    "train, test, train_labels, test_labels = train_test_split(town.text, data.topic_related,\n",
-    "                                                          test_size=0.2,\n",
-    "                                                          random_state=42)\n",
-    "\n",
-    "train_input = bert_encode(train, tokenizer, max_len=160)\n",
-    "test_input = bert_encode(test, tokenizer, max_len=160)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "gu5PX-saIhSQ"
-   },
-   "source": [
-    "The next three cells are unchanged from the Kaggle code. These blocks build the model and predict the outputs for the data."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 18,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 410
-    },
-    "colab_type": "code",
-    "id": "OYjF80BvIjId",
-    "outputId": "327fa3e6-2fd1-4f53-bca1-8bc2456ca391"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Model: \"model_1\"\n",
-      "__________________________________________________________________________________________________\n",
-      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
-      "==================================================================================================\n",
-      "input_word_ids (InputLayer)     [(None, 160)]        0                                            \n",
-      "__________________________________________________________________________________________________\n",
-      "input_mask (InputLayer)         [(None, 160)]        0                                            \n",
-      "__________________________________________________________________________________________________\n",
-      "segment_ids (InputLayer)        [(None, 160)]        0                                            \n",
-      "__________________________________________________________________________________________________\n",
-      "keras_layer (KerasLayer)        [(None, 768), (None, 109482241   input_word_ids[0][0]             \n",
-      "                                                                 input_mask[0][0]                 \n",
-      "                                                                 segment_ids[0][0]                \n",
-      "__________________________________________________________________________________________________\n",
-      "tf_op_layer_strided_slice_1 (Te [(None, 768)]        0           keras_layer[1][1]                \n",
-      "__________________________________________________________________________________________________\n",
-      "dense_1 (Dense)                 (None, 1)            769         tf_op_layer_strided_slice_1[0][0]\n",
-      "==================================================================================================\n",
-      "Total params: 109,483,010\n",
-      "Trainable params: 109,483,009\n",
-      "Non-trainable params: 1\n",
-      "__________________________________________________________________________________________________\n"
-     ]
-    }
-   ],
-   "source": [
-    "model = build_model(bert_layer, max_len=160)\n",
-    "model.summary()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 19,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 124
-    },
-    "colab_type": "code",
-    "id": "g3JP8ENtIry4",
-    "outputId": "6360a6df-3da4-4087-ed43-23fc52a2c20e"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Epoch 1/3\n",
-      "94/94 [==============================] - 947s 10s/step - loss: 0.3609 - accuracy: 0.8203 - val_loss: 0.2471 - val_accuracy: 0.9200\n",
-      "Epoch 2/3\n",
-      "94/94 [==============================] - 942s 10s/step - loss: 0.1991 - accuracy: 0.9332 - val_loss: 0.2080 - val_accuracy: 0.9227\n",
-      "Epoch 3/3\n",
-      "94/94 [==============================] - 910s 10s/step - loss: 0.1244 - accuracy: 0.9639 - val_loss: 0.2133 - val_accuracy: 0.9253\n"
-     ]
-    }
-   ],
-   "source": [
-    "train_history = model.fit(\n",
-    "    train_input, train_labels,\n",
-    "    validation_split=0.2,\n",
-    "    epochs=3,\n",
-    "    batch_size=16\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 0,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "S3KrM3vnUaru"
-   },
-   "outputs": [],
-   "source": [
-    "test_pred = model.predict(test_input)\n",
-    "train_pred = model.predict(train_input)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "After producing outputs, we're interested in seeing how well the model has done. The following block defines a scoring function that prints accuracy, precision, recall, and F1 score for both the training and test sets. It then defines a function for plotting a confusion matrix as a color-coded heat map."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 0,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "l9C1RyC0mYLF"
-   },
-   "outputs": [],
-   "source": [
-    "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n",
-    "from sklearn.metrics import confusion_matrix\n",
-    "import itertools\n",
-    "from matplotlib import pyplot as plt\n",
-    "plt.rcParams['figure.figsize'] = [10,10]\n",
-    "\n",
-    "\n",
-    "def scoring(Ytrain, Ytrain_pred, Ytest, Ypred):\n",
-    "    \n",
-    "\n",
-    "    acc_train = accuracy_score(Ytrain, Ytrain_pred)\n",
-    "    prec_train = precision_score(Ytrain, Ytrain_pred)\n",
-    "    rec_train = recall_score(Ytrain, Ytrain_pred)\n",
-    "    f1_train = f1_score(Ytrain, Ytrain_pred)\n",
-    "    \n",
-    "    acc = accuracy_score(Ytest, Ypred)\n",
-    "    prec = precision_score(Ytest, Ypred)\n",
-    "    rec = recall_score(Ytest, Ypred)\n",
-    "    f1 = f1_score(Ytest, Ypred)\n",
-    "\n",
-    "    scores = {'accuracy': acc,\n",
-    "              'precision': prec,\n",
-    "             'recall': rec,\n",
-    "             'f1 score': f1}\n",
-    "\n",
-    "    scores_train = {'accuracy': acc_train,\n",
-    "              'precision': prec_train,\n",
-    "             'recall': rec_train,\n",
-    "             'f1 score': f1_train}\n",
-    "\n",
-    "    print('Ytrain vs Ytrain_pred')\n",
-    "    for metric in scores_train.keys():\n",
-    "        print(f'{metric}: {scores_train[metric]:.2f}')\n",
-    "    \n",
-    "    print('\\nYtest vs Ypred')\n",
-    "    for metric in scores.keys():\n",
-    "        print(f'{metric}: {scores[metric]:.2f}')\n",
-    "        \n",
-    "\n",
-    "\n",
-    "def plot_confusion_matrix(cm, labels: list, normalize=False, title='Confusion Matrix', cmap=plt.cm.Oranges):\n",
-    "    if normalize:\n",
-    "        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
-    "    plt.figure(figsize=(10,10))\n",
-    "    plt.imshow(cm, interpolation='nearest',cmap=cmap)\n",
-    "    plt.title(title,size=24)\n",
-    "    plt.colorbar(aspect=4)\n",
-    "    tick_marks = np.arange(len(labels))\n",
-    "    plt.xticks(tick_marks, labels, size=14)\n",
-    "    plt.yticks(tick_marks, labels, size=14)\n",
-    "    \n",
-    "    fmt = '.2f' if normalize else 'd' # format of decimal precision to display\n",
-    "    thresh = cm.max()/2 # threshold to change color of text depending on color of cell\n",
-    "    for i,j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): # realistically, cm.shape[0] = cm.shape[1] = len(labels)\n",
-    "        plt.text(j, i, format(cm[i,j], fmt), fontsize=20, # i,j rev bc of diff btwn matrix and list of lists indexing\n",
-    "                 horizontalalignment='center',\n",
-    "                 color='white' if cm[i,j] > thresh else 'black') # if the cell color is too dark, make the text white\n",
-    "    plt.grid(False)\n",
-    "    plt.tight_layout()\n",
-    "    plt.ylabel('True label', size=16)\n",
-    "    plt.xlabel('Predicted label', size=16)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Now let's see the results."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 22,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 759
-    },
-    "colab_type": "code",
-    "id": "YoE6nWXFquEK",
-    "outputId": "8c0d739f-b7fd-40c2-dd16-5e90dae6061f"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Ytrain vs Ytrain_pred\n",
-      "accuracy: 0.97\n",
-      "precision: 0.97\n",
-      "recall: 0.99\n",
-      "f1 score: 0.98\n",
-      "\n",
-      "Ytest vs Ypred\n",
-      "accuracy: 0.94\n",
-      "precision: 0.94\n",
-      "recall: 0.98\n",
-      "f1 score: 0.96\n"
-     ]
-    },
-    {
-     "data": {
-      "image/png": "\n",
-      "text/plain": [
-       "<Figure size 720x720 with 2 Axes>"
-      ]
-     },
-     "metadata": {
-      "needs_background": "light",
-      "tags": []
-     },
-     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "round_train_pred = np.round(train_pred)\n",
-    "round_test_pred = np.round(test_pred)\n",
-    "scoring(train_labels, round_train_pred, test_labels, round_test_pred)\n",
-    "cm = confusion_matrix(test_labels, round_test_pred)\n",
-    "cm_labels = ['Not topic-related', 'Topic-related']\n",
-    "plot_confusion_matrix(cm, cm_labels, normalize=False,title=f'BERT Confusion Matrix\\nn={len(test_labels)}')"
-   ]
-  }
- ],
- "metadata": {
-  "colab": {
-   "machine_shape": "hm",
-   "name": "2nd Copy of BERT final.ipynb",
-   "provenance": []
-  },
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.7.6"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 1
-}