pylearn: comparison, dataset.py @ 188:f01ac276c6fb
added __contains__ to Dataset, added parent constructor call to ArrayDataSet
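The new `DataSet.__contains__` makes `name in dataset` true when the name is either a field or an attribute. The minimal sketch below reproduces that test on a hypothetical `ToyDataSet` class. The class implements only `fieldNames()` and `attributeNames()`, so it illustrates the membership logic under that assumption and is not the pylearn class itself.

```python
class ToyDataSet(object):
    """Hypothetical stand-in for a DataSet: only the two name-listing
    methods that __contains__ relies on are implemented."""

    def __init__(self, fieldnames, attributenames):
        self._fieldnames = list(fieldnames)
        self._attributenames = list(attributenames)

    def fieldNames(self):
        return self._fieldnames

    def attributeNames(self):
        return self._attributenames

    def __contains__(self, fieldname):
        # Same test as in the changeset: a name counts as "in" the dataset
        # if it is either a field name or an attribute name.
        return (fieldname in self.fieldNames()) \
            or (fieldname in self.attributeNames())


ds = ToyDataSet(fieldnames=['input', 'target'], attributenames=['n_classes'])
print('input' in ds)      # True  (a field)
print('n_classes' in ds)  # True  (an attribute)
print('weights' in ds)    # False
```

Python's `in` operator calls `__contains__`. Defining it lets callers test for a field or attribute without building the name lists themselves.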
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 14 May 2008 14:49:08 -0400 |
parents | 895b4b60f5e8 |
children | cb6b945acf5a |
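The second change gives `ArrayDataSet.__init__` a `**kwargs` parameter and passes those keywords to `ArrayFieldsDataSet.__init__` before the subclass sets its own state. The sketch below shows that forwarding pattern with a hypothetical parent, `BaseFieldsDataSet`. Its `description` keyword and attribute dictionary are assumptions made for illustration, because this diff does not show the real parent's signature.

```python
class BaseFieldsDataSet(object):
    """Hypothetical parent class; the keyword names are illustrative,
    not the actual ArrayFieldsDataSet signature."""

    def __init__(self, description=None, **attributes):
        self.description = description
        self.attributes = dict(attributes)


class ArraySubclass(BaseFieldsDataSet):
    def __init__(self, data_array, fields_columns, **kwargs):
        # Forward any extra keyword arguments to the parent constructor,
        # so parent-level state exists before the subclass's own fields.
        BaseFieldsDataSet.__init__(self, **kwargs)
        self.data = data_array
        self.fields_columns = fields_columns


ds = ArraySubclass([[1, 2], [3, 4]], {'x': 0, 'y': 1},
                   description='toy', scale=2.0)
print(ds.description)  # toy
print(ds.attributes)   # {'scale': 2.0}
```

Without the explicit parent call, any initialization done in the parent's `__init__` would be skipped for subclass instances.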
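The `ArrayDataSet` docstring in this diff says a field's columns can be given as an integer, a slice, or an iterable of column indices. The rough sketch below uses plain NumPy indexing on the second axis to select one field's values for every example. The helper `field_values` and the field names are hypothetical, and this is not the pylearn code that interprets `fields_columns`.

```python
import numpy as np

# A 4-example, 5-column array: axis 0 iterates over examples,
# axis 1 holds the columns that make up the fields.
data = np.arange(20).reshape(4, 5)

# Hypothetical field map using the three kinds of column specifiers.
fields_columns = {
    'label': 4,            # integer: a single column
    'input': slice(0, 4),  # slice: an interval of columns
    'odd': [1, 3],         # iterable: explicit column indices
}


def field_values(array, columns):
    """Return the values of one field for every example (hypothetical helper)."""
    if isinstance(columns, (int, slice)):
        return array[:, columns]
    # Any other iterable of indices is turned into a list for fancy indexing.
    return array[:, list(columns)]


for name in sorted(fields_columns):
    print(name, field_values(data, fields_columns[name]).shape)
# input (4, 4)
# label (4,)
# odd (4, 2)
```

With a 2-D array, an integer column yields one ordinary number per example. This matches the docstring's remark that field values are (N-2)-dimensional when N=2.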
187:ebbb0e749565 | 188:f01ac276c6fb |
---|---|
205 | 205 |
206 The default implementation calls the minibatches iterator and extracts the first example of each field. | 206 The default implementation calls the minibatches iterator and extracts the first example of each field. |
207 """ | 207 """ |
208 return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1)) | 208 return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1)) |
209 | 209 |
| | 210 def __contains__(self, fieldname): |
| | 211 return (fieldname in self.fieldNames()) \ |
| | 212 or (fieldname in self.attributeNames()) |
210 | 213 |
211 class MinibatchWrapAroundIterator(object): | 214 class MinibatchWrapAroundIterator(object): |
212 """ | 215 """ |
213 An iterator for minibatches that handles the case where we need to wrap around the | 216 An iterator for minibatches that handles the case where we need to wrap around the |
214 dataset because n_batches*minibatch_size > len(dataset). It is constructed from | 217 dataset because n_batches*minibatch_size > len(dataset). It is constructed from |
935 whose first axis iterates over examples, second axis determines fields. | 938 whose first axis iterates over examples, second axis determines fields. |
936 If the underlying array is N-dimensional (has N axes), then the field | 939 If the underlying array is N-dimensional (has N axes), then the field |
937 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2). | 940 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2). |
938 """ | 941 """ |
939 | 942 |
940 def __init__(self, data_array, fields_columns): | 943 def __init__(self, data_array, fields_columns, **kwargs): |
941 """ | 944 """ |
942 Construct an ArrayDataSet from the underlying numpy array (data) and | 945 Construct an ArrayDataSet from the underlying numpy array (data) and |
943 a map (fields_columns) from fieldnames to field columns. The columns of a field are specified | 946 a map (fields_columns) from fieldnames to field columns. The columns of a field are specified |
944 using the standard arguments for indexing/slicing: integer for a column index, | 947 using the standard arguments for indexing/slicing: integer for a column index, |
945 slice for an interval of columns (with possible stride), or iterable of column indices. | 948 slice for an interval of columns (with possible stride), or iterable of column indices. |
946 """ | 949 """ |
| | 950 ArrayFieldsDataSet.__init__(self, **kwargs) |
947 self.data=data_array | 951 self.data=data_array |
948 self.fields_columns=fields_columns | 952 self.fields_columns=fields_columns |
949 | 953 |
950 # check consistency and complete slices definitions | 954 # check consistency and complete slices definitions |
951 for fieldname, fieldcolumns in self.fields_columns.items(): | 955 for fieldname, fieldcolumns in self.fields_columns.items(): |