Mercurial > pylearn
comparison dataset.py @ 226:3595ba2610f7
merged
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 23 May 2008 17:12:12 -0400 |
parents | 517364d48ae0 |
children | 17c5d080964b |
comparison
legend (row status): equal | deleted | inserted | replaced
225:8bc16220b29a | 226:3595ba2610f7 |
---|---|
243 assert offset+minibatch_size<=self.L | 243 assert offset+minibatch_size<=self.L |
244 ds_nbatches = (self.L-self.next_row)/self.minibatch_size | 244 ds_nbatches = (self.L-self.next_row)/self.minibatch_size |
245 if n_batches is not None: | 245 if n_batches is not None: |
246 ds_nbatches = min(n_batches,ds_nbatches) | 246 ds_nbatches = min(n_batches,ds_nbatches) |
247 if fieldnames: | 247 if fieldnames: |
248 if not dataset.hasFields(*fieldnames): | 248 assert dataset.hasFields(*fieldnames) |
249 raise ValueError('field not present', fieldnames) | |
250 else: | 249 else: |
251 self.fieldnames=dataset.fieldNames() | 250 self.fieldnames=dataset.fieldNames() |
252 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, | 251 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, |
253 ds_nbatches,self.next_row) | 252 ds_nbatches,self.next_row) |
254 | 253 |
668 | 667 |
669 self._fields=fields_lookuplist | 668 self._fields=fields_lookuplist |
670 assert len(fields_lookuplist)>0 | 669 assert len(fields_lookuplist)>0 |
671 self.length=len(fields_lookuplist[0]) | 670 self.length=len(fields_lookuplist[0]) |
672 for field in fields_lookuplist[1:]: | 671 for field in fields_lookuplist[1:]: |
672 if self.length != len(field) : | |
673 print 'self.length = ',self.length | |
674 print 'len(field) = ', len(field) | |
675 print 'self._fields.keys() = ', self._fields.keys() | |
676 print 'field=',field | |
673 assert self.length==len(field) | 677 assert self.length==len(field) |
674 self.values_vstack=values_vstack | 678 self.values_vstack=values_vstack |
675 self.values_hstack=values_hstack | 679 self.values_hstack=values_hstack |
676 | 680 |
677 def __len__(self): | 681 def __len__(self): |
696 if fieldname not in self._fields.keys(): | 700 if fieldname not in self._fields.keys(): |
697 return False | 701 return False |
698 return True | 702 return True |
699 | 703 |
700 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 704 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
705 #@TODO bug somewhere here, fieldnames doesnt seem to be well handled | |
701 class Iterator(object): | 706 class Iterator(object): |
702 def __init__(self,ds): | 707 def __init__(self,ds,fieldnames): |
708 # tbm: added two next lines to handle fieldnames | |
709 if fieldnames is None: fieldnames = ds._fields.keys() | |
710 self.fieldnames = fieldnames | |
711 | |
703 self.ds=ds | 712 self.ds=ds |
704 self.next_example=offset | 713 self.next_example=offset |
705 assert minibatch_size > 0 | 714 assert minibatch_size > 0 |
706 if offset+minibatch_size > ds.length: | 715 if offset+minibatch_size > ds.length: |
707 raise NotImplementedError() | 716 raise NotImplementedError() |
708 def __iter__(self): | 717 def __iter__(self): |
709 return self | 718 return self |
710 def next(self): | 719 def next(self): |
711 upper = self.next_example+minibatch_size | 720 upper = self.next_example+minibatch_size |
712 assert upper<=self.ds.length | 721 assert upper<=self.ds.length |
713 minibatch = Example(self.ds._fields.keys(), | 722 #minibatch = Example(self.ds._fields.keys(), |
714 [field[self.next_example:upper] | 723 # [field[self.next_example:upper] |
715 for field in self.ds._fields]) | 724 # for field in self.ds._fields]) |
725 # tbm: modif to use fieldnames | |
726 values = [] | |
727 for f in self.fieldnames : | |
728 #print 'we have field',f,'in fieldnames' | |
729 values.append( self.ds._fields[f][self.next_example:upper] ) | |
730 minibatch = Example(self.fieldnames,values) | |
731 #print minibatch | |
716 self.next_example+=minibatch_size | 732 self.next_example+=minibatch_size |
717 return minibatch | 733 return minibatch |
718 | 734 |
719 return Iterator(self) | 735 # tbm: added fieldnames to handle subset of fieldnames |
736 return Iterator(self,fieldnames) | |
720 | 737 |
721 def valuesVStack(self,fieldname,fieldvalues): | 738 def valuesVStack(self,fieldname,fieldvalues): |
722 return self.values_vstack(fieldname,fieldvalues) | 739 return self.values_vstack(fieldname,fieldvalues) |
723 | 740 |
724 def valuesHStack(self,fieldnames,fieldvalues): | 741 def valuesHStack(self,fieldnames,fieldvalues): |
968 | 985 |
969 # check consistency and complete slices definitions | 986 # check consistency and complete slices definitions |
970 for fieldname, fieldcolumns in self.fields_columns.items(): | 987 for fieldname, fieldcolumns in self.fields_columns.items(): |
971 if type(fieldcolumns) is int: | 988 if type(fieldcolumns) is int: |
972 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] | 989 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] |
973 | 990 self.fields_columns[fieldname]=[fieldcolumns] |
974 if 0: | |
975 #I changed this because it didn't make sense to me, | |
976 # and it made it more difficult to write my learner. | |
977 # If it breaks stuff, let's talk about it. | |
978 # - James 22/05/2008 | |
979 self.fields_columns[fieldname]=[fieldcolumns] | |
980 else: | |
981 self.fields_columns[fieldname]=fieldcolumns | |
982 | |
983 elif type(fieldcolumns) is slice: | 991 elif type(fieldcolumns) is slice: |
984 start,step=None,None | 992 start,step=None,None |
985 if not fieldcolumns.start: | 993 if not fieldcolumns.start: |
986 start=0 | 994 start=0 |
987 if not fieldcolumns.step: | 995 if not fieldcolumns.step: |