{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# How to evaluate embeddings using linear algebra and analogies" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The individual dimensions of word and phrase vectors have no explicit meaning. However, the embeddings place words that are used in similar contexts close to each other in the latent space, and this proximity carries over to semantic relationships. As a result, analogies can be expressed by adding and subtracting word vectors: for example, the vector for *king* minus *man* plus *woman* should lie close to the vector for *queen*.\n", "\n", "Just as words can be used in different contexts, they can relate to other words in different ways, and these relationships correspond to different directions in the latent space. Accordingly, there are several types of analogies that the embeddings should reflect, to the extent that the training data supports them.\n", "\n", "To evaluate the quality of embedding vectors, the word2vec authors provide a list of several thousand analogies spanning geography, grammar and syntax, and family relationships (see directory [analogies](data/analogies))."
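, "\n",
 "\n",
 "The vector arithmetic described above can be illustrated with a minimal sketch. It assumes a made-up toy vocabulary with hand-picked 3-dimensional vectors rather than trained embeddings, and `solve_analogy` and `analogy_accuracy` are hypothetical helpers, not part of any library. The first answers *a is to b as c is to ?* by returning the word whose unit-length vector has the highest cosine similarity to `b - a + c`, excluding the three query words (a scoring rule often called 3CosAdd). The second skips tuples containing unknown words and reports the share of `(a, b, c, d)` tuples, laid out like the rows of the word2vec analogy file, that are answered correctly.\n",
 "\n",
 "
```python
import numpy as np

# Hypothetical toy embeddings: hand-picked 3-d vectors (not trained) chosen so
# that the offset man -> woman equals the offset king -> queen in the raw vectors.
vocab = ['man', 'woman', 'king', 'queen', 'apple']
vectors = np.array([
    [1.0, 0.0, 0.2],  # man
    [1.0, 1.0, 0.2],  # woman
    [3.0, 0.0, 0.3],  # king
    [3.0, 1.0, 0.3],  # queen
    [0.0, 0.2, 5.0],  # apple
])
word_to_idx = {w: i for i, w in enumerate(vocab)}

# Scale every vector to unit length so that dot products are cosine similarities.
unit = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)


def solve_analogy(a, b, c):
    # Answer 'a is to b as c is to ?' with the word closest to b - a + c.
    ia, ib, ic = (word_to_idx[w] for w in (a, b, c))
    target = unit[ib] - unit[ia] + unit[ic]
    scores = unit @ (target / np.linalg.norm(target))
    # Exclude the query words, which would otherwise often score highest.
    scores[[ia, ib, ic]] = -np.inf
    return vocab[int(np.argmax(scores))]


def analogy_accuracy(analogies):
    # Skip tuples with out-of-vocabulary words, then report the share of
    # (a, b, c, d) tuples for which the predicted word equals d.
    known = [t for t in analogies if all(w in word_to_idx for w in t)]
    hits = [solve_analogy(a, b, c) == d for a, b, c, d in known]
    return float(np.mean(hits)) if hits else float('nan')


print(solve_analogy('man', 'woman', 'king'))  # queen
print(analogy_accuracy([('man', 'woman', 'king', 'queen'),
                        ('king', 'queen', 'man', 'woman'),
                        ('athens', 'greece', 'berlin', 'germany')]))  # 1.0
```
"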
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2020-06-21T02:02:00.366112Z", "start_time": "2020-06-21T02:01:59.675597Z" }, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "%matplotlib inline\n", "from pathlib import Path\n", "import pandas as pd\n", "import numpy as np\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "" } }, "source": [ "### Settings" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2020-06-21T02:02:00.377200Z", "start_time": "2020-06-21T02:02:00.367549Z" }, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "sns.set_style('white')\n", "pd.set_option('float_format', '{:,.2f}'.format)\n", "np.random.seed(42)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2020-06-21T02:02:00.382459Z", "start_time": "2020-06-21T02:02:00.378877Z" } }, "outputs": [], "source": [ "analogy_path = Path('data', 'analogies-en.txt')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2020-06-21T02:02:00.390451Z", "start_time": "2020-06-21T02:02:00.383658Z" } }, "outputs": [], "source": [ "def format_time(t):\n", " m, s = divmod(t, 60)\n", " h, m = divmod(m, 60)\n", " return f'{h:02.0f}:{m:02.0f}:{s:02.0f}'" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Evaluation: Analogies" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2020-06-21T02:02:00.476793Z", "start_time": "2020-06-21T02:02:00.391463Z" }, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "df = pd.read_csv(analogy_path, header=None, names=['category'], 
dtype=str).squeeze('columns')  # single column -> Series (read_csv's squeeze=True was removed in pandas 2.0)\n", "# lines starting with ':' are category headers; all other lines hold four words\n", "categories = df[df.str.startswith(':')]\n", "analogies = df[~df.str.startswith(':')].str.split(expand=True)\n", "analogies.columns = list('abcd')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2020-06-21T02:02:00.511818Z", "start_time": "2020-06-21T02:02:00.478472Z" }, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " category a b c d\n", "1 : capital-common-countries athens greece baghdad iraq\n", "2 : capital-common-countries athens greece bangkok thailand\n", "3 : capital-common-countries athens greece beijing china\n", "4 : capital-common-countries athens greece berlin germany\n", "5 : capital-common-countries athens greece bern switzerland" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.concat([categories, analogies], axis=1)\n", "df.category = df.category.ffill()\n", "df = df[df['a'].notnull()]\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2020-06-21T02:02:00.529477Z", "start_time": "2020-06-21T02:02:00.513126Z" }, "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "analogy_cnt = df.groupby('category').size().sort_values(ascending=False).to_frame('n')\n", "analogy_example = df.groupby('category').first()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2020-06-21T02:02:00.540665Z", "start_time": "2020-06-21T02:02:00.530425Z" }, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " n a b c \\\n", "category \n", ": capital-world 8556 abuja nigeria accra \n", ": city-in-state 4242 chicago illinois houston \n", ": gram6-nationality-adjective 1640 albania albanian argentina \n", ": gram7-past-tense 1560 dancing danced decreasing \n", ": gram8-plural 1332 banana bananas bird \n", ": gram3-comparative 1332 bad worse big \n", ": gram4-superlative 1122 bad worst big \n", ": gram5-present-participle 1056 code coding dance \n", ": gram1-adjective-to-adverb 992 amazing amazingly apparent \n", ": gram9-plural-verbs 870 decrease decreases describe \n", ": currency 866 algeria dinar angola \n", ": gram2-opposite 812 acceptable unacceptable aware \n", ": family 506 boy girl brother \n", ": capital-common-countries 506 athens greece baghdad \n", "\n", " d \n", "category \n", ": capital-world ghana \n", ": city-in-state texas \n", ": gram6-nationality-adjective argentinean \n", ": gram7-past-tense decreased \n", ": gram8-plural birds \n", ": gram3-comparative bigger \n", ": gram4-superlative biggest \n", ": gram5-present-participle dancing \n", ": gram1-adjective-to-adverb apparently \n", ": gram9-plural-verbs describes \n", ": currency kwanza \n", ": gram2-opposite unaware \n", ": family sister \n", ": capital-common-countries iraq " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "analogy_cnt.join(analogy_example)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2020-06-21T02:03:02.440216Z", "start_time": "2020-06-21T02:03:02.125152Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "analogy_cnt.join(analogy_example)['n'].sort_values().plot.barh(title='# Analogies by Category',\n", " figsize=(14, 6))\n", "sns.despine()\n", "plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "47px", "left": "38px", "right": "1340px", "top": "66.5px", "width": "362px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }