summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11564.misc1
-rw-r--r--synapse/storage/database.py50
-rw-r--r--synapse/storage/databases/main/presence.py5
3 files changed, 49 insertions, 7 deletions
diff --git a/changelog.d/11564.misc b/changelog.d/11564.misc
new file mode 100644
index 0000000000..2c48e22de0
--- /dev/null
+++ b/changelog.d/11564.misc
@@ -0,0 +1 @@
+Add some safety checks that storage functions are used correctly.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3b44e6469c..a219999f15 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -13,8 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 import time
+import types
 from collections import defaultdict
 from sys import intern
 from time import monotonic as monotonic_time
@@ -526,6 +528,12 @@ class DatabasePool:
         the function will correctly handle being aborted and retried half way
         through its execution.
 
+        Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
+        since they could be evaluated multiple times (which would produce an empty
+        result on the second or subsequent evaluation). Likewise, the closure of `func`
+        must not reference any generators.  This method attempts to detect such usage
+        and will log an error.
+
         Args:
             conn
             desc
@@ -536,6 +544,39 @@ class DatabasePool:
             **kwargs
         """
 
+        # Robustness check: ensure that none of the arguments are generators, since that
+        # will fail if we have to repeat the transaction.
+        # For now, we just log an error, and hope that it works on the first attempt.
+        # TODO: raise an exception.
+        for i, arg in enumerate(args):
+            if inspect.isgenerator(arg):
+                logger.error(
+                    "Programming error: generator passed to new_transaction as "
+                    "argument %i to function %s",
+                    i,
+                    func,
+                )
+        for name, val in kwargs.items():
+            if inspect.isgenerator(val):
+                logger.error(
+                    "Programming error: generator passed to new_transaction as "
+                    "argument %s to function %s",
+                    name,
+                    func,
+                )
+        # also check variables referenced in func's closure
+        if inspect.isfunction(func):
+            f = cast(types.FunctionType, func)
+            if f.__closure__:
+                for i, cell in enumerate(f.__closure__):
+                    if inspect.isgenerator(cell.cell_contents):
+                        logger.error(
+                            "Programming error: function %s references generator %s "
+                            "via its closure",
+                            f,
+                            f.__code__.co_freevars[i],
+                        )
+
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -1226,9 +1267,9 @@ class DatabasePool:
         self,
         table: str,
         key_names: Collection[str],
-        key_values: Collection[Iterable[Any]],
+        key_values: Collection[Collection[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[Any]],
+        value_values: Collection[Collection[Any]],
         desc: str,
     ) -> None:
         """
@@ -1920,7 +1961,7 @@ class DatabasePool:
         self,
         table: str,
         column: str,
-        iterable: Iterable[Any],
+        iterable: Collection[Any],
         keyvalues: Dict[str, Any],
         desc: str,
     ) -> int:
@@ -1931,7 +1972,8 @@ class DatabasePool:
         Args:
             table: string giving the table name
             column: column name to test for inclusion against `iterable`
-            iterable: list
+            iterable: list of values to match against `column`. NB cannot be a generator
+                as it may be evaluated multiple times.
             keyvalues: dict of column names and values to select the rows with
             desc: description of the transaction, for logging and metrics
 
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 02d534ae45..cbf9ec38f7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         """
         # Add user entries to the table, updating the presence_stream_id column if the user already
         # exists in the table.
+        presence_stream_id = self._presence_id_gen.get_current_token()
         await self.db_pool.simple_upsert_many(
             table="users_to_send_full_presence_to",
             key_names=("user_id",),
@@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             # devices at different times, each device will receive full presence once - when
             # the presence stream ID in their sync token is less than the one in the table
             # for their user ID.
-            value_values=(
-                (self._presence_id_gen.get_current_token(),) for _ in user_ids
-            ),
+            value_values=[(presence_stream_id,) for _ in user_ids],
             desc="add_users_to_send_full_presence_to",
         )