summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/data_stores/main/account_data.py62
1 files changed, 43 insertions, 19 deletions
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 46b494b334..f9eef1b78e 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@
 
 import abc
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
-    def get_all_updated_account_data(
-        self, last_global_id, last_room_id, current_id, limit
-    ):
-        """Get all the client account_data that has changed on the server
+    async def get_updated_global_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
+
         Args:
-            last_global_id(int): The position to fetch from for top level data
-            last_room_id(int): The position to fetch from for per room data
-            current_id(int): The position to fetch up to.
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
         Returns:
-            A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, and type string.
+            A list of tuples of stream_id int, user_id string,
+            and type string.
         """
-        if last_room_id == current_id and last_global_id == current_id:
-            return defer.succeed(([], []))
+        if last_id == current_id:
+            return []
 
-        def get_updated_account_data_txn(txn):
+        def get_updated_global_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_global_id, current_id, limit))
-            global_results = txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return await self.db.runInteraction(
+            "get_updated_global_account_data", get_updated_global_account_data_txn
+        )
+
+    async def get_updated_room_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
 
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            room_id string and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_room_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_room_id, current_id, limit))
-            room_results = txn.fetchall()
-            return global_results, room_results
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
 
-        return self.db.runInteraction(
-            "get_all_updated_account_data_txn", get_updated_account_data_txn
+        return await self.db.runInteraction(
+            "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
     def get_updated_account_data_for_user(self, user_id, stream_id):