Mercurial > pylearn
comparison test_dataset.py @ 94:9c8f3c9c247b
corrected the use of .all()
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 17:44:49 -0400 |
parents | eee739fefdff |
children | 352910e0dbf5 |
Comparison legend: equal, deleted, inserted, replaced
91:eee739fefdff | 94:9c8f3c9c247b |
---|---|
45 assert len(ds)==10 | 45 assert len(ds)==10 |
46 #assert ds==a? should this work? | 46 #assert ds==a? should this work? |
47 | 47 |
48 #not in doc!!! | 48 #not in doc!!! |
49 for example in range(len(ds)): | 49 for example in range(len(ds)): |
50 assert ds[example]['x'].all()==a[example][:2].all() | 50 assert (ds[example]['x']==a[example][:3]).all() |
51 assert ds[example]['y']==a[example][3] | 51 assert ds[example]['y']==a[example][3] |
52 assert ds[example]['z'].all()==a[example][0:3:2].all() | 52 assert (ds[example]['z']==a[example][[0,2]]).all() |
53 | 53 |
54 # - for example in dataset: | 54 # - for example in dataset: |
55 i=0 | 55 i=0 |
56 for example in ds: | 56 for example in ds: |
57 assert example['x'].all()==a[i][:2].all() | 57 assert (example['x']==a[i][:3]).all() |
58 assert example['y']==a[i][3] | 58 assert example['y']==a[i][3] |
59 assert example['z'].all()==a[i][0:3:2].all() | 59 assert (example['z']==a[i][0:3:2]).all() |
60 assert numpy.append(example['x'],example['y']).all()==a[i].all() | 60 assert (numpy.append(example['x'],example['y'])==a[i]).all() |
61 i+=1 | 61 i+=1 |
62 assert i==len(ds) | 62 assert i==len(ds) |
63 # - for val1,val2,... in dataset: | 63 # - for val1,val2,... in dataset: |
64 i=0 | 64 i=0 |
65 for x,y,z in ds: | 65 for x,y,z in ds: |
66 assert x.all()==a[i][:2].all() | 66 assert (x==a[i][:3]).all() |
67 assert y==a[i][3] | 67 assert y==a[i][3] |
68 assert z.all()==a[i][0:3:2].all() | 68 assert (z==a[i][0:3:2]).all() |
69 assert numpy.append(example['x'],example['y']).all()==a[i].all() | 69 assert (numpy.append(x,y)==a[i]).all() |
70 i+=1 | 70 i+=1 |
71 assert i==len(ds) | 71 assert i==len(ds) |
72 # - for example in dataset(field1, field2,field3, ...): | 72 # - for example in dataset(field1, field2,field3, ...): |
73 i=0 | 73 i=0 |
74 for example in ds('x','y','z'): | 74 for example in ds('x','y','z'): |
75 assert example['x'].all()==a[i][:2].all() | 75 assert (example['x']==a[i][:3]).all() |
76 assert example['y']==a[i][3] | 76 assert example['y']==a[i][3] |
77 assert example['z'].all()==a[i][0:3:2].all() | 77 assert (example['z']==a[i][0:3:2]).all() |
78 assert numpy.append(example['x'],example['y']).all()==a[i].all() | 78 assert (numpy.append(example['x'],example['y'])==a[i]).all() |
79 i+=1 | 79 i+=1 |
80 assert i==len(ds) | 80 assert i==len(ds) |
81 | 81 |
82 # - for val1,val2,val3 in dataset(field1, field2,field3): | 82 # - for val1,val2,val3 in dataset(field1, field2,field3): |
83 | 83 |
84 # - for example in dataset(field1, field2,field3, ...): | 84 # - for example in dataset(field1, field2,field3, ...): |
85 | 85 |
86 def test_ds_iterator(iterator1,iterator2,iterator3): | 86 def test_ds_iterator(iterator1,iterator2,iterator3): |
87 i=0 | 87 i=0 |
88 for x,y in iterator1: | 88 for x,y in iterator1: |
89 assert x.all()==a[i][:2].all() | 89 assert (x==a[i][:3]).all() |
90 assert y==a[i][3] | 90 assert y==a[i][3] |
91 assert numpy.append(x,y).all()==a[i].all() | 91 assert (numpy.append(x,y)==a[i]).all() |
92 i+=1 | 92 i+=1 |
93 assert i==len(ds) | 93 assert i==len(ds) |
94 i=0 | 94 i=0 |
95 for y,z in iterator2: | 95 for y,z in iterator2: |
96 assert y==a[i][3] | 96 assert y==a[i][3] |
97 assert z.all()==a[i][0:3:2].all() | 97 assert (z==a[i][0:3:2]).all() |
98 i+=1 | 98 i+=1 |
99 assert i==len(ds) | 99 assert i==len(ds) |
100 i=0 | 100 i=0 |
101 for x,y,z in iterator3: | 101 for x,y,z in iterator3: |
102 assert x.all()==a[i][:2].all() | 102 assert (x==a[i][:3]).all() |
103 assert y==a[i][3] | 103 assert y==a[i][3] |
104 assert z.all()==a[i][0:3:2].all() | 104 assert (z==a[i][0:3:2]).all() |
105 assert numpy.append(x,y).all()==a[i].all() | 105 assert (numpy.append(x,y)==a[i]).all() |
106 i+=1 | 106 i+=1 |
107 assert i==len(ds) | 107 assert i==len(ds) |
108 | 108 |
109 #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3): | 109 #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3): |
110 test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z')) | 110 test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z')) |
112 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | 112 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): |
113 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): | 113 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): |
114 assert len(minibatch)==2 | 114 assert len(minibatch)==2 |
115 assert len(minibatch[0])==3 | 115 assert len(minibatch[0])==3 |
116 assert len(minibatch[1])==3 | 116 assert len(minibatch[1])==3 |
117 assert minibatch[0][:,0:3:2].all()==minibatch[1].all() | 117 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() |
118 i=0 | 118 i=0 |
119 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): | 119 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): |
120 assert len(minibatch)==2 | 120 assert len(minibatch)==2 |
121 assert len(minibatch[0])==3 | 121 assert len(minibatch[0])==3 |
122 assert len(minibatch[1])==3 | 122 assert len(minibatch[1])==3 |
123 for id in range(3): | 123 for id in range(3): |
124 assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() | 124 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() |
125 i+=1 | 125 i+=1 |
126 | 126 |
127 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): | 127 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): |
128 for x,z in ds.minibatches(['x','z'], minibatch_size=3): | 128 for x,z in ds.minibatches(['x','z'], minibatch_size=3): |
129 assert len(x)==3 | 129 assert len(x)==3 |
130 assert len(z)==3 | 130 assert len(z)==3 |
131 assert x[:,0:3:2].all()==z.all() | 131 assert (x[:,0:3:2]==z).all() |
132 i=0 | 132 i=0 |
133 for x,y in ds.minibatches(['x','y'], minibatch_size=3): | 133 for x,y in ds.minibatches(['x','y'], minibatch_size=3): |
134 assert len(x)==3 | 134 assert len(x)==3 |
135 assert len(y)==3 | 135 assert len(y)==3 |
136 for id in range(3): | 136 for id in range(3): |
137 assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() | 137 assert (numpy.append(x[id],y[id])==a[i]).all() |
138 i+=1 | 138 i+=1 |
139 # - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict. | 139 # - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict. |
140 # for x,y,z in ds: | 140 # for x,y,z in ds: |
141 # assert x.all()==a[i][:2].all() | 141 # assert (x==a[i][:2]).all() |
142 # assert y==a[i][3] | 142 # assert y==a[i][3] |
143 # assert z.all()==a[i][0:3:2].all() | 143 # assert (z==a[i][0:3:2]).all() |
144 # assert numpy.append(x,y).all()==a[i].all() | 144 # assert (numpy.append(x,y)==a[i]).all() |
145 # i+=1 | 145 # i+=1 |
146 | 146 |
147 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3): | 147 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3): |
148 # print minibatch | 148 # print minibatch |
149 # minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4) | 149 # minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4) |
171 | 171 |
172 def test_ds(orig,ds,index): | 172 def test_ds(orig,ds,index): |
173 i=0 | 173 i=0 |
174 assert len(ds)==len(index) | 174 assert len(ds)==len(index) |
175 for x,z,y in ds('x','z','y'): | 175 for x,z,y in ds('x','z','y'): |
176 assert orig[index[i]]['x'].all()==a[index[i]][:3].all() | 176 assert (orig[index[i]]['x']==a[index[i]][:3]).all() |
177 assert orig[index[i]]['x'].all()==x.all() | 177 assert (orig[index[i]]['x']==x).all() |
178 assert orig[index[i]]['y']==a[index[i]][3] | 178 assert orig[index[i]]['y']==a[index[i]][3] |
179 assert orig[index[i]]['y']==y | 179 assert orig[index[i]]['y']==y |
180 assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all() | 180 assert (orig[index[i]]['z']==a[index[i]][0:3:2]).all() |
181 assert orig[index[i]]['z'].all()==z.all() | 181 assert (orig[index[i]]['z']==z).all() |
182 i+=1 | 182 i+=1 |
183 del i | 183 del i |
184 ds[0] | 184 ds[0] |
185 if len(ds)>2: | 185 if len(ds)>2: |
186 ds[:1] | 186 ds[:1] |