summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/events.py20
-rw-r--r--synapse/handlers/initial_sync.py10
-rw-r--r--synapse/handlers/presence.py210
3 files changed, 135 insertions, 105 deletions
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index ec18a42a68..71a89f09c7 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -19,6 +19,7 @@ import random
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
+from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.utils import log_function
 from synapse.types import UserID
 from synapse.visibility import filter_events_for_client
@@ -97,6 +98,8 @@ class EventStreamHandler(BaseHandler):
                 explicit_room_id=room_id,
             )
 
+            time_now = self.clock.time_msec()
+
             # When the user joins a new room, or another user joins a currently
             # joined room, we need to send down presence for those users.
             to_add = []
@@ -112,19 +115,20 @@ class EventStreamHandler(BaseHandler):
                         users = await self.state.get_current_users_in_room(
                             event.room_id
                         )
-                        states = await presence_handler.get_states(users, as_event=True)
-                        to_add.extend(states)
                     else:
+                        users = [event.state_key]
 
-                        ev = await presence_handler.get_state(
-                            UserID.from_string(event.state_key), as_event=True
-                        )
-                        to_add.append(ev)
+                    states = await presence_handler.get_states(users)
+                    to_add.extend(
+                        {
+                            "type": EventTypes.Presence,
+                            "content": format_user_presence_state(state, time_now),
+                        }
+                        for state in states
+                    )
 
             events.extend(to_add)
 
-            time_now = self.clock.time_msec()
-
             chunks = await self._event_serializer.serialize_events(
                 events,
                 time_now,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index b116500c7d..f88bad5f25 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -381,10 +381,16 @@ class InitialSyncHandler(BaseHandler):
                 return []
 
             states = await presence_handler.get_states(
-                [m.user_id for m in room_members], as_event=True
+                [m.user_id for m in room_members]
             )
 
-            return states
+            return [
+                {
+                    "type": EventTypes.Presence,
+                    "content": format_user_presence_state(s, time_now),
+                }
+                for s in states
+            ]
 
         async def get_receipts():
             receipts = await self.store.get_linearized_receipts_for_room(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6912165622..5cbefae177 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,10 +22,10 @@ The methods that define policy are:
     - PresenceHandler._handle_timeouts
     - should_notify
 """
-
+import abc
 import logging
 from contextlib import contextmanager
-from typing import Dict, List, Set
+from typing import Dict, Iterable, List, Set
 
 from six import iteritems, itervalues
 
@@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.presence import UserPresenceState
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import cached
 from synapse.util.metrics import Measure
@@ -99,13 +100,106 @@ EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
 
 
-class PresenceHandler(object):
+class BasePresenceHandler(abc.ABC):
+    """Parts of the PresenceHandler that are shared between workers and master"""
+
+    def __init__(self, hs: "synapse.server.HomeServer"):
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+
+        active_presence = self.store.take_presence_startup_info()
+        self.user_to_current_state = {state.user_id: state for state in active_presence}
+
+    @abc.abstractmethod
+    async def user_syncing(
+        self, user_id: str, affect_presence: bool
+    ) -> ContextManager[None]:
+        """Returns a context manager that should surround any stream requests
+        from the user.
+
+        This allows us to keep track of who is currently streaming and who isn't
+        without having to have timers outside of this module to avoid flickering
+        when users disconnect/reconnect.
+
+        Args:
+            user_id: the user that is starting a sync
+            affect_presence: If false this function will be a no-op.
+                Useful for streams that are not associated with an actual
+                client that is being used by a user.
+        """
+
+    @abc.abstractmethod
+    def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+        """Get an iterable of syncing users on this worker, to send to the presence handler
+
+        This is called when a replication connection is established. It should return
+        a list of user ids, which are then sent as USER_SYNC commands to inform the
+        process handling presence about those users.
+
+        Returns:
+            An iterable of user_id strings.
+        """
+
+    async def get_state(self, target_user: UserID) -> UserPresenceState:
+        results = await self.get_states([target_user.to_string()])
+        return results[0]
+
+    async def get_states(
+        self, target_user_ids: Iterable[str]
+    ) -> List[UserPresenceState]:
+        """Get the presence state for users."""
+
+        updates_d = await self.current_state_for_users(target_user_ids)
+        updates = list(updates_d.values())
+
+        for user_id in set(target_user_ids) - {u.user_id for u in updates}:
+            updates.append(UserPresenceState.default(user_id))
+
+        return updates
+
+    async def current_state_for_users(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, UserPresenceState]:
+        """Get the current presence state for multiple users.
+
+        Returns:
+            dict: `user_id` -> `UserPresenceState`
+        """
+        states = {
+            user_id: self.user_to_current_state.get(user_id, None)
+            for user_id in user_ids
+        }
+
+        missing = [user_id for user_id, state in iteritems(states) if not state]
+        if missing:
+            # There are things not in our in memory cache. Lets pull them out of
+            # the database.
+            res = await self.store.get_presence_for_users(missing)
+            states.update(res)
+
+            missing = [user_id for user_id, state in iteritems(states) if not state]
+            if missing:
+                new = {
+                    user_id: UserPresenceState.default(user_id) for user_id in missing
+                }
+                states.update(new)
+                self.user_to_current_state.update(new)
+
+        return states
+
+    @abc.abstractmethod
+    async def set_state(
+        self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+    ) -> None:
+        """Set the presence state of the user. """
+
+
+class PresenceHandler(BasePresenceHandler):
     def __init__(self, hs: "synapse.server.HomeServer"):
+        super().__init__(hs)
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
         self.server_name = hs.hostname
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
         self.wheel_timer = WheelTimer()
         self.notifier = hs.get_notifier()
         self.federation = hs.get_federation_sender()
@@ -115,13 +209,6 @@ class PresenceHandler(object):
 
         federation_registry.register_edu_handler("m.presence", self.incoming_presence)
 
-        active_presence = self.store.take_presence_startup_info()
-
-        # A dictionary of the current state of users. This is prefilled with
-        # non-offline presence from the DB. We should fetch from the DB if
-        # we can't find a users presence in here.
-        self.user_to_current_state = {state.user_id: state for state in active_presence}
-
         LaterGauge(
             "synapse_handlers_presence_user_to_current_state_size",
             "",
@@ -130,7 +217,7 @@ class PresenceHandler(object):
         )
 
         now = self.clock.time_msec()
-        for state in active_presence:
+        for state in self.user_to_current_state.values():
             self.wheel_timer.insert(
                 now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
             )
@@ -361,10 +448,18 @@ class PresenceHandler(object):
 
         timers_fired_counter.inc(len(states))
 
+        syncing_user_ids = {
+            user_id
+            for user_id, count in self.user_to_num_current_syncs.items()
+            if count
+        }
+        for user_ids in self.external_process_to_current_syncs.values():
+            syncing_user_ids.update(user_ids)
+
         changes = handle_timeouts(
             states,
             is_mine_fn=self.is_mine_id,
-            syncing_user_ids=self.get_currently_syncing_users(),
+            syncing_user_ids=syncing_user_ids,
             now=now,
         )
 
@@ -462,22 +557,9 @@ class PresenceHandler(object):
 
         return _user_syncing()
 
-    def get_currently_syncing_users(self):
-        """Get the set of user ids that are currently syncing on this HS.
-        Returns:
-            set(str): A set of user_id strings.
-        """
-        if self.hs.config.use_presence:
-            syncing_user_ids = {
-                user_id
-                for user_id, count in self.user_to_num_current_syncs.items()
-                if count
-            }
-            for user_ids in self.external_process_to_current_syncs.values():
-                syncing_user_ids.update(user_ids)
-            return syncing_user_ids
-        else:
-            return set()
+    def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+        # since we are the process handling presence, there is nothing to do here.
+        return []
 
     async def update_external_syncs_row(
         self, process_id, user_id, is_syncing, sync_time_msec
@@ -554,34 +636,6 @@ class PresenceHandler(object):
         res = await self.current_state_for_users([user_id])
         return res[user_id]
 
-    async def current_state_for_users(self, user_ids):
-        """Get the current presence state for multiple users.
-
-        Returns:
-            dict: `user_id` -> `UserPresenceState`
-        """
-        states = {
-            user_id: self.user_to_current_state.get(user_id, None)
-            for user_id in user_ids
-        }
-
-        missing = [user_id for user_id, state in iteritems(states) if not state]
-        if missing:
-            # There are things not in our in memory cache. Lets pull them out of
-            # the database.
-            res = await self.store.get_presence_for_users(missing)
-            states.update(res)
-
-            missing = [user_id for user_id, state in iteritems(states) if not state]
-            if missing:
-                new = {
-                    user_id: UserPresenceState.default(user_id) for user_id in missing
-                }
-                states.update(new)
-                self.user_to_current_state.update(new)
-
-        return states
-
     async def _persist_and_notify(self, states):
         """Persist states in the database, poke the notifier and send to
         interested remote servers
@@ -669,40 +723,6 @@ class PresenceHandler(object):
             federation_presence_counter.inc(len(updates))
             await self._update_states(updates)
 
-    async def get_state(self, target_user, as_event=False):
-        results = await self.get_states([target_user.to_string()], as_event=as_event)
-
-        return results[0]
-
-    async def get_states(self, target_user_ids, as_event=False):
-        """Get the presence state for users.
-
-        Args:
-            target_user_ids (list)
-            as_event (bool): Whether to format it as a client event or not.
-
-        Returns:
-            list
-        """
-
-        updates = await self.current_state_for_users(target_user_ids)
-        updates = list(updates.values())
-
-        for user_id in set(target_user_ids) - {u.user_id for u in updates}:
-            updates.append(UserPresenceState.default(user_id))
-
-        now = self.clock.time_msec()
-        if as_event:
-            return [
-                {
-                    "type": "m.presence",
-                    "content": format_user_presence_state(state, now),
-                }
-                for state in updates
-            ]
-        else:
-            return updates
-
     async def set_state(self, target_user, state, ignore_status_msg=False):
         """Set the presence state of the user.
         """
@@ -889,7 +909,7 @@ class PresenceHandler(object):
             user_ids = await self.state.get_current_users_in_room(room_id)
             user_ids = list(filter(self.is_mine_id, user_ids))
 
-            states = await self.current_state_for_users(user_ids)
+            states_d = await self.current_state_for_users(user_ids)
 
             # Filter out old presence, i.e. offline presence states where
             # the user hasn't been active for a week. We can change this
@@ -899,7 +919,7 @@ class PresenceHandler(object):
             now = self.clock.time_msec()
             states = [
                 state
-                for state in states.values()
+                for state in states_d.values()
                 if state.state != PresenceState.OFFLINE
                 or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
                 or state.status_msg is not None