summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/_base.py4
-rw-r--r--synapse/replication/http/devices.py2
-rw-r--r--synapse/replication/http/federation.py20
-rw-r--r--synapse/replication/http/login.py2
-rw-r--r--synapse/replication/http/membership.py16
-rw-r--r--synapse/replication/http/register.py4
-rw-r--r--synapse/replication/http/send_event.py2
-rw-r--r--synapse/replication/slave/storage/_base.py2
-rw-r--r--synapse/replication/slave/storage/account_data.py2
-rw-r--r--synapse/replication/slave/storage/client_ips.py2
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py2
-rw-r--r--synapse/replication/slave/storage/devices.py2
-rw-r--r--synapse/replication/slave/storage/events.py2
-rw-r--r--synapse/replication/slave/storage/filtering.py2
-rw-r--r--synapse/replication/slave/storage/groups.py2
-rw-r--r--synapse/replication/slave/storage/presence.py2
-rw-r--r--synapse/replication/slave/storage/pushers.py2
-rw-r--r--synapse/replication/slave/storage/receipts.py2
-rw-r--r--synapse/replication/slave/storage/room.py2
-rw-r--r--synapse/replication/tcp/client.py8
-rw-r--r--synapse/replication/tcp/handler.py2
-rw-r--r--synapse/replication/tcp/resource.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py6
-rw-r--r--synapse/replication/tcp/streams/events.py4
24 files changed, 48 insertions, 48 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index ba16f22c91..b448da6710 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -33,7 +33,7 @@ from synapse.util.stringutils import random_string
 logger = logging.getLogger(__name__)
 
 
-class ReplicationEndpoint:
+class ReplicationEndpoint(metaclass=abc.ABCMeta):
     """Helper base class for defining new replication HTTP endpoints.
 
     This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
@@ -72,8 +72,6 @@ class ReplicationEndpoint:
             is received.
     """
 
-    __metaclass__ = abc.ABCMeta
-
     NAME = abc.abstractproperty()  # type: str  # type: ignore
     PATH_ARGS = abc.abstractproperty()  # type: Tuple[str, ...]  # type: ignore
     METHOD = "POST"
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 20f3ba76c0..807b85d2e1 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -53,7 +53,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
     CACHE = False
 
     def __init__(self, hs):
-        super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.device_list_updater = hs.get_device_handler().device_list_updater
         self.store = hs.get_datastore()
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 6b56315148..5393b9a9e7 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -57,7 +57,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
     PATH_ARGS = ()
 
     def __init__(self, hs):
-        super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
@@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         self.federation_handler = hs.get_handlers().federation_handler
 
     @staticmethod
-    async def _serialize_payload(store, event_and_contexts, backfilled):
+    async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
         """
         Args:
             store
+            room_id (str)
             event_and_contexts (list[tuple[FrozenEvent, EventContext]])
             backfilled (bool): Whether or not the events are the result of
                 backfilling
@@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
                 }
             )
 
-        payload = {"events": event_payloads, "backfilled": backfilled}
+        payload = {
+            "events": event_payloads,
+            "backfilled": backfilled,
+            "room_id": room_id,
+        }
 
         return payload
 
@@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         with Measure(self.clock, "repl_fed_send_events_parse"):
             content = parse_json_object_from_request(request)
 
+            room_id = content["room_id"]
             backfilled = content["backfilled"]
 
             event_payloads = content["events"]
@@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         logger.info("Got %d events from federation", len(event_and_contexts))
 
         max_stream_id = await self.federation_handler.persist_events_and_notify(
-            event_and_contexts, backfilled
+            room_id, event_and_contexts, backfilled
         )
 
         return 200, {"max_stream_id": max_stream_id}
@@ -144,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("edu_type",)
 
     def __init__(self, hs):
-        super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -187,7 +193,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
     CACHE = False
 
     def __init__(self, hs):
-        super(ReplicationGetQueryRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -230,7 +236,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("room_id",)
 
     def __init__(self, hs):
-        super(ReplicationCleanRoomRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
 
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index fb326bb869..4c81e2d784 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -32,7 +32,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
 
     def __init__(self, hs):
-        super(RegisterDeviceReplicationServlet, self).__init__(hs)
+        super().__init__(hs)
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 741329ab5f..30680baee8 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict, Requester, UserID
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("room_id", "user_id")
 
     def __init__(self, hs):
-        super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.federation_handler = hs.get_handlers().federation_handler
         self.store = hs.get_datastore()
@@ -107,7 +107,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("invite_event_id",)
 
     def __init__(self, hs: "HomeServer"):
-        super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -168,7 +168,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
     CACHE = False  # No point caching as should return instantly.
 
     def __init__(self, hs):
-        super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.registeration_handler = hs.get_registration_handler()
         self.store = hs.get_datastore()
@@ -181,9 +181,9 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
         Args:
             room_id (str)
             user_id (str)
-            change (str): Either "joined" or "left"
+            change (str): "left"
         """
-        assert change in ("joined", "left")
+        assert change == "left"
 
         return {}
 
@@ -192,9 +192,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
 
         user = UserID.from_string(user_id)
 
-        if change == "joined":
-            user_joined_room(self.distributor, user, room_id)
-        elif change == "left":
+        if change == "left":
             user_left_room(self.distributor, user, room_id)
         else:
             raise Exception("Unrecognized change: %r", change)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index a02b27474d..7b12ec9060 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -29,7 +29,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
 
     def __init__(self, hs):
-        super(ReplicationRegisterServlet, self).__init__(hs)
+        super().__init__(hs)
         self.store = hs.get_datastore()
         self.registration_handler = hs.get_registration_handler()
 
@@ -104,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
 
     def __init__(self, hs):
-        super(ReplicationPostRegisterActionsServlet, self).__init__(hs)
+        super().__init__(hs)
         self.store = hs.get_datastore()
         self.registration_handler = hs.get_registration_handler()
 
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index f13d452426..9a3a694d5d 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -52,7 +52,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("event_id",)
 
     def __init__(self, hs):
-        super(ReplicationSendEventRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.event_creation_handler = hs.get_event_creation_handler()
         self.store = hs.get_datastore()
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 60f2e1245f..d25fa49e1a 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(BaseSlavedStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index bb66ba9b80..4268565fc8 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -34,7 +34,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
             ],
         )
 
-        super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index a6fdedde63..1f8dafe7ea 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -22,7 +22,7 @@ from ._base import BaseSlavedStore
 
 class SlavedClientIpStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 533d927701..5b045bed02 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_inbox", "stream_id"
         )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 3b788c9625..e0d86240dd 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index da1cc836cf..fbffe6d85c 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -56,7 +56,7 @@ class SlavedEventStore(
     BaseSlavedStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedEventStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 2562b6fc38..6a23252861 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -21,7 +21,7 @@ from ._base import BaseSlavedStore
 
 class SlavedFilteringStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
     get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 567b4a5cc1..30955bcbfe 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -23,7 +23,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 025f6f6be8..55620c03d8 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -25,7 +25,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 class SlavedPresenceStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
         self._presence_on_startup = self._get_active_presence(db_conn)  # type: ignore
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 9da218bfe8..c418730ba8 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedPusherStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 5c2986e050..6195917376 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -30,7 +30,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 80ae803ad9..109ac6bea1 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -23,7 +23,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
         )
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index d6ecf5b327..e82b9e386f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,6 +29,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
+from synapse.types import UserID
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -98,7 +99,6 @@ class ReplicationDataHandler:
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
-        self.pusher_pool = hs.get_pusherpool()
         self.notifier = hs.get_notifier()
         self._reactor = hs.get_reactor()
         self._clock = hs.get_clock()
@@ -148,14 +148,12 @@ class ReplicationDataHandler:
                 if event.rejected_reason:
                     continue
 
-                extra_users = ()  # type: Tuple[str, ...]
+                extra_users = ()  # type: Tuple[UserID, ...]
                 if event.type == EventTypes.Member:
-                    extra_users = (event.state_key,)
+                    extra_users = (UserID.from_string(event.state_key),)
                 max_token = self.store.get_room_max_stream_ordering()
                 self.notifier.on_new_room_event(event, token, max_token, extra_users)
 
-            await self.pusher_pool.on_new_notifications(token, token)
-
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
         # greater than the received row position.
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1c303f3a46..b323841f73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -109,7 +109,7 @@ class ReplicationCommandHandler:
             if isinstance(stream, (EventsStream, BackfillStream)):
                 # Only add EventStream and BackfillStream as a source on the
                 # instance in charge of event persistence.
-                if hs.config.worker.writers.events == hs.get_instance_name():
+                if hs.get_instance_name() in hs.config.worker.writers.events:
                     self._streams_to_replicate.append(stream)
 
                 continue
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 04d894fb3d..687984e7a8 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -93,7 +93,7 @@ class ReplicationStreamer:
         """
         if not self.command_handler.connected():
             # Don't bother if nothing is listening. We still need to advance
-            # the stream tokens otherwise they'll fall beihind forever
+            # the stream tokens otherwise they'll fall behind forever
             for stream in self.streams:
                 stream.discard_updates_and_advance()
             return
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 682d47f402..54dccd15a6 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -345,7 +345,7 @@ class PushRulesStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-        super(PushRulesStream, self).__init__(
+        super().__init__(
             hs.get_instance_name(),
             self._current_token,
             self.store.get_all_push_rule_updates,
@@ -383,7 +383,7 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
-    @attr.s
+    @attr.s(slots=True)
     class CachesStreamRow:
         """Stream to inform workers they should invalidate their cache.
 
@@ -441,7 +441,7 @@ class DeviceListsStream(Stream):
     told about a device update.
     """
 
-    @attr.s
+    @attr.s(slots=True)
     class DeviceListsStreamRow:
         entity = attr.ib(type=str)
 
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f929fc3954..ccc7ca30d8 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
+from ._base import Stream, StreamUpdateResult, Token
 
 """Handling of the 'events' replication stream
 
@@ -117,7 +117,7 @@ class EventsStream(Stream):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self._store.get_current_events_token),
+            self._store._stream_id_gen.get_current_token_for_writer,
             self._update_function,
         )