comparison dataset.py @ 188:f01ac276c6fb

added __contains__ to Dataset, added parent constructor call to ArrayDataSet
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 14 May 2008 14:49:08 -0400
parents 895b4b60f5e8
children cb6b945acf5a
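
The two changes described above can be illustrated with a minimal, self-contained sketch. It is not the code from dataset.py: SketchDataSet and SketchArrayDataSet are hypothetical stand-ins, and the `description` keyword and the contents of attributeNames() are assumptions made only so the example runs. In the changeset itself the parent call is ArrayFieldsDataSet.__init__(self, **kwargs).

```python
class SketchDataSet(object):
    """Hypothetical stand-in for the DataSet base class (not the real one)."""

    def __init__(self, description=None):
        # Assumption: the parent constructor accepts optional keyword arguments.
        self.description = description

    def fieldNames(self):
        raise NotImplementedError

    def attributeNames(self):
        # Assumption: attributes are dataset-level names, here only 'description'.
        return ["description"]

    def __contains__(self, fieldname):
        # Same test as the method added in this changeset:
        # a name is "in" the dataset if it is a field or an attribute.
        return (fieldname in self.fieldNames()) \
            or (fieldname in self.attributeNames())


class SketchArrayDataSet(SketchDataSet):
    def __init__(self, data_array, fields_columns, **kwargs):
        # Forward extra keyword arguments so the parent can initialize its state.
        SketchDataSet.__init__(self, **kwargs)
        self.data = data_array
        self.fields_columns = fields_columns

    def fieldNames(self):
        return list(self.fields_columns.keys())


ds = SketchArrayDataSet([[1, 2, 3]], {"x": slice(0, 2), "y": 2}, description="toy")
print("x" in ds, "description" in ds, "z" in ds)  # prints: True True False
```

Without the parent constructor call, anything the base class sets up in __init__ (here, `description`) would be missing on the subclass instance, and in this sketch the attribute lookup would then raise AttributeError.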
--- dataset.py	187:ebbb0e749565
+++ dataset.py	188:f01ac276c6fb
@@ -205,10 +205,13 @@
 
         The default implementation calls the minibatches iterator and extracts the first example of each field.
         """
         return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1))
 
+    def __contains__(self, fieldname):
+        return (fieldname in self.fieldNames()) \
+                or (fieldname in self.attributeNames())
+
     class MinibatchWrapAroundIterator(object):
         """
         An iterator for minibatches that handles the case where we need to wrap around the
         dataset because n_batches*minibatch_size > len(dataset). It is constructed from
@@ -935,17 +938,18 @@
     whose first axis iterates over examples, second axis determines fields.
     If the underlying array is N-dimensional (has N axes), then the field
     values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2).
     """
 
-    def __init__(self, data_array, fields_columns):
+    def __init__(self, data_array, fields_columns, **kwargs):
         """
         Construct an ArrayDataSet from the underlying numpy array (data) and
         a map (fields_columns) from fieldnames to field columns. The columns of a field are specified
         using the standard arguments for indexing/slicing: integer for a column index,
         slice for an interval of columns (with possible stride), or iterable of column indices.
         """
+        ArrayFieldsDataSet.__init__(self, **kwargs)
         self.data=data_array
         self.fields_columns=fields_columns
 
         # check consistency and complete slices definitions
         for fieldname, fieldcolumns in self.fields_columns.items():