X-Git-Url: https://git.njae.me.uk/?a=blobdiff_plain;f=segment.py;h=dd0b2a8347ee800c4addf996f369ea0293b47bb7;hb=2f33e16ccc84ddb0023f3621dd6ad545c1bb3251;hp=e4b019f4c8248d8647f938fea8295c132308de0b;hpb=792bef4fa890a8c834ddd83ab9a573d0e2a75dc9;p=cipher-tools.git

diff --git a/segment.py b/segment.py
index e4b019f..dd0b2a8 100644
--- a/segment.py
+++ b/segment.py
@@ -1,20 +1,12 @@
-# import re, string, random, glob, operator, heapq
 import string
 import collections
 from math import log10
 import itertools
+import sys
+from functools import lru_cache
+sys.setrecursionlimit(1000000)
 
-def memo(f):
-    "Memoize function f."
-    table = {}
-    def fmemo(*args):
-        if args not in table:
-            table[args] = f(*args)
-        return table[args]
-    fmemo.memo = table
-    return fmemo
-
-@memo
+@lru_cache()
 def segment(text):
     """Return a list of words that is the best segmentation of text.
     """
@@ -31,7 +23,7 @@ def splits(text, L=20):
 def Pwords(words): 
     """The Naive Bayes log probability of a sequence of words.
     """
-    return sum(Pw[w] for w in words)
+    return sum(Pw[w.lower()] for w in words)
 
 class Pdist(dict):
     """A probability distribution estimated from counts in datafile.
@@ -58,7 +50,5 @@ def avoid_long_words(key, N):
     """
     return -log10((N * 10**(len(key) - 2)))
 
-# N = 1024908267229 ## Number of tokens
-
 Pw  = Pdist(datafile('count_1w.txt'), avoid_long_words)