diff --git a/changelog.d/11564.misc b/changelog.d/11564.misc
new file mode 100644
index 0000000000..2c48e22de0
--- /dev/null
+++ b/changelog.d/11564.misc
@@ -0,0 +1 @@
+Add some safety checks that storage functions are used correctly.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3b44e6469c..a219999f15 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import time
+import types
from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time
@@ -526,6 +528,12 @@ class DatabasePool:
the function will correctly handle being aborted and retried half way
through its execution.
+ Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
+ since they could be evaluated multiple times (which would produce an empty
+ result on the second or subsequent evaluation). Likewise, the closure of `func`
+ must not reference any generators. This method attempts to detect such usage
+ and will log an error.
+
Args:
conn
desc
@@ -536,6 +544,39 @@ class DatabasePool:
**kwargs
"""
+ # Robustness check: ensure that none of the arguments are generators, since that
+ # will fail if we have to repeat the transaction.
+ # For now, we just log an error, and hope that it works on the first attempt.
+ # TODO: raise an exception.
+ for i, arg in enumerate(args):
+ if inspect.isgenerator(arg):
+ logger.error(
+ "Programming error: generator passed to new_transaction as "
+ "argument %i to function %s",
+ i,
+ func,
+ )
+ for name, val in kwargs.items():
+ if inspect.isgenerator(val):
+ logger.error(
+ "Programming error: generator passed to new_transaction as "
+ "argument %s to function %s",
+ name,
+ func,
+ )
+ # also check variables referenced in func's closure
+ if inspect.isfunction(func):
+ f = cast(types.FunctionType, func)
+ if f.__closure__:
+ for i, cell in enumerate(f.__closure__):
+ if inspect.isgenerator(cell.cell_contents):
+ logger.error(
+ "Programming error: function %s references generator %s "
+ "via its closure",
+ f,
+ f.__code__.co_freevars[i],
+ )
+
start = monotonic_time()
txn_id = self._TXN_ID
@@ -1226,9 +1267,9 @@ class DatabasePool:
self,
table: str,
key_names: Collection[str],
- key_values: Collection[Iterable[Any]],
+ key_values: Collection[Collection[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[Any]],
+ value_values: Collection[Collection[Any]],
desc: str,
) -> None:
"""
@@ -1920,7 +1961,7 @@ class DatabasePool:
self,
table: str,
column: str,
- iterable: Iterable[Any],
+ iterable: Collection[Any],
keyvalues: Dict[str, Any],
desc: str,
) -> int:
@@ -1931,7 +1972,8 @@ class DatabasePool:
Args:
table: string giving the table name
column: column name to test for inclusion against `iterable`
- iterable: list
+ iterable: list of values to match against `column`. NB cannot be a generator
+ as it may be evaluated multiple times.
keyvalues: dict of column names and values to select the rows with
desc: description of the transaction, for logging and metrics
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 02d534ae45..cbf9ec38f7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
"""
# Add user entries to the table, updating the presence_stream_id column if the user already
# exists in the table.
+ presence_stream_id = self._presence_id_gen.get_current_token()
await self.db_pool.simple_upsert_many(
table="users_to_send_full_presence_to",
key_names=("user_id",),
@@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
# devices at different times, each device will receive full presence once - when
# the presence stream ID in their sync token is less than the one in the table
# for their user ID.
- value_values=(
- (self._presence_id_gen.get_current_token(),) for _ in user_ids
- ),
+ value_values=[(presence_stream_id,) for _ in user_ids],
desc="add_users_to_send_full_presence_to",
)
|