changeset 161:2761360e1db0

queries with only 1 quant column
author catherine@dellzilla
date Mon, 29 Sep 2008 16:50:51 -0400
parents 8329ab7a3a49
children 8a0490229923
files sqlpython/plothandler.py
diffstat 1 files changed, 22 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/plothandler.py	Mon Sep 29 15:41:09 2008 -0400
+++ b/sqlpython/plothandler.py	Mon Sep 29 16:50:51 2008 -0400
@@ -14,17 +14,27 @@
             self.title = sqlSession.tblname
             self.xlabel = sqlSession.curs.description[0][0]
             self.datatypes = [d[1] for d in sqlSession.curs.description]
-            for (colNum, datatype) in enumerate(self.datatypes):
-                if colNum > 0 and datatype in self.plottable_types:
-                    yseries = [row[colNum] for row in sqlSession.rows]
-                    if max(yseries) is not None:
+            plottableSeries = [dt in self.plottable_types for dt in self.datatypes]
+            if plottableSeries.count(True) == 0:
+                raise ValueError, 'At least one quantitative column needed to plot.'
+            elif len(plottableSeries) == 1: # only one column, but it is plottable
+                idx = plottableSeries.index(True)
+                self.yserieslists = [[row[0] for row in sqlSession.rows]]
+                self.legends = [sqlSession.curs.description[0][0]]                
+                self.xvalues = range(len(sqlSession.rows))
+                self.xticks = [row[0] for row in sqlSession.rows]
+            else:
+                for (colNum, plottable) in enumerate(plottableSeries):
+                    if colNum > 0 and plottable:
+                        yseries = [row[colNum] for row in sqlSession.rows]
                         self.yserieslists.append(yseries)
                         self.legends.append(sqlSession.curs.description[colNum][0])
-            if self.datatypes[0] in self.plottable_types:
-                self.xvalues = [r[0] for r in sqlSession.rows]
-            else:
-                self.xvalues = range(sqlSession.curs.rowcount)
-                self.xticks = [r[0] for r in sqlSession.rows]
+                if plottableSeries[0]:
+                    self.xvalues = [r[0] for r in sqlSession.rows]
+                else:
+                    self.xvalues = range(sqlSession.curs.rowcount)
+                    self.xticks = [r[0] for r in sqlSession.rows]
+            
         def shelve(self):
             s = shelve.open(shelvename,'c')
             for k in ('xvalues xticks yserieslists title legends xlabel outformat'.split()):
@@ -38,11 +48,11 @@
             self.draw()
         def bar(self):
             barEdges = pylab.arange(len(self.xvalues))
-            width = 0.5 / len(self.yserieslists)
+            width = 0.6 / len(self.yserieslists)
             colorcycler = itertools.cycle('rgb')
             for (offset, yseries) in enumerate(self.yserieslists):
                 self.yplots.append(pylab.bar(barEdges + (offset*width), yseries, width=width, color=colorcycler.next()))
-            pylab.xticks(barEdges + 0.25, self.xticks or self.xvalues)            
+            pylab.xticks(barEdges + 0.3, self.xticks or self.xvalues)            
         def line(self, markers):
             for yseries in self.yserieslists:
                 self.yplots.append(pylab.plot(self.xvalues, yseries, markers))
@@ -76,4 +86,4 @@
 except ImportError:
     class Plot(object):
         def __init__(self, *args, **kwargs):
-            raise ImportError, 'Must install python-matplotlib to draw plots'
\ No newline at end of file
+            raise ImportError, 'Must install python-matplotlib and pytyon-numpy to draw plots'
\ No newline at end of file