diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6460eb9952..969c73c1e7 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -24,9 +24,11 @@ The methods that define policy are:
import abc
import contextlib
import logging
+from bisect import bisect
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
+ Collection,
Dict,
FrozenSet,
Iterable,
@@ -53,10 +55,11 @@ from synapse.replication.http.presence import (
ReplicationBumpPresenceActiveTime,
ReplicationPresenceSetState,
)
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand
-from synapse.state import StateHandler
+from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
-from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure
@@ -118,19 +121,21 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class BasePresenceHandler(abc.ABC):
- """Parts of the PresenceHandler that are shared between workers and master"""
+ """Parts of the PresenceHandler that are shared between workers and presence
+ writer"""
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.presence_router = hs.get_presence_router()
self.state = hs.get_state_handler()
+ self.is_mine_id = hs.is_mine_id
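+        # Only processes that send federation traffic directly have a
+        # FederationSender; outbound presence is otherwise routed through the
+        # PresenceFederationQueue below.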
self._federation = None
- if hs.should_send_federation() or not hs.config.worker_app:
+ if hs.should_send_federation():
self._federation = hs.get_federation_sender()
- self._send_federation = hs.should_send_federation()
+ self._federation_queue = PresenceFederationQueue(hs, self)
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
@@ -253,28 +258,38 @@ class BasePresenceHandler(abc.ABC):
"""
pass
- async def process_replication_rows(self, token, rows):
- """Process presence stream rows received over replication."""
- pass
+ async def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
+ """Process streams received over replication."""
+ await self._federation_queue.process_replication_rows(
+ stream_name, instance_name, token, rows
+ )
+
+ def get_federation_queue(self) -> "PresenceFederationQueue":
+ """Get the presence federation queue."""
+ return self._federation_queue
async def maybe_send_presence_to_interested_destinations(
self, states: List[UserPresenceState]
):
"""If this instance is a federation sender, send the states to all
- destinations that are interested.
+ destinations that are interested. Filters out any states for remote
+ users.
"""
- if not self._send_federation:
+ if not self._federation:
return
- # If this worker sends federation we must have a FederationSender.
- assert self._federation
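+        # The states passed in may include those of remote users (e.g. from
+        # the presence stream); only presence for our own users should be
+        # sent out over federation.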
+ states = [s for s in states if self.is_mine_id(s.user_id)]
+
+ if not states:
+ return
hosts_and_states = await get_interested_remotes(
self.store,
self.presence_router,
states,
- self.state,
)
for destinations, states in hosts_and_states:
@@ -292,10 +307,17 @@ class WorkerPresenceHandler(BasePresenceHandler):
def __init__(self, hs):
super().__init__(hs)
self.hs = hs
- self.is_mine_id = hs.is_mine_id
+
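+        # Presence writes are proxied to the first (and currently only)
+        # configured presence writer instance.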
+ self._presence_writer_instance = hs.config.worker.writers.presence[0]
self._presence_enabled = hs.config.use_presence
+ # Route presence EDUs to the right worker
+ hs.get_federation_registry().register_instances_for_edu(
+ "m.presence",
+ hs.config.worker.writers.presence,
+ )
+
# The number of ongoing syncs on this process, by user id.
# Empty if _presence_enabled is false.
self._user_to_num_current_syncs = {} # type: Dict[str, int]
@@ -303,8 +325,8 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
- # user_id -> last_sync_ms. Lists the users that have stopped syncing
- # but we haven't notified the master of that yet
+ # user_id -> last_sync_ms. Lists the users that have stopped syncing but
+ # we haven't notified the presence writer of that yet
self.users_going_offline = {}
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
@@ -337,22 +359,23 @@ class WorkerPresenceHandler(BasePresenceHandler):
)
def mark_as_coming_online(self, user_id):
- """A user has started syncing. Send a UserSync to the master, unless they
- had recently stopped syncing.
+ """A user has started syncing. Send a UserSync to the presence writer,
+ unless they had recently stopped syncing.
Args:
user_id (str)
"""
going_offline = self.users_going_offline.pop(user_id, None)
if not going_offline:
- # Safe to skip because we haven't yet told the master they were offline
+ # Safe to skip because we haven't yet told the presence writer they
+ # were offline
self.send_user_sync(user_id, True, self.clock.time_msec())
def mark_as_going_offline(self, user_id):
- """A user has stopped syncing. We wait before notifying the master as
- its likely they'll come back soon. This allows us to avoid sending
- a stopped syncing immediately followed by a started syncing notification
- to the master
+ """A user has stopped syncing. We wait before notifying the presence
+        writer as it's likely they'll come back soon. This allows us to avoid
+        sending a stopped syncing notification immediately followed by a
+        started syncing notification to the presence writer.
Args:
user_id (str)
@@ -360,8 +383,8 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.users_going_offline[user_id] = self.clock.time_msec()
def send_stop_syncing(self):
- """Check if there are any users who have stopped syncing a while ago
- and haven't come back yet. If there are poke the master about them.
+ """Check if there are any users who have stopped syncing a while ago and
+        haven't come back yet. If there are, poke the presence writer about them.
"""
now = self.clock.time_msec()
for user_id, last_sync_ms in list(self.users_going_offline.items()):
@@ -421,7 +444,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
# If this is a federation sender, notify about presence updates.
await self.maybe_send_presence_to_interested_destinations(states)
- async def process_replication_rows(self, token, rows):
+ async def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
+ await super().process_replication_rows(stream_name, instance_name, token, rows)
+
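+        # Rows for the presence federation stream are handled by the base
+        # class; the rest of this method only cares about the main presence
+        # stream.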
+ if stream_name != PresenceStream.NAME:
+ return
+
states = [
UserPresenceState(
row.user_id,
@@ -470,9 +500,12 @@ class WorkerPresenceHandler(BasePresenceHandler):
if not self.hs.config.use_presence:
return
- # Proxy request to master
+        # Proxy request to the instance that writes presence
await self._set_state_client(
- user_id=user_id, state=state, ignore_status_msg=ignore_status_msg
+ instance_name=self._presence_writer_instance,
+ user_id=user_id,
+ state=state,
+ ignore_status_msg=ignore_status_msg,
)
async def bump_presence_active_time(self, user):
@@ -483,16 +516,17 @@ class WorkerPresenceHandler(BasePresenceHandler):
if not self.hs.config.use_presence:
return
- # Proxy request to master
+        # Proxy request to the instance that writes presence
user_id = user.to_string()
- await self._bump_active_client(user_id=user_id)
+ await self._bump_active_client(
+ instance_name=self._presence_writer_instance, user_id=user_id
+ )
class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
- self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
@@ -721,15 +755,12 @@ class PresenceHandler(BasePresenceHandler):
self.store,
self.presence_router,
list(to_federation_ping.values()),
- self.state,
)
- # Since this is master we know that we have a federation sender or
- # queue, and so this will be defined.
- assert self._federation
-
for destinations, states in hosts_and_states:
- self._federation.send_presence_to_destinations(states, destinations)
+ self._federation_queue.send_presence_to_destinations(
+ states, destinations
+ )
async def _handle_timeouts(self):
"""Checks the presence of users that have timed out and updates as
@@ -1208,13 +1239,9 @@ class PresenceHandler(BasePresenceHandler):
user_presence_states
)
- # Since this is master we know that we have a federation sender or
- # queue, and so this will be defined.
- assert self._federation
-
# Send out user presence updates for each destination
for destination, user_state_set in presence_destinations.items():
- self._federation.send_presence_to_destinations(
+ self._federation_queue.send_presence_to_destinations(
destinations=[destination], states=user_state_set
)
@@ -1354,7 +1381,6 @@ class PresenceEventSource:
self.get_presence_router = hs.get_presence_router
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- self.state = hs.get_state_handler()
@log_function
async def get_new_events(
@@ -1823,7 +1849,6 @@ async def get_interested_remotes(
store: DataStore,
presence_router: PresenceRouter,
states: List[UserPresenceState],
- state_handler: StateHandler,
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@@ -1834,7 +1859,6 @@ async def get_interested_remotes(
store: The homeserver's data store.
presence_router: A module for augmenting the destinations for presence updates.
states: A list of incoming user presence updates.
- state_handler:
Returns:
A list of 2-tuples of destinations and states, where for
@@ -1851,7 +1875,8 @@ async def get_interested_remotes(
)
for room_id, states in room_ids_to_states.items():
- hosts = await state_handler.get_current_hosts_in_room(room_id)
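+        # Derive the set of interested hosts from the room membership rather
+        # than going via the state handler.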
+ user_ids = await store.get_users_in_room(room_id)
+ hosts = {get_domain_from_id(user_id) for user_id in user_ids}
hosts_and_states.append((hosts, states))
for user_id, states in users_to_states.items():
@@ -1859,3 +1884,198 @@ async def get_interested_remotes(
hosts_and_states.append(([host], states))
return hosts_and_states
+
+
+class PresenceFederationQueue:
+ """Handles sending ad hoc presence updates over federation, which are *not*
+    due to state updates (those get handled via the presence stream), e.g.
+    federation pings and sending existing presence states to newly joined hosts.
+
+    Only the last N minutes will be queued, so if a federation sender instance
+    is down for longer than that, some updates will be dropped. This is OK as
+    presence is ephemeral, and so it will self-correct eventually.
+
+ On workers the class tracks the last received position of the stream from
+ replication, and handles querying for missed updates over HTTP replication,
+    cf. `get_current_token` and `get_replication_rows`.
+ """
+
+ # How long to keep entries in the queue for. Workers that are down for
+ # longer than this duration will miss out on older updates.
+ _KEEP_ITEMS_IN_QUEUE_FOR_MS = 5 * 60 * 1000
+
+ # How often to check if we can expire entries from the queue.
+ _CLEAR_ITEMS_EVERY_MS = 60 * 1000
+
+ def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler):
+ self._clock = hs.get_clock()
+ self._notifier = hs.get_notifier()
+ self._instance_name = hs.get_instance_name()
+ self._presence_handler = presence_handler
+ self._repl_client = ReplicationGetStreamUpdates.make_client(hs)
+
+ # Should we keep a queue of recent presence updates? We only bother if
+ # another process may be handling federation sending.
+ self._queue_presence_updates = True
+
+ # Whether this instance is a presence writer.
+ self._presence_writer = self._instance_name in hs.config.worker.writers.presence
+
+ # The FederationSender instance, if this process sends federation traffic directly.
+ self._federation = None
+
+ if hs.should_send_federation():
+ self._federation = hs.get_federation_sender()
+
+ # We don't bother queuing up presence states if only this instance
+ # is sending federation.
+ if hs.config.worker.federation_shard_config.instances == [
+ self._instance_name
+ ]:
+ self._queue_presence_updates = False
+
+ # The queue of recently queued updates as tuples of: `(timestamp,
+ # stream_id, destinations, user_ids)`. We don't store the full states
+ # for efficiency, and remote workers will already have the full states
+ # cached.
+ self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]]
+
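+        # The next stream ID to allocate. Stream IDs are 1-based, so a token
+        # of 0 means that no rows have been written yet.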
+ self._next_id = 1
+
+ # Map from instance name to current token
+ self._current_tokens = {} # type: Dict[str, int]
+
+ if self._queue_presence_updates:
+ self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
+
+ def _clear_queue(self):
+ """Clear out older entries from the queue."""
+ clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
+
+        # The queue is sorted by timestamp, so we can bisect to find the right
+        # place to purge before. Note that we are searching using a 1-tuple
+        # containing the time, which does The Right Thing since each queue
+        # entry is a tuple whose first element is the timestamp.
+ index = bisect(self._queue, (clear_before,))
+ self._queue = self._queue[index:]
+
+ def send_presence_to_destinations(
+ self, states: Collection[UserPresenceState], destinations: Collection[str]
+ ) -> None:
+ """Send the presence states to the given destinations.
+
+ Will forward to the local federation sender (if there is one) and queue
+        to send over replication (if there are other federation sender instances).
+
+ Must only be called on the presence writer process.
+ """
+
+ # This should only be called on a presence writer.
+ assert self._presence_writer
+
+ if self._federation:
+ self._federation.send_presence_to_destinations(
+ states=states,
+ destinations=destinations,
+ )
+
+ if not self._queue_presence_updates:
+ return
+
+ now = self._clock.time_msec()
+
+ stream_id = self._next_id
+ self._next_id += 1
+
+ self._queue.append((now, stream_id, destinations, {s.user_id for s in states}))
+
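+        # Wake up the replication connections so that other instances are
+        # told about the new row promptly.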
+ self._notifier.notify_replication()
+
+ def get_current_token(self, instance_name: str) -> int:
+ """Get the current position of the stream.
+
+ On workers this returns the last stream ID received from replication.
+ """
+ if instance_name == self._instance_name:
+ return self._next_id - 1
+ else:
+ return self._current_tokens.get(instance_name, 0)
+
+ async def get_replication_rows(
+ self,
+ instance_name: str,
+ from_token: int,
+ upto_token: int,
+ target_row_count: int,
+ ) -> Tuple[List[Tuple[int, Tuple[str, str]]], int, bool]:
+ """Get all the updates between the two tokens.
+
+ We return rows in the form of `(destination, user_id)` to keep the size
+        of each row bounded (rather than returning the full sets in a single row).
+
+ On workers this will query the presence writer process via HTTP replication.
+ """
+ if instance_name != self._instance_name:
+            # If not local, query the presence writer over HTTP replication.
+ result = await self._repl_client(
+ instance_name=instance_name,
+ stream_name=PresenceFederationStream.NAME,
+ from_token=from_token,
+ upto_token=upto_token,
+ )
+ return result["updates"], result["upto_token"], result["limited"]
+
+        # We can find the correct position in the queue by noting that there is
+        # exactly one entry per stream ID, and that the last entry has an ID of
+        # `self._next_id - 1`, so we can count backwards from the end.
+        #
+        # Since we return all rows in the range `from_token < stream_id <=
+        # upto_token`, we look for the index of the entry with a stream ID of
+        # `from_token + 1`.
+        #
+        # Since the start of the queue is periodically truncated we need to
+        # handle the case where the `from_token` stream ID has already been
+        # dropped.
+        start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
+
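+        # Flatten each queue entry into one row per (destination, user_id)
+        # pair, stopping once we pass `upto_token` or exceed the target row
+        # count.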
+ to_send = [] # type: List[Tuple[int, Tuple[str, str]]]
+ limited = False
+ new_id = upto_token
+ for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
+ if stream_id > upto_token:
+ break
+
+ new_id = stream_id
+
+ to_send.extend(
+ (stream_id, (destination, user_id))
+ for destination in destinations
+ for user_id in user_ids
+ )
+
+ if len(to_send) > target_row_count:
+ limited = True
+ break
+
+ return to_send, new_id, limited
+
+ async def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
+ if stream_name != PresenceFederationStream.NAME:
+ return
+
+        # We keep track of the current tokens (so that we can catch up with
+        # anything we missed after a disconnect)
+ self._current_tokens[instance_name] = token
+
+ # If we're a federation sender we pull out the presence states to send
+ # and forward them on.
+ if not self._federation:
+ return
+
+ hosts_to_users = {} # type: Dict[str, Set[str]]
+ for row in rows:
+ hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
+
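+        # Look up each user's current presence state and forward it to every
+        # interested destination, batched per host.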
+ for host, user_ids in hosts_to_users.items():
+ states = await self._presence_handler.current_state_for_users(user_ids)
+ self._federation.send_presence_to_destinations(
+ states=states.values(),
+ destinations=[host],
+ )