Mercurial > pylearn
changeset 151:39bb21348fdf
Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 12 May 2008 15:51:43 -0400 |
parents | 625d2b21ee48 (current diff) 9abd19af822e (diff) |
children | 3f627e844cba 71107b0ac860 ae5651a3696b |
files | dataset.py |
diffstat | 2 files changed, 34 insertions(+), 4 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Mon May 12 15:50:34 2008 -0400 +++ b/dataset.py Mon May 12 15:51:43 2008 -0400 @@ -1077,7 +1077,7 @@ for example in self.dataset.source_dataset[cache_len:upper]: self.dataset.cached_examples.append(example) all_fields_minibatch = Example(self.dataset.fieldNames(), - self.dataset.cached_examples[self.current:self.current+minibatch_size]) + *self.dataset.cached_examples[self.current:self.current+minibatch_size]) if self.dataset.fieldNames()==fieldnames: return all_fields_minibatch return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
--- a/test_dataset.py Mon May 12 15:50:34 2008 -0400 +++ b/test_dataset.py Mon May 12 15:51:43 2008 -0400 @@ -383,12 +383,41 @@ assert example+example2==example3 assert have_raised("var['x']+var['x']",x=example) +def test_CacheDataSet(): + print "test_CacheDataSet" + a2 = numpy.random.rand(10,4) + ds1 = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested + ds2 = CachedDataSet(ds1) + ds3 = CachedDataSet(ds1,cache_all_upon_construction=True) + assert len(ds2)==10 + + test_iterate_over_examples(a2, ds2) + test_getitem(a2, ds2) + +# - for val1,val2,val3 in dataset(field1, field2,field3): + test_ds_iterator(a2,ds2('x','y'),ds2('y','z'),ds2('x','y','z')) + + + assert len(ds2.fields())==3 + for field in ds2.fields(): + for field_value in field: # iterate over the values associated to that field for all the ds examples + pass + for field in ds2('x','z').fields(): + pass + for field in ds2.fields('x','y'): + pass + for field_examples in ds2.fields(): + for example_value in field_examples: + pass + + assert ds2 == ds2.fields().examples() +# for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work +# assert numpy.append(x,y)==z + + def test_ApplyFunctionDataSet(): print "test_ApplyFunctionDataSet" raise NotImplementedError() -def test_CacheDataSet(): - print "test_CacheDataSet" - raise NotImplementedError() def test_FieldsSubsetDataSet(): print "test_FieldsSubsetDataSet" raise NotImplementedError() @@ -411,4 +440,5 @@ test1() test_LookupList() test_ArrayDataSet() +test_CacheDataSet() #test pmat.py