Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/_base.py                              | 53
-rw-r--r--  synapse/storage/data_stores/main/account_data.py      |  6
-rw-r--r--  synapse/storage/data_stores/main/deviceinbox.py       |  2
-rw-r--r--  synapse/storage/data_stores/main/end_to_end_keys.py   |  6
-rw-r--r--  synapse/storage/data_stores/main/events.py            |  8
-rw-r--r--  synapse/storage/data_stores/main/filtering.py         |  2
-rw-r--r--  synapse/storage/data_stores/main/media_repository.py  |  6
-rw-r--r--  synapse/storage/data_stores/main/receipts.py          |  2
-rw-r--r--  synapse/storage/data_stores/main/registration.py      |  4
-rw-r--r--  synapse/storage/data_stores/main/stream.py             |  2
-rw-r--r--  synapse/storage/data_stores/main/tags.py              |  4
-rw-r--r--  synapse/storage/prepare_database.py                   |  2
12 files changed, 62 insertions, 35 deletions
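
The bulk of the change is in synapse/storage/_base.py: the transaction retry loop now creates the LoggingTransaction cursor before the try and always closes it in a new finally block, for the reasons spelled out in the long comment below. The receipts.py hunk is a knock-on effect of the same change: a lazy generator over the cursor has to become a list, since the cursor is closed as soon as the transaction function returns. A minimal sketch of the cursor-handling pattern against a plain DB-API connection (the function and names here are illustrative, not Synapse's actual API):

import sqlite3

def run_in_transaction(conn, func, *args, **kwargs):
    # Sketch of the pattern in the hunk below: one fresh cursor per attempt,
    # always closed in the finally block, whether we commit, retry, or let an
    # exception propagate.
    attempts = 5
    for attempt in range(attempts):
        cursor = conn.cursor()
        try:
            result = func(cursor, *args, **kwargs)
            conn.commit()
            return result
        except sqlite3.OperationalError:
            conn.rollback()
            if attempt == attempts - 1:
                raise  # out of retries
        finally:
            # Once we return, the connection may be handed to another caller
            # (and rolled back), which would invalidate any still-open cursor,
            # so close it unconditionally.
            cursor.close()

Synapse's real implementation additionally threads through the transaction name, database engine, and the after/exception callbacks shown in the diff; those are omitted here.
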
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ab596fa68d..459901ac60 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -409,16 +409,15 @@ class SQLBaseStore(object):
             i = 0
             N = 5
             while True:
+                cursor = LoggingTransaction(
+                    conn.cursor(),
+                    name,
+                    self.database_engine,
+                    after_callbacks,
+                    exception_callbacks,
+                )
                 try:
-                    txn = conn.cursor()
-                    txn = LoggingTransaction(
-                        txn,
-                        name,
-                        self.database_engine,
-                        after_callbacks,
-                        exception_callbacks,
-                    )
-                    r = func(txn, *args, **kwargs)
+                    r = func(cursor, *args, **kwargs)
                     conn.commit()
                     return r
                 except self.database_engine.module.OperationalError as e:
@@ -456,6 +455,40 @@ class SQLBaseStore(object):
                                 )
                             continue
                     raise
+                finally:
+                    # we're either about to retry with a new cursor, or we're about to
+                    # release the connection. Once we release the connection, it could
+                    # get used for another query, which might do a conn.rollback().
+                    #
+                    # In the latter case, even though that probably wouldn't affect the
+                    # results of this transaction, python's sqlite will reset all
+                    # statements on the connection [1], which will make our cursor
+                    # invalid [2].
+                    #
+                    # In any case, continuing to read rows after commit()ing seems
+                    # dubious from the PoV of ACID transactional semantics
+                    # (sqlite explicitly says that once you commit, you may see rows
+                    # from subsequent updates.)
+                    #
+                    # In psycopg2, cursors are essentially a client-side fabrication -
+                    # all the data is transferred to the client side when the statement
+                    # finishes executing - so in theory we could go on streaming results
+                    # from the cursor, but attempting to do so would make us
+                    # incompatible with sqlite, so let's make sure we're not doing that
+                    # by closing the cursor.
+                    #
+                    # (*named* cursors in psycopg2 are different and are proper server-
+                    # side things, but (a) we don't use them and (b) they are implicitly
+                    # closed by ending the transaction anyway.)
+                    #
+                    # In short, if we haven't finished with the cursor yet, that's a
+                    # problem waiting to bite us.
+                    #
+                    # TL;DR: we're done with the cursor, so we can close it.
+                    #
+                    # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+                    # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+                    cursor.close()
         except Exception as e:
             logger.debug("[TXN FAIL] {%s} %s", name, e)
             raise
@@ -851,7 +884,7 @@ class SQLBaseStore(object):
             allvalues.update(values)
             latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
 
-        sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
+        sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
             table,
             ", ".join(k for k in allvalues),
             ", ".join("?" for _ in allvalues),
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 6afbfc0d74..22093484ed 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -184,14 +184,14 @@ class AccountDataWorkerStore(SQLBaseStore):
             current_id(int): The position to fetch up to.
         Returns:
             A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, type string, and content string.
+            room_id string, and type string.
         """
         if last_room_id == current_id and last_global_id == current_id:
             return defer.succeed(([], []))
 
         def get_updated_account_data_txn(txn):
             sql = (
-                "SELECT stream_id, user_id, account_data_type, content"
+                "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
@@ -199,7 +199,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             global_results = txn.fetchall()
 
             sql = (
-                "SELECT stream_id, user_id, room_id, account_data_type, content"
+                "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 96cd0fb77a..a23744f11c 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -380,7 +380,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
-                sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
+                sql = "SELECT device_id FROM devices WHERE user_id = ?"
                 txn.execute(sql, (user_id,))
                 message_json = json.dumps(messages_by_device["*"])
                 for row in txn:
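
This hunk, and most of the remaining ones in this commit, simply merge adjacent string literals into a single literal. Python concatenates adjacent string literals at compile time (the split form is typically what an automatic reformatter leaves behind after joining previously wrapped lines), so the SQL sent to the database is unchanged. An illustrative check, not code from the patch:

# Adjacent string literals are concatenated at compile time, so both
# spellings produce exactly the same SQL string.
old = "SELECT device_id FROM devices" " WHERE user_id = ?"
new = "SELECT device_id FROM devices WHERE user_id = ?"
assert old == new
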
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 073412a78d..d8ad59ad93 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -138,9 +138,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 result.setdefault(user_id, {})[device_id] = None
 
         # get signatures on the device
-        signature_sql = (
-            "SELECT * " "  FROM e2e_cross_signing_signatures " " WHERE %s"
-        ) % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
+        signature_sql = ("SELECT *  FROM e2e_cross_signing_signatures WHERE %s") % (
+            " OR ".join("(" + q + ")" for q in signature_query_clauses)
+        )
 
         txn.execute(signature_sql, signature_query_params)
         rows = self.cursor_to_dict(txn)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 878f7568a6..627c0b67f1 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -713,9 +713,7 @@ class EventsStore(
 
                 metadata_json = encode_json(event.internal_metadata.get_dict())
 
-                sql = (
-                    "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?"
-                )
+                sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
                 txn.execute(sql, (metadata_json, event.event_id))
 
                 # Add an entry to the ex_outlier_stream table to replicate the
@@ -732,7 +730,7 @@ class EventsStore(
                     },
                 )
 
-                sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
+                sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
                 txn.execute(sql, (False, event.event_id))
 
                 # Update the event_backward_extremities table now that this
@@ -1479,7 +1477,7 @@ class EventsStore(
 
         # We do joins against events_to_purge for e.g. calculating state
         # groups to purge, etc., so lets make an index.
-        txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
+        txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
 
         txn.execute("SELECT event_id, should_delete FROM events_to_purge")
         event_rows = txn.fetchall()
diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py
index a2a2a67927..f05ace299a 100644
--- a/synapse/storage/data_stores/main/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -55,7 +55,7 @@ class FilteringStore(SQLBaseStore):
             if filter_id_response is not None:
                 return filter_id_response[0]
 
-            sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
+            sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
             txn.execute(sql, (user_localpart,))
             max_id = txn.fetchone()[0]
             if max_id is None:
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 84b5f3ad5e..0f2887bdce 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -337,7 +337,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         if len(media_ids) == 0:
             return
 
-        sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
+        sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
@@ -365,11 +365,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             return
 
         def _delete_url_cache_media_txn(txn):
-            sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
+            sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-            sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
+            sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 0c24430f28..8b17334ff4 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -280,7 +280,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 args.append(limit)
             txn.execute(sql, args)
 
-            return (r[0:5] + (json.loads(r[5]),) for r in txn)
+            return list(r[0:5] + (json.loads(r[5]),) for r in txn)
 
         return self.runInteraction(
             "get_all_updated_receipts", get_all_updated_receipts_txn
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index ee1b2b2bbf..6a594c160c 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -377,9 +377,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
 
         def f(txn):
-            sql = (
-                "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
-            )
+            sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
             txn.execute(sql, (user_id,))
             return dict(txn)
 
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 8780fdd989..9ae4a913a1 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -616,7 +616,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _get_max_topological_txn(self, txn, room_id):
         txn.execute(
-            "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
+            "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
             (room_id,),
         )
 
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 10d1887f75..aa24339717 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -83,9 +83,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         )
 
         def get_tag_content(txn, tag_ids):
-            sql = (
-                "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
-            )
+            sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
             results = []
             for stream_id, user_id, room_id in tag_ids:
                 txn.execute(sql, (user_id, room_id))
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 2e7753820e..731e1c9d9c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -447,7 +447,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
         # Mark as done.
         cur.execute(
             database_engine.convert_param_style(
-                "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
+                "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
             ),
             (modname, name),
         )
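
Returning to the receipts.py hunk above: with the cursor now closed as soon as the transaction function returns, a generator expression over txn would only be consumed after that close and would fail, which is why it is materialised with list(). A standalone illustration using plain sqlite3 (not Synapse code; the table and data are made up):

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE receipts (room_id TEXT, data TEXT)")
conn.execute("INSERT INTO receipts VALUES ('!room:example.org', '{\"ts\": 1}')")

cur = conn.cursor()
cur.execute("SELECT room_id, data FROM receipts")
lazy = (row[0:1] + (json.loads(row[1]),) for row in cur)  # nothing read yet
cur.close()
# Consuming `lazy` now raises sqlite3.ProgrammingError, because the rows
# would be fetched from a cursor that has already been closed.

cur = conn.cursor()
cur.execute("SELECT room_id, data FROM receipts")
eager = list(row[0:1] + (json.loads(row[1]),) for row in cur)  # rows read here
cur.close()
print(eager)  # [('!room:example.org', {'ts': 1})]
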