X-Git-Url: https://git.njae.me.uk/?a=blobdiff_plain;f=word-chains%2Fword-chain-solution.ipynb;h=8a87922bb6b57a44ab864fc9365df8122424c93f;hb=4c5672cbef4ef0df1d6964c55a42fe80634766a9;hp=e90cacaabbe306e738ac03b11ccfddfc0d247d3e;hpb=736040348711b246f04e606e18eb85b54d1c475d;p=ou-summer-of-code-2017.git diff --git a/word-chains/word-chain-solution.ipynb b/word-chains/word-chain-solution.ipynb index e90caca..8a87922 100644 --- a/word-chains/word-chain-solution.ipynb +++ b/word-chains/word-chain-solution.ipynb @@ -106,9 +106,9 @@ }, "outputs": [], "source": [ - "def extend(chain):\n", - " return [chain + [s] for s in neighbours[chain[-1]]\n", - " if s not in chain]" + "# def extend(chain):\n", + "# return [chain + [s] for s in neighbours[chain[-1]]\n", + "# if s not in chain]" ] }, { @@ -119,8 +119,11 @@ }, "outputs": [], "source": [ - "def extend_raw(chain):\n", - " nbrs = [w for w in adjacents(chain[-1]) if w in words]\n", + "def extend(chain, closed=None):\n", + " if closed:\n", + " nbrs = set(neighbours[chain[-1]]) - closed\n", + " else:\n", + " nbrs = neighbours[chain[-1]]\n", " return [chain + [s] for s in nbrs\n", " if s not in chain]" ] @@ -133,8 +136,10 @@ }, "outputs": [], "source": [ - "def bfs_search(start, target, debug=False):\n", - " return bfs([[start]], target, debug=debug)" + "def extend_raw(chain):\n", + " nbrs = [w for w in adjacents(chain[-1]) if w in words]\n", + " return [chain + [s] for s in nbrs\n", + " if s not in chain]" ] }, { @@ -145,7 +150,8 @@ }, "outputs": [], "source": [ - "def bfs(agenda, goal, debug=False):\n", + "def bfs_search(start, goal, debug=False):\n", + " agenda = [[start]]\n", " finished = False\n", " while not finished and agenda:\n", " current = agenda[0]\n", @@ -164,25 +170,42 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 19, "metadata": { "collapsed": true }, "outputs": [], "source": [ - "def dfs_search(start, target, debug=False):\n", - " return dfs([[start]], target, debug=debug)" + "def bfs_search_closed(start, goal, debug=False):\n", + " agenda = [[start]]\n", + " closed = set()\n", + " finished = False\n", + " while not finished and agenda:\n", + " current = agenda[0]\n", + " if debug:\n", + " print(current)\n", + " if current[-1] == goal:\n", + " finished = True\n", + " else:\n", + " closed.add(current[-1])\n", + " successors = extend(current, closed)\n", + " agenda = agenda[1:] + successors\n", + " if agenda:\n", + " return current\n", + " else:\n", + " return None " ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ - "def dfs(agenda, goal, debug=False):\n", + "def dfs_search(start, goal, debug=False):\n", + " agenda = [[start]]\n", " finished = False\n", " while not finished and agenda:\n", " current = agenda[0]\n", @@ -201,27 +224,15 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ - "def astar_search(start, target, debug=False):\n", - " agenda = [(distance(start, target), [start])]\n", + "def astar_search(start, goal, debug=False):\n", + " agenda = [(distance(start, goal), [start])]\n", " heapq.heapify(agenda)\n", - " return astar(agenda, target, debug=debug)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "def astar(agenda, goal, debug=False):\n", " finished = False\n", " while not finished and agenda:\n", " _, current = heapq.heappop(agenda)\n", @@ -241,27 +252,45 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ - "def astar_search_raw(start, target, debug=False):\n", - " agenda = [(distance(start, target), [start])]\n", + "# Uses direct lookup of successors, rather than using cached neighbours in the dict\n", + "def astar_search_raw(start, goal, debug=False):\n", + " agenda = [(distance(start, goal), [start])]\n", " heapq.heapify(agenda)\n", - " return astar_raw(agenda, target, debug=debug)" + " finished = False\n", + " while not finished and agenda:\n", + " _, current = heapq.heappop(agenda)\n", + " if debug:\n", + " print(current)\n", + " if current[-1] == goal:\n", + " finished = True\n", + " else:\n", + " successors = extend_raw(current) # Difference here\n", + " for s in successors:\n", + " heapq.heappush(agenda, (len(current) + distance(s[-1], goal) - 1, s))\n", + " if agenda:\n", + " return current\n", + " else:\n", + " return None " ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ - "def astar_raw(agenda, goal, debug=False):\n", + "def astar_search_closed(start, goal, debug=False):\n", + " agenda = [(distance(start, goal), [start])]\n", + " heapq.heapify(agenda)\n", + " closed = set()\n", " finished = False\n", " while not finished and agenda:\n", " _, current = heapq.heappop(agenda)\n", @@ -270,18 +299,19 @@ " if current[-1] == goal:\n", " finished = True\n", " else:\n", - " successors = extend_raw(current)\n", + " closed.add(current[-1])\n", + " successors = extend(current, closed)\n", " for s in successors:\n", " heapq.heappush(agenda, (len(current) + distance(s[-1], goal) - 1, s))\n", " if agenda:\n", " return current\n", " else:\n", - " return None " + " return None " ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -290,7 +320,7 @@ "['vice', 'dice', 'dire', 'dare', 'ware', 'wars']" ] }, - "execution_count": 16, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -301,7 +331,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -310,7 +340,7 @@ "['vice', 'dice', 'dire', 'dare', 'ware', 'wars']" ] }, - "execution_count": 17, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -321,7 +351,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -330,7 +360,7 @@ "6" ] }, - "execution_count": 18, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -341,7 +371,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -350,7 +380,7 @@ "6" ] }, - "execution_count": 19, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -367,7 +397,7 @@ { "data": { "text/plain": [ - "793" + "6" ] }, "execution_count": 20, @@ -376,19 +406,39 @@ } ], "source": [ - "len(dfs_search('vice', 'wars'))" + "len(bfs_search_closed('vice', 'wars'))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "793" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(dfs_search('vice', 'wars'))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10000 loops, best of 3: 153 µs per loop\n" + "10000 loops, best of 3: 158 µs per loop\n" ] } ], @@ -399,14 +449,14 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "100 loops, best of 3: 15.8 ms per loop\n" + "100 loops, best of 3: 15.6 ms per loop\n" ] } ], @@ -417,14 +467,32 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10000 loops, best of 3: 168 µs per loop\n" + ] + } + ], + "source": [ + "%%timeit\n", + "astar_search_closed('vice', 'wars')" + ] + }, + { + "cell_type": "code", + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1 loop, best of 3: 1min 42s per loop\n" + "1 loop, best of 3: 1min 40s per loop\n" ] } ], @@ -435,14 +503,32 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 loop, best of 3: 597 ms per loop\n" + ] + } + ], + "source": [ + "%%timeit\n", + "bfs_search_closed('vice', 'wars')" + ] + }, + { + "cell_type": "code", + "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10 loops, best of 3: 88 ms per loop\n" + "10 loops, best of 3: 85.5 ms per loop\n" ] } ], @@ -469,7 +555,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 28, "metadata": { "collapsed": true }, @@ -491,7 +577,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 29, "metadata": { "scrolled": true }, @@ -503,7 +589,7 @@ " '`bash`, `cash`, `dash`, `gash`, `hash`, `lash`, `mash`, `rasp`, `rush`, `sash`, `wash`')" ] }, - "execution_count": 26, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -514,7 +600,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 30, "metadata": { "scrolled": true }, @@ -526,7 +612,7 @@ " '`base`, `bash`, `bask`, `bass`, `bast`, `bath`, `bosh`, `bush`, `case`, `cash`, `cask`, `cast`, `dash`, `dish`, `gash`, `gasp`, `gosh`, `gush`, `hash`, `hasp`, `hath`, `hush`, `lash`, `lass`, `last`, `lath`, `lush`, `mash`, `mask`, `mass`, `mast`, `math`, `mesh`, `mush`, `push`, `ramp`, `rasp`, `ruse`, `rush`, `rusk`, `rust`, `sash`, `sass`, `tush`, `wash`, `wasp`, `wish`')" ] }, - "execution_count": 27, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -537,7 +623,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 31, "metadata": { "scrolled": true }, @@ -548,7 +634,7 @@ "180" ] }, - "execution_count": 28, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -559,7 +645,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 32, "metadata": { "scrolled": true }, @@ -570,7 +656,7 @@ "2195" ] }, - "execution_count": 29, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -581,7 +667,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 33, "metadata": { "scrolled": true }, @@ -592,7 +678,7 @@ "2192" ] }, - "execution_count": 30, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -603,7 +689,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 34, "metadata": { "scrolled": true }, @@ -612,7 +698,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "100 loops, best of 3: 5.96 ms per loop\n" + "100 loops, best of 3: 5.82 ms per loop\n" ] } ], @@ -623,7 +709,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 35, "metadata": { "scrolled": true },