pylearn: comparison of dataset.py @ 3:378b68d5c4ad
Added first (untested) version of ArrayDataSet
author | bengioy@bengiomac.local |
date | Sun, 23 Mar 2008 14:41:22 -0400 |
parents | 3fddb1c8f955 |
children | f7dcfb5f9d5b |
2:3fddb1c8f955 (old) | 3:378b68d5c4ad (new) |
---|---|
13 """ | 13 """ |
14 | 14 |
15 def __init__(self): | 15 def __init__(self): |
16 pass | 16 pass |
17 | 17 |
18 def __iter__(): | 18 def __iter__(self): |
19 return self | 19 return self |
20 | 20 |
21 def next(): | 21 def next(self): |
22 """Return the next example in the dataset.""" | 22 """Return the next example in the dataset.""" |
23 raise NotImplementedError | 23 raise NotImplementedError |
24 | 24 |
25 def __getattr__(fieldname): | 25 def __getattr__(self,fieldname): |
26 """Return a sub-dataset containing only the given fieldname as field.""" | 26 """Return a sub-dataset containing only the given fieldname as field.""" |
27 return self(fieldname) | 27 return self(fieldname) |
28 | 28 |
29 def __call__(*fieldnames): | 29 def __call__(self,*fieldnames): |
30 """Return a sub-dataset containing only the given fieldnames as fields.""" | 30 """Return a sub-dataset containing only the given fieldnames as fields.""" |
31 raise NotImplementedError | 31 raise NotImplementedError |
32 | 32 |
33 fieldNames(self): | 33 def fieldNames(self): |
34 """Return the list of field names that are supported by getattr and getFields.""" | 34 """Return the list of field names that are supported by getattr and getFields.""" |
35 raise NotImplementedError | 35 raise NotImplementedError |
36 | 36 |
37 class FiniteDataSet(DataSet): | 37 class FiniteDataSet(DataSet): |
38 """ | 38 """ |
53 raise NotImplementedError | 53 raise NotImplementedError |
54 | 54 |
55 def __getslice__(self,*slice_args): | 55 def __getslice__(self,*slice_args): |
56 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" | 56 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" |
57 raise NotImplementedError | 57 raise NotImplementedError |
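
The rows above fix the abstract `DataSet`/`FiniteDataSet` interface so its methods actually take `self` (and give `fieldNames` its missing `def`): `__iter__`/`next` define iteration over examples, while `__getattr__`/`__call__` define access to named fields. A minimal, self-contained sketch of that iteration protocol, not part of the changeset (`ListDataSet` is a hypothetical example class):

    class ListDataSet(object):
        """Hypothetical concrete dataset backed by a plain list of examples."""
        def __init__(self, examples):
            self.examples = examples
            self.position = -1

        def __iter__(self):
            return self

        def next(self):  # Python 2 iterator protocol, as in the interface above
            self.position += 1
            if self.position >= len(self.examples):
                raise StopIteration
            return self.examples[self.position]

    assert list(ListDataSet([1, 2, 3])) == [1, 2, 3]
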
58 | |
59 # we may want ArrayDataSet defined in another python file | |
60 | |
61 from numpy import * | |
62 | |
63 class ArrayDataSet(FiniteDataSet): | |
64 """ | |
65 A fixed-length and fixed-width dataset in which each element is a numpy.array | |
66 or a number, hence the whole dataset corresponds to a numpy.array. Fields | |
67 must correspond to a slice of columns. If the dataset has fields, | |
68 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. | |
69 Any dataset can also be converted to a numpy.array (losing the notion of fields) | |
70 by the asarray(dataset) call. | |
71 """ | |
72 | |
73 def __init__(self,dataset=None,data=None,fields={}): |
74 """ | |
75 Construct an ArrayDataSet, either from a DataSet, or from | |
76 a numpy.array plus an optional specification of fields (by | |
77 a dictionary of column slices indexed by field names). | |
78 """ | |
79 self.current_row=-1 # used for view of this dataset as an iterator | |
80 if dataset: | |
81 assert data==None and fields=={} | |
82 # convert dataset to an ArrayDataSet | |
83 raise NotImplementedError | |
84 if data is not None: |
85 assert dataset==None | |
86 self.data=data | |
87 self.fields=fields | |
88 self.width = data.shape[1] | |
89 for fieldname in fields: | |
90 fieldslice=fields[fieldname] | |
91 assert fieldslice.start>=0 and fieldslice.stop<=self.width |
92 | |
93 def next(self): | |
94 """Return the next example in the dataset. If the dataset has fields, | |
95 the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.""" | |
96 self.current_row+=1 |
97 if self.current_row==len(self.data): |
98 self.current_row=0 |
99 if self.fields: |
100 return self[self.current_row] |
101 else: | |
102 return self.data[self.current_row] | |
103 | |
104 def __getattr__(self,fieldname): | |
105 """Return a sub-dataset containing only the given fieldname as field.""" | |
106 data = self.data[:,self.fields[fieldname]] |
107 return ArrayDataSet(data=data) | |
108 | |
109 def __call__(self,*fieldnames): | |
110 """Return a sub-dataset containing only the given fieldnames as fields.""" | |
111 min_col=self.data.shape[1] | |
112 max_col=0 | |
113 for field_slice in [self.fields[fieldname] for fieldname in fieldnames]: |
114 min_col=min(min_col,field_slice.start) | |
115 max_col=max(max_col,field_slice.stop) | |
116 new_fields={} | |
117 for fieldname in fieldnames: |
118 new_fields[fieldname]=slice(self.fields[fieldname].start-min_col,self.fields[fieldname].stop-min_col,self.fields[fieldname].step) |
119 return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields) | |
120 | |
121 def fieldNames(self): | |
122 """Return the list of field names that are supported by getattr and getFields.""" | |
123 return self.fields.keys() | |
124 | |
125 def __len__(self): | |
126 """len(dataset) returns the number of examples in the dataset.""" | |
127 return len(self.data) | |
58 | 128 |
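
The hunk above introduces the core of `ArrayDataSet`: all examples live in one 2-D numpy array (`self.data`) and `fields` maps a field name to a slice of its columns, which is what the constructor's assertion and `__getattr__`/`__call__` rely on. Since the changeset is explicitly untested, here is a small, self-contained illustration of that column-slice convention using plain numpy (the field names and values are made up):

    import numpy

    # Hypothetical layout: columns 0-2 hold 'input', column 3 holds 'target'.
    data = numpy.array([[0.0, 0.1, 0.2, 1.0],
                        [1.0, 1.1, 1.2, 0.0]])
    fields = {'input': slice(0, 3), 'target': slice(3, 4)}

    # What dataset.input is meant to expose: the columns selected by the field's slice.
    input_block = data[:, fields['input']]     # shape (2, 3)
    target_block = data[:, fields['target']]   # shape (2, 1)
    assert input_block.shape == (2, 3) and target_block.shape == (2, 1)
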
129 def __getitem__(self,i): | |
130 """ | |
131 dataset[i] returns the (i+1)-th example of the dataset. If the dataset has fields | |
132 then a one-example dataset is returned (to be able to handle example.field accesses). | |
133 """ | |
134 if self.fields: | |
135 if isinstance(i,slice): | |
136 return ArrayDataSet(data=self.data[i],fields=self.fields) |
137 return ArrayDataSet(data=self.data[i:i+1],fields=self.fields) | |
138 else: | |
139 return self.data[i] |
140 | |
141 def __getslice__(self,*slice_args): | |
142 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" | |
143 return ArrayDataSet(data=self.data[slice(*slice_args)],fields=self.fields) |
144 | |
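
`__getitem__` above wraps a single example as a one-row `ArrayDataSet` (when fields are defined) instead of returning the row itself, so that `example.fieldname` can still be resolved through the column slices. The numpy behaviour this relies on is that `data[i]` drops to 1-D while `data[i:i+1]` stays 2-D; a quick check with plain numpy, not part of the changeset:

    import numpy

    data = numpy.arange(8.0).reshape(2, 4)

    row = data[0]         # 1-D, shape (4,): a bare column slice no longer applies cleanly
    one_row = data[0:1]   # 2-D, shape (1, 4): still sliceable per field, e.g. one_row[:, slice(0, 3)]
    assert row.shape == (4,) and one_row.shape == (1, 4)
    assert one_row[:, slice(0, 3)].shape == (1, 3)
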
145 def asarray(self): | |
146 if self.fields: | |
147 columns_used = zeros((self.data.shape[1]),dtype=bool) | |
148 for field_slice in self.fields.values(): | |
149 for c in xrange(field_slice.start,field_slice.stop,field_slice.step): | |
150 columns_used[c]=True | |
151 # try to figure out if we can map all the slices into one slice: | |
152 mappable_to_one_slice = True | |
153 start=0 | |
154 while start<len(columns_used) and not columns_used[start]: | |
155 start+=1 | |
156 stop=len(columns_used) | |
157 while stop>0 and not columns_used[stop-1]: | |
158 stop-=1 | |
159 step=0 | |
160 i=start | |
161 while i<stop: | |
162 j=i+1 | |
163 while not columns_used[j] and j<stop: | |
164 j+=1 | |
165 if step: | |
166 if step!=j-i: | |
167 mappable_to_one_slice = False | |
168 break | |
169 else: | |
170 step = j-i | |
171 if mappable_to_one_slice: | |
172 return data[slice(start,stop,step)] | |
173 # else make contiguous copy | |
174 n_columns = sum(columns_used) | |
175 result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype) | |
176 c=0 | |
177 for field_slice in self.fields.values(): | |
178 slice_width=field_slice.stop-field_slice.start | |
179 if field_slice.step: | |
180 slice_width /= field_slice.step | |
181 # copy the field here | |
182 result[:,slice(c,slice_width)]=self.data[field_slice] | |
183 c+=slice_width | |
184 return result | |
185 return self.data | |
186 |
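
`asarray` above tries to detect whether the columns covered by the fields already form a single evenly-strided slice and, if not, copies the field columns into a contiguous array. The draft is untested; the same end result (ignoring the single-slice shortcut) can be sketched with plain numpy as follows, again not part of the changeset:

    import numpy

    data = numpy.arange(12.0).reshape(3, 4)
    fields = {'a': slice(0, 2), 'b': slice(3, 4)}   # column 2 belongs to no field

    # Concatenate the column block of every field into one contiguous array;
    # column order follows dict iteration order, as in the draft's own loop.
    packed = numpy.hstack([data[:, s] for s in fields.values()])
    assert packed.shape == (3, 3)   # two 'a' columns plus one 'b' column
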