# HG changeset patch # User James Bergstra # Date 1212440998 14400 # Node ID c702abb7f87557ae97a5576f46d6850623a57d72 # Parent 3156a9976183b20c5faa6e4a092fdbc3a573193f# Parent c8f19a9eb10fac6eec3cf7c77f4e2476a1c25c24 merged diff -r 3156a9976183 -r c702abb7f875 dataset.py --- a/dataset.py Mon Jun 02 17:08:17 2008 -0400 +++ b/dataset.py Mon Jun 02 17:09:58 2008 -0400 @@ -47,14 +47,14 @@ columns/attributes are called fields. The field value for a particular example can be an arbitrary python object, which depends on the particular dataset. - We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method + We call a DataSet a 'stream' when its length is unbounded (in which case its __len__ method should return sys.maxint). A DataSet is a generator of iterators; these iterators can run through the examples or the fields in a variety of ways. A DataSet need not necessarily have a finite or known length, so this class can be used to interface to a 'stream' which feeds on-line learning (however, as noted below, some operations are not - feasible or not recommanded on streams). + feasible or not recommended on streams). To iterate over examples, there are several possibilities: - for example in dataset: @@ -81,7 +81,7 @@ - for field_examples in dataset.fields(): for example_value in field_examples: ... - but when the dataset is a stream (unbounded length), it is not recommanded to do + but when the dataset is a stream (unbounded length), it is not recommended to do such things because the underlying dataset may refuse to access the different fields in an unsynchronized ways. Hence the fields() method is illegal for streams, by default. The result of fields() is a L{DataSetFields} object, which iterates over fields, @@ -599,7 +599,7 @@ * for field_examples in dataset.fields(): for example_value in field_examples: ... - but when the dataset is a stream (unbounded length), it is not recommanded to do + but when the dataset is a stream (unbounded length), it is not recommended to do such things because the underlying dataset may refuse to access the different fields in an unsynchronized ways. Hence the fields() method is illegal for streams, by default. The result of fields() is a DataSetFields object, which iterates over fields, @@ -1016,12 +1016,13 @@ def __getitem__(self,key): """More efficient implementation than the default __getitem__""" fieldnames=self.fields_columns.keys() + values=self.fields_columns.values() if type(key) is int: return Example(fieldnames, - [self.data[key,self.fields_columns[f]] for f in fieldnames]) + [self.data[key,col] for col in values]) if type(key) is slice: return MinibatchDataSet(Example(fieldnames, - [self.data[key,self.fields_columns[f]] for f in fieldnames])) + [self.data[key,col] for col in values])) if type(key) is list: for i in range(len(key)): if self.hasFields(key[i]): @@ -1030,9 +1031,10 @@ #we must separate differently for list as numpy # doesn't support self.data[[i1,...],[i2,...]] # when their is more then two i1 and i2 - [self.data[key,:][:,self.fields_columns[f]] - if isinstance(self.fields_columns[f],list) else - self.data[key,self.fields_columns[f]] for f in fieldnames]), + [self.data[key,:][:,col] + if isinstance(col,list) else + self.data[key,col] for col in values]), + self.valuesVStack,self.valuesHStack) @@ -1054,6 +1056,8 @@ assert offset>=0 and offset=self.dataset.data.shape[0]: raise StopIteration sub_data = self.dataset.data[self.current] - self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names] + self.minibatch._values = [sub_data[c] for c in self.columns] + self.current+=self.minibatch_size return self.minibatch diff -r 3156a9976183 -r c702abb7f875 test_dataset.py --- a/test_dataset.py Mon Jun 02 17:08:17 2008 -0400 +++ b/test_dataset.py Mon Jun 02 17:09:58 2008 -0400 @@ -194,38 +194,41 @@ m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) assert isinstance(m,DataSet.MinibatchWrapAroundIterator) for x,y in m: - assert len(x)==3 - assert len(y)==3 - for id in range(3): + assert len(x)==m.minibatch_size + assert len(y)==m.minibatch_size + for id in range(m.minibatch_size): assert (numpy.append(x[id],y[id])==array[i+4]).all() i+=1 - assert i==3 + assert i==m.n_batches*m.minibatch_size del x,y,i,id,m i=0 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) assert isinstance(m,DataSet.MinibatchWrapAroundIterator) for x,y in m: - assert len(x)==3 - assert len(y)==3 - for id in range(3): + assert len(x)==m.minibatch_size + assert len(y)==m.minibatch_size + for id in range(m.minibatch_size): assert (numpy.append(x[id],y[id])==array[i+4]).all() i+=1 - assert i==6 + assert i==m.n_batches*m.minibatch_size del x,y,i,id,m i=0 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) assert isinstance(m,DataSet.MinibatchWrapAroundIterator) for x,y in m: - assert len(x)==3 - assert len(y)==3 - for id in range(3): + assert len(x)==m.minibatch_size + assert len(y)==m.minibatch_size + for id in range(m.minibatch_size): assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() i+=1 assert i==m.n_batches*m.minibatch_size del x,y,i,id + #@todo: we can't do minibatch bigger then the size of the dataset??? + assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) + assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) def test_ds_iterator(array,iterator1,iterator2,iterator3): l=len(iterator1) @@ -494,10 +497,7 @@ print "test_speed" import time a2 = numpy.random.rand(100000,400) - ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested - ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested ds = ArrayDataSet(a2,{'all':slice(0,a2.shape[1],1)}) - #assert ds==a? should this work? mat = numpy.random.rand(400,100) @print_timing def f_array1(a):