Mercurial > pylearn
comparison pylearn/dbdict/test_api0.py @ 538:798607a058bd
added missing files
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 12 Nov 2008 22:00:20 -0500 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
537:b054271b2504 | 538:798607a058bd |
---|---|
1 from api0 import * | |
2 import threading, time, commands, os, sys, math, random, datetime | |
3 | |
4 import psycopg2, psycopg2.extensions | |
5 from sqlalchemy import create_engine, desc | |
6 from sqlalchemy.orm import sessionmaker | |
7 from sqlalchemy import Table, Column, MetaData, ForeignKey | |
8 from sqlalchemy import Integer, String, Float, DateTime, Text, Binary | |
9 from sqlalchemy.orm import mapper, relation, backref, eagerload | |
10 from sqlalchemy.sql import operators, select | |
11 | |
12 import unittest | |
13 | |
class T(unittest.TestCase):
    """Tests for the ``DbHandle`` dict-in-a-database API from ``api0``.

    Each test obtains its own in-memory SQLite database from exactly one
    call to ``go()``, so tests never share rows.
    """

    @staticmethod
    def _jobs():
        """Yield the job dicts used by the insert/query/delete tests.

        The grid is 1 lr x 4 scales x 3 rng_seeds x 2 priorities = 24 jobs,
        each with the keys dvalid, dtest, desc, lr, scale, rng_seed, priority.

        The dicts are spelled out explicitly rather than built with
        ``dict(locals())``: under Python 2 a list-comprehension variable leaks
        into ``locals()``, which added a stray ``'i'`` key and made the number
        of keys per job depend on the interpreter version.
        """
        dvalid, dtest = 'dvalid', 'dtest file'
        desc = 'debugging'
        for lr in [0.001]:
            for scale in [0.0001 * math.sqrt(10.0) ** k for k in range(4)]:
                for rng_seed in [4, 5, 6]:
                    for priority in [None, 1]:
                        yield dict(dvalid=dvalid, dtest=dtest, desc=desc,
                                   lr=lr, scale=scale, rng_seed=rng_seed,
                                   priority=priority)

    def _new_db(self):
        """Return ``(Session, db)`` built from a single call to ``go()``.

        ``go()`` creates a brand-new in-memory database on every call, so it
        is called once here and its results are shared by ``Session`` and
        ``db``; two calls would leave them pointing at different databases.
        """
        Session, t_dict, t_pair, t_link = self.go()
        return Session, DbHandle(Session, t_dict, t_pair, t_link)

    def _insert_jobs(self, db):
        """Insert every ``_jobs()`` dict into ``db``; return them as a list."""
        jlist = list(self._jobs())
        self.assertEqual(len(jlist), 1 * 4 * 3 * 2)
        for dct in jlist:
            db.insert(**dct)
        return jlist

    def _assert_same_items(self, expected, actual):
        """Assert that ``expected`` and ``actual`` hold pairwise-equal items.

        ``expected`` holds plain dicts; ``actual`` holds the objects returned
        by the database (anything with an ``items()`` method). Lengths are
        compared first because ``zip`` would silently drop extra elements.
        """
        self.assertEqual(len(expected), len(actual))
        for i, (j, q) in enumerate(zip(expected, actual)):
            jitems = sorted(j.items())
            qitems = sorted(q.items())
            self.assertEqual(jitems, qitems, (i, jitems, qitems))

    def test_bad_dict_table(self):
        """Make sure our crude version of schema checking kinda works"""
        engine = create_engine('sqlite:///:memory:', echo=False)

        table_prefix = 'bergstrj_scratch_test_'
        metadata = MetaData()
        # A 'trial' table whose columns differ from the ones go() builds;
        # DbHandle is expected to reject it.
        t_trial = Table(table_prefix + 'trial', metadata,
                Column('id', Integer, primary_key=True),
                Column('desc', String(256)),  # comment: why running this trial?
                Column('priority', Float(53)),  # aka Double
                Column('start', DateTime),
                Column('finish', DateTime),
                Column('host', String(256)))
        metadata.create_all(engine)

        try:
            DbHandle(None, t_trial, None, None)
        except ValueError as e:
            # e.args works on both Python 2 and 3 (e[0] is Python-2 only).
            if e.args and e.args[0] == DbHandle.e_bad_table:
                return
        self.fail('DbHandle did not reject a table with the wrong schema')

    def go(self):
        """Create a fresh in-memory database and return DbHandle's arguments.

        Returns ``(Session, t_trial, t_keyval, t_trial_keyval)``. Every call
        creates a new ``sqlite:///:memory:`` engine, i.e. a new, empty,
        independent database, so call it once per test.
        """
        engine = create_engine('sqlite:///:memory:', echo=False)
        # Bind explicitly (as test_bad_dict_table does) so sessions always
        # talk to this engine rather than depending on metadata.bind.
        Session = sessionmaker(bind=engine, autoflush=True, transactional=True)

        table_prefix = 'bergstrj_scratch_test_'
        metadata = MetaData()

        t_trial = Table(table_prefix + 'trial', metadata,
                Column('id', Integer, primary_key=True),
                Column('create', DateTime),
                Column('write', DateTime),
                Column('read', DateTime))

        t_keyval = Table(table_prefix + 'keyval', metadata,
                Column('id', Integer, primary_key=True),
                Column('name', String(32), nullable=False),  # name of attribute
                Column('ntype', Integer),  # number-type tag (see test_setitem_2)
                Column('fval', Float(53)),  # numeric value; aka Double
                Column('sval', Text),  # string value
                Column('bval', Binary))  # binary value

        t_trial_keyval = Table(table_prefix + 'trial_keyval', metadata,
                Column('dict_id', Integer, ForeignKey('%s.id' % t_trial),
                       primary_key=True),
                Column('pair_id', Integer, ForeignKey('%s.id' % t_keyval),
                       primary_key=True))

        metadata.bind = engine
        metadata.create_all()  # does nothing when tables already exist

        self.engine = engine
        return Session, t_trial, t_keyval, t_trial_keyval

    def test_insert_save(self):
        """Inserted dicts are stored and come back identical via every query form."""
        Session, db = self._new_db()
        jlist = self._insert_jobs(db)

        # make sure that they really got inserted into the db
        orig_keycount = db._session.query(db._KeyVal).count()
        self.assertTrue(orig_keycount > 0, orig_keycount)

        orig_dctcount = Session().query(db._Dict).count()
        self.assertEqual(orig_dctcount, len(jlist))

        orig_keycount = Session().query(db._KeyVal).count()
        self.assertTrue(orig_keycount > 0, orig_keycount)

        # the three ways of listing everything must agree
        q0list = list(db.query().all())
        q1list = list(db.query())
        q2list = list(db)

        self.assertTrue(q0list == q1list, (q0list, q1list))
        self.assertTrue(q0list == q2list, (q0list, q2list))

        self._assert_same_items(jlist, q0list)

    def test_query_0(self):
        """Querying by one key returns exactly the matching dicts."""
        Session, db = self._new_db()
        jlist = self._insert_jobs(db)

        qlist = list(db.query(rng_seed=5))
        jlist5 = [j for j in jlist if j['rng_seed'] == 5]
        # one third of the jobs use rng_seed 5
        self.assertEqual(len(qlist), len(jlist) // 3)

        self._assert_same_items(jlist5, qlist)

    def test_delete_keywise(self):
        """Deleting one key removes it from the objects and from the db."""
        Session, db = self._new_db()
        jlist = self._insert_jobs(db)
        n_seed5 = len([j for j in jlist if j['rng_seed'] == 5])  # 8 jobs

        orig_keycount = Session().query(db._KeyVal).count()

        del_count = Session().query(db._KeyVal).filter_by(name='rng_seed',
                fval=5.0).count()
        self.assertEqual(del_count, n_seed5, del_count)

        # delete all the rng_seed = 5 entries
        qlist_before = list(db.query(rng_seed=5))
        for q in qlist_before:
            del q['rng_seed']

        # check that it's gone from our objects
        for q in qlist_before:
            self.assertTrue('rng_seed' not in q)  # via __contains__
            self.assertTrue('rng_seed' not in q.keys())  # via keys()
            try:
                q['rng_seed']  # via __getitem__
            except KeyError:
                pass
            else:
                self.fail('deleted key rng_seed is still readable')

        # check that it's gone from dictionaries in the database
        qlist_after = list(db.query(rng_seed=5))
        self.assertEqual(qlist_after, [])

        # check that exactly one key per affected dict was removed
        new_keycount = Session().query(db._KeyVal).count()
        self.assertEqual(orig_keycount, new_keycount + n_seed5,
                (orig_keycount, new_keycount))

        # check that no keys have rng_seed == 5
        gone_count = Session().query(db._KeyVal).filter_by(name='rng_seed',
                fval=5.0).count()
        self.assertEqual(gone_count, 0, gone_count)

    def test_delete_dictwise(self):
        """Deleting whole dicts removes them and all of their keys."""
        Session, db = self._new_db()
        jlist = self._insert_jobs(db)
        jlist5 = [j for j in jlist if j['rng_seed'] == 5]
        jlist_kept = [j for j in jlist if j['rng_seed'] != 5]

        orig_keycount = Session().query(db._KeyVal).count()
        orig_dctcount = Session().query(db._Dict).count()
        self.assertEqual(orig_dctcount, len(jlist))

        # delete all the rng_seed = 5 dictionaries
        for q in list(db.query(rng_seed=5)):
            q.delete()

        # check that the right number has been removed
        post_dctcount = Session().query(db._Dict).count()
        self.assertEqual(post_dctcount, len(jlist) - len(jlist5))

        # check that the remaining ones are correct; hold a reference to the
        # session so the fetched objects are not orphaned while inspected
        session = Session()
        remaining = session.query(db._Dict).all()
        self._assert_same_items(jlist_kept, remaining)

        # check that every key of every deleted dict has been removed
        n_deleted_keys = sum(len(j) for j in jlist5)
        new_keycount = Session().query(db._KeyVal).count()
        self.assertEqual(orig_keycount - n_deleted_keys, new_keycount,
                (orig_keycount, new_keycount))

    def _check_setitem(self, b0, b1, b1_match, old_key, new_key):
        """Overwrite key 'b' from ``b0`` to ``b1`` and verify the database.

        b1_match -- KeyVal column filters that locate the dict holding b1
        old_key  -- KeyVal filters (besides name='b') that must match no row
        new_key  -- KeyVal filters (besides name='b') that must match one row
        """
        Session, db = self._new_db()

        dbjob = db.insert(a=0, b=b0, c='hello')
        dbjob['b'] = b1

        # hold one session so the fetched object stays usable for comparison
        session = Session()

        # check that the change is in db
        qjob = session.query(db._Dict).filter(
                db._Dict._attrs.any(name='b', **b1_match)).first()
        self.assertTrue(qjob is not None, 'updated dict not found in db')
        self.assertFalse(qjob is dbjob)
        self.assertTrue(qjob == dbjob)

        # check that the old b key is gone
        count = session.query(db._KeyVal).filter_by(name='b', **old_key).count()
        self.assertEqual(count, 0, count)

        # check that the new b key is there
        count = session.query(db._KeyVal).filter_by(name='b', **new_key).count()
        self.assertEqual(count, 1, count)

    def test_setitem_0(self):
        """replace a float with another float"""
        b0, b1 = 6.0, 9.0
        self._check_setitem(b0, b1,
                b1_match=dict(fval=b1),
                old_key=dict(fval=b0),
                new_key=dict(fval=b1))

    def test_setitem_1(self):
        """replace with different sql type"""
        b0, b1 = 6.0, 'asdf'  # a different dtype
        self._check_setitem(b0, b1,
                b1_match=dict(sval=b1),
                old_key=dict(fval=b0),
                new_key=dict(sval=b1, fval=None))

    def test_setitem_2(self):
        """replace with different number type"""
        b0, b1 = 6.0, 7
        self._check_setitem(b0, b1,
                b1_match=dict(fval=b1),
                old_key=dict(fval=b0, ntype=1),
                new_key=dict(fval=b1, ntype=0))
348 | |
349 | |
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()