diff pylearn/dbdict/test_api0.py @ 538:798607a058bd

added missing files
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 12 Nov 2008 22:00:20 -0500
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/dbdict/test_api0.py	Wed Nov 12 22:00:20 2008 -0500
@@ -0,0 +1,351 @@
+from api0 import *
+import threading, time, commands, os, sys, math, random, datetime
+
+import psycopg2, psycopg2.extensions
+from sqlalchemy import create_engine, desc
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy import Table, Column, MetaData, ForeignKey    
+from sqlalchemy import Integer, String, Float, DateTime, Text, Binary
+from sqlalchemy.orm import mapper, relation, backref, eagerload
+from sqlalchemy.sql import operators, select
+
+import unittest
+
+class T(unittest.TestCase):
+
+    def test_bad_dict_table(self):
+        """Make sure our crude version of schema checking kinda works"""
+        engine = create_engine('sqlite:///:memory:', echo=False)
+        Session = sessionmaker(bind=engine, autoflush=True, transactional=True)
+
+        table_prefix='bergstrj_scratch_test_'
+        metadata = MetaData()
+        t_trial = Table(table_prefix+'trial', metadata,
+                Column('id', Integer, primary_key=True),
+                Column('desc', String(256)), #comment: why running this trial?
+                Column('priority', Float(53)), #aka Double
+                Column('start', DateTime),
+                Column('finish', DateTime),
+                Column('host', String(256)))
+        metadata.create_all(engine)
+
+        try:
+            h = DbHandle(None, t_trial, None, None)
+        except ValueError, e:
+            if e[0] == DbHandle.e_bad_table:
+                return
+        self.fail()
+
+
+    def go(self):
+        """Create tables and session_maker"""
+        engine = create_engine('sqlite:///:memory:', echo=False)
+        Session = sessionmaker(autoflush=True, transactional=True)
+
+        table_prefix='bergstrj_scratch_test_'
+        metadata = MetaData()
+
+        t_trial = Table(table_prefix+'trial', metadata,
+                Column('id', Integer, primary_key=True),
+                Column('create', DateTime),
+                Column('write', DateTime),
+                Column('read', DateTime))
+
+        t_keyval = Table(table_prefix+'keyval', metadata,
+                Column('id', Integer, primary_key=True),
+                Column('name', String(32), nullable=False), #name of attribute
+                Column('ntype', Integer),
+                Column('fval', Float(53)), #aka Double
+                Column('sval', Text), #aka Double
+                Column('bval', Binary)) #TODO: store text (strings of unbounded length)
+
+        t_trial_keyval = Table(table_prefix+'trial_keyval', metadata,
+                Column('dict_id', Integer, ForeignKey('%s.id' % t_trial),
+                    primary_key=True),
+                Column('pair_id', Integer, ForeignKey('%s.id' % t_keyval),
+                    primary_key=True))
+
+        metadata.bind = engine
+        metadata.create_all() # does nothing when tables already exist
+        
+        self.engine = engine
+        return Session, t_trial, t_keyval, t_trial_keyval
+
+
+    def test_insert_save(self):
+
+        Session, t_dict, t_pair, t_link = self.go()
+
+        db = DbHandle(*self.go())
+
+        def jobs():
+            dvalid, dtest = 'dvalid', 'dtest file'
+            desc = 'debugging'
+            for lr in [0.001]:
+                for scale in [0.0001 * math.sqrt(10.0)**i for i in range(4)]:
+                    for rng_seed in [4, 5, 6]:
+                        for priority in [None, 1]:
+                            yield dict(locals())
+
+        jlist = list(jobs())
+        assert len(jlist) == 1*4*3*2
+        for i, dct in enumerate(jobs()):
+            t = db.insert(**dct)
+
+        #make sure that they really got inserted into the db
+        orig_keycount = db._session.query(db._KeyVal).count()
+        self.failUnless(orig_keycount > 0, orig_keycount)
+
+        orig_dctcount = Session().query(db._Dict).count()
+        self.failUnless(orig_dctcount ==len(jlist), orig_dctcount)
+
+        orig_keycount = Session().query(db._KeyVal).count()
+        self.failUnless(orig_keycount > 0, orig_keycount)
+
+        #queries
+        q0list = list(db.query().all())
+        q1list = list(db.query())
+        q2list = list(db)
+
+        self.failUnless(q0list == q1list, (q0list,q1list))
+        self.failUnless(q0list == q2list, (q0list,q1list))
+
+        self.failUnless(len(q0list) == len(jlist))
+
+        for i, (j, q) in enumerate(zip(jlist, q0list)):
+            jitems = list(j.items())
+            qitems = list(q.items())
+            jitems.sort()
+            qitems.sort()
+            if jitems != qitems:
+                print i
+                print jitems
+                print qitems
+            self.failUnless(jitems == qitems, (jitems, qitems))
+
+    def test_query_0(self):
+        Session, t_dict, t_pair, t_link = self.go()
+
+        db = DbHandle(*self.go())
+
+        def jobs():
+            dvalid, dtest = 'dvalid', 'dtest file'
+            desc = 'debugging'
+            for lr in [0.001]:
+                for scale in [0.0001 * math.sqrt(10.0)**i for i in range(4)]:
+                    for rng_seed in [4, 5, 6]:
+                        for priority in [None, 1]:
+                            yield dict(locals())
+
+        jlist = list(jobs())
+        assert len(jlist) == 1*4*3*2
+        for i, dct in enumerate(jobs()):
+            t = db.insert(**dct)
+
+        qlist = list(db.query(rng_seed=5))
+        self.failUnless(len(qlist) == len(jlist)/3)
+
+        jlist5 = [j for j in jlist if j['rng_seed'] == 5]
+
+        for i, (j, q) in enumerate(zip(jlist5, qlist)):
+            jitems = list(j.items())
+            qitems = list(q.items())
+            jitems.sort()
+            qitems.sort()
+            if jitems != qitems:
+                print i
+                print jitems
+                print qitems
+            self.failUnless(jitems == qitems, (jitems, qitems))
+
+    def test_delete_keywise(self):
+        Session, t_dict, t_pair, t_link = self.go()
+
+        db = DbHandle(*self.go())
+
+        def jobs():
+            dvalid, dtest = 'dvalid', 'dtest file'
+            desc = 'debugging'
+            for lr in [0.001]:
+                for scale in [0.0001 * math.sqrt(10.0)**i for i in range(4)]:
+                    for rng_seed in [4, 5, 6]:
+                        for priority in [None, 1]:
+                            yield dict(locals())
+
+        jlist = list(jobs())
+        assert len(jlist) == 1*4*3*2
+        for i, dct in enumerate(jobs()):
+            t = db.insert(**dct)
+
+        orig_keycount = Session().query(db._KeyVal).count()
+
+        del_count = Session().query(db._KeyVal).filter_by(name='rng_seed',
+                fval=5.0).count()
+        self.failUnless(del_count == 8, del_count)
+
+        #delete all the rng_seed = 5 entries
+        qlist_before = list(db.query(rng_seed=5))
+        for q in qlist_before:
+            del q['rng_seed']
+
+        #check that it's gone from our objects
+        for q in qlist_before:
+            self.failUnless('rng_seed' not in q)  #via __contains__
+            self.failUnless('rng_seed' not in q.keys()) #via keys()
+            exc=None
+            try:
+                r = q['rng_seed'] # via __getitem__
+                print 'r,', r
+            except KeyError, e:
+                pass
+
+        #check that it's gone from dictionaries in the database
+        qlist_after = list(db.query(rng_seed=5))
+        self.failUnless(qlist_after == [])
+
+        #check that exactly 8 keys were removed
+        new_keycount = Session().query(db._KeyVal).count()
+        self.failUnless(orig_keycount == new_keycount + 8, (orig_keycount,
+            new_keycount))
+
+        #check that no keys have rng_seed == 5
+        gone_count = Session().query(db._KeyVal).filter_by(name='rng_seed',
+                fval=5.0).count()
+        self.failUnless(gone_count == 0, gone_count)
+
+
+    def test_delete_dictwise(self):
+        Session, t_dict, t_pair, t_link = self.go()
+
+        db = DbHandle(*self.go())
+
+        def jobs():
+            dvalid, dtest = 'dvalid', 'dtest file'
+            desc = 'debugging'
+            for lr in [0.001]:
+                for scale in [0.0001 * math.sqrt(10.0)**i for i in range(4)]:
+                    for rng_seed in [4, 5, 6]:
+                        for priority in [None, 1]:
+                            yield dict(locals())
+
+        jlist = list(jobs())
+        assert len(jlist) == 1*4*3*2
+        for i, dct in enumerate(jobs()):
+            t = db.insert(**dct)
+
+        orig_keycount = Session().query(db._KeyVal).count()
+        orig_dctcount = Session().query(db._Dict).count()
+        self.failUnless(orig_dctcount == len(jlist))
+
+        #delete all the rng_seed = 5 dictionaries
+        qlist_before = list(db.query(rng_seed=5))
+        for q in qlist_before:
+            q.delete()
+
+        #check that the right number has been removed
+        post_dctcount = Session().query(db._Dict).count()
+        self.failUnless(post_dctcount == len(jlist)-8)
+
+        #check that the remaining ones are correct
+        for a, b, in zip(
+                [j for j in jlist if j['rng_seed'] != 5],
+                Session().query(db._Dict).all()):
+            self.failUnless(a == b)
+
+        #check that the keys have all been removed
+        n_keys_per_dict = 8
+        new_keycount = Session().query(db._KeyVal).count()
+        self.failUnless(orig_keycount - 8 * n_keys_per_dict == new_keycount, (orig_keycount,
+            new_keycount))
+
+
+    def test_setitem_0(self):
+        Session, t_dict, t_pair, t_link = self.go()
+
+        db = DbHandle(*self.go())
+
+        b0 = 6.0
+        b1 = 9.0
+
+        job = dict(a=0, b=b0, c='hello')
+
+        dbjob = db.insert(**job)
+
+        dbjob['b'] = b1
+
+        #check that the change is in db
+        qjob = Session().query(db._Dict).filter(db._Dict._attrs.any(name='b',
+            fval=b1)).first()
+        self.failIf(qjob is dbjob)
+        self.failUnless(qjob == dbjob)
+
+        #check that the b:b0 key is gone
+        count = Session().query(db._KeyVal).filter_by(name='b', fval=b0).count()
+        self.failUnless(count == 0, count)
+
+        #check that the b:b1 key is there
+        count = Session().query(db._KeyVal).filter_by(name='b', fval=b1).count()
+        self.failUnless(count == 1, count)
+
+    def test_setitem_1(self):
+        """replace with different sql type"""
+        Session, t_dict, t_pair, t_link = self.go()
+
+        db = DbHandle(*self.go())
+
+        b0 = 6.0
+        b1 = 'asdf' # a different dtype
+
+        job = dict(a=0, b=b0, c='hello')
+
+        dbjob = db.insert(**job)
+
+        dbjob['b'] = b1
+
+        #check that the change is in db
+        qjob = Session().query(db._Dict).filter(db._Dict._attrs.any(name='b',
+            sval=b1)).first()
+        self.failIf(qjob is dbjob)
+        self.failUnless(qjob == dbjob)
+
+        #check that the b:b0 key is gone
+        count = Session().query(db._KeyVal).filter_by(name='b', fval=b0).count()
+        self.failUnless(count == 0, count)
+
+        #check that the b:b1 key is there
+        count = Session().query(db._KeyVal).filter_by(name='b', sval=b1,
+                fval=None).count()
+        self.failUnless(count == 1, count)
+
+    def test_setitem_2(self):
+        """replace with different number type"""
+        Session, t_dict, t_pair, t_link = self.go()
+
+        db = DbHandle(*self.go())
+
+        b0 = 6.0
+        b1 = 7
+
+        job = dict(a=0, b=b0, c='hello')
+
+        dbjob = db.insert(**job)
+
+        dbjob['b'] = b1
+
+        #check that the change is in db
+        qjob = Session().query(db._Dict).filter(db._Dict._attrs.any(name='b',
+            fval=b1)).first()
+        self.failIf(qjob is dbjob)
+        self.failUnless(qjob == dbjob)
+
+        #check that the b:b0 key is gone
+        count = Session().query(db._KeyVal).filter_by(name='b', fval=b0,ntype=1).count()
+        self.failUnless(count == 0, count)
+
+        #check that the b:b1 key is there
+        count = Session().query(db._KeyVal).filter_by(name='b', fval=b1,ntype=0).count()
+        self.failUnless(count == 1, count)
+        
+
+if __name__ == '__main__':
+    unittest.main()