comparison dataset.py @ 19:57f4015e2e09

Iterators extend LookupList
author bergstrj@iro.umontreal.ca
date Thu, 27 Mar 2008 01:59:44 -0400
parents 759d17112b23
children 266c68cb6136
comparison
equal deleted inserted replaced
18:60b164a0d84a 19:57f4015e2e09
38 all the fields of DataSet self. Every field of "i" will give access to 38 all the fields of DataSet self. Every field of "i" will give access to
39 the field of a single example. Fields should be accessible via 39 the field of a single example. Fields should be accessible via
40 i[identifier], but the derived class is free to accept any type of 40 i[identifier], but the derived class is free to accept any type of
41 identifier, and add extra functionality to the iterator. 41 identifier, and add extra functionality to the iterator.
42 """ 42 """
43 for i in self.minibatches( minibatch_size = 1): 43 raise AbstractFunction()
44 yield Example(i.keys(), [v[0] for v in i.values()])
45 44
46 def zip(self, *fieldnames): 45 def zip(self, *fieldnames):
47 """ 46 """
48 Supports two forms of syntax: 47 Supports two forms of syntax:
49 48
59 f1, f2, and f3 fields of a single example on each loop iteration. 58 f1, f2, and f3 fields of a single example on each loop iteration.
60 59
61 The derived class may accept fieldname arguments of any type. 60 The derived class may accept fieldname arguments of any type.
62 61
63 """ 62 """
64 for i in self.minibatches(fieldnames, minibatch_size = 1): 63 class Iter(LookupList):
65 yield [f[0] for f in i] 64 def __init__(self, ll):
65 LookupList.__init__(self, ll.keys(), ll.values())
66 self.ll = ll
67 def __iter__(self): #makes for loop work
68 return self
69 def next(self):
70 self.ll.next()
71 self._values = [v[0] for v in self.ll._values]
72 return self
73 return Iter(self.minibatches(fieldnames, minibatch_size = 1))
66 74
67 minibatches_fieldnames = None 75 minibatches_fieldnames = None
68 minibatches_minibatch_size = 1 76 minibatches_minibatch_size = 1
69 minibatches_n_batches = None 77 minibatches_n_batches = None
70 def minibatches(self, 78 def minibatches(self,
175 self.dataset=dataset 183 self.dataset=dataset
176 self.minibatch_size=minibatch_size 184 self.minibatch_size=minibatch_size
177 assert minibatch_size>=1 and minibatch_size<=len(dataset) 185 assert minibatch_size>=1 and minibatch_size<=len(dataset)
178 self.current = -self.minibatch_size 186 self.current = -self.minibatch_size
179 self.fieldnames = fieldnames 187 self.fieldnames = fieldnames
188 if len(dataset) % minibatch_size:
189 raise NotImplementedError()
180 190
181 def __iter__(self): 191 def __iter__(self):
182 return self 192 return self
183 193
184 def next(self): 194 def next(self):
285 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array. 295 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array.
286 Any dataset can also be converted to a numpy array (losing the notion of fields) 296 Any dataset can also be converted to a numpy array (losing the notion of fields)
287 by the numpy.array(dataset) call. 297 by the numpy.array(dataset) call.
288 """ 298 """
289 299
290 class Iterator(object): 300 class Iterator(LookupList):
291 """An iterator over a finite dataset that implements wrap-around""" 301 """An iterator over a finite dataset that implements wrap-around"""
292 def __init__(self, dataset, fieldnames, minibatch_size, next_max): 302 def __init__(self, dataset, fieldnames, minibatch_size, next_max):
303 LookupList.__init__(self, fieldnames, [0] * len(fieldnames))
293 self.dataset=dataset 304 self.dataset=dataset
294 self.fieldnames = fieldnames
295 self.minibatch_size=minibatch_size 305 self.minibatch_size=minibatch_size
296 self.next_count = 0 306 self.next_count = 0
297 self.next_max = next_max 307 self.next_max = next_max
298 self.current = -self.minibatch_size 308 self.current = -self.minibatch_size
299 assert minibatch_size > 0 309 assert minibatch_size > 0
300 if minibatch_size >= len(dataset): 310 if minibatch_size >= len(dataset):
301 raise NotImplementedError() 311 raise NotImplementedError()
302 312
303 def __iter__(self): 313 def __iter__(self): #makes for loop work
304 #Why do we do this? -JB
305 return self 314 return self
306 315
307 @staticmethod 316 @staticmethod
308 def matcat(a, b): 317 def matcat(a, b):
309 a0, a1 = a.shape 318 a0, a1 = a.shape
321 self.next_count += 1 330 self.next_count += 1
322 if self.next_count == self.next_max: 331 if self.next_count == self.next_max:
323 raise StopIteration 332 raise StopIteration
324 333
325 #determine the first and last elements of the slice we'll return 334 #determine the first and last elements of the slice we'll return
335 rows = self.dataset.data.shape[0]
326 self.current += self.minibatch_size 336 self.current += self.minibatch_size
327 if self.current >= len(self.dataset): 337 if self.current >= rows:
328 self.current -= len(self.dataset) 338 self.current -= rows
329 upper = self.current + self.minibatch_size 339 upper = self.current + self.minibatch_size
330 340
331 if upper <= len(self.dataset): 341 data = self.dataset.data
342
343 if upper <= rows:
332 #this is the easy case, we only need one slice 344 #this is the easy case, we only need one slice
333 dataview = self.dataset.data[self.current:upper] 345 dataview = data[self.current:upper]
334 else: 346 else:
335 # the minibatch wraps around the end of the dataset 347 # the minibatch wraps around the end of the dataset
336 dataview = self.dataset.data[self.current:] 348 dataview = data[self.current:]
337 upper -= len(self.dataset) 349 upper -= rows
338 assert upper > 0 350 assert upper > 0
339 dataview = self.matcat(dataview, self.dataset.data[:upper]) 351 dataview = self.matcat(dataview, data[:upper])
340 352
341 353
342 rval = [dataview[:, self.dataset.fields[f]] for f in self.fieldnames] 354 self._values = [dataview[:, self.dataset.fields[f]]\
343 355 for f in self._names]
344 if self.fieldnames: 356
345 rval = Example(self.fieldnames, rval) 357 return self
346
347 return rval
348 358
349 359
350 def __init__(self, data, fields=None): 360 def __init__(self, data, fields=None):
351 """ 361 """
352 There are two ways to construct an ArrayDataSet: (1) from an 362 There are two ways to construct an ArrayDataSet: (1) from an
369 step=1 379 step=1
370 if not fieldslice.start or not fieldslice.step: 380 if not fieldslice.start or not fieldslice.step:
371 fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step) 381 fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step)
372 # and coherent with the data array 382 # and coherent with the data array
373 assert fieldslice.start >= 0 and fieldslice.stop <= cols 383 assert fieldslice.start >= 0 and fieldslice.stop <= cols
384
385 def __iter__(self):
386 return self.zip(*self.fieldNames())
374 387
375 def minibatches(self, 388 def minibatches(self,
376 fieldnames = DataSet.minibatches_fieldnames, 389 fieldnames = DataSet.minibatches_fieldnames,
377 minibatch_size = DataSet.minibatches_minibatch_size, 390 minibatch_size = DataSet.minibatches_minibatch_size,
378 n_batches = DataSet.minibatches_n_batches): 391 n_batches = DataSet.minibatches_n_batches):