Skip to content
Open
174 changes: 174 additions & 0 deletions Lib/test/test_sqlite3/test_userfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,34 @@ def value(self): return 1 << 65
self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
self.cur.execute, self.query % "err_val_ret")

def test_close_conn_in_window_func_value(self):
# gh-145040: closing connection in window function value() callback.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x INTEGER)")
con.executemany("INSERT INTO t VALUES(?)",
[(i,) for i in range(20)])

class CloseConnWindow:
def step(self, value):
pass
def finalize(self):
return 0
def value(self):
con.close()
return 0
def inverse(self, value):
pass

con.create_window_function("evil_win", 1, CloseConnWindow)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
cursor = con.execute(
"SELECT evil_win(x) OVER "
"(ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM t"
)
list(cursor)
con.close()


class AggregateTests(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -723,6 +751,152 @@ def test_agg_keyword_args(self):
'takes exactly 3 positional arguments'):
self.con.create_aggregate("test", 1, aggregate_class=AggrText)

def test_aggr_close_conn_in_step(self):
# Connection.close() in an aggregate step callback must not crash.
con = sqlite.connect(":memory:", autocommit=True)
cur = con.cursor()
cur.execute("CREATE TABLE t(x INTEGER)")
for i in range(50):
cur.execute("INSERT INTO t VALUES (?)", (i,))

class CloseConnAgg:
def __init__(self):
self.total = 0

def step(self, value):
self.total += value
con.close()

def finalize(self):
return self.total

con.create_aggregate("agg_close", 1, CloseConnAgg)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.execute("SELECT agg_close(x) FROM t")
con.close()

def test_close_conn_in_nested_callback(self):
# gh-145040: close() must be prevented even in nested callbacks.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x INTEGER)")
for i in range(5):
con.execute("INSERT INTO t VALUES(?)", (i,))

def outer_func(x):
con.close()
return x

def inner_func(x):
return x * 10

con.create_function("outer_func", 1, outer_func)
con.create_function("inner_func", 1, inner_func)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.execute("SELECT outer_func(inner_func(x)) FROM t")
# Connection must still be usable after the failed close attempt.
self.assertEqual(con.execute("SELECT 1").fetchone(), (1,))
con.close()

def test_close_conn_in_nested_callback_caught(self):
# gh-145040: close attempt must propagate even if the exception
# is caught inside the callback and a nested execute consumes
# the flag.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x INTEGER)")
con.execute("INSERT INTO t VALUES(1)")

def swallow_close(x):
try:
con.close()
except sqlite.ProgrammingError:
pass
try:
con.execute("SELECT 1")
except sqlite.ProgrammingError:
pass
return x

con.create_function("swallow_close", 1, swallow_close)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.execute("SELECT swallow_close(x) FROM t")
# Connection must still be usable.
self.assertEqual(con.execute("SELECT 1").fetchone(), (1,))
con.close()

def test_close_conn_in_udf_during_executemany(self):
# gh-145040: closing connection in UDF during executemany.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x)")

def close_conn(x):
con.close()
return x

con.create_function("close_conn", 1, close_conn)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.executemany("INSERT INTO t VALUES(close_conn(?))",
[(i,) for i in range(10)])
con.close()

def test_close_conn_in_progress_handler_during_iternext(self):
# gh-145040: closing connection in progress handler during iteration.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x)")
con.executemany("INSERT INTO t VALUES(?)",
[(i,) for i in range(100)])

count = 0
def close_progress():
nonlocal count
count += 1
if count >= 5:
con.close()
return 1
return 0

cursor = con.execute("SELECT * FROM t")
con.set_progress_handler(close_progress, 1)
msg = "from within a callback"
import test.support
with test.support.catch_unraisable_exception():
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
for row in cursor:
pass
del cursor
gc_collect()
con.close()

def test_close_conn_in_collation_callback(self):
# gh-145040: closing connection in collation callback.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x TEXT)")
con.executemany("INSERT INTO t VALUES(?)",
[(f"item_{i}",) for i in range(50)])

count = 0
def evil_collation(a, b):
nonlocal count
count += 1
if count == 10:
con.close()
if a < b:
return -1
elif a > b:
return 1
return 0

con.create_collation("evil_coll", evil_collation)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.execute(
"SELECT * FROM t ORDER BY x COLLATE evil_coll"
)
con.close()


class AuthorizerTests(unittest.TestCase):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Fixed a crash in the :mod:`sqlite3` module caused by closing the database
connection from within a callback function invoked during
``sqlite3_step()`` (e.g., an aggregate ``step``, a user-defined function
via :meth:`~sqlite3.Connection.create_function`, a progress handler, or a
collation callback). Raise :exc:`~sqlite3.ProgrammingError` instead of
crashing.
40 changes: 39 additions & 1 deletion Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1()
#include "pycore_pylifecycle.h" // _Py_IsInterpreterFinalizing()
#include "pycore_unicodeobject.h" // _PyUnicode_AsUTF8NoNUL
#include "pycore_weakref.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need pycore_weakref.h; we should be fine with the public API functions here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See blob.c for how to iterate through the weak refs using the public API.


#include <stdbool.h>

Expand Down Expand Up @@ -283,10 +284,17 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, PyObject *database,
goto error;
}

/* Create lists of weak references to blobs */
/* Create lists of weak references to cursors and blobs */
PyObject *cursors = PyList_New(0);
if (cursors == NULL) {
Py_DECREF(statement_cache);
goto error;
}

PyObject *blobs = PyList_New(0);
if (blobs == NULL) {
Py_DECREF(statement_cache);
Py_DECREF(cursors);
goto error;
}

Expand All @@ -299,7 +307,9 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, PyObject *database,
self->check_same_thread = check_same_thread;
self->thread_ident = PyThread_get_thread_ident();
self->statement_cache = statement_cache;
self->cursors = cursors;
self->blobs = blobs;
self->close_attempted_in_callback = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding this flag, did you try to just raise in pysqlite_connection_close_impl?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(The exception should propagate through the existing (slightly weird) callback exception handling in connection.c)

self->row_factory = Py_NewRef(Py_None);
self->text_factory = Py_NewRef(&PyUnicode_Type);
self->trace_ctx = NULL;
Expand Down Expand Up @@ -381,6 +391,7 @@ connection_traverse(PyObject *op, visitproc visit, void *arg)
pysqlite_Connection *self = _pysqlite_Connection_CAST(op);
Py_VISIT(Py_TYPE(self));
Py_VISIT(self->statement_cache);
Py_VISIT(self->cursors);
Py_VISIT(self->blobs);
Py_VISIT(self->row_factory);
Py_VISIT(self->text_factory);
Expand All @@ -405,6 +416,7 @@ connection_clear(PyObject *op)
{
pysqlite_Connection *self = _pysqlite_Connection_CAST(op);
Py_CLEAR(self->statement_cache);
Py_CLEAR(self->cursors);
Py_CLEAR(self->blobs);
Py_CLEAR(self->row_factory);
Py_CLEAR(self->text_factory);
Expand Down Expand Up @@ -655,6 +667,32 @@ pysqlite_connection_close_impl(pysqlite_Connection *self)
return NULL;
}

/* Check if any cursor is locked (actively executing a query);
* closing during a callback is illegal per the SQLite C API docs. */
assert(PyList_CheckExact(self->cursors));
Py_ssize_t n = PyList_GET_SIZE(self->cursors);
for (Py_ssize_t i = 0; i < n; i++) {
PyObject *weakref = PyList_GET_ITEM(self->cursors, i);
if (_PyWeakref_IsDead(weakref)) {
continue;
}
PyObject *obj;
if (!PyWeakref_GetRef(weakref, &obj)) {
continue;
}
int locked = ((pysqlite_Cursor *)obj)->locked;
Py_DECREF(obj);
if (locked) {
self->close_attempted_in_callback = 1;
PyTypeObject *tp = Py_TYPE(self);
pysqlite_state *state = pysqlite_get_state_by_type(tp);
PyErr_SetString(state->ProgrammingError,
"Cannot close the database connection "
"from within a callback function.");
return NULL;
}
}

pysqlite_close_all_blobs(self);
Py_CLEAR(self->statement_cache);
if (connection_close(self) < 0) {
Expand Down
8 changes: 7 additions & 1 deletion Modules/_sqlite/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,18 @@ typedef struct

int initialized;

/* set to 1 when close() is attempted while a cursor is locked (actively
* executing); checked after stmt_step() returns to raise the appropriate
* ProgrammingError */
int close_attempted_in_callback;

/* thread identification of the thread the connection was created in */
unsigned long thread_ident;

PyObject *statement_cache;

/* Lists of weak references to blobs used within this connection */
/* Lists of weak references to cursors and blobs used within this connection */
PyObject *cursors;
PyObject *blobs;

PyObject* row_factory;
Expand Down
Loading
Loading