# HG changeset patch # User bengioy@zircon.iro.umontreal.ca # Date 1208901791 14400 # Node ID 438440ba0627c29b8652d31ba656a8808705a04b # Parent 2508c373cf2935f1753f40241630379e648fee92 Rewriting dataset.py completely diff -r 2508c373cf29 -r 438440ba0627 dataset.py --- a/dataset.py Fri Apr 18 01:36:56 2008 -0400 +++ b/dataset.py Tue Apr 22 18:03:11 2008 -0400 @@ -1,23 +1,33 @@ from lookup_list import LookupList Example = LookupList +from misc import * import copy class AbstractFunction (Exception): """Derived class must override this function""" - +class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" + class DataSet(object): """A virtual base class for datasets. + A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction + with learning algorithms (for training and testing them): rows/records are called examples, and + columns/attributes are called fields. The field value for a particular example can be an arbitrary + python object, which depends on the particular dataset. + + We call a DataSet a 'stream' when its length is unbounded (len(dataset)==float("infinity")). + A DataSet is a generator of iterators; these iterators can run through the - examples in a variety of ways. A DataSet need not necessarily have a finite + examples or the fields in a variety of ways. A DataSet need not necessarily have a finite or known length, so this class can be used to interface to a 'stream' which - feeds on-line learning. + feeds on-line learning (however, as noted below, some operations are not + feasible or not recommanded on streams). To iterate over examples, there are several possibilities: - - for example in dataset.zip([field1, field2,field3, ...]) - - for val1,val2,val3 in dataset.zip([field1, field2,field3]) - - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N) - - for example in dataset + * for example in dataset([field1, field2,field3, ...]): + * for val1,val2,val3 in dataset([field1, field2,field3]): + * for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): + * for example in dataset: Each of these is documented below. All of these iterators are expected to provide, in addition to the usual 'next()' method, a 'next_index()' method which returns a non-negative integer pointing to the position of the next @@ -26,40 +36,97 @@ can wrap around the dataset in order to do multiple passes through it, in possibly unregular ways if the minibatch size is not a divisor of the dataset length. - + + To iterate over fields, one can do + * for fields in dataset.fields() + * for fields in dataset(field1,field2,...).fields() to select a subset of fields + * for fields in dataset.fields(field1,field2,...) to select a subset of fields + and each of these fields is iterable over the examples: + * for field_examples in dataset.fields(): + for example_value in field_examples: + ... + but when the dataset is a stream (unbounded length), it is not recommanded to do + such things because the underlying dataset may refuse to access the different fields in + an unsynchronized ways. Hence the fields() method is illegal for streams, by default. + The result of fields() is a DataSetFields object, which iterates over fields, + and whose elements are iterable over examples. A DataSetFields object can + be turned back into a DataSet with its examples() method: + dataset2 = dataset1.fields().examples() + and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1). + Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. - Note: The content of a field can be of any type. + Note: The content of a field can be of any type. Field values can also be 'missing' + (e.g. to handle semi-supervised learning), and in the case of numeric (numpy array) + fields (i.e. an ArrayFieldsDataSet), NaN plays the role of a missing value. + + Dataset elements can be indexed and sub-datasets (with a subset + of examples) can be extracted. These operations are not supported + by default in the case of streams. + + * dataset[:n] returns a dataset with the n first examples. - Note: A dataset can recognize a potentially infinite number of field names (i.e. the field - values can be computed on-demand, when particular field names are used in one of the - iterators). + * dataset[i1:i2:s] returns a dataset with the examples i1,i1+s,...i2-s. + + * dataset[i] returns an Example. + + * dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in. + + Datasets can be concatenated either vertically (increasing the length) or + horizontally (augmenting the set of fields), if they are compatible, using + the following operations (with the same basic semantics as numpy.hstack + and numpy.vstack): + + * dataset1 | dataset2 | dataset3 == dataset.hstack([dataset1,dataset2,dataset3]) - Datasets of finite length should be sub-classes of FiniteLengthDataSet. + creates a new dataset whose list of fields is the concatenation of the list of + fields of the argument datasets. This only works if they all have the same length. + + * dataset1 + dataset2 + dataset3 == dataset.vstack([dataset1,dataset2,dataset3]) + + creates a new dataset that concatenates the examples from the argument datasets + (and whose length is the sum of the length of the argument datasets). This only + works if they all have the same fields. - Datasets whose elements can be indexed and whose sub-datasets (with a subset - of examples) can be extracted should be sub-classes of - SliceableDataSet. + According to the same logic, and viewing a DataSetFields object associated to + a DataSet as a kind of transpose of it, fields1 + fields2 concatenates fields of + a DataSetFields fields1 and fields2, and fields1 | fields2 concatenates their + examples. + - Datasets with a finite number of fields should be sub-classes of - FiniteWidthDataSet. + A DataSet sub-class should always redefine the following methods: + * __len__ if it is not a stream + * __getitem__ may not be feasible with some streams + * fieldNames + * minibatches + * valuesHStack + * valuesVStack + For efficiency of implementation, a sub-class might also want to redefine + * hasFields """ + infinity = float("infinity") + def __init__(self): pass - class Iterator(LookupList): - def __init__(self, ll): - LookupList.__init__(self, ll.keys(), ll.values()) - self.ll = ll + class MinibatchToSingleExampleIterator(object): + """ + Converts the result of minibatch iterator with minibatch_size==1 into + single-example values in the result. Therefore the result of + iterating on the dataset itself gives a sequence of single examples + (whereas the result of iterating over minibatches gives in each + Example field an iterable object over the individual examples in + the minibatch). + """ + def __init__(self, minibatch_iterator): + self.minibatch_iterator = minibatch_iterator def __iter__(self): #makes for loop work return self def next(self): - self.ll.next() - self._values = [v[0] for v in self.ll._values] - return self + return self.minibatch_iterator.next()[0] def next_index(self): - return self.ll.next_index() + return self.minibatch_iterator.next_index() def __iter__(self): """Supports the syntax "for i in dataset: ..." @@ -70,28 +137,10 @@ i["fielname"] or i[3] (in the order defined by the elements of the Example returned by this iterator), but the derived class is free to accept any type of identifier, and add extra functionality to the iterator. - """ - return DataSet.Iterator(self.minibatches(None, minibatch_size = 1)) - def zip(self, *fieldnames): + The default implementation calls the minibatches iterator and extracts the first example of each field. """ - Supports two forms of syntax: - - for i in dataset.zip([f1, f2, f3]): ... - - for i1, i2, i3 in dataset.zip([f1, f2, f3]): ... - - Using the first syntax, "i" will be an indexable object, such as a list, - tuple, or Example instance, such that on every iteration, i[0] is the f1 - field of the current example, i[1] is the f2 field, and so on. - - Using the second syntax, i1, i2, i3 will contain the the contents of the - f1, f2, and f3 fields of a single example on each loop iteration. - - The derived class may accept fieldname arguments of any type. - - """ - return DataSet.Iterator(self.minibatches(fieldnames, minibatch_size = 1)) + return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1)) minibatches_fieldnames = None minibatches_minibatch_size = 1 @@ -101,7 +150,7 @@ minibatch_size = minibatches_minibatch_size, n_batches = minibatches_n_batches): """ - Supports three forms of syntax: + Return an iterator that supports three forms of syntax: for i in dataset.minibatches(None,**kwargs): ... @@ -122,6 +171,13 @@ Using the third syntax, i1, i2, i3 will be list-like containers of the f1, f2, and f3 fields of a batch of examples on each loop iteration. + The minibatches iterator is expected to return upon each call to next() + a DataSetFields object, which is a LookupList (indexed by the field names) whose + elements are iterable over the minibatch examples, and which keeps a pointer to + a sub-dataset that can be used to iterate over the individual examples + in the minibatch. Hence a minibatch can be converted back to a regular + dataset or its fields can be looked at individually (and possibly iterated over). + PARAMETERS - fieldnames (list of any type, default None): The loop variables i1, i2, i3 (in the example above) should contain the @@ -143,134 +199,48 @@ """ raise AbstractFunction() + + def __len__(self): + """ + len(dataset) returns the number of examples in the dataset. + By default, a DataSet is a 'stream', i.e. it has an unbounded (infinite) length. + Sub-classes which implement finite-length datasets should redefine this method. + Some methods only make sense for finite-length datasets, and will perform + assert len(dataset)0 + + def fieldNames(self): + """ + Return the list of field names that are supported by the iterators, + and for which hasFields(fieldname) would return True. """ raise AbstractFunction() - def merge_field_values(self,*field_value_pairs): - """ - Return the value that corresponds to merging the values of several fields, - given as arguments (field_name, field_value) pairs with self.hasField(field_name). - This may be used by implementations of merge_fields. - Raise a ValueError if the operation is not possible. - """ - fieldnames,fieldvalues = zip(*field_value_pairs) - raise ValueError("Unable to merge values of these fields:"+repr(fieldnames)) - - def examples2minibatch(self,examples): - """ - Combine a list of Examples into a minibatch. A minibatch is an Example whose fields - are iterable over the examples of the minibatch. - """ - raise AbstractFunction() - - def rename(self,rename_dict): - """ - Changes a dataset into one that renames fields, using a dictionnary that maps old field - names to new field names. The only fields visible by the returned dataset are those - whose names are keys of the rename_dict. + def __call__(self,*fieldnames): """ - self_class = self.__class__ - class SelfRenamingDataSet(RenamingDataSet,self_class): - pass - self.__class__ = SelfRenamingDataSet - # set the rename_dict and src fields - SelfRenamingDataSet.__init__(self,self,rename_dict) - return self - - def apply_function(self,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): + Return a dataset that sees only the fields whose name are specified. """ - Changes a dataset into one that contains as fields the results of applying - the given function (example-wise) to the specified input_fields. The - function should return a sequence whose elements will be stored in - fields whose names are given in the output_fields list. If copy_inputs - is True then the resulting dataset will also contain the fields of self. - If accept_minibatches, then the function may be called - with minibatches as arguments (what is returned by the minibatches - iterator). In any case, the computations may be delayed until the examples - of the resulting dataset are requested. If cache is True, then - once the output fields for some examples have been computed, then - are cached (to avoid recomputation if the same examples are again - requested). - """ - self_class = self.__class__ - class SelfApplyFunctionDataSet(ApplyFunctionDataSet,self_class): - pass - self.__class__ = SelfApplyFunctionDataSet - # set the required additional fields - ApplyFunctionDataSet.__init__(self,self,function, input_fields, output_fields, copy_inputs, accept_minibatches, cache) - return self + assert self.hasFields(fieldnames) + return self.fields(fieldnames).examples() - -class FiniteLengthDataSet(DataSet): - """ - Virtual interface for datasets that have a finite length (number of examples), - and thus recognize a len(dataset) call. - """ - def __init__(self): - DataSet.__init__(self) - - def __len__(self): - """len(dataset) returns the number of examples in the dataset.""" - raise AbstractFunction() - - def __call__(self,fieldname_or_fieldnames): - """ - Extract one or more fields. This may be an expensive operation when the - dataset is large. It is not the recommanded way to access individual values - (use the iterators instead). If the argument is a string fieldname, then the result - is a sequence (iterable object) of values for that field, for the whole dataset. If the - argument is a list of field names, then the result is a 'batch', i.e., an Example with keys - corresponding to the given field names and values being iterable objects over the - individual example values. + def fields(self,*fieldnames): """ - if type(fieldname_or_fieldnames) is string: - minibatch = self.minibatches([fieldname_or_fieldnames],len(self)).next() - return minibatch[fieldname_or_fieldnames] - return self.minibatches(fieldname_or_fieldnames,len(self)).next() - -class SliceableDataSet(DataSet): - """ - Virtual interface, a subclass of DataSet for datasets which are sliceable - and whose individual elements can be accessed, generally respecting the - python semantics for [spec], where spec is either a non-negative integer - (for selecting one example), a python slice(start,stop,step) for selecting a regular - sub-dataset comprising examples start,start+step,start+2*step,...,n (with n 0 - if minibatch_size >= len(dataset): - raise NotImplementedError() - - def __iter__(self): #makes for loop work - return self - - @staticmethod - def matcat(a, b): - a0, a1 = a.shape - b0, b1 = b.shape - assert a1 == b1 - assert a.dtype is b.dtype - rval = numpy.empty( (a0 + b0, a1), dtype=a.dtype) - rval[:a0,:] = a - rval[a0:,:] = b - return rval + def __add__(self,other): + """ + dataset1 + dataset2 is a dataset that concatenates the examples from the argument datasets + (and whose length is the sum of the length of the argument datasets). This only + works if they all have the same fields. + """ + return VStackedDataSet(self,other) - def next_index(self): - n_rows = self.dataset.data.shape[0] - next_i = self.current+self.minibatch_size - if next_i >= n_rows: - next_i -= n_rows - return next_i - - def next(self): - - #check for end-of-loop - self.next_count += 1 - if self.next_count == self.next_max: - raise StopIteration +def hstack(datasets): + """ + hstack(dataset1,dataset2,...) returns dataset1 | datataset2 | ... + which is a dataset whose fields list is the concatenation of the fields + of the individual datasets. + """ + assert len(datasets)>0 + if len(datasets)==1: + return datasets[0] + return HStackedDataSet(datasets) - #determine the first and last elements of the minibatch slice we'll return - n_rows = self.dataset.data.shape[0] - self.current = self.next_index() - upper = self.current + self.minibatch_size - - data = self.dataset.data - - if upper <= n_rows: - #this is the easy case, we only need once slice - dataview = data[self.current:upper] - else: - # the minibatch wraps around the end of the dataset - dataview = data[self.current:] - upper -= n_rows - assert upper > 0 - dataview = self.matcat(dataview, data[:upper]) - - self._values = [dataview[:, self.dataset.fields[f]]\ - for f in self._names] - return self +def vstack(datasets): + """ + vstack(dataset1,dataset2,...) returns dataset1 + datataset2 + ... + which is a dataset which iterates first over the examples of dataset1, then + over those of dataset2, etc. + """ + assert len(datasets)>0 + if len(datasets)==1: + return datasets[0] + return VStackedDataSet(datasets) - def __init__(self, data, fields=None): - """ - There are two ways to construct an ArrayDataSet: (1) from an - existing dataset (which may result in a copy of the data in a numpy array), - or (2) from a numpy.array (the data argument), along with an optional description - of the fields (a LookupList of column slices (or column lists) indexed by field names). - """ - self.data=data - self.fields=fields - rows, cols = data.shape +class DataSetFields(LookupList): + """ + Although a DataSet iterates over examples (like rows of a matrix), an associated + DataSetFields iterates over fields (like columns of a matrix), and can be understood + as a transpose of the associated dataset. - if fields: - for fieldname,fieldslice in fields.items(): - assert type(fieldslice) is int or isinstance(fieldslice,slice) or hasattr(fieldslice,"__iter__") - if hasattr(fieldslice,"__iter__"): # is a sequence - for i in fieldslice: - assert type(i) is int - elif isinstance(fieldslice,slice): - # make sure fieldslice.start and fieldslice.step are defined - start=fieldslice.start - step=fieldslice.step - if not start: - start=0 - if not step: - step=1 - if not fieldslice.start or not fieldslice.step: - fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step) - # and coherent with the data array - assert fieldslice.start >= 0 and fieldslice.stop <= cols + To iterate over fields, one can do + * for fields in dataset.fields() + * for fields in dataset(field1,field2,...).fields() to select a subset of fields + * for fields in dataset.fields(field1,field2,...) to select a subset of fields + and each of these fields is iterable over the examples: + * for field_examples in dataset.fields(): + for example_value in field_examples: + ... + but when the dataset is a stream (unbounded length), it is not recommanded to do + such things because the underlying dataset may refuse to access the different fields in + an unsynchronized ways. Hence the fields() method is illegal for streams, by default. + The result of fields() is a DataSetFields object, which iterates over fields, + and whose elements are iterable over examples. A DataSetFields object can + be turned back into a DataSet with its examples() method: + dataset2 = dataset1.fields().examples() + and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1). + """ + def __init__(self,dataset,*fieldnames): + self.dataset=dataset + assert dataset.hasField(*fieldnames) + LookupList.__init__(self,dataset.fieldNames(), + dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(),minibatch_size=len(dataset)).next() + def examples(self): + return self.dataset + + def __or__(self,other): + """ + fields1 | fields2 is a DataSetFields that whose list of examples is the concatenation + of the list of examples of DataSetFields fields1 and fields2. + """ + return (self.examples() + other.examples()).fields() - def minibatches(self, - fieldnames = DataSet.minibatches_fieldnames, - minibatch_size = DataSet.minibatches_minibatch_size, - n_batches = DataSet.minibatches_n_batches): + def __add__(self,other): """ - If the fieldnames list is None, it means that we want to see ALL the fields. + fields1 + fields2 is a DataSetFields that whose list of fields is the concatenation + of the fields of DataSetFields fields1 and fields2. + """ + return (self.examples() | other.examples()).fields() - If the n_batches is None, we want to see all the examples possible - for the given minibatch_size (possibly missing some near the end). +class MinibatchDataSet(DataSet): + """ + Turn a LookupList of same-length fields into an example-iterable dataset. + Each element of the lookup-list should be an iterable and sliceable, all of the same length. + """ + def __init__(self,fields_lookuplist,values_vstack=DataSet().valuesVStack, + values_hstack=DataSet().valuesHStack): """ - # substitute the defaults: - if n_batches is None: n_batches = len(self) / minibatch_size - return ArrayDataSet.Iterator(self, fieldnames, minibatch_size, n_batches) - - def fieldNames(self): - """Return the list of field names that are supported by getattr and hasField.""" - return self.fields.keys() + The user can (and generally should) also provide values_vstack(fieldname,fieldvalues) + and a values_hstack(fieldnames,fieldvalues) functions behaving with the same + semantics as the DataSet methods of the same name (but without the self argument). + """ + self.fields=fields_lookuplist + assert len(fields_lookuplist)>0 + self.length=len(fields_lookuplist[0]) + for field in fields_lookuplist[1:]: + assert self.length==len(field) + self.values_vstack=values_vstack + self.values_hstack=values_hstack def __len__(self): - """len(dataset) returns the number of examples in the dataset.""" - return len(self.data) - - def __getitem__(self,i): - """ - dataset[i] returns the (i+1)-th Example of the dataset. - If there are no fields the result is just a numpy array (for the i-th row of the dataset data matrix). - dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. - dataset[i:j:s] returns the subdataset with examples i,i+2,i+4...,j-2. - dataset[[i1,i2,..,in]] returns the subdataset with examples i1,i2,...,in. - """ - if self.fields: - fieldnames,fieldslices=zip(*self.fields.items()) - return Example(self.fields.keys(),[self.data[i,fieldslice] for fieldslice in self.fields.values()]) - else: - return self.data[i] - - def __getslice__(self,*args): - """ - dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. - dataset[i:j:s] returns the subdataset with examples i,i+2,i+4...,j-2. - """ - return ArrayDataSet(self.data.__getslice__(*args), fields=self.fields) - - def indices_of_unique_columns_used(self): - """ - Return the unique indices of the columns actually used by the fields, and a boolean - that signals (if True) that used columns overlap. If they do then the - indices are not repeated in the result. - """ - columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) - overlapping_columns = False - for field_slice in self.fields.values(): - if sum(columns_used[field_slice])>0: overlapping_columns=True - columns_used[field_slice]=True - return [i for i,used in enumerate(columns_used) if used],overlapping_columns + return self.length - def slice_of_unique_columns_used(self): - """ - Return None if the indices_of_unique_columns_used do not form a slice. If they do, - return that slice. It means that the columns used can be extracted - from the data array without making a copy. If the fields overlap - but their unique columns used form a slice, still return that slice. - """ - columns_used,overlapping_columns = self.indices_of_columns_used() - mappable_to_one_slice = True - if not overlapping_fields: - start=0 - while start0 and not columns_used[stop-1]: - stop-=1 - step=0 - i=start - while i=0 - fieldnames = output_fields - if copy_inputs: fieldnames = src.fieldNames() + output_fields - if accept_minibatches: - # make a single minibatch with all the inputs - inputs = src.minibatches(input_fields,len(src)).next() - # and apply the function to it, and transpose into a list of examples (field values, actually) - self.cached_examples = zip(*Example(output_fields,function(*inputs))) - else: - # compute a list with one tuple per example, with the function outputs - self.cached_examples = [ function(input) for input in src.zip(input_fields) ] - elif cache: - # maybe a fixed-size array kind of structure would be more efficient than a list - # in the case where src is FiniteDataSet. -YB - self.cached_examples = [] + def __getitem__(self,i): + return Example(self.fields.keys(),[field[i] for field in self.fields]) def fieldNames(self): - if self.copy_inputs: - return self.output_fields + self.src.fieldNames() - return self.output_fields - - def minibatches(self, - fieldnames = DataSet.minibatches_fieldnames, - minibatch_size = DataSet.minibatches_minibatch_size, - n_batches = DataSet.minibatches_n_batches): - - class Iterator(LookupList): + return self.fields.keys() + + def hasField(self,*fieldnames): + for fieldname in fieldnames: + if fieldname not in self.fields: + return False + return True - def __init__(self,dataset): - if fieldnames is None: - assert hasattr(dataset,"fieldNames") - fieldnames = dataset.fieldNames() - self.example_index=0 - LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) - self.dataset=dataset - self.src_iterator=self.src.minibatches(list(set.union(set(fieldnames),set(dataset.input_fields))), - minibatch_size,n_batches) - self.fieldnames_not_in_input = [] - if self.copy_inputs: - self.fieldnames_not_in_input = filter(lambda x: not x in dataset.input_fields, fieldnames) - + def minibatches(self, + fieldnames = minibatches_fieldnames, + minibatch_size = minibatches_minibatch_size, + n_batches = minibatches_n_batches): + class Iterator(object): + def __init__(self,ds): + self.ds=ds + self.next_example=0 + self.n_batches_done=0 + assert minibatch_size > 0 + if minibatch_size > ds.length + raise NotImplementedError() def __iter__(self): return self - def next_index(self): - return self.src_iterator.next_index() - + return self.next_example def next(self): - example_index = self.src_iterator.next_index() - src_examples = self.src_iterator.next() - if self.dataset.copy_inputs: - function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields] + upper = next_example+minibatch_size + if upper<=self.ds.length: + minibatch = Example(self.ds.fields.keys(), + [field[next_example:upper] + for field in self.ds.fields]) + else: # we must concatenate (vstack) the bottom and top parts of our minibatch + minibatch = Example(self.ds.fields.keys(), + [self.ds.valuesVStack(name,[value[next_example:], + value[0:upper-self.ds.length]]) + for name,value in self.ds.fields.items()]) + self.next_example+=minibatch_size + self.n_batches_done+=1 + if n_batches: + if self.n_batches_done==n_batches: + raise StopIteration + if self.next_example>=self.ds.length: + self.next_example-=self.ds.length else: - function_inputs = src_examples - if self.dataset.cached_examples: - cache_len=len(self.cached_examples) - if example_index=self.ds.length: + raise StopIteration + return DataSetFields(MinibatchDataSet(minibatch),fieldnames) - for fieldname in fieldnames: - assert fieldname in self.output_fields or self.src.hasFields(fieldname) return Iterator(self) + def valuesVStack(self,fieldname,fieldvalues): + return self.values_vstack(fieldname,fieldvalues) + def valuesHStack(self,fieldnames,fieldvalues): + return self.values_hstack(fieldnames,fieldvalues) + +class HStackedDataSet(DataSet): + """ + A DataSet that wraps several datasets and shows a view that includes all their fields, + i.e. whose list of fields is the concatenation of their lists of fields. + + If a field name is found in more than one of the datasets, then either an error is + raised or the fields are renamed (either by prefixing the __name__ attribute + of the dataset + ".", if it exists, or by suffixing the dataset index in the argument list). + + TODO: automatically detect a chain of stacked datasets due to A | B | C | D ... + """ + def __init__(self,datasets,accept_nonunique_names=False): + DataSet.__init__(self) + self.datasets=datasets + self.accept_nonunique_names=accept_nonunique_names + self.fieldname2dataset={} + + def rename_field(fieldname,dataset,i): + if hasattr(dataset,"__name__"): + return dataset.__name__ + "." + fieldname + return fieldname+"."+str(i) + + # make sure all datasets have the same length and unique field names + self.length=None + names_to_change=[] + for i in xrange(len(datasets)): + dataset = datasets[i] + length=len(dataset) + if self.length: + assert self.length==length + else: + self.length=length + for fieldname in dataset.fieldNames(): + if fieldname in self.fieldname2dataset: # name conflict! + if accept_nonunique_names: + fieldname=rename_field(fieldname,dataset,i) + names2change.append((fieldname,i)) + else: + raise ValueError("Incompatible datasets: non-unique field name = "+fieldname) + self.fieldname2dataset[fieldname]=i + for fieldname,i in names_to_change: + del self.fieldname2dataset[fieldname] + self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i + + def hasField(self,*fieldnames): + for fieldname in fieldnames: + if not fieldname in self.fieldname2dataset: + return False + return True + + def fieldNames(self): + return self.fieldname2dataset.keys() + + def minibatches(self, + fieldnames = minibatches_fieldnames, + minibatch_size = minibatches_minibatch_size, + n_batches = minibatches_n_batches): + + class Iterator(object): + def __init__(self,hsds,iterators): + self.hsds=hsds + self.iterators=iterators + def __iter__(self): + return self + def next_index(self): + return self.iterators[0].next_index() + def next(self): + # concatenate all the fields of the minibatches + minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators]) + # and return a DataSetFields whose dataset is the transpose (=examples()) of this minibatch + return DataSetFields(MinibatchDataSet(minibatch,self.hsds.valuesVStack, + self.hsds.valuesHStack), + fieldnames if fieldnames else hsds.fieldNames()) + + assert self.hasfields(fieldnames) + # find out which underlying datasets are necessary to service the required fields + # and construct corresponding minibatch iterators + if fieldnames: + datasets=set([]) + fields_in_dataset=dict([(dataset,[]) for dataset in datasets]) + for fieldname in fieldnames: + dataset=self.datasets[self.fieldnames2dataset[fieldname]] + datasets.add(dataset) + fields_in_dataset[dataset].append(fieldname) + datasets=list(datasets) + iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches) + for dataset in datasets] + else: + datasets=self.datasets + iterators=[dataset.minibatches(None,minibatch_size,n_batches) for dataset in datasets] + return Iterator(self,iterators) + + + def valuesVStack(self,fieldname,fieldvalues): + return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) + + def valuesHStack(self,fieldnames,fieldvalues): + """ + We will use the sub-dataset associated with the first fieldname in the fieldnames list + to do the work, hoping that it can cope with the other values (i.e. won't care + about the incompatible fieldnames). Hence this heuristic will always work if + all the fieldnames are of the same sub-dataset. + """ + return self.datasets[self.fieldname2dataset[fieldnames[0]]].valuesHStack(fieldnames,fieldvalues) + +class VStackedDataSet(DataSet): + """ + A DataSet that wraps several datasets and shows a view that includes all their examples, + in the order provided. This clearly assumes that they all have the same field names + and all (except possibly the last one) are of finite length. + + TODO: automatically detect a chain of stacked datasets due to A + B + C + D ... + """ + def __init__(self,datasets): + self.datasets=datasets + self.length=0 + self.index2dataset={} + # we use this map from row index to dataset index for constant-time random access of examples, + # to avoid having to search for the appropriate dataset each time and slice is asked for + for dataset,k in enumerate(datasets[0:-1]): + L=len(dataset) + assert L