comparison test_dataset.py @ 94:9c8f3c9c247b

corrected the use of .all()
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 05 May 2008 17:44:49 -0400
parents eee739fefdff
children 352910e0dbf5
comparison
equal deleted inserted replaced
91:eee739fefdff 94:9c8f3c9c247b
45 assert len(ds)==10 45 assert len(ds)==10
46 #assert ds==a? should this work? 46 #assert ds==a? should this work?
47 47
48 #not in doc!!! 48 #not in doc!!!
49 for example in range(len(ds)): 49 for example in range(len(ds)):
50 assert ds[example]['x'].all()==a[example][:2].all() 50 assert (ds[example]['x']==a[example][:3]).all()
51 assert ds[example]['y']==a[example][3] 51 assert ds[example]['y']==a[example][3]
52 assert ds[example]['z'].all()==a[example][0:3:2].all() 52 assert (ds[example]['z']==a[example][[0,2]]).all()
53 53
54 # - for example in dataset: 54 # - for example in dataset:
55 i=0 55 i=0
56 for example in ds: 56 for example in ds:
57 assert example['x'].all()==a[i][:2].all() 57 assert (example['x']==a[i][:3]).all()
58 assert example['y']==a[i][3] 58 assert example['y']==a[i][3]
59 assert example['z'].all()==a[i][0:3:2].all() 59 assert (example['z']==a[i][0:3:2]).all()
60 assert numpy.append(example['x'],example['y']).all()==a[i].all() 60 assert (numpy.append(example['x'],example['y'])==a[i]).all()
61 i+=1 61 i+=1
62 assert i==len(ds) 62 assert i==len(ds)
63 # - for val1,val2,... in dataset: 63 # - for val1,val2,... in dataset:
64 i=0 64 i=0
65 for x,y,z in ds: 65 for x,y,z in ds:
66 assert x.all()==a[i][:2].all() 66 assert (x==a[i][:3]).all()
67 assert y==a[i][3] 67 assert y==a[i][3]
68 assert z.all()==a[i][0:3:2].all() 68 assert (z==a[i][0:3:2]).all()
69 assert numpy.append(example['x'],example['y']).all()==a[i].all() 69 assert (numpy.append(x,y)==a[i]).all()
70 i+=1 70 i+=1
71 assert i==len(ds) 71 assert i==len(ds)
72 # - for example in dataset(field1, field2,field3, ...): 72 # - for example in dataset(field1, field2,field3, ...):
73 i=0 73 i=0
74 for example in ds('x','y','z'): 74 for example in ds('x','y','z'):
75 assert example['x'].all()==a[i][:2].all() 75 assert (example['x']==a[i][:3]).all()
76 assert example['y']==a[i][3] 76 assert example['y']==a[i][3]
77 assert example['z'].all()==a[i][0:3:2].all() 77 assert (example['z']==a[i][0:3:2]).all()
78 assert numpy.append(example['x'],example['y']).all()==a[i].all() 78 assert (numpy.append(example['x'],example['y'])==a[i]).all()
79 i+=1 79 i+=1
80 assert i==len(ds) 80 assert i==len(ds)
81 81
82 # - for val1,val2,val3 in dataset(field1, field2,field3): 82 # - for val1,val2,val3 in dataset(field1, field2,field3):
83 83
84 # - for example in dataset(field1, field2,field3, ...): 84 # - for example in dataset(field1, field2,field3, ...):
85 85
86 def test_ds_iterator(iterator1,iterator2,iterator3): 86 def test_ds_iterator(iterator1,iterator2,iterator3):
87 i=0 87 i=0
88 for x,y in iterator1: 88 for x,y in iterator1:
89 assert x.all()==a[i][:2].all() 89 assert (x==a[i][:3]).all()
90 assert y==a[i][3] 90 assert y==a[i][3]
91 assert numpy.append(x,y).all()==a[i].all() 91 assert (numpy.append(x,y)==a[i]).all()
92 i+=1 92 i+=1
93 assert i==len(ds) 93 assert i==len(ds)
94 i=0 94 i=0
95 for y,z in iterator2: 95 for y,z in iterator2:
96 assert y==a[i][3] 96 assert y==a[i][3]
97 assert z.all()==a[i][0:3:2].all() 97 assert (z==a[i][0:3:2]).all()
98 i+=1 98 i+=1
99 assert i==len(ds) 99 assert i==len(ds)
100 i=0 100 i=0
101 for x,y,z in iterator3: 101 for x,y,z in iterator3:
102 assert x.all()==a[i][:2].all() 102 assert (x==a[i][:3]).all()
103 assert y==a[i][3] 103 assert y==a[i][3]
104 assert z.all()==a[i][0:3:2].all() 104 assert (z==a[i][0:3:2]).all()
105 assert numpy.append(x,y).all()==a[i].all() 105 assert (numpy.append(x,y)==a[i]).all()
106 i+=1 106 i+=1
107 assert i==len(ds) 107 assert i==len(ds)
108 108
109 #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3): 109 #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3):
110 test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z')) 110 test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z'))
112 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): 112 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
113 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): 113 for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
114 assert len(minibatch)==2 114 assert len(minibatch)==2
115 assert len(minibatch[0])==3 115 assert len(minibatch[0])==3
116 assert len(minibatch[1])==3 116 assert len(minibatch[1])==3
117 assert minibatch[0][:,0:3:2].all()==minibatch[1].all() 117 assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
118 i=0 118 i=0
119 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): 119 for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
120 assert len(minibatch)==2 120 assert len(minibatch)==2
121 assert len(minibatch[0])==3 121 assert len(minibatch[0])==3
122 assert len(minibatch[1])==3 122 assert len(minibatch[1])==3
123 for id in range(3): 123 for id in range(3):
124 assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() 124 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all()
125 i+=1 125 i+=1
126 126
127 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): 127 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
128 for x,z in ds.minibatches(['x','z'], minibatch_size=3): 128 for x,z in ds.minibatches(['x','z'], minibatch_size=3):
129 assert len(x)==3 129 assert len(x)==3
130 assert len(z)==3 130 assert len(z)==3
131 assert x[:,0:3:2].all()==z.all() 131 assert (x[:,0:3:2]==z).all()
132 i=0 132 i=0
133 for x,y in ds.minibatches(['x','y'], minibatch_size=3): 133 for x,y in ds.minibatches(['x','y'], minibatch_size=3):
134 assert len(x)==3 134 assert len(x)==3
135 assert len(y)==3 135 assert len(y)==3
136 for id in range(3): 136 for id in range(3):
137 assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() 137 assert (numpy.append(x[id],y[id])==a[i]).all()
138 i+=1 138 i+=1
139 # - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict. 139 # - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict.
140 # for x,y,z in ds: 140 # for x,y,z in ds:
141 # assert x.all()==a[i][:2].all() 141 # assert (x==a[i][:2]).all()
142 # assert y==a[i][3] 142 # assert y==a[i][3]
143 # assert z.all()==a[i][0:3:2].all() 143 # assert (z==a[i][0:3:2]).all()
144 # assert numpy.append(x,y).all()==a[i].all() 144 # assert (numpy.append(x,y)==a[i]).all()
145 # i+=1 145 # i+=1
146 146
147 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3): 147 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3):
148 # print minibatch 148 # print minibatch
149 # minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4) 149 # minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4)
171 171
172 def test_ds(orig,ds,index): 172 def test_ds(orig,ds,index):
173 i=0 173 i=0
174 assert len(ds)==len(index) 174 assert len(ds)==len(index)
175 for x,z,y in ds('x','z','y'): 175 for x,z,y in ds('x','z','y'):
176 assert orig[index[i]]['x'].all()==a[index[i]][:3].all() 176 assert (orig[index[i]]['x']==a[index[i]][:3]).all()
177 assert orig[index[i]]['x'].all()==x.all() 177 assert (orig[index[i]]['x']==x).all()
178 assert orig[index[i]]['y']==a[index[i]][3] 178 assert orig[index[i]]['y']==a[index[i]][3]
179 assert orig[index[i]]['y']==y 179 assert orig[index[i]]['y']==y
180 assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all() 180 assert (orig[index[i]]['z']==a[index[i]][0:3:2]).all()
181 assert orig[index[i]]['z'].all()==z.all() 181 assert (orig[index[i]]['z']==z).all()
182 i+=1 182 i+=1
183 del i 183 del i
184 ds[0] 184 ds[0]
185 if len(ds)>2: 185 if len(ds)>2:
186 ds[:1] 186 ds[:1]