Mercurial > pylearn
comparison dataset.py @ 4:f7dcfb5f9d5b
Added test for dataset.
author | bengioy@bengiomac.local |
---|---|
date | Sun, 23 Mar 2008 22:14:10 -0400 |
parents | 378b68d5c4ad |
children | 8039918516fe |
comparison
Legend: equal, deleted, inserted, replaced
3:378b68d5c4ad | 4:f7dcfb5f9d5b |
---|---|
56 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" | 56 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" |
57 raise NotImplementedError | 57 raise NotImplementedError |
58 | 58 |
59 # we may want ArrayDataSet defined in another python file | 59 # we may want ArrayDataSet defined in another python file |
60 | 60 |
61 from numpy import * | 61 import numpy |
62 | 62 |
63 class ArrayDataSet(FiniteDataSet): | 63 class ArrayDataSet(FiniteDataSet): |
64 """ | 64 """ |
65 A fixed-length and fixed-width dataset in which each element is a numpy.array | 65 A fixed-length and fixed-width dataset in which each element is a numpy.array |
66 or a number, hence the whole dataset corresponds to a numpy.array. Fields | 66 or a number, hence the whole dataset corresponds to a numpy.array. Fields |
68 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. | 68 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. |
69 Any dataset can also be converted to a numpy.array (losing the notion of fields) | 69 Any dataset can also be converted to a numpy.array (losing the notion of fields) |
70 by the asarray(dataset) call. | 70 by the asarray(dataset) call. |
71 """ | 71 """ |
72 | 72 |
73 def __self__(self,dataset=None,data=None,fields={}): | 73 def __init__(self,dataset=None,data=None,fields={}): |
74 """ | 74 """ |
75 Construct an ArrayDataSet, either from a DataSet, or from | 75 Construct an ArrayDataSet, either from a DataSet, or from |
76 a numpy.array plus an optional specification of fields (by | 76 a numpy.array plus an optional specification of fields (by |
77 a dictionary of column slices indexed by field names). | 77 a dictionary of column slices indexed by field names). |
78 """ | 78 """ |
79 self.current_row=-1 # used for view of this dataset as an iterator | 79 self.current_row=-1 # used for view of this dataset as an iterator |
80 if dataset: | 80 if dataset!=None: |
81 assert data==None and fields=={} | 81 assert data==None and fields=={} |
82 # convert dataset to an ArrayDataSet | 82 # convert dataset to an ArrayDataSet |
83 raise NotImplementedError | 83 raise NotImplementedError |
84 if data: | 84 if data!=None: |
85 assert dataset==None | 85 assert dataset==None |
86 self.data=data | 86 self.data=data |
87 self.fields=fields | 87 self.fields=fields |
88 self.width = data.shape[1] | 88 self.width = data.shape[1] |
89 for fieldname in fields: | 89 for fieldname in fields: |
90 fieldslice=fields[fieldname] | 90 fieldslice=fields[fieldname] |
91 assert fieldslice.start>=0 and fieldslice.stop<=width) | 91 # make sure fieldslice.start and fieldslice.step are defined |
92 start=fieldslice.start | |
93 step=fieldslice.step | |
94 if not start: | |
95 start=0 | |
96 if not step: | |
97 step=1 | |
98 if not fieldslice.start or not fieldslice.step: | |
99 fieldslice = slice(start,fieldslice.stop,step) | |
100 # and coherent with the data array | |
101 assert fieldslice.start>=0 and fieldslice.stop<=self.width | |
92 | 102 |
93 def next(self): | 103 def next(self): |
94 """Return the next example in the dataset. If the dataset has fields, | 104 """ |
95 the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.""" | 105 Return the next example in the dataset. If the dataset has fields, |
96 if fields: | 106 the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. |
107 """ | |
108 if self.fields: | |
97 self.current_row+=1 | 109 self.current_row+=1 |
98 if self.current_row==len(self.data): | 110 if self.current_row==len(self.data): |
99 self.current_row=0 | 111 self.current_row=-1 |
112 raise StopIteration | |
100 return self[self.current_row] | 113 return self[self.current_row] |
101 else: | 114 else: |
102 return self.data[self.current_row] | 115 return self.data[self.current_row] |
103 | 116 |
104 def __getattr__(self,fieldname): | 117 def __getattr__(self,fieldname): |
105 """Return a sub-dataset containing only the given fieldname as field.""" | 118 """Return a sub-dataset containing only the given fieldname as field.""" |
106 data = self.fields[fieldname] | 119 data=self.data[self.fields[fieldname]] |
107 return ArrayDataSet(data=data) | 120 if len(data)==1: |
121 return data | |
122 else: | |
123 return ArrayDataSet(data=data) | |
108 | 124 |
109 def __call__(self,*fieldnames): | 125 def __call__(self,*fieldnames): |
110 """Return a sub-dataset containing only the given fieldnames as fields.""" | 126 """Return a sub-dataset containing only the given fieldnames as fields.""" |
111 min_col=self.data.shape[1] | 127 min_col=self.data.shape[1] |
112 max_col=0 | 128 max_col=0 |
142 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" | 158 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" |
143 return ArrayDataSet(data=self.data[slice(slice_args)],fields=self.fields) | 159 return ArrayDataSet(data=self.data[slice(slice_args)],fields=self.fields) |
144 | 160 |
145 def asarray(self): | 161 def asarray(self): |
146 if self.fields: | 162 if self.fields: |
147 columns_used = zeros((self.data.shape[1]),dtype=bool) | 163 columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) |
148 for field_slice in self.fields.values(): | 164 for field_slice in self.fields.values(): |
149 for c in xrange(field_slice.start,field_slice.stop,field_slice.step): | 165 for c in xrange(field_slice.start,field_slice.stop,field_slice.step): |
150 columns_used[c]=True | 166 columns_used[c]=True |
151 # try to figure out if we can map all the slices into one slice: | 167 # try to figure out if we can map all the slices into one slice: |
152 mappable_to_one_slice = True | 168 mappable_to_one_slice = True |
173 # else make contiguous copy | 189 # else make contiguous copy |
174 n_columns = sum(columns_used) | 190 n_columns = sum(columns_used) |
175 result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype) | 191 result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype) |
176 c=0 | 192 c=0 |
177 for field_slice in self.fields.values(): | 193 for field_slice in self.fields.values(): |
178 slice_width=field_slice.stop-field_slice.start | 194 slice_width=field_slice.stop-field_slice.start/field_slice.step |
179 if field_slice.step: | |
180 slice_width /= field_slice.step | |
181 # copy the field here | 195 # copy the field here |
182 result[:,slice(c,slice_width)]=self.data[field_slice] | 196 result[:,slice(c,slice_width)]=self.data[field_slice] |
183 c+=slice_width | 197 c+=slice_width |
184 return result | 198 return result |
185 return self.data | 199 return self.data |