summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7636.misc1
-rw-r--r--synapse/handlers/presence.py29
-rw-r--r--synapse/handlers/typing.py40
-rw-r--r--synapse/push/pusherpool.py4
-rw-r--r--synapse/replication/tcp/streams/_base.py29
-rw-r--r--synapse/storage/data_stores/main/events_worker.py41
-rw-r--r--synapse/storage/data_stores/main/presence.py41
-rw-r--r--synapse/storage/data_stores/main/push_rule.py56
-rw-r--r--synapse/storage/data_stores/main/receipts.py82
9 files changed, 251 insertions, 72 deletions
diff --git a/changelog.d/7636.misc b/changelog.d/7636.misc
new file mode 100644
index 0000000000..f93149502e
--- /dev/null
+++ b/changelog.d/7636.misc
@@ -0,0 +1 @@
+Refactor getting replication updates from database.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 2e8914be14..d2f25ae12a 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,7 @@ The methods that define policy are:
 import abc
 import logging
 from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set
+from typing import Dict, Iterable, List, Set, Tuple
 
 from prometheus_client import Counter
 from typing_extensions import ContextManager
@@ -773,7 +773,9 @@ class PresenceHandler(BasePresenceHandler):
 
         return False
 
-    async def get_all_presence_updates(self, last_id, current_id, limit):
+    async def get_all_presence_updates(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, list]], int, bool]:
         """
         Gets a list of presence update rows from between the given stream ids.
         Each row has:
@@ -785,10 +787,31 @@ class PresenceHandler(BasePresenceHandler):
         - last_user_sync_ts(int)
         - status_msg(int)
         - currently_active(int)
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
+
         # TODO(markjh): replicate the unpersisted changes.
         # This could use the in-memory stores for recent changes.
-        rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
+        rows = await self.store.get_all_presence_updates(
+            instance_name, last_id, current_id, limit
+        )
         return rows
 
     def notify_new_event(self):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c7bc14c623..4330abb9f7 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,7 +15,7 @@
 
 import logging
 from collections import namedtuple
-from typing import List
+from typing import List, Tuple
 
 from twisted.internet import defer
 
@@ -259,14 +259,31 @@ class TypingHandler(object):
         )
 
     async def get_all_typing_updates(
-        self, last_id: int, current_id: int, limit: int
-    ) -> List[dict]:
-        """Get up to `limit` typing updates between the given tokens, earliest
-        updates first.
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        """Get updates for typing replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
 
         if last_id == current_id:
-            return []
+            return [], current_id, False
 
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
@@ -280,9 +297,16 @@ class TypingHandler(object):
             serial = self._room_serials[room_id]
             if last_id < serial <= current_id:
                 typing = self._room_typing[room_id]
-                rows.append((serial, room_id, list(typing)))
+                rows.append((serial, [room_id, list(typing)]))
         rows.sort()
-        return rows[:limit]
+
+        limited = False
+        if len(rows) > limit:
+            rows = rows[:limit]
+            current_id = rows[-1][0]
+            limited = True
+
+        return rows, current_id, limited
 
     def get_current_token(self):
         return self._latest_room_serial
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 88d203aa44..f6a5458681 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -215,11 +215,9 @@ class PusherPool:
         try:
             # Need to subtract 1 from the minimum because the lower bound here
             # is not inclusive
-            updated_receipts = yield self.store.get_all_updated_receipts(
+            users_affected = yield self.store.get_users_sent_receipts_between(
                 min_stream_id - 1, max_stream_id
             )
-            # This returns a tuple, user_id is at index 3
-            users_affected = {r[3] for r in updated_receipts}
 
             for u in users_affected:
                 if u in self.pushers:
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4acefc8a96..f196eff072 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -264,7 +264,7 @@ class BackfillStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_current_backfill_token),
-            db_query_to_update_function(store.get_all_new_backfill_event_rows),
+            store.get_all_new_backfill_event_rows,
         )
 
 
@@ -291,9 +291,7 @@ class PresenceStream(Stream):
         if hs.config.worker_app is None:
             # on the master, query the presence handler
             presence_handler = hs.get_presence_handler()
-            update_function = db_query_to_update_function(
-                presence_handler.get_all_presence_updates
-            )
+            update_function = presence_handler.get_all_presence_updates
         else:
             # Query master process
             update_function = make_http_update_function(hs, self.NAME)
@@ -318,9 +316,7 @@ class TypingStream(Stream):
 
         if hs.config.worker_app is None:
             # on the master, query the typing handler
-            update_function = db_query_to_update_function(
-                typing_handler.get_all_typing_updates
-            )
+            update_function = typing_handler.get_all_typing_updates
         else:
             # Query master process
             update_function = make_http_update_function(hs, self.NAME)
@@ -352,7 +348,7 @@ class ReceiptsStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_max_receipt_stream_id),
-            db_query_to_update_function(store.get_all_updated_receipts),
+            store.get_all_updated_receipts,
         )
 
 
@@ -367,26 +363,17 @@ class PushRulesStream(Stream):
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
+
         super(PushRulesStream, self).__init__(
-            hs.get_instance_name(), self._current_token, self._update_function
+            hs.get_instance_name(),
+            self._current_token,
+            self.store.get_all_push_rule_updates,
         )
 
     def _current_token(self, instance_name: str) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    async def _update_function(
-        self, instance_name: str, from_token: Token, to_token: Token, limit: int
-    ):
-        rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-
-        limited = False
-        if len(rows) == limit:
-            to_token = rows[-1][0]
-            limited = True
-
-        return [(row[0], (row[2],)) for row in rows], to_token, limited
-
 
 class PushersStream(Stream):
     """A user has added/changed/removed a pusher
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 213d69100a..a48c7a96ca 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -1077,9 +1077,32 @@ class EventsWorkerStore(SQLBaseStore):
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
         )
 
-    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+    async def get_all_new_backfill_event_rows(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        """Get updates for backfill replication stream, including all new
+        backfilled events and events that have gone from being outliers to not.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
+        """
         if last_id == current_id:
-            return defer.succeed([])
+            return [], current_id, False
 
         def get_all_new_backfill_event_rows(txn):
             sql = (
@@ -1094,10 +1117,12 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (-last_id, -current_id, limit))
-            new_event_updates = txn.fetchall()
+            new_event_updates = [(row[0], row[1:]) for row in txn]
 
+            limited = False
             if len(new_event_updates) == limit:
                 upper_bound = new_event_updates[-1][0]
+                limited = True
             else:
                 upper_bound = current_id
 
@@ -1114,11 +1139,15 @@ class EventsWorkerStore(SQLBaseStore):
                 " ORDER BY event_stream_ordering DESC"
             )
             txn.execute(sql, (-last_id, -upper_bound))
-            new_event_updates.extend(txn.fetchall())
+            new_event_updates.extend((row[0], row[1:]) for row in txn)
 
-            return new_event_updates
+            if len(new_event_updates) >= limit:
+                upper_bound = new_event_updates[-1][0]
+                limited = True
 
-        return self.db.runInteraction(
+            return new_event_updates, upper_bound, limited
+
+        return await self.db.runInteraction(
             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
         )
 
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index dab31e0c2d..7574612619 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Tuple
+
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
@@ -73,9 +75,32 @@ class PresenceStore(SQLBaseStore):
             )
             txn.execute(sql + clause, [stream_id] + list(args))
 
-    def get_all_presence_updates(self, last_id, current_id, limit):
+    async def get_all_presence_updates(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        """Get updates for presence replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
+        """
+
         if last_id == current_id:
-            return defer.succeed([])
+            return [], current_id, False
 
         def get_all_presence_updates_txn(txn):
             sql = """
@@ -89,9 +114,17 @@ class PresenceStore(SQLBaseStore):
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            updates = [(row[0], row[1:]) for row in txn]
+
+            upper_bound = current_id
+            limited = False
+            if len(updates) >= limit:
+                upper_bound = updates[-1][0]
+                limited = True
+
+            return updates, upper_bound, limited
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_presence_updates", get_all_presence_updates_txn
         )
 
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index ef8f40959f..f6e78ca590 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -16,7 +16,7 @@
 
 import abc
 import logging
-from typing import Union
+from typing import List, Tuple, Union
 
 from canonicaljson import json
 
@@ -348,23 +348,53 @@ class PushRulesWorkerStore(
             results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
         return results
 
-    def get_all_push_rule_updates(self, last_id, current_id, limit):
-        """Get all the push rules changes that have happend on the server"""
+    async def get_all_push_rule_updates(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for push_rules replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
+        """
+
         if last_id == current_id:
-            return defer.succeed([])
+            return [], current_id, False
 
         def get_all_push_rule_updates_txn(txn):
-            sql = (
-                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
-                " op, priority_class, priority, conditions, actions"
-                " FROM push_rules_stream"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC LIMIT ?"
-            )
+            sql = """
+                SELECT stream_id, user_id
+                FROM push_rules_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
+
+            limited = False
+            upper_bound = current_id
+            if len(updates) == limit:
+                limited = True
+                upper_bound = updates[-1][0]
+
+            return updates, upper_bound, limited
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_push_rule_updates", get_all_push_rule_updates_txn
         )
 
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index d4a7163049..8f5505bd67 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -16,6 +16,7 @@
 
 import abc
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -267,26 +268,79 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
         return results
 
-    def get_all_updated_receipts(self, last_id, current_id, limit=None):
+    def get_users_sent_receipts_between(self, last_id: int, current_id: int):
+        """Get all users who sent receipts between `last_id` exclusive and
+        `current_id` inclusive.
+
+        Returns:
+            Deferred[List[str]]
+        """
+
         if last_id == current_id:
             return defer.succeed([])
 
-        def get_all_updated_receipts_txn(txn):
-            sql = (
-                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
-                " FROM receipts_linearized"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-            )
-            args = [last_id, current_id]
-            if limit is not None:
-                sql += " LIMIT ?"
-                args.append(limit)
-            txn.execute(sql, args)
+        def _get_users_sent_receipts_between_txn(txn):
+            sql = """
+                SELECT DISTINCT user_id FROM receipts_linearized
+                WHERE ? < stream_id AND stream_id <= ?
+            """
+            txn.execute(sql, (last_id, current_id))
 
-            return [r[0:5] + (json.loads(r[5]),) for r in txn]
+            return [r[0] for r in txn]
 
         return self.db.runInteraction(
+            "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
+        )
+
+    async def get_all_updated_receipts(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        """Get updates for receipts replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
+        """
+
+        if last_id == current_id:
+            return [], current_id, False
+
+        def get_all_updated_receipts_txn(txn):
+            sql = """
+                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
+                FROM receipts_linearized
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, current_id, limit))
+
+            updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn]
+
+            limited = False
+            upper_bound = current_id
+
+            if len(updates) == limit:
+                limited = True
+                upper_bound = updates[-1][0]
+
+            return updates, upper_bound, limited
+
+        return await self.db.runInteraction(
             "get_all_updated_receipts", get_all_updated_receipts_txn
         )