diff embeddings/process.py @ 468:a07948f780b9

Moved embeddings out of sandbox
author Joseph Turian <turian@iro.umontreal.ca>
date Tue, 21 Oct 2008 16:24:44 -0400
parents sandbox/embeddings/process.py@f3711bcc467e
children 4335309f4924
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/embeddings/process.py	Tue Oct 21 16:24:44 2008 -0400
@@ -0,0 +1,114 @@
+"""
+Read the vocabulary and weights files and map words to their embedding vectors.
+"""
+
+import re
+import sys
+
+# parameters defines VOCABFILE, WEIGHTSFILE, NUMBER_OF_WORDS, DIMENSIONS, and UNKNOWN
+from parameters import *
+
+# Module state, filled in lazily by read_embeddings()
+__words = None
+__word_to_embedding = None
+__read = False
+
+def word_to_embedding(w):
+    """Return the embedding vector for word w, reading the weights file on first use."""
+    read_embeddings()
+    return __word_to_embedding[w]
+
+def read_embeddings():
+    """Read the vocabulary and weights files once; later calls are no-ops."""
+    global __words
+    global __word_to_embedding
+    global __read
+    if __read: return
+
+    __words = [w.strip() for w in open(VOCABFILE)]
+    assert len(__words) == NUMBER_OF_WORDS
+
+    __word_to_embedding = {}
+
+    sys.stderr.write("Reading %s...\n" % WEIGHTSFILE)
+    f = open(WEIGHTSFILE)
+    f.readline()    # the first line of the weights file is skipped
+    # the second line holds all NUMBER_OF_WORDS * DIMENSIONS weights,
+    # concatenated in vocabulary order
+    vals = [float(v) for v in f.readline().split()]
+    assert len(vals) == NUMBER_OF_WORDS * DIMENSIONS
+    for i in range(NUMBER_OF_WORDS):
+        l = vals[DIMENSIONS*i:DIMENSIONS*(i+1)]
+        w = __words[i]
+        __word_to_embedding[w] = l
+    __read = True
+    sys.stderr.write("...done reading %s\n" % WEIGHTSFILE)
+
+numberre = re.compile("[0-9]")   # each digit is replaced by the token NUMBER
+slashre = re.compile(r"\\/")     # Penn Treebank escapes "/" as "\/"
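+# For example (illustrative):
+#   slashre.sub("/", "3\\/4")      -> "3/4"
+#   numberre.sub("NUMBER", "3/4")  -> "NUMBER/NUMBER"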
+
+def preprocess(l):
+    """
+    Convert a sequence of tokens so that it can be embedded directly.
+    Return the preprocessed sequence.
+    @note: Preprocessing is appropriate for Penn Treebank style documents.
+    """
+    read_embeddings()
+    lnew = []
+    for origw in l:
+        if origw == "-LRB-": w = "("
+        elif origw == "-RRB-": w = ")"
+        elif origw == "-LCB-": w = "{"
+        elif origw == "-RCB-": w = "}"
+        elif origw == "-LSB-": w = "["
+        elif origw == "-RSB-": w = "]"
+        else:
+            w = origw
+            w = w.lower()
+            w = slashre.sub("/", w)        # unescape PTB "\/" to "/"
+            w = numberre.sub("NUMBER", w)  # replace each digit with NUMBER
+        if w not in __word_to_embedding:
+            sys.stderr.write("Word not in vocabulary, using %s: %s (original %s)\n" % (UNKNOWN, w, origw))
+            w = UNKNOWN
+        assert w in __word_to_embedding
+        lnew.append(w)
+    return lnew
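+
+# Illustrative example (tokens missing from the vocabulary fall back to UNKNOWN):
+#
+#   tokens = ["-LRB-", "The", "price", "-RRB-"]
+#   clean = preprocess(tokens)                    # ["(", "the", "price", ")"]
+#   vecs = [word_to_embedding(w) for w in clean]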
+
+#def convert_string(s, strict=False):
+#    """
+#    Convert a string to a sequence of embeddings.
+#    @param strict: If strict, then words *must* be in the vocabulary.
+#    @todo: DEPRECATED Remove this function.
+#    """
+#    read_embeddings()
+#    e = []
+#    for origw in string.split(string.lower(s)):
+#        w = numberre.sub("NUMBER", origw)
+#        if w in __word_to_embedding:
+#            e.append(__word_to_embedding[w])
+#        else:
+#            sys.stderr.write("Word not in vocabulary, using %s: %s (original %s)\n" % (UNKNOWN, w, origw))
+#            assert not strict
+#            e.append(__word_to_embedding[UNKNOWN])
+#    return e
+
+#def test():
+#    """
+#    Debugging code.
+#    """
+#    read_embeddings()
+#    for w in __word_to_embedding:
+#        assert len(__word_to_embedding[w]) == 50
+#    import numpy
+#    for w1 in __words:
+#        e1 = numpy.asarray(__word_to_embedding[w1])
+#        lst = []
+#        print w1, numpy.dot(e1, e1)
+#        for w2 in __word_to_embedding:
+#            if w1 >= w2: continue
+#            e2 = numpy.asarray(__word_to_embedding[w2])
+#            d = (e1 - e2)
+#            l2 = numpy.dot(d, d)
+#            lst.append((l2, w1, w2))
+#        lst.sort()
+#        print lst[:10]
+#
+#test()