Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r-- | synapse/storage/database.py | 66
1 file changed, 47 insertions, 19 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3ef2bdd74b..12750d9b89 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -241,9 +241,17 @@ class LoggingTransaction:
         self.exception_callbacks = exception_callbacks
 
     def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
-        """Call the given callback on the main twisted thread after the
-        transaction has finished. Used to invalidate the caches on the
-        correct thread.
+        """Call the given callback on the main twisted thread after the transaction has
+        finished.
+
+        Mostly used to invalidate the caches on the correct thread.
+
+        Note that transactions may be retried a few times if they encounter database
+        errors such as serialization failures. Callbacks given to `call_after`
+        will accumulate across transaction attempts and will _all_ be called once a
+        transaction attempt succeeds, regardless of whether previous transaction
+        attempts failed. Otherwise, if all transaction attempts fail, all
+        `call_on_exception` callbacks will be run instead.
         """
         # if self.after_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
@@ -254,6 +262,15 @@ class LoggingTransaction:
     def call_on_exception(
         self, callback: Callable[..., object], *args: Any, **kwargs: Any
     ):
+        """Call the given callback on the main twisted thread after the transaction has
+        failed.
+
+        Note that transactions may be retried a few times if they encounter database
+        errors such as serialization failures. Callbacks given to `call_on_exception`
+        will accumulate across transaction attempts and will _all_ be called once the
+        final transaction attempt fails. No `call_on_exception` callbacks will be run
+        if any transaction attempt succeeds.
+        """
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
         # is not the case.
@@ -2013,29 +2030,40 @@ class DatabasePool:
         max_value: int,
         limit: int = 100000,
     ) -> Tuple[Dict[Any, int], int]:
-        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
-        # It doesn't really matter how many we get, the StreamChangeCache will
-        # do the right thing to ensure it respects the max size of cache.
-        sql = (
-            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
-            " WHERE %(stream)s > ? - %(limit)s"
-            " GROUP BY %(entity)s"
-        ) % {
-            "table": table,
-            "entity": entity_column,
-            "stream": stream_column,
-            "limit": limit,
-        }
+        """Gets roughly the last N changes in the given stream table as a
+        map from entity to the stream ID of the most recent change.
+
+        Also returns the minimum stream ID.
+        """
+
+        # This may return many rows for the same entity, but the `limit` is only
+        # a suggestion so we don't care that much.
+        #
+        # Note: Some stream tables can have multiple rows with the same stream
+        # ID. Instead of handling this with complicated SQL, we instead simply
+        # add one to the returned minimum stream ID to ensure correctness.
+        sql = f"""
+            SELECT {entity_column}, {stream_column}
+            FROM {table}
+            ORDER BY {stream_column} DESC
+            LIMIT ?
+        """
 
         txn = db_conn.cursor(txn_name="get_cache_dict")
-        txn.execute(sql, (int(max_value),))
+        txn.execute(sql, (limit,))
 
-        cache = {row[0]: int(row[1]) for row in txn}
+        # The rows come out in reverse stream ID order, so we want to keep the
+        # stream ID of the first row for each entity.
+        cache: Dict[Any, int] = {}
+        for row in txn:
+            cache.setdefault(row[0], int(row[1]))
 
         txn.close()
 
         if cache:
-            min_val = min(cache.values())
+            # We add one here as we don't know if we have all rows for the
+            # minimum stream ID.
+            min_val = min(cache.values()) + 1
         else:
             min_val = max_value
 
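A note on the retry semantics documented in the new docstrings above: the sketch below is not Synapse's implementation (the real retry loop lives elsewhere, in DatabasePool's transaction machinery); RetryableError, FakeTxn and run_with_retries are invented names used only to illustrate how `call_after` and `call_on_exception` callbacks accumulate across attempts.

# A minimal standalone sketch, NOT Synapse's code: all names here are
# invented for illustration.
from typing import Any, Callable, Dict, List, Tuple

_Callback = Tuple[Callable[..., object], Tuple[Any, ...], Dict[str, Any]]

class RetryableError(Exception):
    """Stands in for a retryable database error, e.g. a serialization failure."""

class FakeTxn:
    """Toy transaction: callbacks go into lists shared across retries, so
    they accumulate exactly as the docstrings above describe."""

    def __init__(self, after: List[_Callback], on_exc: List[_Callback]) -> None:
        self._after = after
        self._on_exc = on_exc

    def call_after(self, cb: Callable[..., object], *args: Any, **kwargs: Any) -> None:
        self._after.append((cb, args, kwargs))

    def call_on_exception(self, cb: Callable[..., object], *args: Any, **kwargs: Any) -> None:
        self._on_exc.append((cb, args, kwargs))

def run_with_retries(func: Callable[[FakeTxn], Any], attempts: int = 5) -> Any:
    after: List[_Callback] = []
    on_exc: List[_Callback] = []
    for attempt in range(attempts):
        txn = FakeTxn(after, on_exc)
        try:
            result = func(txn)
        except RetryableError:
            if attempt == attempts - 1:
                # Every attempt failed: run ALL accumulated
                # call_on_exception callbacks, then re-raise.
                for cb, args, kwargs in on_exc:
                    cb(*args, **kwargs)
                raise
            continue  # retry; callbacks registered so far are kept
        # Some attempt succeeded: run ALL accumulated call_after callbacks,
        # including ones registered by attempts that later failed.
        for cb, args, kwargs in after:
            cb(*args, **kwargs)
        return result

# Example: an attempt that fails once before succeeding still fires the
# call_after registered during the failed attempt.
state = {"tries": 0}
def interaction(txn: FakeTxn) -> str:
    state["tries"] += 1
    txn.call_after(print, f"after-callback from attempt {state['tries']}")
    if state["tries"] == 1:
        raise RetryableError()
    return "done"

run_with_retries(interaction)  # prints the callbacks from both attempts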
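The `min(cache.values()) + 1` step in the third hunk also merits a worked example. The helper below is a hypothetical standalone rendering of the post-query half of `get_cache_dict` (the name `cache_dict_from_rows` and the entities/stream IDs are made up for illustration).

# A toy illustration of the new minimum-stream-ID logic, assuming `rows` is
# what `ORDER BY {stream_column} DESC LIMIT ?` produced: sorted by stream ID,
# descending, and possibly truncated by the LIMIT.
from typing import Any, Dict, List, Tuple

def cache_dict_from_rows(
    rows: List[Tuple[Any, int]], max_value: int
) -> Tuple[Dict[Any, int], int]:
    cache: Dict[Any, int] = {}
    for entity, stream_id in rows:
        # The first row seen per entity is its most recent change.
        cache.setdefault(entity, int(stream_id))

    if cache:
        # The LIMIT may have cut off other rows sharing the smallest stream
        # ID we saw, so that stream ID cannot be trusted; one above it can.
        min_val = min(cache.values()) + 1
    else:
        min_val = max_value
    return cache, min_val

# Suppose LIMIT 3 split stream ID 7 across the cut: a fourth row ("@d", 7)
# was dropped, so the cache can only vouch for changes with stream ID >= 8.
rows = [("@a", 9), ("@b", 7), ("@c", 7)]
print(cache_dict_from_rows(rows, max_value=9))
# -> ({'@a': 9, '@b': 7, '@c': 7}, 8)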