diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3ef2bdd74b..12750d9b89 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -241,9 +241,17 @@ class LoggingTransaction:
self.exception_callbacks = exception_callbacks
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
- """Call the given callback on the main twisted thread after the
- transaction has finished. Used to invalidate the caches on the
- correct thread.
+ """Call the given callback on the main twisted thread after the transaction has
+ finished.
+
+ Mostly used to invalidate the caches on the correct thread.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `call_after`
+ will accumulate across transaction attempts and will _all_ be called once a
+ transaction attempt succeeds, regardless of whether previous transaction
+ attempts failed. Otherwise, if all transaction attempts fail, all
+ `call_on_exception` callbacks will be run instead.
"""
# if self.after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
@@ -254,6 +262,15 @@ class LoggingTransaction:
def call_on_exception(
self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
+ """Call the given callback on the main twisted thread after the transaction has
+ failed.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `call_on_exception`
+ will accumulate across transaction attempts and will _all_ be called once the
+ final transaction attempt fails. No `call_on_exception` callbacks will be run
+ if any transaction attempt succeeds.
+ """
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
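The retry semantics spelled out in these two docstrings can be illustrated with a small, self-contained sketch. The `run_with_retries` helper and its `attempt` / `max_attempts` parameters below are hypothetical and only mimic the callback bookkeeping that `DatabasePool` performs around `LoggingTransaction`; they are not part of this patch.

from typing import Callable, List

def run_with_retries(
    attempt: Callable[[List[Callable[[], None]], List[Callable[[], None]]], None],
    max_attempts: int = 3,
) -> None:
    # Hypothetical stand-in for the real retry loop: callbacks registered via
    # call_after / call_on_exception accumulate across *all* attempts.
    after_callbacks: List[Callable[[], None]] = []
    exception_callbacks: List[Callable[[], None]] = []

    for attempt_no in range(max_attempts):
        try:
            attempt(after_callbacks, exception_callbacks)
        except Exception:
            if attempt_no == max_attempts - 1:
                # Every attempt failed: run all accumulated exception callbacks.
                for cb in exception_callbacks:
                    cb()
                raise
            # Otherwise retry, keeping the callbacks registered so far.
        else:
            # One attempt succeeded: run *all* accumulated after-callbacks,
            # including ones registered by earlier, failed attempts.
            for cb in after_callbacks:
                cb()
            return

If the first attempt registers a cache-invalidation callback via `call_after` and then hits a serialization failure, a successful retry still runs that callback; only if every attempt fails do the `call_on_exception` callbacks fire.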
@@ -2013,29 +2030,40 @@ class DatabasePool:
max_value: int,
limit: int = 100000,
) -> Tuple[Dict[Any, int], int]:
- # Fetch a mapping of room_id -> max stream position for "recent" rooms.
- # It doesn't really matter how many we get, the StreamChangeCache will
- # do the right thing to ensure it respects the max size of cache.
- sql = (
- "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
- " WHERE %(stream)s > ? - %(limit)s"
- " GROUP BY %(entity)s"
- ) % {
- "table": table,
- "entity": entity_column,
- "stream": stream_column,
- "limit": limit,
- }
+ """Gets roughly the last N changes in the given stream table as a
+ map from entity to the stream ID of the most recent change.
+
+ Also returns the minimum stream ID.
+ """
+
+ # This may return many rows for the same entity, but the `limit` is only
+ # a suggestion so we don't care that much.
+ #
+ # Note: Some stream tables can have multiple rows with the same stream
+ # ID. Instead of handling this with complicated SQL, we instead simply
+ # add one to the returned minimum stream ID to ensure correctness.
+ sql = f"""
+     SELECT {entity_column}, {stream_column}
+     FROM {table}
+     ORDER BY {stream_column} DESC
+     LIMIT ?
+ """
txn = db_conn.cursor(txn_name="get_cache_dict")
- txn.execute(sql, (int(max_value),))
+ txn.execute(sql, (limit,))
- cache = {row[0]: int(row[1]) for row in txn}
+ # The rows come out in reverse stream ID order, so we want to keep the
+ # stream ID of the first row for each entity.
+ cache: Dict[Any, int] = {}
+ for row in txn:
+     cache.setdefault(row[0], int(row[1]))
txn.close()
if cache:
-     min_val = min(cache.values())
+     # We add one here as we don't know if we have all rows for the
+     # minimum stream ID.
+     min_val = min(cache.values()) + 1
else:
    min_val = max_value
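For reference, the rewritten query and the dedup / minimum-stream handling in `get_cache_dict` can be exercised standalone. The sketch below is not Synapse code; it re-implements the same logic against an in-memory SQLite table with hypothetical names (`room_changes`, `room_id`, `stream_id`) purely to show the expected behaviour.

import sqlite3
from typing import Any, Dict, Tuple

def get_cache_dict_sketch(
    conn: sqlite3.Connection,
    table: str,
    entity_column: str,
    stream_column: str,
    max_value: int,
    limit: int = 100000,
) -> Tuple[Dict[Any, int], int]:
    # Same shape as the patched query: newest rows first, capped at `limit`.
    sql = f"""
        SELECT {entity_column}, {stream_column}
        FROM {table}
        ORDER BY {stream_column} DESC
        LIMIT ?
    """
    cache: Dict[Any, int] = {}
    for entity, stream_id in conn.execute(sql, (limit,)):
        # Rows arrive in descending stream order, so the first row seen for
        # each entity is its most recent change.
        cache.setdefault(entity, int(stream_id))

    if cache:
        # Several rows can share the lowest stream ID we saw, and LIMIT may
        # have cut some of them off, so add one to stay on the safe side.
        min_val = min(cache.values()) + 1
    else:
        min_val = max_value
    return cache, min_val

# Example: entity "a" changed at streams 1 and 5, entity "b" at stream 3.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE room_changes (room_id TEXT, stream_id INTEGER)")
conn.executemany(
    "INSERT INTO room_changes VALUES (?, ?)",
    [("a", 1), ("b", 3), ("a", 5)],
)
print(get_cache_dict_sketch(conn, "room_changes", "room_id", "stream_id", max_value=5))
# -> ({'a': 5, 'b': 3}, 4)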