Updated word chain problem
[ou-summer-of-code-2017.git] / word-chains / word-chain-solution.ipynb
index e90cacaabbe306e738ac03b11ccfddfc0d247d3e..8a87922bb6b57a44ab864fc9365df8122424c93f 100644 (file)
    },
    "outputs": [],
    "source": [
-    "def extend(chain):\n",
-    "    return [chain + [s] for s in neighbours[chain[-1]]\n",
-    "           if s not in chain]"
+    "def extend(chain):\n",
+    "    return [chain + [s] for s in neighbours[chain[-1]]\n",
+    "           if s not in chain]"
    ]
   },
   {
    },
    "outputs": [],
    "source": [
-    "def extend_raw(chain):\n",
-    "    nbrs = [w for w in adjacents(chain[-1]) if w in words]\n",
+    "def extend(chain, closed=None):\n",
+    "    if closed:\n",
+    "        nbrs = set(neighbours[chain[-1]]) - closed\n",
+    "    else:\n",
+    "        nbrs = neighbours[chain[-1]]\n",
     "    return [chain + [s] for s in nbrs\n",
     "           if s not in chain]"
    ]
    },
    "outputs": [],
    "source": [
-    "def bfs_search(start, target, debug=False):\n",
-    "    return bfs([[start]], target, debug=debug)"
+    "def extend_raw(chain):\n",
+    "    nbrs = [w for w in adjacents(chain[-1]) if w in words]\n",
+    "    return [chain + [s] for s in nbrs\n",
+    "           if s not in chain]"
    ]
   },
   {
    },
    "outputs": [],
    "source": [
-    "def bfs(agenda, goal, debug=False):\n",
+    "def bfs_search(start, goal, debug=False):\n",
+    "    agenda = [[start]]\n",
     "    finished = False\n",
     "    while not finished and agenda:\n",
     "        current = agenda[0]\n",
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 19,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def dfs_search(start, target, debug=False):\n",
-    "    return dfs([[start]], target, debug=debug)"
+    "def bfs_search_closed(start, goal, debug=False):\n",
+    "    agenda = [[start]]\n",
+    "    closed = set()\n",
+    "    finished = False\n",
+    "    while not finished and agenda:\n",
+    "        current = agenda[0]\n",
+    "        if debug:\n",
+    "            print(current)\n",
+    "        if current[-1] == goal:\n",
+    "            finished = True\n",
+    "        else:\n",
+    "            closed.add(current[-1])\n",
+    "            successors = extend(current, closed)\n",
+    "            agenda = agenda[1:] + successors\n",
+    "    if agenda:\n",
+    "        return current\n",
+    "    else:\n",
+    "        return None   "
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 10,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def dfs(agenda, goal, debug=False):\n",
+    "def dfs_search(start, goal, debug=False):\n",
+    "    agenda = [[start]]\n",
     "    finished = False\n",
     "    while not finished and agenda:\n",
     "        current = agenda[0]\n",
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 11,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def astar_search(start, target, debug=False):\n",
-    "    agenda = [(distance(start, target), [start])]\n",
+    "def astar_search(start, goal, debug=False):\n",
+    "    agenda = [(distance(start, goal), [start])]\n",
     "    heapq.heapify(agenda)\n",
-    "    return astar(agenda, target, debug=debug)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 13,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
-   "source": [
-    "def astar(agenda, goal, debug=False):\n",
     "    finished = False\n",
     "    while not finished and agenda:\n",
     "        _, current = heapq.heappop(agenda)\n",
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 12,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def astar_search_raw(start, target, debug=False):\n",
-    "    agenda = [(distance(start, target), [start])]\n",
+    "# Uses direct lookup of successors, rather than using cached neighbours in the dict\n",
+    "def astar_search_raw(start, goal, debug=False):\n",
+    "    agenda = [(distance(start, goal), [start])]\n",
     "    heapq.heapify(agenda)\n",
-    "    return astar_raw(agenda, target, debug=debug)"
+    "    finished = False\n",
+    "    while not finished and agenda:\n",
+    "        _, current = heapq.heappop(agenda)\n",
+    "        if debug:\n",
+    "            print(current)\n",
+    "        if current[-1] == goal:\n",
+    "            finished = True\n",
+    "        else:\n",
+    "            successors = extend_raw(current) # Difference here\n",
+    "            for s in successors:\n",
+    "                heapq.heappush(agenda, (len(current) + distance(s[-1], goal) - 1, s))\n",
+    "    if agenda:\n",
+    "        return current\n",
+    "    else:\n",
+    "        return None        "
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 13,
    "metadata": {
     "collapsed": true
    },
    "outputs": [],
    "source": [
-    "def astar_raw(agenda, goal, debug=False):\n",
+    "def astar_search_closed(start, goal, debug=False):\n",
+    "    agenda = [(distance(start, goal), [start])]\n",
+    "    heapq.heapify(agenda)\n",
+    "    closed = set()\n",
     "    finished = False\n",
     "    while not finished and agenda:\n",
     "        _, current = heapq.heappop(agenda)\n",
     "        if current[-1] == goal:\n",
     "            finished = True\n",
     "        else:\n",
-    "            successors = extend_raw(current)\n",
+    "            closed.add(current[-1])\n",
+    "            successors = extend(current, closed)\n",
     "            for s in successors:\n",
     "                heapq.heappush(agenda, (len(current) + distance(s[-1], goal) - 1, s))\n",
     "    if agenda:\n",
     "        return current\n",
     "    else:\n",
-    "        return None        "
+    "        return None   "
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
        "['vice', 'dice', 'dire', 'dare', 'ware', 'wars']"
       ]
      },
-     "execution_count": 16,
+     "execution_count": 14,
      "metadata": {},
      "output_type": "execute_result"
     }
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [
     {
        "['vice', 'dice', 'dire', 'dare', 'ware', 'wars']"
       ]
      },
-     "execution_count": 17,
+     "execution_count": 15,
      "metadata": {},
      "output_type": "execute_result"
     }
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [
     {
        "6"
       ]
      },
-     "execution_count": 18,
+     "execution_count": 16,
      "metadata": {},
      "output_type": "execute_result"
     }
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [
     {
        "6"
       ]
      },
-     "execution_count": 19,
+     "execution_count": 17,
      "metadata": {},
      "output_type": "execute_result"
     }
     {
      "data": {
       "text/plain": [
-       "793"
+       "6"
       ]
      },
      "execution_count": 20,
     }
    ],
    "source": [
-    "len(dfs_search('vice', 'wars'))"
+    "len(bfs_search_closed('vice', 'wars'))"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 21,
    "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "793"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(dfs_search('vice', 'wars'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "10000 loops, best of 3: 153 µs per loop\n"
+      "10000 loops, best of 3: 158 µs per loop\n"
      ]
     }
    ],
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "100 loops, best of 3: 15.8 ms per loop\n"
+      "100 loops, best of 3: 15.6 ms per loop\n"
      ]
     }
    ],
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "10000 loops, best of 3: 168 µs per loop\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "astar_search_closed('vice', 'wars')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "1 loop, best of 3: 1min 42s per loop\n"
+      "1 loop, best of 3: 1min 40s per loop\n"
      ]
     }
    ],
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1 loop, best of 3: 597 ms per loop\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "bfs_search_closed('vice', 'wars')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "10 loops, best of 3: 88 ms per loop\n"
+      "10 loops, best of 3: 85.5 ms per loop\n"
      ]
     }
    ],
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 28,
    "metadata": {
     "collapsed": true
    },
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 29,
    "metadata": {
     "scrolled": true
    },
        " '`bash`, `cash`, `dash`, `gash`, `hash`, `lash`, `mash`, `rasp`, `rush`, `sash`, `wash`')"
       ]
      },
-     "execution_count": 26,
+     "execution_count": 29,
      "metadata": {},
      "output_type": "execute_result"
     }
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 30,
    "metadata": {
     "scrolled": true
    },
        " '`base`, `bash`, `bask`, `bass`, `bast`, `bath`, `bosh`, `bush`, `case`, `cash`, `cask`, `cast`, `dash`, `dish`, `gash`, `gasp`, `gosh`, `gush`, `hash`, `hasp`, `hath`, `hush`, `lash`, `lass`, `last`, `lath`, `lush`, `mash`, `mask`, `mass`, `mast`, `math`, `mesh`, `mush`, `push`, `ramp`, `rasp`, `ruse`, `rush`, `rusk`, `rust`, `sash`, `sass`, `tush`, `wash`, `wasp`, `wish`')"
       ]
      },
-     "execution_count": 27,
+     "execution_count": 30,
      "metadata": {},
      "output_type": "execute_result"
     }
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 31,
    "metadata": {
     "scrolled": true
    },
        "180"
       ]
      },
-     "execution_count": 28,
+     "execution_count": 31,
      "metadata": {},
      "output_type": "execute_result"
     }
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 32,
    "metadata": {
     "scrolled": true
    },
        "2195"
       ]
      },
-     "execution_count": 29,
+     "execution_count": 32,
      "metadata": {},
      "output_type": "execute_result"
     }
   },
   {
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": 33,
    "metadata": {
     "scrolled": true
    },
        "2192"
       ]
      },
-     "execution_count": 30,
+     "execution_count": 33,
      "metadata": {},
      "output_type": "execute_result"
     }
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 34,
    "metadata": {
     "scrolled": true
    },
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "100 loops, best of 3: 5.96 ms per loop\n"
+      "100 loops, best of 3: 5.82 ms per loop\n"
      ]
     }
    ],
   },
   {
    "cell_type": "code",
-   "execution_count": 32,
+   "execution_count": 35,
    "metadata": {
     "scrolled": true
    },