From f45bc628f1cb94b488c6ab14faeca1fba00fc9c4 Mon Sep 17 00:00:00 2001 From: Tom Keffer Date: Tue, 15 Jan 2013 18:00:36 +0000 Subject: [PATCH] Now bad SELECT statements raise weedb.OperationalError --- bin/weedb/mysql.py | 2 +- bin/weedb/sqlite.py | 20 ++++++++++++++++---- bin/weedb/test/test_weedb.py | 36 +++++++++++++++++++++++++++++------- 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/bin/weedb/mysql.py b/bin/weedb/mysql.py index 07701746..8a422549 100644 --- a/bin/weedb/mysql.py +++ b/bin/weedb/mysql.py @@ -165,7 +165,7 @@ class Cursor(object): try: self.cursor.execute(mysql_string, sql_tuple) - except _mysql_exceptions.OperationalError, e: + except (_mysql_exceptions.OperationalError, _mysql_exceptions.ProgrammingError), e: raise weedb.OperationalError(e) return self diff --git a/bin/weedb/sqlite.py b/bin/weedb/sqlite.py index 2857c7cf..ed077159 100644 --- a/bin/weedb/sqlite.py +++ b/bin/weedb/sqlite.py @@ -75,9 +75,7 @@ class Connection(weedb.Connection): def cursor(self): """Return a cursor object.""" - # The sqlite3 cursor object is very full featured. We can simply return - # it (with no wrapper) - return self.connection.cursor() + return Cursor(self.connection) def tables(self): """Returns a list of tables in the database.""" @@ -105,4 +103,18 @@ class Connection(weedb.Connection): def begin(self): self.connection.execute("BEGIN TRANSACTION") - \ No newline at end of file + +class Cursor(sqlite3.Cursor): + """A wrapper around the sqlite cursor object""" + + # The sqlite3 cursor object is very full featured. We need only turn + # the sqlite exceptions into weedb exceptions. + def __init__(self, *args, **kwargs): + sqlite3.Cursor.__init__(self, *args, **kwargs) + + def execute(self, *args, **kwargs): + try: + sqlite3.Cursor.execute(self, *args, **kwargs) + except sqlite3.OperationalError, e: + # Convert to a weedb exception + raise weedb.OperationalError(e) \ No newline at end of file diff --git a/bin/weedb/test/test_weedb.py b/bin/weedb/test/test_weedb.py index f960cce5..380ed325 100644 --- a/bin/weedb/test/test_weedb.py +++ b/bin/weedb/test/test_weedb.py @@ -9,6 +9,7 @@ # """Test the weedb package""" +from __future__ import with_statement import unittest import weedb @@ -34,12 +35,13 @@ class Common(unittest.TestCase): weedb.create(self.db_dict) self.assertRaises(weedb.DatabaseExists, weedb.create, self.db_dict) _connect = weedb.connect(self.db_dict) - _cursor = _connect.cursor() - _cursor.execute("""CREATE TABLE test1 ( dateTime INTEGER NOT NULL UNIQUE PRIMARY KEY, """\ - """min REAL, mintime INTEGER, max REAL, maxtime INTEGER, sum REAL, count INTEGER);""") - _cursor.execute("""CREATE TABLE test2 ( dateTime INTEGER NOT NULL UNIQUE PRIMARY KEY, """\ - """min REAL, mintime INTEGER, max REAL, maxtime INTEGER, sum REAL, count INTEGER);""") - _cursor.close() + with weedb.Transaction(_connect) as _cursor: + _cursor.execute("""CREATE TABLE test1 ( dateTime INTEGER NOT NULL UNIQUE PRIMARY KEY, """\ + """min REAL, mintime INTEGER, max REAL, maxtime INTEGER, sum REAL, count INTEGER);""") + _cursor.execute("""CREATE TABLE test2 ( dateTime INTEGER NOT NULL UNIQUE PRIMARY KEY, """\ + """min REAL, mintime INTEGER, max REAL, maxtime INTEGER, sum REAL, count INTEGER);""") + for irec in range(20): + _cursor.execute("INSERT INTO test1 (dateTime, min, mintime) VALUES (?, ?, ?)", (irec, 10*irec, irec)) _connect.close() def test_drop(self): @@ -74,6 +76,25 @@ class Common(unittest.TestCase): self.assertRaises(weedb.OperationalError, _connect.columnsOf, 'foo') _connect.close() + def test_select(self): + self.populate_db() + _connect = weedb.connect(self.db_dict) + _cursor = _connect.cursor() + _cursor.execute("SELECT dateTime, min FROM test1") + for i, _row in enumerate(_cursor): + self.assertEqual(_row[0], i) + _cursor.close() + _connect.close() + + def test_bad_select(self): + self.populate_db() + _connect = weedb.connect(self.db_dict) + _cursor = _connect.cursor() + with self.assertRaises(weedb.OperationalError): + _cursor.execute("SELECT dateTime, min FROM foo") + _cursor.close() + _connect.close() + class TestSqlite(Common): def __init__(self, *args, **kwargs): @@ -88,7 +109,8 @@ class TestMySQL(Common): def suite(): - tests = ['test_drop', 'test_double_create', 'test_no_db', 'test_no_tables', 'test_create', 'test_bad_table'] + tests = ['test_drop', 'test_double_create', 'test_no_db', 'test_no_tables', + 'test_create', 'test_bad_table', 'test_select', 'test_bad_select'] return unittest.TestSuite(map(TestSqlite, tests) + map(TestMySQL, tests)) if __name__ == '__main__':