summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-06-11 15:37:22 +0100
committerErik Johnston <erik@matrix.org>2021-06-11 15:37:22 +0100
commita4b573ee48e03f3c3afcf1f666469be92dc65878 (patch)
tree3a26159d7763d0704489bca95ffd44dfe535c206 /synapse
parentMerge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes (diff)
parentFixup changelog (diff)
downloadsynapse-a4b573ee48e03f3c3afcf1f666469be92dc65878.tar.xz
Merge branch 'release-v1.36' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/handlers/presence.py50
-rw-r--r--synapse/storage/databases/main/presence.py2
-rw-r--r--synapse/storage/util/id_generators.py15
4 files changed, 48 insertions, 21 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 445e8a5cad..407ba14a76 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.35.1"
+__version__ = "1.36.0rc2"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index f5a049d754..44ed7a0712 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -299,14 +299,14 @@ class BasePresenceHandler(abc.ABC):
         if not states:
             return
 
-        hosts_and_states = await get_interested_remotes(
+        hosts_to_states = await get_interested_remotes(
             self.store,
             self.presence_router,
             states,
         )
 
-        for destinations, states in hosts_and_states:
-            self._federation.send_presence_to_destinations(states, destinations)
+        for destination, host_states in hosts_to_states.items():
+            self._federation.send_presence_to_destinations(host_states, [destination])
 
     async def send_full_presence_to_users(self, user_ids: Collection[str]):
         """
@@ -495,9 +495,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
             users=users_to_states.keys(),
         )
 
-        # If this is a federation sender, notify about presence updates.
-        await self.maybe_send_presence_to_interested_destinations(states)
-
     async def process_replication_rows(
         self, stream_name: str, instance_name: str, token: int, rows: list
     ):
@@ -519,11 +516,27 @@ class WorkerPresenceHandler(BasePresenceHandler):
             for row in rows
         ]
 
-        for state in states:
-            self.user_to_current_state[state.user_id] = state
+        # The list of states to notify sync streams and remote servers about.
+        # This is calculated by comparing the old and new states for each user
+        # using `should_notify(..)`.
+        #
+        # Note that this is necessary as the presence writer will periodically
+        # flush presence state changes that should not be notified about to the
+        # DB, and so will be sent over the replication stream.
+        state_to_notify = []
+
+        for new_state in states:
+            old_state = self.user_to_current_state.get(new_state.user_id)
+            self.user_to_current_state[new_state.user_id] = new_state
+
+            if not old_state or should_notify(old_state, new_state):
+                state_to_notify.append(new_state)
 
         stream_id = token
-        await self.notify_from_replication(states, stream_id)
+        await self.notify_from_replication(state_to_notify, stream_id)
+
+        # If this is a federation sender, notify about presence updates.
+        await self.maybe_send_presence_to_interested_destinations(state_to_notify)
 
     def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
         return [
@@ -829,15 +842,15 @@ class PresenceHandler(BasePresenceHandler):
             if to_federation_ping:
                 federation_presence_out_counter.inc(len(to_federation_ping))
 
-                hosts_and_states = await get_interested_remotes(
+                hosts_to_states = await get_interested_remotes(
                     self.store,
                     self.presence_router,
                     list(to_federation_ping.values()),
                 )
 
-                for destinations, states in hosts_and_states:
+                for destination, states in hosts_to_states.items():
                     self._federation_queue.send_presence_to_destinations(
-                        states, destinations
+                        states, [destination]
                     )
 
     async def _handle_timeouts(self) -> None:
@@ -1962,7 +1975,7 @@ async def get_interested_remotes(
     store: DataStore,
     presence_router: PresenceRouter,
     states: List[UserPresenceState],
-) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
+) -> Dict[str, Set[UserPresenceState]]:
     """Given a list of presence states figure out which remote servers
     should be sent which.
 
@@ -1974,11 +1987,9 @@ async def get_interested_remotes(
         states: A list of incoming user presence updates.
 
     Returns:
-        A list of 2-tuples of destinations and states, where for
-        each tuple the list of UserPresenceState should be sent to each
-        destination
+        A map from destinations to presence states to send to that destination.
     """
-    hosts_and_states = []  # type: List[Tuple[Collection[str], List[UserPresenceState]]]
+    hosts_and_states: Dict[str, Set[UserPresenceState]] = {}
 
     # First we look up the rooms each user is in (as well as any explicit
     # subscriptions), then for each distinct room we look up the remote
@@ -1990,11 +2001,12 @@ async def get_interested_remotes(
     for room_id, states in room_ids_to_states.items():
         user_ids = await store.get_users_in_room(room_id)
         hosts = {get_domain_from_id(user_id) for user_id in user_ids}
-        hosts_and_states.append((hosts, states))
+        for host in hosts:
+            hosts_and_states.setdefault(host, set()).update(states)
 
     for user_id, states in users_to_states.items():
         host = get_domain_from_id(user_id)
-        hosts_and_states.append(([host], states))
+        hosts_and_states.setdefault(host, set()).update(states)
 
     return hosts_and_states
 
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 6a2baa7841..1388771c40 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -50,7 +50,7 @@ class PresenceStore(SQLBaseStore):
                 instance_name=self._instance_name,
                 tables=[("presence_stream", "instance_name", "stream_id")],
                 sequence_name="presence_stream_sequence",
-                writers=hs.config.worker.writers.to_device,
+                writers=hs.config.worker.writers.presence,
             )
         else:
             self._presence_id_gen = StreamIdGenerator(
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b1bd3a52d9..f1e62f9e85 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -397,6 +397,11 @@ class MultiWriterIdGenerator:
                 # ... persist event ...
         """
 
+        # If we have a list of instances that are allowed to write to this
+        # stream, make sure we're in it.
+        if self._writers and self._instance_name not in self._writers:
+            raise Exception("Tried to allocate stream ID on non-writer")
+
         return _MultiWriterCtxManager(self)
 
     def get_next_mult(self, n: int):
@@ -406,6 +411,11 @@ class MultiWriterIdGenerator:
                 # ... persist events ...
         """
 
+        # If we have a list of instances that are allowed to write to this
+        # stream, make sure we're in it.
+        if self._writers and self._instance_name not in self._writers:
+            raise Exception("Tried to allocate stream ID on non-writer")
+
         return _MultiWriterCtxManager(self, n)
 
     def get_next_txn(self, txn: LoggingTransaction):
@@ -416,6 +426,11 @@ class MultiWriterIdGenerator:
             # ... persist event ...
         """
 
+        # If we have a list of instances that are allowed to write to this
+        # stream, make sure we're in it.
+        if self._writers and self._instance_name not in self._writers:
+            raise Exception("Tried to allocate stream ID on non-writer")
+
         next_id = self._load_next_id_txn(txn)
 
         with self._lock: