Mercurial > pylearn
comparison dataset.py @ 45:a5c70dc42972
Test functions for dataset.py
author | bengioy@grenat.iro.umontreal.ca |
---|---|
date | Tue, 29 Apr 2008 11:25:36 -0400 |
parents | 5a85fda9b19b |
children | c5b07e87b0cb |
comparison legend (rows were colour-coded in the original view): equal, deleted, inserted, replaced
44:5a85fda9b19b | 45:a5c70dc42972 |
---|---|
2 from lookup_list import LookupList | 2 from lookup_list import LookupList |
3 Example = LookupList | 3 Example = LookupList |
4 from misc import unique_elements_list_intersection | 4 from misc import unique_elements_list_intersection |
5 from string import join | 5 from string import join |
6 from sys import maxint | 6 from sys import maxint |
7 import numpy | |
7 | 8 |
8 class AbstractFunction (Exception): """Derived class must override this function""" | 9 class AbstractFunction (Exception): """Derived class must override this function""" |
9 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" | 10 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" |
10 #class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)""" | 11 #class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)""" |
11 | 12 |
374 # or a list of indices | 375 # or a list of indices |
375 elif type(i) is list: | 376 elif type(i) is list: |
376 rows = i | 377 rows = i |
377 if rows is not None: | 378 if rows is not None: |
378 fields_values = zip(*[self[row] for row in rows]) | 379 fields_values = zip(*[self[row] for row in rows]) |
379 return DataSet.MinibatchDataSet( | 380 return MinibatchDataSet( |
380 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) | 381 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) |
381 for fieldname,field_values | 382 for fieldname,field_values |
382 in zip(self.fieldNames(),fields_values)])) | 383 in zip(self.fieldNames(),fields_values)])) |
383 # else check for a fieldname | 384 # else check for a fieldname |
384 if self.hasFields(i): | 385 if self.hasFields(i): |