Mercurial > pylearn
changeset 12:ff4e551490f1
Added LookupList type in lookup_list.py and used it to keep order
of field names in Example in ArrayDataSet. Example is now just = LookupList.
author | bengioy@esprit.iro.umontreal.ca |
---|---|
date | Wed, 26 Mar 2008 18:21:57 -0400 |
parents | be128b9127c8 |
children | 633453635d51 759d17112b23 |
files | dataset.py lookup_list.py |
diffstat | 2 files changed, 93 insertions(+), 58 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Wed Mar 26 15:01:30 2008 -0400 +++ b/dataset.py Wed Mar 26 18:21:57 2008 -0400 @@ -1,42 +1,7 @@ -class Example(object): - """ - An example is something that is like a tuple but whose elements can be named, to that - following syntactic constructions work as one would expect: - example.x = [1, 2, 3] # set a field - x, y, z = example - x = example[0] - x = example["x"] - """ - def __init__(self,names,values): - assert len(values)==len(names) - self.__dict__['values']=values - self.__dict__['fields']={} - for i in xrange(len(values)): - self.fields[names[i]]=i - - def __getitem__(self,i): - if isinstance(i,int): - return self.values[i] - else: - return self.values[self.fields[i]] - - def __setitem__(self,i,value): - if isinstance(i,int): - self.values[i]=value - else: - self.values[self.fields[i]]=value - - def __getattr__(self,name): - return self.values[self.fields[name]] - - def __setattr__(self,name,value): - self.values[self.fields[name]]=value - - def __len__(self): - return len(self.values) - - +from lookup_list import LookupList +Example = LookupList + class DataSet(object): """ This is a virtual base class or interface for datasets. @@ -192,15 +157,15 @@ by the numpy.array(dataset) call. """ - def __init__(self,dataset=None,data=None,fields={}): + def __init__(self,dataset=None,data=None,fields=None): """ There are two ways to construct an ArrayDataSet: (1) from an existing dataset (which may result in a copy of the data in a numpy array), or (2) from a numpy.array (the data argument), along with an optional description - of the fields (dictionary of column slices indexed by field names). + of the fields (a LookupList of column slices indexed by field names). """ if dataset!=None: - assert data==None and fields=={} + assert data==None and fields==None # Make ONE big minibatch with all the examples, to separate the fields. n_examples=len(dataset) batch = dataset.minibatches(n_examples).next() @@ -210,6 +175,7 @@ fieldnames = batch.fields.keys() total_width = 0 type = None + fields = LookupList() for i in xrange(n_fields): field = array(batch[i]) assert field.shape[0]==n_examples @@ -227,19 +193,19 @@ self.data=data self.fields=fields self.width = data.shape[1] - for fieldname in fields: - fieldslice=fields[fieldname] - # make sure fieldslice.start and fieldslice.step are defined - start=fieldslice.start - step=fieldslice.step - if not start: - start=0 - if not step: - step=1 - if not fieldslice.start or not fieldslice.step: - fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step) - # and coherent with the data array - assert fieldslice.start>=0 and fieldslice.stop<=self.width + if fields: + for fieldname,fieldslice in fields.items(): + # make sure fieldslice.start and fieldslice.step are defined + start=fieldslice.start + step=fieldslice.step + if not start: + start=0 + if not step: + step=1 + if not fieldslice.start or not fieldslice.step: + fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step) + # and coherent with the data array + assert fieldslice.start>=0 and fieldslice.stop<=self.width def __getattr__(self,fieldname): """ @@ -258,9 +224,9 @@ for field_slice in self.fields.values(): min_col=min(min_col,field_slice.start) max_col=max(max_col,field_slice.stop) - new_fields={} - for field in self.fields: - new_fields[field[0]]=slice(field[1].start-min_col,field[1].stop-min_col,field[1].step) + new_fields=LookupList() + for fieldname,fieldslice in self.fields.items(): + new_fields[fieldname]=slice(fieldslice.start-min_col,fieldslice.stop-min_col,fieldslice.step) return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields) def fieldNames(self): @@ -278,7 +244,7 @@ """ if self.fields: fieldnames,fieldslices=zip(*self.fields.items()) - return Example(fieldnames,[self.data[i,fieldslice] for fieldslice in fieldslices]) + return Example(self.fields.keys(),[self.data[i,fieldslice] for fieldslice in self.fields.values()]) else: return self.data[i]
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/lookup_list.py Wed Mar 26 18:21:57 2008 -0400 @@ -0,0 +1,69 @@ + +class LookupList(object): + """ + A LookupList is a sequence whose elements can be named (and unlike + a dictionary the order of the elements depends not on their key but + on the order given by the user through construction) so that + following syntactic constructions work as one would expect: + example = Example(['x','y','z'],[1,2,3]) + example.x = [1, 2, 3] # set or change a field + x, y, z = example + x = example[0] + x = example["x"] + print example.keys() # returns ['x','y','z'] + print example.values() # returns [[1,2,3],2,3] + """ + def __init__(self,names=[],values=[]): + assert len(values)==len(names) + self.__dict__['_values']=values + self.__dict__['_name2index']={} + self.__dict__['_names']=names + for i in xrange(len(values)): + self._name2index[names[i]]=i + + def keys(self): + return _names + + def values(self): + return _values + + def items(self): + return zip(self._names,self._values) + + def __getitem__(self,key): + """ + The key in example[key] can either be an integer to index the fields + or the name of the field. + """ + if isinstance(key,int): + return self._values[key] + else: # if not an int, key must be a name + return self._values[self._name2index[key]] + + def __setitem__(self,key,value): + if isinstance(key,int): + self._values[key]=value + else: # if not an int, key must be a name + if key in self._name2index: + self._values[self._name2index[key]]=value + else: + self._name2index[key]=len(self) + self._values.append(value) + self._names.append(key) + + def __getattr__(self,name): + return self._values[self._name2index[name]] + + def __setattr__(self,name,value): + if name in self._name2index: + self._values[self._name2index[name]]=value + else: + self._name2index[name]=len(self) + self._values.append(value) + self._names.append(name) + + def __len__(self): + return len(self._values) + + def __repr__(self): + return "{%s}" % ", ".join([str(k) + "=" + repr(v) for k,v in self.items()])