diff options
Diffstat (limited to 'synapse/replication/tcp/client.py')
-rw-r--r-- | synapse/replication/tcp/client.py | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 20cb8a654f..28826302f5 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,12 +16,17 @@ """ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple from twisted.internet.protocol import ReconnectingClientFactory -from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.api.constants import EventTypes from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamEventRow, + EventsStreamRow, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -83,8 +88,10 @@ class ReplicationDataHandler: to handle updates in additional ways. """ - def __init__(self, store: BaseSlavedStore): - self.store = store + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.pusher_pool = hs.get_pusherpool() + self.notifier = hs.get_notifier() async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -102,6 +109,28 @@ class ReplicationDataHandler: """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + if stream_name == EventsStream.NAME: + # We shouldn't get multiple rows per token for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + if row.type != EventsStreamEventRow.TypeId: + continue + assert isinstance(row, EventsStreamRow) + + event = await self.store.get_event( + row.data.event_id, allow_rejected=True + ) + if event.rejected_reason: + continue + + extra_users = () # type: Tuple[str, ...] + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event(event, token, max_token, extra_users) + + await self.pusher_pool.on_new_notifications(token, token) + async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) |