Mercurial > pylearn
view pylearn/dbdict/test_api0.py @ 538:798607a058bd
added missing files
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 12 Nov 2008 22:00:20 -0500 |
parents | |
children |
line wrap: on
line source
from api0 import * import threading, time, commands, os, sys, math, random, datetime import psycopg2, psycopg2.extensions from sqlalchemy import create_engine, desc from sqlalchemy.orm import sessionmaker from sqlalchemy import Table, Column, MetaData, ForeignKey from sqlalchemy import Integer, String, Float, DateTime, Text, Binary from sqlalchemy.orm import mapper, relation, backref, eagerload from sqlalchemy.sql import operators, select import unittest class T(unittest.TestCase): def test_bad_dict_table(self): """Make sure our crude version of schema checking kinda works""" engine = create_engine('sqlite:///:memory:', echo=False) Session = sessionmaker(bind=engine, autoflush=True, transactional=True) table_prefix='bergstrj_scratch_test_' metadata = MetaData() t_trial = Table(table_prefix+'trial', metadata, Column('id', Integer, primary_key=True), Column('desc', String(256)), #comment: why running this trial? Column('priority', Float(53)), #aka Double Column('start', DateTime), Column('finish', DateTime), Column('host', String(256))) metadata.create_all(engine) try: h = DbHandle(None, t_trial, None, None) except ValueError, e: if e[0] == DbHandle.e_bad_table: return self.fail() def go(self): """Create tables and session_maker""" engine = create_engine('sqlite:///:memory:', echo=False) Session = sessionmaker(autoflush=True, transactional=True) table_prefix='bergstrj_scratch_test_' metadata = MetaData() t_trial = Table(table_prefix+'trial', metadata, Column('id', Integer, primary_key=True), Column('create', DateTime), Column('write', DateTime), Column('read', DateTime)) t_keyval = Table(table_prefix+'keyval', metadata, Column('id', Integer, primary_key=True), Column('name', String(32), nullable=False), #name of attribute Column('ntype', Integer), Column('fval', Float(53)), #aka Double Column('sval', Text), #aka Double Column('bval', Binary)) #TODO: store text (strings of unbounded length) t_trial_keyval = Table(table_prefix+'trial_keyval', metadata, Column('dict_id', Integer, ForeignKey('%s.id' % t_trial), primary_key=True), Column('pair_id', Integer, ForeignKey('%s.id' % t_keyval), primary_key=True)) metadata.bind = engine metadata.create_all() # does nothing when tables already exist self.engine = engine return Session, t_trial, t_keyval, t_trial_keyval def test_insert_save(self): Session, t_dict, t_pair, t_link = self.go() db = DbHandle(*self.go()) def jobs(): dvalid, dtest = 'dvalid', 'dtest file' desc = 'debugging' for lr in [0.001]: for scale in [0.0001 * math.sqrt(10.0)**i for i in range(4)]: for rng_seed in [4, 5, 6]: for priority in [None, 1]: yield dict(locals()) jlist = list(jobs()) assert len(jlist) == 1*4*3*2 for i, dct in enumerate(jobs()): t = db.insert(**dct) #make sure that they really got inserted into the db orig_keycount = db._session.query(db._KeyVal).count() self.failUnless(orig_keycount > 0, orig_keycount) orig_dctcount = Session().query(db._Dict).count() self.failUnless(orig_dctcount ==len(jlist), orig_dctcount) orig_keycount = Session().query(db._KeyVal).count() self.failUnless(orig_keycount > 0, orig_keycount) #queries q0list = list(db.query().all()) q1list = list(db.query()) q2list = list(db) self.failUnless(q0list == q1list, (q0list,q1list)) self.failUnless(q0list == q2list, (q0list,q1list)) self.failUnless(len(q0list) == len(jlist)) for i, (j, q) in enumerate(zip(jlist, q0list)): jitems = list(j.items()) qitems = list(q.items()) jitems.sort() qitems.sort() if jitems != qitems: print i print jitems print qitems self.failUnless(jitems == qitems, (jitems, qitems)) def test_query_0(self): Session, t_dict, t_pair, t_link = self.go() db = DbHandle(*self.go()) def jobs(): dvalid, dtest = 'dvalid', 'dtest file' desc = 'debugging' for lr in [0.001]: for scale in [0.0001 * math.sqrt(10.0)**i for i in range(4)]: for rng_seed in [4, 5, 6]: for priority in [None, 1]: yield dict(locals()) jlist = list(jobs()) assert len(jlist) == 1*4*3*2 for i, dct in enumerate(jobs()): t = db.insert(**dct) qlist = list(db.query(rng_seed=5)) self.failUnless(len(qlist) == len(jlist)/3) jlist5 = [j for j in jlist if j['rng_seed'] == 5] for i, (j, q) in enumerate(zip(jlist5, qlist)): jitems = list(j.items()) qitems = list(q.items()) jitems.sort() qitems.sort() if jitems != qitems: print i print jitems print qitems self.failUnless(jitems == qitems, (jitems, qitems)) def test_delete_keywise(self): Session, t_dict, t_pair, t_link = self.go() db = DbHandle(*self.go()) def jobs(): dvalid, dtest = 'dvalid', 'dtest file' desc = 'debugging' for lr in [0.001]: for scale in [0.0001 * math.sqrt(10.0)**i for i in range(4)]: for rng_seed in [4, 5, 6]: for priority in [None, 1]: yield dict(locals()) jlist = list(jobs()) assert len(jlist) == 1*4*3*2 for i, dct in enumerate(jobs()): t = db.insert(**dct) orig_keycount = Session().query(db._KeyVal).count() del_count = Session().query(db._KeyVal).filter_by(name='rng_seed', fval=5.0).count() self.failUnless(del_count == 8, del_count) #delete all the rng_seed = 5 entries qlist_before = list(db.query(rng_seed=5)) for q in qlist_before: del q['rng_seed'] #check that it's gone from our objects for q in qlist_before: self.failUnless('rng_seed' not in q) #via __contains__ self.failUnless('rng_seed' not in q.keys()) #via keys() exc=None try: r = q['rng_seed'] # via __getitem__ print 'r,', r except KeyError, e: pass #check that it's gone from dictionaries in the database qlist_after = list(db.query(rng_seed=5)) self.failUnless(qlist_after == []) #check that exactly 8 keys were removed new_keycount = Session().query(db._KeyVal).count() self.failUnless(orig_keycount == new_keycount + 8, (orig_keycount, new_keycount)) #check that no keys have rng_seed == 5 gone_count = Session().query(db._KeyVal).filter_by(name='rng_seed', fval=5.0).count() self.failUnless(gone_count == 0, gone_count) def test_delete_dictwise(self): Session, t_dict, t_pair, t_link = self.go() db = DbHandle(*self.go()) def jobs(): dvalid, dtest = 'dvalid', 'dtest file' desc = 'debugging' for lr in [0.001]: for scale in [0.0001 * math.sqrt(10.0)**i for i in range(4)]: for rng_seed in [4, 5, 6]: for priority in [None, 1]: yield dict(locals()) jlist = list(jobs()) assert len(jlist) == 1*4*3*2 for i, dct in enumerate(jobs()): t = db.insert(**dct) orig_keycount = Session().query(db._KeyVal).count() orig_dctcount = Session().query(db._Dict).count() self.failUnless(orig_dctcount == len(jlist)) #delete all the rng_seed = 5 dictionaries qlist_before = list(db.query(rng_seed=5)) for q in qlist_before: q.delete() #check that the right number has been removed post_dctcount = Session().query(db._Dict).count() self.failUnless(post_dctcount == len(jlist)-8) #check that the remaining ones are correct for a, b, in zip( [j for j in jlist if j['rng_seed'] != 5], Session().query(db._Dict).all()): self.failUnless(a == b) #check that the keys have all been removed n_keys_per_dict = 8 new_keycount = Session().query(db._KeyVal).count() self.failUnless(orig_keycount - 8 * n_keys_per_dict == new_keycount, (orig_keycount, new_keycount)) def test_setitem_0(self): Session, t_dict, t_pair, t_link = self.go() db = DbHandle(*self.go()) b0 = 6.0 b1 = 9.0 job = dict(a=0, b=b0, c='hello') dbjob = db.insert(**job) dbjob['b'] = b1 #check that the change is in db qjob = Session().query(db._Dict).filter(db._Dict._attrs.any(name='b', fval=b1)).first() self.failIf(qjob is dbjob) self.failUnless(qjob == dbjob) #check that the b:b0 key is gone count = Session().query(db._KeyVal).filter_by(name='b', fval=b0).count() self.failUnless(count == 0, count) #check that the b:b1 key is there count = Session().query(db._KeyVal).filter_by(name='b', fval=b1).count() self.failUnless(count == 1, count) def test_setitem_1(self): """replace with different sql type""" Session, t_dict, t_pair, t_link = self.go() db = DbHandle(*self.go()) b0 = 6.0 b1 = 'asdf' # a different dtype job = dict(a=0, b=b0, c='hello') dbjob = db.insert(**job) dbjob['b'] = b1 #check that the change is in db qjob = Session().query(db._Dict).filter(db._Dict._attrs.any(name='b', sval=b1)).first() self.failIf(qjob is dbjob) self.failUnless(qjob == dbjob) #check that the b:b0 key is gone count = Session().query(db._KeyVal).filter_by(name='b', fval=b0).count() self.failUnless(count == 0, count) #check that the b:b1 key is there count = Session().query(db._KeyVal).filter_by(name='b', sval=b1, fval=None).count() self.failUnless(count == 1, count) def test_setitem_2(self): """replace with different number type""" Session, t_dict, t_pair, t_link = self.go() db = DbHandle(*self.go()) b0 = 6.0 b1 = 7 job = dict(a=0, b=b0, c='hello') dbjob = db.insert(**job) dbjob['b'] = b1 #check that the change is in db qjob = Session().query(db._Dict).filter(db._Dict._attrs.any(name='b', fval=b1)).first() self.failIf(qjob is dbjob) self.failUnless(qjob == dbjob) #check that the b:b0 key is gone count = Session().query(db._KeyVal).filter_by(name='b', fval=b0,ntype=1).count() self.failUnless(count == 0, count) #check that the b:b1 key is there count = Session().query(db._KeyVal).filter_by(name='b', fval=b1,ntype=0).count() self.failUnless(count == 1, count) if __name__ == '__main__': unittest.main()