# HG changeset patch # User bengioy@grenat.iro.umontreal.ca # Date 1209153631 14400 # Node ID 9b68774fcc6b0ed663bd0b1dea143137c806d3d4 # Parent 283e95c15b47c6f4fded1a89abc13002725a1fdb Testing basic functionality and removing obvious bugs diff -r 283e95c15b47 -r 9b68774fcc6b dataset.py --- a/dataset.py Fri Apr 25 12:04:55 2008 -0400 +++ b/dataset.py Fri Apr 25 16:00:31 2008 -0400 @@ -1,13 +1,13 @@ from lookup_list import LookupList Example = LookupList -from misc import * -import copy -import string +from misc import unique_elements_list_intersection +from string import join +from sys import maxint class AbstractFunction (Exception): """Derived class must override this function""" class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" -class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)""" +#class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)""" class DataSet(object): """A virtual base class for datasets. @@ -124,7 +124,7 @@ def __init__(self,description=None,field_types=None): if description is None: # by default return "(,,...)" - description = type(self).__name__ + " ( " + string.join([x.__name__ for x in type(self).__bases__]) + " )" + description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )" self.description=description self.field_types=field_types @@ -143,7 +143,7 @@ return self def next(self): size1_minibatch = self.minibatch_iterator.next() - return Example(size1_minibatch.keys,[value[0] for value in size1_minibatch.values()]) + return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()]) def next_index(self): return self.minibatch_iterator.next_index() @@ -197,11 +197,11 @@ return self.next_row def next(self): - if self.n_batches and self.n_batches_done==self.n_batches: + if self.n_batches and self.n_batches_done==self.n_batches raise StopIteration - upper = self.next_row+minibatch_size + upper = self.next_row+self.minibatch_size if upper <=self.L: - minibatch = self.minibatch_iterator.next() + minibatch = self.iterator.next() else: if not self.n_batches: raise StopIteration @@ -214,8 +214,8 @@ for name in self.fieldnames]) self.next_row=upper self.n_batches_done+=1 - if upper >= L: - self.next_row -= L + if upper >= self.L: + self.next_row -= self.L return minibatch @@ -275,7 +275,7 @@ any other object that supports integer indexing and slicing. """ - return MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset) + return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset) def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): """ @@ -322,14 +322,14 @@ """ Return a dataset that sees only the fields whose name are specified. """ - assert self.hasFields(fieldnames) - return self.fields(fieldnames).examples() + assert self.hasFields(*fieldnames) + return self.fields(*fieldnames).examples() def fields(self,*fieldnames): """ Return a DataSetFields object associated with this dataset. """ - return DataSetFields(self,fieldnames) + return DataSetFields(self,*fieldnames) def __getitem__(self,i): """ @@ -371,7 +371,7 @@ rows = i if rows is not None: fields_values = zip(*[self[row] for row in rows]) - return MinibatchDataSet( + return DataSet.MinibatchDataSet( Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) for fieldname,field_values in zip(self.fieldNames(),fields_values)])) @@ -459,7 +459,41 @@ return datasets[0] return VStackedDataSet(datasets) +class FieldsSubsetDataSet(DataSet): + """ + A sub-class of DataSet that selects a subset of the fields. + """ + def __init__(self,src,fieldnames): + self.src=src + self.fieldnames=fieldnames + assert src.hasFields(*fieldnames) + self.valuesHStack = src.valuesHStack + self.valuesVStack = src.valuesVStack + def __len__(self): return len(self.src) + + def fieldNames(self): + return self.fieldnames + + def __iter__(self): + class Iterator(object): + def __init__(self,ds): + self.ds=ds + self.src_iter=ds.src.__iter__() + def __iter__(self): return self + def next(self): + example = self.src_iter.next() + return Example(self.ds.fieldnames, + [example[field] for field in self.ds.fieldnames]) + return Iterator(self) + + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): + assert self.hasFields(*fieldnames) + return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) + def __getitem__(self,i): + return FieldsSubsetDataSet(self.src[i],self.fieldnames) + + class DataSetFields(LookupList): """ Although a DataSet iterates over examples (like rows of a matrix), an associated @@ -488,13 +522,18 @@ the examples. """ def __init__(self,dataset,*fieldnames): - self.dataset=dataset if not fieldnames: fieldnames=dataset.fieldNames() + elif fieldnames is not dataset.fieldNames(): + dataset = FieldsSubsetDataSet(dataset,fieldnames) assert dataset.hasFields(*fieldnames) - LookupList.__init__(self,dataset.fieldNames(), - dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(), - minibatch_size=len(dataset)).next()) + self.dataset=dataset + minibatch_iterator = dataset.minibatches(fieldnames, + minibatch_size=len(dataset), + n_batches=1) + minibatch=minibatch_iterator.next() + LookupList.__init__(self,fieldnames,minibatch) + def examples(self): return self.dataset @@ -813,16 +852,16 @@ """ Construct an ArrayDataSet from the underlying numpy array (data) and - a map from fieldnames to field columns. The columns of a field are specified + a map (fields_columns) from fieldnames to field columns. The columns of a field are specified using the standard arguments for indexing/slicing: integer for a column index, slice for an interval of columns (with possible stride), or iterable of column indices. """ - def __init__(self, data_array, fields_names_columns): + def __init__(self, data_array, fields_columns): self.data=data_array - self.fields=fields_names_columns + self.fields_columns=fields_columns # check consistency and complete slices definitions - for fieldname, fieldcolumns in self.fields.items(): + for fieldname, fieldcolumns in self.fields_columns.items(): if type(fieldcolumns) is int: assert fieldcolumns>=0 and fieldcolumns=0 and i=0 and offset=0 and offset