summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py43
1 files changed, 27 insertions, 16 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0264dea61d..12750d9b89 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2030,29 +2030,40 @@ class DatabasePool:
         max_value: int,
         limit: int = 100000,
     ) -> Tuple[Dict[Any, int], int]:
-        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
-        # It doesn't really matter how many we get, the StreamChangeCache will
-        # do the right thing to ensure it respects the max size of cache.
-        sql = (
-            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
-            " WHERE %(stream)s > ? - %(limit)s"
-            " GROUP BY %(entity)s"
-        ) % {
-            "table": table,
-            "entity": entity_column,
-            "stream": stream_column,
-            "limit": limit,
-        }
+        """Gets roughly the last N changes in the given stream table as a
+        map from entity to the stream ID of the most recent change.
+
+        Also returns the minimum stream ID.
+        """
+
+        # This may return many rows for the same entity, but the `limit` is only
+        # a suggestion so we don't care that much.
+        #
+        # Note: Some stream tables can have multiple rows with the same stream
+        # ID. Instead of handling this with complicated SQL, we instead simply
+        # add one to the returned minimum stream ID to ensure correctness.
+        sql = f"""
+            SELECT {entity_column}, {stream_column}
+            FROM {table}
+            ORDER BY {stream_column} DESC
+            LIMIT ?
+        """
 
         txn = db_conn.cursor(txn_name="get_cache_dict")
-        txn.execute(sql, (int(max_value),))
+        txn.execute(sql, (limit,))
 
-        cache = {row[0]: int(row[1]) for row in txn}
+        # The rows come out in reverse stream ID order, so we want to keep the
+        # stream ID of the first row for each entity.
+        cache: Dict[Any, int] = {}
+        for row in txn:
+            cache.setdefault(row[0], int(row[1]))
 
         txn.close()
 
         if cache:
-            min_val = min(cache.values())
+            # We add one here as we don't know if we have all rows for the
+            # minimum stream ID.
+            min_val = min(cache.values()) + 1
         else:
             min_val = max_value