comparison dataset.py @ 226:3595ba2610f7

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 23 May 2008 17:12:12 -0400
parents 517364d48ae0
children 17c5d080964b
comparison
equal deleted inserted replaced
225:8bc16220b29a 226:3595ba2610f7
243 assert offset+minibatch_size<=self.L 243 assert offset+minibatch_size<=self.L
244 ds_nbatches = (self.L-self.next_row)/self.minibatch_size 244 ds_nbatches = (self.L-self.next_row)/self.minibatch_size
245 if n_batches is not None: 245 if n_batches is not None:
246 ds_nbatches = min(n_batches,ds_nbatches) 246 ds_nbatches = min(n_batches,ds_nbatches)
247 if fieldnames: 247 if fieldnames:
248 if not dataset.hasFields(*fieldnames): 248 assert dataset.hasFields(*fieldnames)
249 raise ValueError('field not present', fieldnames)
250 else: 249 else:
251 self.fieldnames=dataset.fieldNames() 250 self.fieldnames=dataset.fieldNames()
252 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, 251 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
253 ds_nbatches,self.next_row) 252 ds_nbatches,self.next_row)
254 253
668 667
669 self._fields=fields_lookuplist 668 self._fields=fields_lookuplist
670 assert len(fields_lookuplist)>0 669 assert len(fields_lookuplist)>0
671 self.length=len(fields_lookuplist[0]) 670 self.length=len(fields_lookuplist[0])
672 for field in fields_lookuplist[1:]: 671 for field in fields_lookuplist[1:]:
672 if self.length != len(field) :
673 print 'self.length = ',self.length
674 print 'len(field) = ', len(field)
675 print 'self._fields.keys() = ', self._fields.keys()
676 print 'field=',field
673 assert self.length==len(field) 677 assert self.length==len(field)
674 self.values_vstack=values_vstack 678 self.values_vstack=values_vstack
675 self.values_hstack=values_hstack 679 self.values_hstack=values_hstack
676 680
677 def __len__(self): 681 def __len__(self):
696 if fieldname not in self._fields.keys(): 700 if fieldname not in self._fields.keys():
697 return False 701 return False
698 return True 702 return True
699 703
700 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 704 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
705 #@TODO bug somewhere here, fieldnames doesn't seem to be well handled
701 class Iterator(object): 706 class Iterator(object):
702 def __init__(self,ds): 707 def __init__(self,ds,fieldnames):
708 # tbm: added the next two lines to handle fieldnames
709 if fieldnames is None: fieldnames = ds._fields.keys()
710 self.fieldnames = fieldnames
711
703 self.ds=ds 712 self.ds=ds
704 self.next_example=offset 713 self.next_example=offset
705 assert minibatch_size > 0 714 assert minibatch_size > 0
706 if offset+minibatch_size > ds.length: 715 if offset+minibatch_size > ds.length:
707 raise NotImplementedError() 716 raise NotImplementedError()
708 def __iter__(self): 717 def __iter__(self):
709 return self 718 return self
710 def next(self): 719 def next(self):
711 upper = self.next_example+minibatch_size 720 upper = self.next_example+minibatch_size
712 assert upper<=self.ds.length 721 assert upper<=self.ds.length
713 minibatch = Example(self.ds._fields.keys(), 722 #minibatch = Example(self.ds._fields.keys(),
714 [field[self.next_example:upper] 723 # [field[self.next_example:upper]
715 for field in self.ds._fields]) 724 # for field in self.ds._fields])
725 # tbm: modified to use fieldnames
726 values = []
727 for f in self.fieldnames :
728 #print 'we have field',f,'in fieldnames'
729 values.append( self.ds._fields[f][self.next_example:upper] )
730 minibatch = Example(self.fieldnames,values)
731 #print minibatch
716 self.next_example+=minibatch_size 732 self.next_example+=minibatch_size
717 return minibatch 733 return minibatch
718 734
719 return Iterator(self) 735 # tbm: added fieldnames to handle subset of fieldnames
736 return Iterator(self,fieldnames)
720 737
721 def valuesVStack(self,fieldname,fieldvalues): 738 def valuesVStack(self,fieldname,fieldvalues):
722 return self.values_vstack(fieldname,fieldvalues) 739 return self.values_vstack(fieldname,fieldvalues)
723 740
724 def valuesHStack(self,fieldnames,fieldvalues): 741 def valuesHStack(self,fieldnames,fieldvalues):
968 985
969 # check consistency and complete slices definitions 986 # check consistency and complete slices definitions
970 for fieldname, fieldcolumns in self.fields_columns.items(): 987 for fieldname, fieldcolumns in self.fields_columns.items():
971 if type(fieldcolumns) is int: 988 if type(fieldcolumns) is int:
972 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] 989 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
973 990 self.fields_columns[fieldname]=[fieldcolumns]
974 if 0:
975 #I changed this because it didn't make sense to me,
976 # and it made it more difficult to write my learner.
977 # If it breaks stuff, let's talk about it.
978 # - James 22/05/2008
979 self.fields_columns[fieldname]=[fieldcolumns]
980 else:
981 self.fields_columns[fieldname]=fieldcolumns
982
983 elif type(fieldcolumns) is slice: 991 elif type(fieldcolumns) is slice:
984 start,step=None,None 992 start,step=None,None
985 if not fieldcolumns.start: 993 if not fieldcolumns.start:
986 start=0 994 start=0
987 if not fieldcolumns.step: 995 if not fieldcolumns.step: