summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/slave/storage/events.py89
-rw-r--r--synapse/replication/slave/storage/push_rule.py8
-rw-r--r--synapse/replication/tcp/streams/_base.py68
3 files changed, 56 insertions, 109 deletions
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index b313720a4b..1a1a50a24f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,11 +15,6 @@
 # limitations under the License.
 import logging
 
-from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import (
-    EventsStreamCurrentStateRow,
-    EventsStreamEventRow,
-)
 from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
 from synapse.storage.data_stores.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -35,7 +30,6 @@ from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 logger = logging.getLogger(__name__)
 
@@ -62,11 +56,6 @@ class SlavedEventStore(
     BaseSlavedStore,
 ):
     def __init__(self, database: Database, db_conn, hs):
-        self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-        self._backfill_id_gen = SlavedIdTracker(
-            db_conn, "events", "stream_ordering", step=-1
-        )
-
         super(SlavedEventStore, self).__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
@@ -92,81 +81,3 @@ class SlavedEventStore(
 
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "events":
-            self._stream_id_gen.advance(token)
-            for row in rows:
-                self._process_event_stream_row(token, row)
-        elif stream_name == "backfill":
-            self._backfill_id_gen.advance(-token)
-            for row in rows:
-                self.invalidate_caches_for_event(
-                    -token,
-                    row.event_id,
-                    row.room_id,
-                    row.type,
-                    row.state_key,
-                    row.redacts,
-                    row.relates_to,
-                    backfilled=True,
-                )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
-
-    def _process_event_stream_row(self, token, row):
-        data = row.data
-
-        if row.type == EventsStreamEventRow.TypeId:
-            self.invalidate_caches_for_event(
-                token,
-                data.event_id,
-                data.room_id,
-                data.type,
-                data.state_key,
-                data.redacts,
-                data.relates_to,
-                backfilled=False,
-            )
-        elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
-
-            if data.type == EventTypes.Member:
-                self.get_rooms_for_user_with_stream_ordering.invalidate(
-                    (data.state_key,)
-                )
-        else:
-            raise Exception("Unknown events stream row type %s" % (row.type,))
-
-    def invalidate_caches_for_event(
-        self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
-        self._invalidate_get_event_cache(event_id)
-
-        self.get_latest_event_ids_in_room.invalidate((room_id,))
-
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
-
-        if not backfilled:
-            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
-
-        if redacts:
-            self._invalidate_get_event_cache(redacts)
-
-        if etype == EventTypes.Member:
-            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
-            self.get_invited_rooms_for_local_user.invalidate((state_key,))
-
-        if relates_to:
-            self.get_relations_for_event.invalidate_many((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
-            self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 5d5816d7eb..6adb19463a 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,19 +15,11 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
-from synapse.storage.database import Database
 
-from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def __init__(self, database: Database, db_conn, hs):
-        self._push_rules_stream_id_gen = SlavedIdTracker(
-            db_conn, "push_rules_stream", "stream_id"
-        )
-        super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
-
     def get_push_rules_stream_token(self):
         return (
             self._push_rules_stream_id_gen.get_current_token(),
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b48a6a3e91..d42aaff055 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import heapq
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 # the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
 # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
 # just a row from a database query, though this is dependent on the stream in question.
 #
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
 
 # The type returned by the update_function of a stream, as well as get_updates(),
 # get_updates_since, etc.
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
     """
 
     AccountDataStreamRow = namedtuple(
-        "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
+        "AccountDataStream",
+        ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )
 
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(self.store.get_max_account_data_stream_id),
-            db_query_to_update_function(self._update_function),
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, to_token: int, limit: int
+    ) -> StreamUpdateResult:
+        limited = False
+        global_results = await self.store.get_updated_global_account_data(
+            from_token, to_token, limit
         )
 
-    async def _update_function(self, from_token, to_token, limit):
-        global_results, room_results = await self.store.get_all_updated_account_data(
-            from_token, from_token, to_token, limit
+        # if the global results hit the limit, we'll need to limit the room results to
+        # the same stream token.
+        if len(global_results) >= limit:
+            to_token = global_results[-1][0]
+            limited = True
+
+        room_results = await self.store.get_updated_room_account_data(
+            from_token, to_token, limit
         )
 
-        results = list(room_results)
-        results.extend(
-            (stream_id, user_id, None, account_data_type)
+        # likewise, if the room results hit the limit, limit the global results to
+        # the same stream token.
+        if len(room_results) >= limit:
+            to_token = room_results[-1][0]
+            limited = True
+
+        # convert the global results to the right format, and limit them to the to_token
+        # at the same time
+        global_rows = (
+            (stream_id, (user_id, None, account_data_type))
             for stream_id, user_id, account_data_type in global_results
+            if stream_id <= to_token
+        )
+
+        # we know that the room_results are already limited to `to_token` so no need
+        # for a check on `stream_id` here.
+        room_rows = (
+            (stream_id, (user_id, room_id, account_data_type))
+            for stream_id, user_id, room_id, account_data_type in room_results
         )
 
-        return results
+        # we need to return a sorted list, so merge them together.
+        updates = list(heapq.merge(room_rows, global_rows))
+        return updates, to_token, limited
 
 
 class GroupServerStream(Stream):