comparison dataset.py @ 61:a8b70a9117ad

bugfix: in MinibatchDataSet, renamed the class variable fields to _fields because the parent class has a function called field. bugfix: next_example is a class variable...
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Fri, 02 May 2008 09:55:38 -0400
parents 9165d86855ab
children 23bf2c9eb7b3
comparison
equal deleted inserted replaced
60:9165d86855ab 61:a8b70a9117ad
586 """ 586 """
587 The user can (and generally should) also provide values_vstack(fieldname,fieldvalues) 587 The user can (and generally should) also provide values_vstack(fieldname,fieldvalues)
588 and a values_hstack(fieldnames,fieldvalues) functions behaving with the same 588 and a values_hstack(fieldnames,fieldvalues) functions behaving with the same
589 semantics as the DataSet methods of the same name (but without the self argument). 589 semantics as the DataSet methods of the same name (but without the self argument).
590 """ 590 """
591 self.fields=fields_lookuplist 591 self._fields=fields_lookuplist
592 assert len(fields_lookuplist)>0 592 assert len(fields_lookuplist)>0
593 self.length=len(fields_lookuplist[0]) 593 self.length=len(fields_lookuplist[0])
594 for field in fields_lookuplist[1:]: 594 for field in fields_lookuplist[1:]:
595 assert self.length==len(field) 595 assert self.length==len(field)
596 self.values_vstack=values_vstack 596 self.values_vstack=values_vstack
600 return self.length 600 return self.length
601 601
602 def __getitem__(self,i): 602 def __getitem__(self,i):
603 if type(i) in (int,slice,list): 603 if type(i) in (int,slice,list):
604 return DataSetFields(MinibatchDataSet( 604 return DataSetFields(MinibatchDataSet(
605 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields) 605 Example(self._fields.keys(),[field[i] for field in self._fields])),self._fields)
606 if self.hasFields(i): 606 if self.hasFields(i):
607 return self.fields[i] 607 return self._fields[i]
608 assert i in self.__dict__ # else it means we are trying to access a non-existing property 608 assert i in self.__dict__ # else it means we are trying to access a non-existing property
609 return self.__dict__[i] 609 return self.__dict__[i]
610 610
611 def fieldNames(self): 611 def fieldNames(self):
612 return self.fields.keys() 612 return self._fields.keys()
613 613
614 def hasFields(self,*fieldnames): 614 def hasFields(self,*fieldnames):
615 for fieldname in fieldnames: 615 for fieldname in fieldnames:
616 if fieldname not in self.fields.keys(): 616 if fieldname not in self._fields.keys():
617 return False 617 return False
618 return True 618 return True
619 619
620 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 620 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
621 class Iterator(object): 621 class Iterator(object):
626 if offset+minibatch_size > ds.length: 626 if offset+minibatch_size > ds.length:
627 raise NotImplementedError() 627 raise NotImplementedError()
628 def __iter__(self): 628 def __iter__(self):
629 return self 629 return self
630 def next(self): 630 def next(self):
631 upper = next_example+minibatch_size 631 upper = self.next_example+minibatch_size
632 assert upper<=self.ds.length 632 assert upper<=self.ds.length
633 minibatch = Example(self.ds.fields.keys(), 633 minibatch = Example(self.ds._fields.keys(),
634 [field[next_example:upper] 634 [field[self.next_example:upper]
635 for field in self.ds.fields]) 635 for field in self.ds._fields])
636 self.next_example+=minibatch_size 636 self.next_example+=minibatch_size
637 return DataSetFields(MinibatchDataSet(minibatch),fieldnames) 637 return DataSetFields(MinibatchDataSet(minibatch),fieldnames)
638 638
639 return Iterator(self) 639 return Iterator(self)
640 640