comparison pylearn/dbdict/test_api0.py @ 538:798607a058bd

added missing files
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 12 Nov 2008 22:00:20 -0500
parents
children
comparison
equal deleted inserted replaced
537:b054271b2504 538:798607a058bd
1 from api0 import *
2 import threading, time, commands, os, sys, math, random, datetime
3
4 import psycopg2, psycopg2.extensions
5 from sqlalchemy import create_engine, desc
6 from sqlalchemy.orm import sessionmaker
7 from sqlalchemy import Table, Column, MetaData, ForeignKey
8 from sqlalchemy import Integer, String, Float, DateTime, Text, Binary
9 from sqlalchemy.orm import mapper, relation, backref, eagerload
10 from sqlalchemy.sql import operators, select
11
12 import unittest
13
class T(unittest.TestCase):
    """Integration tests for the api0 ``DbHandle`` key/value dict store.

    Each test builds its own in-memory SQLite database (see ``go``), so no
    state is shared between tests.
    """

    # ------------------------------------------------------------------
    # helpers (no ``test`` prefix, so unittest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _jobs():
        """Yield the 1*4*3*2 = 24 job dictionaries used as fixtures."""
        dvalid, dtest = 'dvalid', 'dtest file'
        desc = 'debugging'
        # Keys are spelled out rather than taken from dict(locals()): on
        # Python 2 the list comprehension below leaks its loop variable into
        # this scope, which would add a spurious key to every job.
        for lr in [0.001]:
            for scale in [0.0001 * math.sqrt(10.0) ** k for k in range(4)]:
                for rng_seed in [4, 5, 6]:
                    for priority in [None, 1]:
                        yield dict(dvalid=dvalid, dtest=dtest, desc=desc,
                                   lr=lr, scale=scale, rng_seed=rng_seed,
                                   priority=priority)

    def _make_db(self):
        """Return ``(Session, db)`` backed by one fresh database.

        ``go`` is called exactly once so that the session factory used by
        the test and the ``DbHandle`` share the same engine and tables.
        """
        Session, t_dict, t_pair, t_link = self.go()
        return Session, DbHandle(Session, t_dict, t_pair, t_link)

    def _insert_jobs(self, db):
        """Insert every fixture job into *db*; return the jobs as a list."""
        jlist = list(self._jobs())
        self.assertEqual(len(jlist), 1 * 4 * 3 * 2)
        for dct in jlist:
            db.insert(**dct)
        return jlist

    def _assert_same_items(self, expected, actual):
        """Assert two dict-like objects hold the same key/value pairs."""
        self.assertEqual(sorted(expected.items()), sorted(actual.items()))

    def _keyval_count(self, Session, db, **criteria):
        """Count keyval rows matching ``filter_by(**criteria)`` in a new session."""
        return Session().query(db._KeyVal).filter_by(**criteria).count()

    def _set_b(self, db, b0, b1):
        """Insert ``{a: 0, b: b0, c: 'hello'}``, then overwrite ``b`` with *b1*."""
        dbjob = db.insert(a=0, b=b0, c='hello')
        dbjob['b'] = b1
        return dbjob

    def _assert_reloads_equal(self, Session, db, dbjob, **attr):
        """Assert a new session finds a dict equal to, but not identical to, *dbjob*.

        The dict is located through one of its key/value rows; *attr* is
        passed to ``db._Dict._attrs.any``.
        """
        qjob = Session().query(db._Dict).filter(
            db._Dict._attrs.any(**attr)).first()
        self.assertTrue(qjob is not None, attr)
        self.assertFalse(qjob is dbjob)
        self.assertEqual(qjob, dbjob)

    # ------------------------------------------------------------------
    # fixtures
    # ------------------------------------------------------------------

    def go(self):
        """Create tables and session_maker

        Builds a fresh in-memory SQLite engine holding the dict (``trial``),
        key/value (``keyval``) and link (``trial_keyval``) tables, stores the
        engine on ``self.engine`` and returns
        ``(Session, t_trial, t_keyval, t_trial_keyval)``, the positional
        arguments ``DbHandle`` is constructed with in these tests.
        """
        engine = create_engine('sqlite:///:memory:', echo=False)
        # Bound explicitly (as in test_bad_dict_table) rather than relying on
        # metadata.bind below to route the session's queries to this engine.
        Session = sessionmaker(bind=engine, autoflush=True, transactional=True)

        table_prefix = 'bergstrj_scratch_test_'
        metadata = MetaData()

        t_trial = Table(table_prefix + 'trial', metadata,
                        Column('id', Integer, primary_key=True),
                        Column('create', DateTime),
                        Column('write', DateTime),
                        Column('read', DateTime))

        t_keyval = Table(table_prefix + 'keyval', metadata,
                         Column('id', Integer, primary_key=True),
                         Column('name', String(32), nullable=False),  # name of attribute
                         # numeric-type tag; test_setitem_2 expects 1 for a
                         # float value and 0 for an int value
                         Column('ntype', Integer),
                         Column('fval', Float(53)),  # numeric value, aka Double
                         Column('sval', Text),  # string value, unbounded length
                         Column('bval', Binary))  # binary value

        t_trial_keyval = Table(table_prefix + 'trial_keyval', metadata,
                               Column('dict_id', Integer,
                                      ForeignKey('%s.id' % t_trial),
                                      primary_key=True),
                               Column('pair_id', Integer,
                                      ForeignKey('%s.id' % t_keyval),
                                      primary_key=True))

        metadata.bind = engine
        metadata.create_all()  # does nothing when tables already exist

        self.engine = engine
        return Session, t_trial, t_keyval, t_trial_keyval

    # ------------------------------------------------------------------
    # tests
    # ------------------------------------------------------------------

    def test_bad_dict_table(self):
        """DbHandle's crude schema check rejects a dict table with the wrong columns."""
        engine = create_engine('sqlite:///:memory:', echo=False)

        table_prefix = 'bergstrj_scratch_test_'
        metadata = MetaData()
        # Unlike the dict table built in ``go`` (id/create/write/read), this
        # one has a different column set and must be rejected.
        t_trial = Table(table_prefix + 'trial', metadata,
                        Column('id', Integer, primary_key=True),
                        Column('desc', String(256)),  # comment: why running this trial?
                        Column('priority', Float(53)),  # aka Double
                        Column('start', DateTime),
                        Column('finish', DateTime),
                        Column('host', String(256)))
        metadata.create_all(engine)

        try:
            DbHandle(None, t_trial, None, None)
        except ValueError:
            # sys.exc_info() instead of an ``except ... , e`` binding keeps
            # this valid syntax on every Python version.
            err = sys.exc_info()[1]
            self.assertEqual(err.args[0], DbHandle.e_bad_table)
        else:
            self.fail('DbHandle accepted a table with the wrong schema')

    def test_insert_save(self):
        Session, db = self._make_db()
        jlist = self._insert_jobs(db)

        # make sure that they really got inserted into the db
        handle_keycount = db._session.query(db._KeyVal).count()
        self.assertTrue(handle_keycount > 0, handle_keycount)

        orig_dctcount = Session().query(db._Dict).count()
        self.assertEqual(orig_dctcount, len(jlist))

        orig_keycount = Session().query(db._KeyVal).count()
        self.assertTrue(orig_keycount > 0, orig_keycount)

        # the three ways of listing every dict must agree
        q0list = list(db.query().all())
        q1list = list(db.query())
        q2list = list(db)

        self.assertEqual(q0list, q1list)
        self.assertEqual(q0list, q2list)
        self.assertEqual(len(q0list), len(jlist))

        for j, q in zip(jlist, q0list):
            self._assert_same_items(j, q)

    def test_query_0(self):
        _, db = self._make_db()
        jlist = self._insert_jobs(db)

        jlist5 = [j for j in jlist if j['rng_seed'] == 5]
        qlist = list(db.query(rng_seed=5))
        self.assertEqual(len(qlist), len(jlist) // 3)
        self.assertEqual(len(qlist), len(jlist5))

        for j, q in zip(jlist5, qlist):
            self._assert_same_items(j, q)

    def test_delete_keywise(self):
        Session, db = self._make_db()
        jlist = self._insert_jobs(db)
        n_seed5 = len([j for j in jlist if j['rng_seed'] == 5])

        orig_keycount = Session().query(db._KeyVal).count()

        del_count = self._keyval_count(Session, db, name='rng_seed', fval=5.0)
        self.assertEqual(del_count, n_seed5)

        # delete all the rng_seed = 5 entries
        qlist_before = list(db.query(rng_seed=5))
        for q in qlist_before:
            del q['rng_seed']

        # check that it's gone from our objects
        for q in qlist_before:
            self.assertTrue('rng_seed' not in q)  # via __contains__
            self.assertTrue('rng_seed' not in q.keys())  # via keys()
            self.assertRaises(KeyError, q.__getitem__, 'rng_seed')  # via __getitem__

        # check that it's gone from dictionaries in the database
        qlist_after = list(db.query(rng_seed=5))
        self.assertEqual(qlist_after, [])

        # check that exactly one key per affected dict was removed
        new_keycount = Session().query(db._KeyVal).count()
        self.assertEqual(orig_keycount, new_keycount + n_seed5)

        # check that no keys have rng_seed == 5
        gone_count = self._keyval_count(Session, db, name='rng_seed', fval=5.0)
        self.assertEqual(gone_count, 0)

    def test_delete_dictwise(self):
        Session, db = self._make_db()
        jlist = self._insert_jobs(db)
        doomed = [j for j in jlist if j['rng_seed'] == 5]
        survivors = [j for j in jlist if j['rng_seed'] != 5]

        orig_keycount = Session().query(db._KeyVal).count()
        orig_dctcount = Session().query(db._Dict).count()
        self.assertEqual(orig_dctcount, len(jlist))

        # delete all the rng_seed = 5 dictionaries (materialized first so we
        # do not delete while iterating a live query)
        for q in list(db.query(rng_seed=5)):
            q.delete()

        # check that the right number has been removed
        post_dctcount = Session().query(db._Dict).count()
        self.assertEqual(post_dctcount, len(jlist) - len(doomed))

        # check that the remaining ones are correct; the query yields mapped
        # _Dict objects, so compare item lists as the other tests do
        remaining = Session().query(db._Dict).all()
        self.assertEqual(len(remaining), len(survivors))
        for j, q in zip(survivors, remaining):
            self._assert_same_items(j, q)

        # check that the keys have all been removed; assumes one keyval row
        # per key of each inserted dict (None values included)
        n_deleted_keys = sum(len(j) for j in doomed)
        new_keycount = Session().query(db._KeyVal).count()
        self.assertEqual(orig_keycount - n_deleted_keys, new_keycount)

    def test_setitem_0(self):
        """replace a float value with another float"""
        Session, db = self._make_db()
        b0, b1 = 6.0, 9.0
        dbjob = self._set_b(db, b0, b1)

        # check that the change is in db
        self._assert_reloads_equal(Session, db, dbjob, name='b', fval=b1)
        # check that the b:b0 key is gone and the b:b1 key is there
        self.assertEqual(self._keyval_count(Session, db, name='b', fval=b0), 0)
        self.assertEqual(self._keyval_count(Session, db, name='b', fval=b1), 1)

    def test_setitem_1(self):
        """replace with different sql type"""
        Session, db = self._make_db()
        b0 = 6.0
        b1 = 'asdf'  # a different dtype
        dbjob = self._set_b(db, b0, b1)

        # check that the change is in db
        self._assert_reloads_equal(Session, db, dbjob, name='b', sval=b1)
        # check that the b:b0 key is gone and the b:b1 key is there
        self.assertEqual(self._keyval_count(Session, db, name='b', fval=b0), 0)
        self.assertEqual(
            self._keyval_count(Session, db, name='b', sval=b1, fval=None), 1)

    def test_setitem_2(self):
        """replace with different number type"""
        Session, db = self._make_db()
        b0, b1 = 6.0, 7
        dbjob = self._set_b(db, b0, b1)

        # check that the change is in db
        self._assert_reloads_equal(Session, db, dbjob, name='b', fval=b1)
        # check that the b:b0 key is gone and the b:b1 key is there
        self.assertEqual(
            self._keyval_count(Session, db, name='b', fval=b0, ntype=1), 0)
        self.assertEqual(
            self._keyval_count(Session, db, name='b', fval=b1, ntype=0), 1)
348
349
if __name__ == '__main__':
    # Run directly rather than imported: hand over to unittest's
    # command-line runner, which collects the test_* methods of this module.
    unittest.main(module='__main__')