comparison test_dataset.py @ 81:4b0859606d05

Added tests for ArrayDataSet and LookupList
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 05 May 2008 10:57:33 -0400
parents dde1fb1b63ba
children 158653a9bc7c
comparison
equal deleted inserted replaced
80:40476a7746e8 81:4b0859606d05
24 print "take a slice and look at field y",ds[1:6:2]["y"] 24 print "take a slice and look at field y",ds[1:6:2]["y"]
25 25
26 def test_ArrayDataSet(): 26 def test_ArrayDataSet():
27 #don't test stream 27 #don't test stream
28 #tested only with float value 28 #tested only with float value
29 #test with y too
30 #test missing value
31
29 a = numpy.random.rand(10,4) 32 a = numpy.random.rand(10,4)
30 print a 33 print a
31 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]}) 34 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})
32 assert len(ds)==10 35 assert len(ds)==10
33 #assert ds==a? should this work? 36 #assert ds==a? should this work?
34 for i in range(len(ds)): 37 for i in range(len(ds)):
35 assert ds[i]['x'].all()==a[i][:2].all() 38 assert ds[i]['x'].all()==a[i][:2].all()
36 assert ds[i]['y']==a[i][3] 39 assert ds[i]['y']==a[i][3]
37 assert ds[i]['z'].all()==a[i][0:3:2].all() 40 assert ds[i]['z'].all()==a[i][0:3:2].all()
38 print "x=",ds["x"] 41 print "x=",ds["x"]
39 print "x|y"
40 i=0 42 i=0
41 for x in ds('x','y'): 43 for x in ds('x','y'):
42 assert numpy.append(x['x'],x['y']).all()==a[i].all() 44 assert numpy.append(x['x'],x['y']).all()==a[i].all()
43 i+=1 45 i+=1
44 # i=0 46
45 # for x in ds['x','y']: # don't work
46 # assert numpy.append(x['x'],x['y']).all()==a[i].all()
47 # i+=1
48 # for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work.
49 # assert numpy.append(x,y)==z
50 i=0 47 i=0
51 for x,y in ds('x','y'): 48 for x,y in ds('x','y'):
52 assert numpy.append(x,y).all()==a[i].all() 49 assert numpy.append(x,y).all()==a[i].all()
53 i+=1 50 i+=1
54 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): 51 for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
68 try: 65 try:
69 ds['h'] # h is not defined... 66 ds['h'] # h is not defined...
70 except : 67 except :
71 have_thrown = True 68 have_thrown = True
72 assert have_thrown == True 69 assert have_thrown == True
70
71 have_thrown = False
72 try:
73 ds[['h']] # h is not defined...
74 except :
75 have_thrown = True
76 assert have_thrown == True
77
73 assert len(ds.fields())==3 78 assert len(ds.fields())==3
74 for field in ds.fields(): 79 for field in ds.fields():
75 for field_value in field: # iterate over the values associated to that field for all the ds examples 80 for field_value in field: # iterate over the values associated to that field for all the ds examples
76 pass 81 pass
77 for field in ds('x','z').fields(): 82 for field in ds('x','z').fields():
83 pass 88 pass
84 89
85 assert ds == ds.fields().examples() 90 assert ds == ds.fields().examples()
86 91
87 92
88 #test missing value 93 #ds[:n] returns a dataset with the first n examples.
89
90 assert len(ds[:3])==3 94 assert len(ds[:3])==3
91 i=0 95 i=0
92 for x,z in ds[:3]('x','z'): 96 for x,z in ds[:3]('x','z'):
93 assert ds[i]['z'].all()==a[i][0:3:2].all() 97 assert ds[i]['z'].all()==a[i][0:3:2].all()
94 i+=1 98 i+=1
99
95 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. 100 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
96 101 ds[1:7:2][1] #fail???
97 #ds[i]# returns an Example. 102 assert len(ds[1:7:2])==3 # should be number example 1,3 and 5
98 103 i=0
104 for x,z in ds[1:7:2]('x','z'):
105 assert ds[i]['z'].all()==a[i][0:3:2].all()
106 i+=1
107 ds2=ds[1:7:2]
108 for i in range(len(ds2)):
109 print ds2[i]
99 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. 110 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
100 111 i=0
112 for x in ds[[1,2]]:
113 assert numpy.append(x['x'],x['y']).all()==a[i].all()
114 i+=1
115 #ds[i1,i2,...]# should we accept????
101 #ds[fieldname]# an iterable over the values of the field fieldname across 116 #ds[fieldname]# an iterable over the values of the field fieldname across
102 #the ds (the iterable is obtained by default by calling valuesVStack 117 #the ds (the iterable is obtained by default by calling valuesVStack
103 #over the values for individual examples). 118 #over the values for individual examples).
104 119
105 #ds.<property># returns the value of a property associated with 120 #ds.<property># returns the value of a property associated with
108 # - 'fieldtypes': a list of types (one per field) 123 # - 'fieldtypes': a list of types (one per field)
109 #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3]) 124 #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])
110 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3]) 125 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])
111 126
112 127
128 # for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work.
129 # assert numpy.append(x,y)==z
130
131 def test_LookupList():
132 #test only the example in the doc???
133 example = LookupList(['x','y','z'],[1,2,3])
134 example['x'] = [1, 2, 3] # set or change a field
135 x, y, z = example
136 x = example[0]
137 x = example["x"]
138 assert example.keys()==['x','y','z']
139 assert example.values()==[[1,2,3],2,3]
140 assert example.items()==[('x',[1,2,3]),('y',2),('z',3)]
141 example.append_keyval('u',0) # adds item with name 'u' and value 0
142 assert len(example)==4 # number of items = 4 here
143 example2 = LookupList(['v','w'], ['a','b'])
144 example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b'])
145 print example3
146 print example+example2
147 print example+example2
148 assert example+example2==example3
149 have_throw=False
150 try:
151 example+example
152 except:
153 have_throw=True
154 assert have_throw
155
156 def ApplyFunctionDataSet():
157 raise NotImplementedError()
158 def CacheDataSet():
159 raise NotImplementedError()
160 def FieldsSubsetDataSet():
161 raise NotImplementedError()
162 def DataSetFields():
163 raise NotImplementedError()
164 def MinibatchDataSet():
165 raise NotImplementedError()
166 def HStackedDataSet():
167 raise NotImplementedError()
168 def VStackedDataSet():
169 raise NotImplementedError()
170 def ArrayFieldsDataSet():
171 raise NotImplementedError()
172
173 test_LookupList()
113 test_ArrayDataSet() 174 test_ArrayDataSet()
114 175
176
177