comparison test_dataset.py @ 84:aa9e786ee849

Added function have_raised, which evaluates the string given as parameter and returns True if evaluating it raised an exception. Code cleanup and renaming. Added a few tests.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 05 May 2008 11:49:40 -0400
parents 158653a9bc7c
children fdf72ea4f2bc
comparison
equal deleted inserted replaced
83:c0f211213a58 84:aa9e786ee849
1 #!/bin/env python 1 #!/bin/env python
2 from dataset import * 2 from dataset import *
3 from math import * 3 from math import *
4 import numpy 4 import numpy
5 5
def have_raised(to_eval, globals_=None, locals_=None):
    """Return True if evaluating the expression string `to_eval` raises.

    `to_eval` is passed to eval(). By default it is evaluated in the
    *caller's* namespace (the caller's globals and locals), so expressions
    such as "example+example" may refer to variables local to the calling
    test function. Explicit `globals_` / `locals_` mappings may be given
    instead; they have the same meaning as eval()'s own arguments (if only
    `globals_` is given, it is also used as the locals).

    Only Exception subclasses count as "raised"; KeyboardInterrupt and
    SystemExit propagate to the caller.

    NOTE: uses eval(); only pass trusted, test-authored strings.
    """
    import sys
    if globals_ is None:
        # Evaluating in this function's own scope would leave every name
        # local to the caller undefined: eval would fail with NameError and
        # the check would pass vacuously. Use the caller's frame instead.
        caller = sys._getframe(1)
        globals_ = caller.f_globals
        if locals_ is None:
            locals_ = caller.f_locals
    try:
        eval(to_eval, globals_, locals_)
    except Exception:
        return True
    return False
13
6 def test1(): 14 def test1():
15 print "test1"
7 global a,ds 16 global a,ds
8 a = numpy.random.rand(10,4) 17 a = numpy.random.rand(10,4)
9 print a 18 print a
10 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]}) 19 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})
11 print "len(ds)=",len(ds) 20 print "len(ds)=",len(ds)
27 #don't test stream 36 #don't test stream
28 #tested only with float value 37 #tested only with float value
29 #test with y too 38 #test with y too
30 #test missing value 39 #test missing value
31 40
41 print "test_ArrayDataSet"
32 a = numpy.random.rand(10,4) 42 a = numpy.random.rand(10,4)
33 print a
34 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]}) 43 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})
35 assert len(ds)==10 44 assert len(ds)==10
36 #assert ds==a? should this work? 45 #assert ds==a? should this work?
37 for i in range(len(ds)): 46 for i in range(len(ds)):
38 assert ds[i]['x'].all()==a[i][:2].all() 47 assert ds[i]['x'].all()==a[i][:2].all()
39 assert ds[i]['y']==a[i][3] 48 assert ds[i]['y']==a[i][3]
40 assert ds[i]['z'].all()==a[i][0:3:2].all() 49 assert ds[i]['z'].all()==a[i][0:3:2].all()
41 print "x=",ds["x"]
42 i=0 50 i=0
43 for x in ds('x','y'): 51 for x in ds('x','y'):
44 assert numpy.append(x['x'],x['y']).all()==a[i].all() 52 assert numpy.append(x['x'],x['y']).all()==a[i].all()
45 i+=1 53 i+=1
46 54
59 # minibatch = minibatch_iterator.__iter__().next() 67 # minibatch = minibatch_iterator.__iter__().next()
60 # print "minibatch=",minibatch 68 # print "minibatch=",minibatch
61 # for var in minibatch: 69 # for var in minibatch:
62 # print "var=",var 70 # print "var=",var
63 # print "take a slice and look at field y",ds[1:6:2]["y"] 71 # print "take a slice and look at field y",ds[1:6:2]["y"]
64 have_thrown = False 72 assert have_raised("ds['h']") # h is not defined...
65 try: 73 assert have_raised("ds[['h']]") # h is not defined...
66 ds['h'] # h is not defined...
67 except :
68 have_thrown = True
69 assert have_thrown == True
70
71 have_thrown = False
72 try:
73 ds[['h']] # h is not defined...
74 except :
75 have_thrown = True
76 assert have_thrown == True
77 74
78 assert len(ds.fields())==3 75 assert len(ds.fields())==3
79 for field in ds.fields(): 76 for field in ds.fields():
80 for field_value in field: # iterate over the values associated to that field for all the ds examples 77 for field_value in field: # iterate over the values associated to that field for all the ds examples
81 pass 78 pass
91 88
92 89
93 #ds[:n] returns a dataset with the n first examples. 90 #ds[:n] returns a dataset with the n first examples.
94 assert len(ds[:3])==3 91 assert len(ds[:3])==3
95 i=0 92 i=0
93 for x,z,y in ds[:3]('x','z','y'):
94 assert ds[i]['x'].all()==a[i][:3].all()
95 assert ds[i]['x'].all()==x.all()
96 assert ds[i]['y']==a[i][3]
97 assert ds[i]['y']==y
98 assert ds[i]['z'].all()==a[i][0:3:2].all()
99 assert ds[i]['z'].all()==z.all()
100 i+=1
101 i=0
96 for x,z in ds[:3]('x','z'): 102 for x,z in ds[:3]('x','z'):
97 assert ds[i]['z'].all()==a[i][0:3:2].all() 103 assert ds[i]['z'].all()==a[i][0:3:2].all()
98 i+=1 104 i+=1
99 105
100 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. 106 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
101 ds[1:7:2][1] #fail??? 107 ds[1:7:2][1]
102 assert len(ds[1:7:2])==3 # should be number example 1,3 and 5 108 assert len(ds[1:7:2])==3 # should be number example 1,3 and 5
103 i=0 109 i=0
104 for x,z in ds[1:7:2]('x','z'): 110 index=[1,3,5]
105 assert ds[i]['z'].all()==a[i][0:3:2].all() 111 for z,y,x in ds[1:7:2]('z','y','x'):
112 assert ds[index[i]]['x'].all()==a[index[i]][:3].all()
113 assert ds[index[i]]['x'].all()==x.all()
114 assert ds[index[i]]['y']==a[index[i]][3]
115 assert ds[index[i]]['y']==y
116 assert ds[index[i]]['z'].all()==a[index[i]][0:3:2].all()
117 assert ds[index[i]]['z'].all()==z.all()
106 i+=1 118 i+=1
107 ds2=ds[1:7:2] 119
108 for i in range(len(ds2)):
109 print ds2[i]
110 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. 120 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
111 i=0 121 i=0
112 for x in ds[[1,2]]: 122 for x in ds[[1,2]]:
113 assert numpy.append(x['x'],x['y']).all()==a[i].all() 123 assert numpy.append(x['x'],x['y']).all()==a[i].all()
114 i+=1 124 i+=1
127 # for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work. 137 # for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work.
128 # assert numpy.append(x,y)==z 138 # assert numpy.append(x,y)==z
129 139
def test_LookupList():
    """Exercise the LookupList usage example from its documentation.

    Checks field assignment and lookup, keys()/values()/items(),
    append_keyval, concatenation with `+`, and expects that adding a
    LookupList to itself raises.
    """
    #test only the example in the doc???
    print("test_LookupList")
    example = LookupList(['x','y','z'],[1,2,3])
    example['x'] = [1, 2, 3] # set or change a field
    assert example["x"] == [1, 2, 3]
    x, y, z = example  # unpacking must not fail (smoke test only)
    x = example[0]     # integer indexing must not fail (smoke test only)
    x = example["x"]
    assert example.keys()==['x','y','z']
    assert example.values()==[[1,2,3],2,3]
    assert example.items()==[('x',[1,2,3]),('y',2),('z',3)]
    example.append_keyval('u',0) # adds item with name 'u' and value 0
    assert len(example)==4 # number of items = 4 here
    example2 = LookupList(['v','w'], ['a','b'])
    example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b'])
    assert example+example2==example3
    # Checked explicitly rather than through a string eval, so that the
    # local name `example` is guaranteed to be in scope and a NameError
    # cannot be mistaken for the expected failure.
    raised = False
    try:
        example + example
    except Exception:
        raised = True
    assert raised, "adding a LookupList to itself should raise"
154 157
def _pending(name):
    """Print the test `name`, then raise NotImplementedError (test not written yet)."""
    print(name)
    raise NotImplementedError()

def test_ApplyFunctionDataSet():
    """Placeholder for ApplyFunctionDataSet tests; always raises NotImplementedError."""
    _pending("test_ApplyFunctionDataSet")

def test_CacheDataSet():
    """Placeholder for CacheDataSet tests; always raises NotImplementedError."""
    _pending("test_CacheDataSet")

def test_FieldsSubsetDataSet():
    """Placeholder for FieldsSubsetDataSet tests; always raises NotImplementedError."""
    _pending("test_FieldsSubsetDataSet")

def test_DataSetFields():
    """Placeholder for DataSetFields tests; always raises NotImplementedError."""
    _pending("test_DataSetFields")

def test_MinibatchDataSet():
    """Placeholder for MinibatchDataSet tests; always raises NotImplementedError."""
    _pending("test_MinibatchDataSet")

def test_HStackedDataSet():
    """Placeholder for HStackedDataSet tests; always raises NotImplementedError."""
    _pending("test_HStackedDataSet")

def test_VStackedDataSet():
    """Placeholder for VStackedDataSet tests; always raises NotImplementedError."""
    _pending("test_VStackedDataSet")

def test_ArrayFieldsDataSet():
    """Placeholder for ArrayFieldsDataSet tests; always raises NotImplementedError."""
    _pending("test_ArrayFieldsDataSet")
171 182
# Run the implemented tests whenever this file is executed (the calls are not
# guarded, so they also run on import). test1 assigns the module-level
# globals `a` and `ds`.
for _case in (test1, test_LookupList, test_ArrayDataSet):
    _case()
del _case