| #-*- coding: ISO-8859-1 -*- |
| # pysqlite2/test/userfunctions.py: tests for user-defined functions and |
| # aggregates. |
| # |
| # Copyright (C) 2005-2007 Gerhard Häring <[email protected]> |
| # |
| # This file is part of pysqlite. |
| # |
| # This software is provided 'as-is', without any express or implied |
| # warranty. In no event will the authors be held liable for any damages |
| # arising from the use of this software. |
| # |
| # Permission is granted to anyone to use this software for any purpose, |
| # including commercial applications, and to alter it and redistribute it |
| # freely, subject to the following restrictions: |
| # |
| # 1. The origin of this software must not be misrepresented; you must not |
| # claim that you wrote the original software. If you use this software |
| # in a product, an acknowledgment in the product documentation would be |
| # appreciated but is not required. |
| # 2. Altered source versions must be plainly marked as such, and must not be |
| # misrepresented as being the original software. |
| # 3. This notice may not be removed or altered from any source distribution. |
| |
| import unittest |
| import sqlite3 as sqlite |
| |
def func_returntext():
    """Return the plain ``str`` literal "foo" (text-result test helper)."""
    text = "foo"
    return text
def func_returnunicode():
    """Return the unicode literal u"bar" (unicode-result test helper)."""
    text = u"bar"
    return text
def func_returnint():
    """Return the integer 42 (integer-result test helper)."""
    answer = 42
    return answer
def func_returnfloat():
    """Return the float 3.14 (float-result test helper)."""
    approx_pi = 3.14
    return approx_pi
def func_returnnull():
    """Return None, which SQLite receives as NULL (null-result test helper)."""
    # Falling off the end of the function yields None implicitly.
| def func_returnblob(): |
| return buffer("blob") |
def func_returnlonglong():
    """Return 2**31, one past the largest signed 32-bit integer."""
    return 2 ** 31
def func_raiseexception():
    """Always raise ZeroDivisionError (exception-propagation test helper)."""
    divisor = 0
    5 // divisor
| |
| def func_isstring(v): |
| return type(v) is unicode |
def func_isint(v):
    """Return True only if *v* is exactly of type ``int`` (bool excluded)."""
    actual = type(v)
    return actual is int
def func_isfloat(v):
    """Return True only if *v* is exactly of type ``float``."""
    actual = type(v)
    return actual is float
def func_isnone(v):
    """Return True only if *v* is None (i.e. SQL NULL was passed in)."""
    # None is the sole instance of NoneType, so an identity test is
    # equivalent to comparing type(v) against type(None).
    return v is None
| def func_isblob(v): |
| return type(v) is buffer |
| def func_islonglong(v): |
| return isinstance(v, (int, long)) and v >= 1<<31 |
| |
class AggrNoStep:
    """Aggregate that deliberately lacks a ``step`` method.

    Kept as an old-style class (no ``object`` base): the aggregate test
    compares against the Python 2 old-style message
    "AggrNoStep instance has no attribute 'step'".
    """

    def __init__(self):
        # No state is needed.
        return

    def finalize(self):
        result = 1
        return result
| |
class AggrNoFinalize:
    """Aggregate that deliberately lacks a ``finalize`` method."""

    def __init__(self):
        # No state is needed.
        return

    def step(self, x):
        # Every input is ignored.
        return None
| |
class AggrExceptionInInit:
    """Aggregate whose constructor always raises ZeroDivisionError."""

    def __init__(self):
        divisor = 0
        5 // divisor

    def step(self, x):
        return None

    def finalize(self):
        return None
| |
class AggrExceptionInStep:
    """Aggregate whose ``step`` always raises ZeroDivisionError.

    ``finalize`` would return 42 if it were ever reached.
    """

    def __init__(self):
        return

    def step(self, x):
        divisor = 0
        5 // divisor

    def finalize(self):
        result = 42
        return result
| |
class AggrExceptionInFinalize:
    """Aggregate whose ``finalize`` always raises ZeroDivisionError."""

    def __init__(self):
        return

    def step(self, x):
        # Inputs are accepted and ignored.
        return None

    def finalize(self):
        divisor = 0
        5 // divisor
| |
| class AggrCheckType: |
| def __init__(self): |
| self.val = None |
| |
| def step(self, whichType, val): |
| theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer} |
| self.val = int(theType[whichType] is type(val)) |
| |
| def finalize(self): |
| return self.val |
| |
class AggrSum:
    """Aggregate that adds up every stepped value, starting from 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        total = self.val
        return total
| |
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions (Connection.create_function).

    Covers conversion of Python return values to SQL values, conversion of
    SQL parameters to Python arguments, error reporting, and whether the
    connection keeps a registered function object alive by itself.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        # Zero-argument functions, each returning one kind of value.
        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        # One-argument predicates reporting the Python type they received.
        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)

    def tearDown(self):
        self.con.close()

    def CheckFuncErrorOnCreate(self):
        # An invalid argument count (-100) must be rejected at registration.
        try:
            self.con.create_function("bla", -100, lambda x: 2*x)
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError:
            pass

    def CheckFuncRefCount(self):
        # Register a function object that nothing else refers to (it is not
        # kept in a local or stashed in globals()), so the call below only
        # works if the connection holds its own reference to it.
        def getfunc():
            def f():
                return 1
            return f
        self.con.create_function("reftest", 0, getfunc())
        cur = self.con.cursor()
        cur.execute("select reftest()")
        self.assertEqual(cur.fetchone()[0], 1)

    def CheckFuncReturnText(self):
        # A plain str returned by the function comes back as unicode.
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, u"bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        if val < 3.139 or val > 3.141:
            self.fail("wrong value")

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), buffer)
        self.assertEqual(val, buffer("blob"))

    def CheckFuncReturnLongLong(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    def CheckFuncException(self):
        # The Python exception is replaced by a generic OperationalError.
        cur = self.con.cursor()
        try:
            cur.execute("select raiseexception()")
            cur.fetchone()
            self.fail("should have raised OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], 'user-defined function raised exception')

    def CheckParamString(self):
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select isblob(?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        cur = self.con.cursor()
        cur.execute("select islonglong(?)", (1<<42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)
| |
class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates (Connection.create_aggregate).

    Covers registration errors, aggregate classes missing methods or raising
    in __init__/step/finalize, parameter type conversion and a simple sum.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, buffer("blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Release the per-test in-memory database opened in setUp.
        self.con.close()

    def CheckAggrErrorOnCreate(self):
        # An invalid argument count (-100) must be rejected when registering
        # an aggregate, so exercise create_aggregate (not create_function).
        try:
            self.con.create_aggregate("bla", -100, AggrSum)
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError:
            pass

    def CheckAggrNoStep(self):
        cur = self.con.cursor()
        try:
            cur.execute("select nostep(t) from test")
            self.fail("should have raised an AttributeError")
        except AttributeError as e:
            self.assertEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        cur = self.con.cursor()
        try:
            cur.execute("select nofinalize(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        cur = self.con.cursor()
        try:
            cur.execute("select excInit(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        cur = self.con.cursor()
        try:
            cur.execute("select excStep(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        cur = self.con.cursor()
        try:
            cur.execute("select excFinalize(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        cur = self.con.cursor()
        cur.execute("select checkType('str', ?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)
| |
class AuthorizerTests(unittest.TestCase):
    """Check that a connection's authorizer callback can block statements.

    Subclasses replace ``authorizer_cb`` with other ways of refusing access;
    the inherited tests require the statements to fail the same way.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Deny every action other than SQLITE_SELECT, and anything whose
        # arguments name column c2 or table t2.
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test: this query runs before the authorizer is
        # installed (presumably so a prepared statement already exists when
        # the authorizer is set -- TODO confirm intent).
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        # Close the per-test connection opened in setUp.
        self.con.close()

    def _assert_prohibited(self, sql):
        """Assert that executing *sql* fails with a "...prohibited" error."""
        try:
            self.con.execute(sql)
        except sqlite.DatabaseError as e:
            if not e.args[0].endswith("prohibited"):
                self.fail("wrong exception text: %s" % e.args[0])
        else:
            self.fail("should have raised an exception due to missing privileges")

    def test_table_access(self):
        self._assert_prohibited("select * from t2")

    def test_column_access(self):
        self._assert_prohibited("select c2 from t1")
| |
class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Authorizer tests where the callback refuses access by raising."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Raise for every action other than SQLITE_SELECT, and for anything
        # whose arguments name column c2 or table t2.
        denied = (action != sqlite.SQLITE_SELECT
                  or arg2 == 'c2'
                  or arg1 == 't2')
        if denied:
            raise ValueError
        return sqlite.SQLITE_OK
| |
class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Authorizer tests where the callback refuses by returning a float."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Return the non-integer 0.0 for every action other than
        # SQLITE_SELECT, and for anything naming column c2 or table t2.
        allowed = (action == sqlite.SQLITE_SELECT
                   and arg2 != 'c2'
                   and arg1 != 't2')
        return sqlite.SQLITE_OK if allowed else 0.0
| |
class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Authorizer tests where the callback refuses with a too-large int."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Return 2**32 (wider than 32 bits) for every action other than
        # SQLITE_SELECT, and for anything naming column c2 or table t2.
        allowed = (action == sqlite.SQLITE_SELECT
                   and arg2 != 'c2'
                   and arg1 != 't2')
        return sqlite.SQLITE_OK if allowed else 1 << 32
| |
| |
def suite():
    """Build the combined test suite for this module.

    FunctionTests and AggregateTests use the legacy ``Check`` method-name
    prefix; the authorizer test cases use the default ``test`` prefix.
    """
    cases = (
        (FunctionTests, "Check"),
        (AggregateTests, "Check"),
        (AuthorizerTests, "test"),
        (AuthorizerRaiseExceptionTests, "test"),
        (AuthorizerIllegalTypeTests, "test"),
        (AuthorizerLargeIntegerTests, "test"),
    )
    return unittest.TestSuite(
        [unittest.makeSuite(case, prefix) for case, prefix in cases])
| |
def test():
    """Run this module's full test suite with a text-mode runner."""
    unittest.TextTestRunner().run(suite())
| |
if __name__ == "__main__":
    # Executed as a script (not imported): run the whole suite.
    test()