# HG changeset patch
# User Frederic Bastien
# Date 1209999723 14400
# Node ID 158653a9bc7c949bcecbfe90a043cb62e10c51f3
# Parent 3499918faa9db49c131674cd46473d472ff0eb99
# Parent 4b0859606d05037463a863e24ab05e2a46ae9750
Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn

diff -r 3499918faa9d -r 158653a9bc7c dataset.py
--- a/dataset.py Mon May 05 09:35:30 2008 -0400
+++ b/dataset.py Mon May 05 11:02:03 2008 -0400
@@ -281,7 +281,7 @@
 
     The minibatches iterator is expected to return upon each call to next() a
     DataSetFields object, which is a LookupList (indexed by the field names) whose
-    elements are iterable over the minibatch examples, and which keeps a pointer to
+    elements are iterable and indexable over the minibatch examples, and which keeps a pointer to
     a sub-dataset that can be used to iterate over the individual examples in the
     minibatch. Hence a minibatch can be converted back to a regular dataset or its
     fields can be looked at individually (and possibly iterated over).
@@ -632,12 +632,13 @@
         return self.length
 
     def __getitem__(self,i):
+        if type(i) in (slice,list):
+            return DataSetFields(MinibatchDataSet(
+                Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames())
         if type(i) is int:
-            return Example(self._fields.keys(),[field[i] for field in self._fields])
-        if type(i) in (slice,list):
-            return MinibatchDataSet(Example(self._fields.keys(),
-                                            [field[i] for field in self._fields]),
-                                    self.valuesVStack,self.valuesHStack)
+            return DataSetFields(MinibatchDataSet(
+                Example(self._fields.keys(),[[field[i]] for field in self._fields])),self.fieldNames())
+
         if self.hasFields(i):
             return self._fields[i]
         assert i in self.__dict__ # else it means we are trying to access a non-existing property
@@ -939,22 +940,29 @@
 
     def __len__(self):
         return len(self.data)
 
-    def __getitem__(self,i):
+    def __getitem__(self,key):
         """More efficient implementation than the default __getitem__"""
         fieldnames=self.fields_columns.keys()
-        if type(i) is int:
+        if type(key) is int:
             return Example(fieldnames,
-                           [self.data[i,self.fields_columns[f]] for f in fieldnames])
-        if type(i) in (slice,list):
+                           [self.data[key,self.fields_columns[f]] for f in fieldnames])
+        if type(key) is slice:
             return MinibatchDataSet(Example(fieldnames,
-                                            [self.data[i,self.fields_columns[f]] for f in fieldnames]),
+                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]))
+        if type(key) is list:
+            for i in range(len(key)):
+                if self.hasFields(key[i]):
+                    key[i]=self.fields_columns[key[i]]
+            return MinibatchDataSet(Example(fieldnames,
+                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]),
                                     self.valuesVStack,self.valuesHStack)
+
         # else check for a fieldname
-        if self.hasFields(i):
-            return Example([i],[self.data[self.fields_columns[i],:]])
+        if self.hasFields(key):
+            return self.data[self.fields_columns[key],:]
         # else we are trying to access a property of the dataset
-        assert i in self.__dict__ # else it means we are trying to access a non-existing property
-        return self.__dict__[i]
+        assert key in self.__dict__ # else it means we are trying to access a non-existing property
+        return self.__dict__[key]
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
diff -r 3499918faa9d -r 158653a9bc7c lookup_list.py
--- a/lookup_list.py Mon May 05 09:35:30 2008 -0400
+++ b/lookup_list.py Mon May 05 11:02:03 2008 -0400
@@ -1,5 +1,5 @@
-from copy import copy
+from copy import deepcopy
 
 
 class LookupList(object):
     """
@@ -17,7 +17,9 @@
     print example.items() # prints [('x',[1,2,3]),('y',2),('z',3)]
     example.append_keyval('u',0) # adds item with name 'u' and value 0
     print len(example) # number of items = 4 here
-    print example+example # addition is like for lists, a concatenation of the items.
+    example2 = LookupList(['v', 'w'], ['a','b'])
+    print example+example2 # addition is like for lists, a concatenation of the items.
+    example + example # raises an error, as we can't have duplicate names.
     Note that the element names should be unique.
     """
     def __init__(self,names=[],values=[]):
@@ -75,15 +77,23 @@
 
         return "{%s}" % ", ".join([str(k) + "=" + repr(v) for k,v in self.items()])
 
     def __add__(self,rhs):
-        new_example = copy(self)
+        new_example = deepcopy(self)
         for item in rhs.items():
             new_example.append_keyval(item[0],item[1])
         return new_example
 
     def __radd__(self,lhs):
-        new_example = copy(lhs)
+        new_example = deepcopy(lhs)
         for item in self.items():
             new_example.append_keyval(item[0],item[1])
         return new_example
+    def __eq__(self, other):
+        return self._values==other._values and self._name2index==other._name2index and self._names==other._names
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __hash__(self):
+        raise NotImplementedError()
 
diff -r 3499918faa9d -r 158653a9bc7c test_dataset.py
--- a/test_dataset.py Mon May 05 09:35:30 2008 -0400
+++ b/test_dataset.py Mon May 05 11:02:03 2008 -0400
@@ -26,6 +26,9 @@
 def test_ArrayDataSet():
     #don't test stream
     #tested only with float value
+    #test with y too
+    #test missing value
+
     a = numpy.random.rand(10,4)
     print a
     ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})
@@ -36,17 +39,11 @@
         assert ds[i]['y']==a[i][3]
         assert ds[i]['z'].all()==a[i][0:3:2].all()
     print "x=",ds["x"]
-    print "x|y"
     i=0
     for x in ds('x','y'):
         assert numpy.append(x['x'],x['y']).all()==a[i].all()
         i+=1
-#    i=0
-#    for x in ds['x','y']: # don't work
-#        assert numpy.append(x['x'],x['y']).all()==a[i].all()
-#        i+=1
-#    for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work.
-#        assert numpy.append(x,y)==z
+
     i=0
     for x,y in ds('x','y'):
         assert numpy.append(x,y).all()==a[i].all()
@@ -70,6 +67,14 @@
     except :
         have_thrown = True
     assert have_thrown == True
+
+    have_thrown = False
+    try:
+        ds[['h']] # h is not defined...
+    except :
+        have_thrown = True
+    assert have_thrown == True
+
     assert len(ds.fields())==3
     for field in ds.fields():
         for field_value in field: # iterate over the values associated to that field for all the ds examples
@@ -85,19 +90,29 @@
 
     assert ds == ds.fields().examples()
 
-    #test missing value
-
+    #ds[:n] returns a dataset with the first n examples.
    assert len(ds[:3])==3
     i=0
     for x,z in ds[:3]('x','z'):
         assert ds[i]['z'].all()==a[i][0:3:2].all()
         i+=1
+
     #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
-
-    #ds[i]# returns an Example.
-
+    ds[1:7:2][1] #fail???
+    assert len(ds[1:7:2])==3 # should be examples number 1, 3 and 5
+    i=0
+    for x,z in ds[1:7:2]('x','z'):
+        assert ds[i]['z'].all()==a[i][0:3:2].all()
+        i+=1
+    ds2=ds[1:7:2]
+    for i in range(len(ds2)):
+        print ds2[i]
     #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
-
+    i=0
+    for x in ds[[1,2]]:
+        assert numpy.append(x['x'],x['y']).all()==a[i].all()
+        i+=1
+    #ds[i1,i2,...]# should we accept????
     #ds[fieldname]# an iterable over the values of the field fieldname across
     #the ds (the iterable is obtained by default by calling valuesVStack
     #over the values for individual examples).
@@ -109,6 +124,54 @@
     #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])
     #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])
+#    for (x,y) in (ds('x','y'),a): #doesn't work # haven't found a variant that works.
+#        assert numpy.append(x,y)==z
+
+def test_LookupList():
+    #test only the example in the doc???
+    example = LookupList(['x','y','z'],[1,2,3])
+    example['x'] = [1, 2, 3] # set or change a field
+    x, y, z = example
+    x = example[0]
+    x = example["x"]
+    assert example.keys()==['x','y','z']
+    assert example.values()==[[1,2,3],2,3]
+    assert example.items()==[('x',[1,2,3]),('y',2),('z',3)]
+    example.append_keyval('u',0) # adds item with name 'u' and value 0
+    assert len(example)==4 # number of items = 4 here
+    example2 = LookupList(['v','w'], ['a','b'])
+    example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b'])
+    print example3
+    print example+example2
+    assert example+example2==example3
+    have_thrown=False
+    try:
+        example+example
+    except:
+        have_thrown=True
+    assert have_thrown
+
+def ApplyFunctionDataSet():
+    raise NotImplementedError()
+def CacheDataSet():
+    raise NotImplementedError()
+def FieldsSubsetDataSet():
+    raise NotImplementedError()
+def DataSetFields():
+    raise NotImplementedError()
+def MinibatchDataSet():
+    raise NotImplementedError()
+def HStackedDataSet():
+    raise NotImplementedError()
+def VStackedDataSet():
+    raise NotImplementedError()
+def ArrayFieldsDataSet():
+    raise NotImplementedError()
+
 
 
 test1()
+test_LookupList()
 test_ArrayDataSet()
+
+
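
As a quick sanity check of the LookupList changes above (deepcopy in __add__/__radd__, the new __eq__/__ne__, and the rule that duplicate field names are rejected), here is a minimal usage sketch. It assumes the patched lookup_list.py is importable; the names data and combined are illustrative only and are not part of the changeset.

from lookup_list import LookupList

data = [1, 2, 3]
example  = LookupList(['x','y','z'], [data, 2, 3])
example2 = LookupList(['v','w'], ['a','b'])

# __add__ now deep-copies the left operand, so the result keeps its own
# copy of the field values instead of sharing them with `example`.
combined = example + example2
data.append(4)                       # mutate the original value of field 'x'
assert example['x'] == [1, 2, 3, 4]  # the original LookupList sees the change
assert combined['x'] == [1, 2, 3]    # the deep-copied result does not

# __eq__ compares the names, the name->index mapping and the values.
assert combined == LookupList(['x','y','z','v','w'],
                              [[1, 2, 3], 2, 3, 'a', 'b'])

# Adding a LookupList to itself raises, since duplicate names are forbidden.
try:
    example + example
except:
    print "duplicate field names rejected, as expected"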