X-Git-Url: https://git.njae.me.uk/?a=blobdiff_plain;f=word-chains%2Fword-chain-solution.ipynb;h=8a87922bb6b57a44ab864fc9365df8122424c93f;hb=4c5672cbef4ef0df1d6964c55a42fe80634766a9;hp=0052abce5b05c22c7669b120bca349d24c2d1937;hpb=a0fc23654a764a8fdec3fa9858b0453c58bc5f34;p=ou-summer-of-code-2017.git

diff --git a/word-chains/word-chain-solution.ipynb b/word-chains/word-chain-solution.ipynb
index 0052abc..8a87922 100644
--- a/word-chains/word-chain-solution.ipynb
+++ b/word-chains/word-chain-solution.ipynb
@@ -106,9 +106,9 @@
    },
    "outputs": [],
    "source": [
-    "def extend(chain):\n",
-    "    return [chain + [s] for s in neighbours[chain[-1]]\n",
-    "           if s not in chain]"
+    "# def extend(chain):\n",
+    "#     return [chain + [s] for s in neighbours[chain[-1]]\n",
+    "#            if s not in chain]"
    ]
   },
   {
@@ -119,8 +119,13 @@
    },
    "outputs": [],
    "source": [
-    "def bfs_search(start, target, debug=False):\n",
-    "    return bfs([[start]], target, debug=debug)"
+    "def extend(chain, closed=None):\n",
+    "    if closed:\n",
+    "        nbrs = set(neighbours[chain[-1]]) - closed\n",
+    "    else:\n",
+    "        nbrs = neighbours[chain[-1]]\n",
+    "    return [chain + [s] for s in nbrs\n",
+    "           if s not in chain]"
    ]
   },
   {
@@ -131,7 +136,22 @@
    },
    "outputs": [],
    "source": [
-    "def bfs(agenda, goal, debug=False):\n",
+    "def extend_raw(chain):\n",
+    "    nbrs = [w for w in adjacents(chain[-1]) if w in words]\n",
+    "    return [chain + [s] for s in nbrs\n",
+    "           if s not in chain]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "def bfs_search(start, goal, debug=False):\n",
+    "    agenda = [[start]]\n",
     "    finished = False\n",
     "    while not finished and agenda:\n",
     "        current = agenda[0]\n",
@@ -150,14 +170,30 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 19,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def dfs_search(start, target, debug=False):\n",
-    "    return dfs([[start]], target, debug=debug)"
+    "def bfs_search_closed(start, goal, debug=False):\n",
+    "    agenda = [[start]]\n",
+    "    closed = set()\n",
+    "    finished = False\n",
+    "    while not finished and agenda:\n",
+    "        current = agenda[0]\n",
+    "        if debug:\n",
+    "            print(current)\n",
+    "        if current[-1] == goal:\n",
+    "            finished = True\n",
+    "        else:\n",
+    "            closed.add(current[-1])\n",
+    "            successors = extend(current, closed)\n",
+    "            agenda = agenda[1:] + successors\n",
+    "    if agenda:\n",
+    "        return current\n",
+    "    else:\n",
+    "        return None   "
    ]
   },
   {
@@ -168,7 +204,8 @@
    },
    "outputs": [],
    "source": [
-    "def dfs(agenda, goal, debug=False):\n",
+    "def dfs_search(start, goal, debug=False):\n",
+    "    agenda = [[start]]\n",
     "    finished = False\n",
     "    while not finished and agenda:\n",
     "        current = agenda[0]\n",
@@ -187,27 +224,44 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 57,
+   "execution_count": 11,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def astar_search(start, target, debug=False):\n",
-    "    agenda = [(distance(start, target), [start])]\n",
+    "def astar_search(start, goal, debug=False):\n",
+    "    agenda = [(distance(start, goal), [start])]\n",
     "    heapq.heapify(agenda)\n",
-    "    return astar(agenda, target, debug=debug)"
+    "    finished = False\n",
+    "    while not finished and agenda:\n",
+    "        _, current = heapq.heappop(agenda)\n",
+    "        if debug:\n",
+    "            print(current)\n",
+    "        if current[-1] == goal:\n",
+    "            finished = True\n",
+    "        else:\n",
+    "            successors = extend(current)\n",
+    "            for s in successors:\n",
+    "                heapq.heappush(agenda, (len(current) + distance(s[-1], goal) - 1, s))\n",
+    "    if agenda:\n",
+    "        return current\n",
+    "    else:\n",
+    "        return None        "
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 55,
+   "execution_count": 12,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def astar(agenda, goal, debug=False):\n",
+    "# Uses direct lookup of successors, rather than using cached neighbours in the dict\n",
+    "def astar_search_raw(start, goal, debug=False):\n",
+    "    agenda = [(distance(start, goal), [start])]\n",
+    "    heapq.heapify(agenda)\n",
     "    finished = False\n",
     "    while not finished and agenda:\n",
     "        _, current = heapq.heappop(agenda)\n",
@@ -216,7 +270,7 @@
     "        if current[-1] == goal:\n",
     "            finished = True\n",
     "        else:\n",
-    "            successors = extend(current)\n",
+    "            successors = extend_raw(current) # Difference here\n",
     "            for s in successors:\n",
     "                heapq.heappush(agenda, (len(current) + distance(s[-1], goal) - 1, s))\n",
     "    if agenda:\n",
@@ -227,7 +281,37 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 58,
+   "execution_count": 13,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "def astar_search_closed(start, goal, debug=False):\n",
+    "    agenda = [(distance(start, goal), [start])]\n",
+    "    heapq.heapify(agenda)\n",
+    "    closed = set()\n",
+    "    finished = False\n",
+    "    while not finished and agenda:\n",
+    "        _, current = heapq.heappop(agenda)\n",
+    "        if debug:\n",
+    "            print(current)\n",
+    "        if current[-1] == goal:\n",
+    "            finished = True\n",
+    "        else:\n",
+    "            closed.add(current[-1])\n",
+    "            successors = extend(current, closed)\n",
+    "            for s in successors:\n",
+    "                heapq.heappush(agenda, (len(current) + distance(s[-1], goal) - 1, s))\n",
+    "    if agenda:\n",
+    "        return current\n",
+    "    else:\n",
+    "        return None   "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
@@ -236,7 +320,7 @@
        "['vice', 'dice', 'dire', 'dare', 'ware', 'wars']"
       ]
      },
-     "execution_count": 58,
+     "execution_count": 14,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -247,7 +331,27 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 60,
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['vice', 'dice', 'dire', 'dare', 'ware', 'wars']"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "astar_search_raw('vice', 'wars')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
    "metadata": {},
    "outputs": [
     {
@@ -256,7 +360,7 @@
        "6"
       ]
      },
-     "execution_count": 60,
+     "execution_count": 16,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -267,7 +371,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [
     {
@@ -276,7 +380,7 @@
        "6"
       ]
      },
-     "execution_count": 15,
+     "execution_count": 17,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -287,7 +391,27 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "6"
+      ]
+     },
+     "execution_count": 20,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(bfs_search_closed('vice', 'wars'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
    "metadata": {},
    "outputs": [
     {
@@ -296,7 +420,7 @@
        "793"
       ]
      },
-     "execution_count": 16,
+     "execution_count": 21,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -307,14 +431,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "10000 loops, best of 3: 154 µs per loop\n"
+      "10000 loops, best of 3: 158 µs per loop\n"
      ]
     }
    ],
@@ -325,7 +449,43 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "100 loops, best of 3: 15.6 ms per loop\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "astar_search_raw('vice', 'wars')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "10000 loops, best of 3: 168 µs per loop\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "astar_search_closed('vice', 'wars')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
    "metadata": {},
    "outputs": [
     {
@@ -343,14 +503,32 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1 loop, best of 3: 597 ms per loop\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "bfs_search_closed('vice', 'wars')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "10 loops, best of 3: 86.3 ms per loop\n"
+      "10 loops, best of 3: 85.5 ms per loop\n"
      ]
     }
    ],
@@ -377,7 +555,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": 28,
    "metadata": {
     "collapsed": true
    },
@@ -399,7 +577,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 38,
+   "execution_count": 29,
    "metadata": {
     "scrolled": true
    },
@@ -411,7 +589,7 @@
        " '`bash`, `cash`, `dash`, `gash`, `hash`, `lash`, `mash`, `rasp`, `rush`, `sash`, `wash`')"
       ]
      },
-     "execution_count": 38,
+     "execution_count": 29,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -422,7 +600,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 39,
+   "execution_count": 30,
    "metadata": {
     "scrolled": true
    },
@@ -434,7 +612,7 @@
        " '`base`, `bash`, `bask`, `bass`, `bast`, `bath`, `bosh`, `bush`, `case`, `cash`, `cask`, `cast`, `dash`, `dish`, `gash`, `gasp`, `gosh`, `gush`, `hash`, `hasp`, `hath`, `hush`, `lash`, `lass`, `last`, `lath`, `lush`, `mash`, `mask`, `mass`, `mast`, `math`, `mesh`, `mush`, `push`, `ramp`, `rasp`, `ruse`, `rush`, `rusk`, `rust`, `sash`, `sass`, `tush`, `wash`, `wasp`, `wish`')"
       ]
      },
-     "execution_count": 39,
+     "execution_count": 30,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -445,7 +623,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 40,
+   "execution_count": 31,
    "metadata": {
     "scrolled": true
    },
@@ -456,7 +634,7 @@
        "180"
       ]
      },
-     "execution_count": 40,
+     "execution_count": 31,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -467,7 +645,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 48,
+   "execution_count": 32,
    "metadata": {
     "scrolled": true
    },
@@ -478,7 +656,7 @@
        "2195"
       ]
      },
-     "execution_count": 48,
+     "execution_count": 32,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -489,7 +667,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 47,
+   "execution_count": 33,
    "metadata": {
     "scrolled": true
    },
@@ -500,7 +678,7 @@
        "2192"
       ]
      },
-     "execution_count": 47,
+     "execution_count": 33,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -511,7 +689,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 46,
+   "execution_count": 34,
    "metadata": {
     "scrolled": true
    },
@@ -520,7 +698,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "100 loops, best of 3: 5.97 ms per loop\n"
+      "100 loops, best of 3: 5.82 ms per loop\n"
      ]
     }
    ],
@@ -531,7 +709,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 44,
+   "execution_count": 35,
    "metadata": {
     "scrolled": true
    },
@@ -540,7 +718,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "100 loops, best of 3: 3.1 ms per loop\n"
+      "100 loops, best of 3: 2.75 ms per loop\n"
      ]
     }
    ],