comparison dataset.py @ 3:378b68d5c4ad

Added first (untested) version of ArrayDataSet
author bengioy@bengiomac.local
date Sun, 23 Mar 2008 14:41:22 -0400
parents 3fddb1c8f955
children f7dcfb5f9d5b
comparison
equal deleted inserted replaced
2:3fddb1c8f955 3:378b68d5c4ad
13 """ 13 """
14 14
def __init__(self):
    """Create a dataset; the base class sets up no state of its own."""
17 17
def __iter__(self):
    """Return the dataset itself: a dataset acts as its own iterator.

    The examples are produced by the next() method.
    """
    return self
20 20
def next(self):
    """Return the next example in the dataset.

    Abstract in the base class: always raises NotImplementedError, so
    subclasses must override it.
    """
    raise NotImplementedError
24 24
def __getattr__(self, fieldname):
    """Return a sub-dataset containing only the given fieldname as field.

    Python calls this hook only for attributes that normal lookup did not
    find, so dataset.x is shorthand for dataset('x').

    Raises AttributeError for special (double-underscore) names, so that
    protocol probes such as hasattr(dataset, '__deepcopy__') get the answer
    they expect instead of being routed to __call__.
    """
    if fieldname.startswith("__") and fieldname.endswith("__"):
        raise AttributeError(fieldname)
    return self(fieldname)
28 28
def __call__(self, *fieldnames):
    """Return a sub-dataset containing only the given fieldnames as fields.

    Abstract in the base class: always raises NotImplementedError, so
    subclasses must override it.
    """
    raise NotImplementedError
32 32
def fieldNames(self):
    """Return the list of field names that are supported by getattr and getFields.

    Abstract in the base class: always raises NotImplementedError, so
    subclasses must override it.
    """
    raise NotImplementedError
36 36
37 class FiniteDataSet(DataSet): 37 class FiniteDataSet(DataSet):
38 """ 38 """
53 raise NotImplementedError 53 raise NotImplementedError
54 54
def __getslice__(self, *slice_args):
    """Support dataset[i:j]: the sub-dataset holding examples i, i+1, ..., j-1.

    Abstract here; this method unconditionally raises NotImplementedError.
    """
    raise NotImplementedError()
58
59 # we may want ArrayDataSet defined in another python file
60
61 from numpy import *
62
class ArrayDataSet(FiniteDataSet):
    """
    A fixed-length and fixed-width dataset stored as a single numpy.array
    whose first axis indexes the examples and whose second axis indexes
    the columns.

    Fields, when given, must each correspond to a slice of columns. If the
    dataset has fields, each 'example' is a one-row ArrayDataSet (so that
    example.field accesses work), otherwise it is the matching row of the
    array. The dataset can be converted to a plain numpy.array (losing the
    notion of fields) by calling dataset.asarray().
    """

    def __init__(self, dataset=None, data=None, fields=None):
        """
        Construct an ArrayDataSet, either from a DataSet, or from
        a numpy.array plus an optional specification of fields (by
        a dictionary of column slices indexed by field names).

        Each field slice needs an explicit start and stop with
        0 <= start and stop <= number of columns; the step may be None.

        Raises ValueError if both or neither of dataset and data are given,
        if data has fewer than two dimensions, or if a field slice falls
        outside the columns. Raises NotImplementedError when dataset is
        given, because that conversion is not written yet.
        """
        # Imported locally so the class does not depend on which names the
        # module-level star import of numpy happens to expose or shadow.
        import numpy

        self.current_row = -1  # used for view of this dataset as an iterator
        if dataset is not None:
            if data is not None or fields:
                raise ValueError(
                    "pass either dataset, or data (plus optional fields), not both")
            # convert dataset to an ArrayDataSet
            raise NotImplementedError(
                "conversion of a DataSet to an ArrayDataSet is not implemented")
        if data is None:
            raise ValueError("either dataset or data must be given")

        data = numpy.asarray(data)
        if data.ndim < 2:
            raise ValueError(
                "data must have at least two dimensions (examples x columns)")
        width = data.shape[1]

        # Copy the mapping so instances never share (or mutate) the caller's dict.
        fields = dict(fields) if fields else {}
        for fieldname, fieldslice in fields.items():
            start, stop = fieldslice.start, fieldslice.stop
            if start is None or stop is None or start < 0 or stop > width:
                raise ValueError(
                    "field %r must be a slice with 0 <= start and stop <= %d, got %r"
                    % (fieldname, width, fieldslice))

        self.data = data
        self.fields = fields
        self.width = width

    def next(self):
        """Return the next example in the dataset. If the dataset has fields,
        the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.

        Iteration wraps around: after the last example the first one is
        returned again. StopIteration is raised only for an empty dataset.
        """
        n_rows = len(self.data)
        if n_rows == 0:
            raise StopIteration
        self.current_row += 1
        if self.current_row >= n_rows:
            self.current_row = 0
        return self[self.current_row]

    # Python 3 spells the iterator method __next__; keep both names working.
    __next__ = next

    def __getattr__(self, fieldname):
        """Return a sub-dataset containing only the given fieldname as field.

        Only called for attributes that normal lookup did not find. Raises
        AttributeError when fieldname is not one of the fields.
        """
        # Read 'fields' through __dict__: going through self.fields here
        # would re-enter __getattr__ forever if it has not been set yet.
        fields = self.__dict__.get("fields")
        if not fields or fieldname not in fields:
            raise AttributeError(
                "%s has no attribute or field %r"
                % (type(self).__name__, fieldname))
        return self(fieldname)

    def __call__(self, *fieldnames):
        """Return a sub-dataset containing only the given fieldnames as fields.

        The result holds the narrowest contiguous range of columns covering
        the requested fields, with the field slices shifted to match.
        Raises KeyError for a name that is not a field.
        """
        min_col = self.data.shape[1]
        max_col = 0
        for name in fieldnames:
            field_slice = self.fields[name]
            if field_slice.start < min_col:
                min_col = field_slice.start
            if field_slice.stop > max_col:
                max_col = field_slice.stop
        new_fields = {}
        for name in fieldnames:
            field_slice = self.fields[name]
            new_fields[name] = slice(field_slice.start - min_col,
                                     field_slice.stop - min_col,
                                     field_slice.step)
        return ArrayDataSet(data=self.data[:, min_col:max_col], fields=new_fields)

    def fieldNames(self):
        """Return the list of field names that are supported by getattr and getFields."""
        return list(self.fields)

    def __len__(self):
        """len(dataset) returns the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, i):
        """
        dataset[i] returns the (i+1)-th example of the dataset. If the dataset has fields
        then a one-example dataset is returned (to be able to handle example.field accesses).

        A slice always returns the corresponding sub-dataset. Raises
        IndexError when an integer index is out of range.
        """
        if isinstance(i, slice):
            return ArrayDataSet(data=self.data[i], fields=self.fields)
        if not self.fields:
            return self.data[i]
        n_rows = len(self.data)
        row = i
        if row < 0:
            # Normalise first: data[-1:0] would be empty, not the last row.
            row += n_rows
        if row < 0 or row >= n_rows:
            raise IndexError(
                "example index %d out of range for %d examples" % (i, n_rows))
        return ArrayDataSet(data=self.data[row:row + 1], fields=self.fields)

    def __getslice__(self, *slice_args):
        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
        return self[slice(*slice_args)]

    def asarray(self):
        """
        Return the dataset as a numpy.array, losing the notion of fields.

        Without fields the underlying array itself is returned. With fields,
        the result holds every column covered by at least one field, once
        each, in increasing column order: as a slice (view) of the data when
        those columns are evenly spaced, otherwise as a contiguous copy.
        """
        if not self.fields:
            return self.data

        import numpy

        columns_used = numpy.zeros(self.data.shape[1], dtype=bool)
        for field_slice in self.fields.values():
            columns_used[field_slice] = True
        columns = numpy.flatnonzero(columns_used)
        if len(columns) == 0:
            return self.data[:, 0:0]

        # try to figure out if we can map all the slices into one slice:
        start = int(columns[0])
        stop = int(columns[-1]) + 1
        step = 1
        mappable_to_one_slice = True
        if len(columns) > 1:
            gaps = numpy.diff(columns)
            step = int(gaps[0])
            mappable_to_one_slice = not (gaps != step).any()
        if mappable_to_one_slice:
            return self.data[:, start:stop:step]
        # else make contiguous copy (indexing with an index array copies)
        return self.data[:, columns]
186