Mercurial > pylearn
comparison dataset.py @ 61:a8b70a9117ad
bugfix: in MinibatchDataSet, renamed the class variable fields to _fields because the parent class has a function called field.
bugfix: next_example is a class variable, so it must be accessed as self.next_example.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Fri, 02 May 2008 09:55:38 -0400 |
parents | 9165d86855ab |
children | 23bf2c9eb7b3 |
comparison
equal
deleted
inserted
replaced
60:9165d86855ab | 61:a8b70a9117ad |
---|---|
586 """ | 586 """ |
587 The user can (and generally should) also provide values_vstack(fieldname,fieldvalues) | 587 The user can (and generally should) also provide values_vstack(fieldname,fieldvalues) |
588 and a values_hstack(fieldnames,fieldvalues) functions behaving with the same | 588 and a values_hstack(fieldnames,fieldvalues) functions behaving with the same |
589 semantics as the DataSet methods of the same name (but without the self argument). | 589 semantics as the DataSet methods of the same name (but without the self argument). |
590 """ | 590 """ |
591 self.fields=fields_lookuplist | 591 self._fields=fields_lookuplist |
592 assert len(fields_lookuplist)>0 | 592 assert len(fields_lookuplist)>0 |
593 self.length=len(fields_lookuplist[0]) | 593 self.length=len(fields_lookuplist[0]) |
594 for field in fields_lookuplist[1:]: | 594 for field in fields_lookuplist[1:]: |
595 assert self.length==len(field) | 595 assert self.length==len(field) |
596 self.values_vstack=values_vstack | 596 self.values_vstack=values_vstack |
600 return self.length | 600 return self.length |
601 | 601 |
602 def __getitem__(self,i): | 602 def __getitem__(self,i): |
603 if type(i) in (int,slice,list): | 603 if type(i) in (int,slice,list): |
604 return DataSetFields(MinibatchDataSet( | 604 return DataSetFields(MinibatchDataSet( |
605 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields) | 605 Example(self._fields.keys(),[field[i] for field in self._fields])),self._fields) |
606 if self.hasFields(i): | 606 if self.hasFields(i): |
607 return self.fields[i] | 607 return self._fields[i] |
608 assert i in self.__dict__ # else it means we are trying to access a non-existing property | 608 assert i in self.__dict__ # else it means we are trying to access a non-existing property |
609 return self.__dict__[i] | 609 return self.__dict__[i] |
610 | 610 |
611 def fieldNames(self): | 611 def fieldNames(self): |
612 return self.fields.keys() | 612 return self._fields.keys() |
613 | 613 |
614 def hasFields(self,*fieldnames): | 614 def hasFields(self,*fieldnames): |
615 for fieldname in fieldnames: | 615 for fieldname in fieldnames: |
616 if fieldname not in self.fields.keys(): | 616 if fieldname not in self._fields.keys(): |
617 return False | 617 return False |
618 return True | 618 return True |
619 | 619 |
620 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 620 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
621 class Iterator(object): | 621 class Iterator(object): |
626 if offset+minibatch_size > ds.length: | 626 if offset+minibatch_size > ds.length: |
627 raise NotImplementedError() | 627 raise NotImplementedError() |
628 def __iter__(self): | 628 def __iter__(self): |
629 return self | 629 return self |
630 def next(self): | 630 def next(self): |
631 upper = next_example+minibatch_size | 631 upper = self.next_example+minibatch_size |
632 assert upper<=self.ds.length | 632 assert upper<=self.ds.length |
633 minibatch = Example(self.ds.fields.keys(), | 633 minibatch = Example(self.ds._fields.keys(), |
634 [field[next_example:upper] | 634 [field[self.next_example:upper] |
635 for field in self.ds.fields]) | 635 for field in self.ds._fields]) |
636 self.next_example+=minibatch_size | 636 self.next_example+=minibatch_size |
637 return DataSetFields(MinibatchDataSet(minibatch),fieldnames) | 637 return DataSetFields(MinibatchDataSet(minibatch),fieldnames) |
638 | 638 |
639 return Iterator(self) | 639 return Iterator(self) |
640 | 640 |