summary refs log tree commit diff
path: root/synapse/handlers/presence.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/presence.py')
-rw-r--r--synapse/handlers/presence.py98
1 files changed, 58 insertions, 40 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 3594f3b00f..1846068150 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,24 +25,22 @@ The methods that define policy are:
 import abc
 import logging
 from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set
-
-from six import iteritems, itervalues
+from typing import Dict, Iterable, List, Set, Tuple
 
 from prometheus_client import Counter
 from typing_extensions import ContextManager
 
-from twisted.internet import defer
-
 import synapse.metrics
 from synapse.api.constants import EventTypes, Membership, PresenceState
 from synapse.api.errors import SynapseError
+from synapse.api.presence import UserPresenceState
 from synapse.logging.context import run_in_background
 from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.presence import UserPresenceState
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.state import StateHandler
+from synapse.storage.databases.main import DataStore
+from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import cached
 from synapse.util.metrics import Measure
@@ -170,14 +168,14 @@ class BasePresenceHandler(abc.ABC):
             for user_id in user_ids
         }
 
-        missing = [user_id for user_id, state in iteritems(states) if not state]
+        missing = [user_id for user_id, state in states.items() if not state]
         if missing:
             # There are things not in our in memory cache. Lets pull them out of
             # the database.
             res = await self.store.get_presence_for_users(missing)
             states.update(res)
 
-            missing = [user_id for user_id, state in iteritems(states) if not state]
+            missing = [user_id for user_id, state in states.items() if not state]
             if missing:
                 new = {
                     user_id: UserPresenceState.default(user_id) for user_id in missing
@@ -321,7 +319,7 @@ class PresenceHandler(BasePresenceHandler):
         is some spurious presence changes that will self-correct.
         """
         # If the DB pool has already terminated, don't try updating
-        if not self.store.db.is_running():
+        if not self.store.db_pool.is_running():
             return
 
         logger.info(
@@ -632,7 +630,7 @@ class PresenceHandler(BasePresenceHandler):
             await self._update_states(
                 [
                     prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
-                    for prev_state in itervalues(prev_states)
+                    for prev_state in prev_states.values()
                 ]
             )
             self.external_process_last_updated_ms.pop(process_id, None)
@@ -775,7 +773,9 @@ class PresenceHandler(BasePresenceHandler):
 
         return False
 
-    async def get_all_presence_updates(self, last_id, current_id, limit):
+    async def get_all_presence_updates(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, list]], int, bool]:
         """
         Gets a list of presence update rows from between the given stream ids.
         Each row has:
@@ -787,10 +787,31 @@ class PresenceHandler(BasePresenceHandler):
         - last_user_sync_ts(int)
         - status_msg(int)
         - currently_active(int)
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
+
         # TODO(markjh): replicate the unpersisted changes.
         # This could use the in-memory stores for recent changes.
-        rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
+        rows = await self.store.get_all_presence_updates(
+            instance_name, last_id, current_id, limit
+        )
         return rows
 
     def notify_new_event(self):
@@ -874,16 +895,9 @@ class PresenceHandler(BasePresenceHandler):
 
             await self._on_user_joined_room(room_id, state_key)
 
-    async def _on_user_joined_room(self, room_id, user_id):
+    async def _on_user_joined_room(self, room_id: str, user_id: str) -> None:
         """Called when we detect a user joining the room via the current state
         delta stream.
-
-        Args:
-            room_id (str)
-            user_id (str)
-
-        Returns:
-            Deferred
         """
 
         if self.is_mine_id(user_id):
@@ -914,8 +928,8 @@ class PresenceHandler(BasePresenceHandler):
             # TODO: Check that this is actually a new server joining the
             # room.
 
-            user_ids = await self.state.get_current_users_in_room(room_id)
-            user_ids = list(filter(self.is_mine_id, user_ids))
+            users = await self.state.get_current_users_in_room(room_id)
+            user_ids = list(filter(self.is_mine_id, users))
 
             states_d = await self.current_state_for_users(user_ids)
 
@@ -1087,7 +1101,7 @@ class PresenceEventSource(object):
             return (list(updates.values()), max_token)
         else:
             return (
-                [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE],
+                [s for s in updates.values() if s.state != PresenceState.OFFLINE],
                 max_token,
             )
 
@@ -1275,22 +1289,24 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
     return new_state, persist_and_notify, federation_ping
 
 
-@defer.inlineCallbacks
-def get_interested_parties(store, states):
+async def get_interested_parties(
+    store: DataStore, states: List[UserPresenceState]
+) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
     """Given a list of states return which entities (rooms, users)
     are interested in the given states.
 
     Args:
-        states (list(UserPresenceState))
+        store
+        states
 
     Returns:
-        2-tuple: `(room_ids_to_states, users_to_states)`,
+        A 2-tuple of `(room_ids_to_states, users_to_states)`,
         with each item being a dict of `entity_name` -> `[UserPresenceState]`
     """
     room_ids_to_states = {}  # type: Dict[str, List[UserPresenceState]]
     users_to_states = {}  # type: Dict[str, List[UserPresenceState]]
     for state in states:
-        room_ids = yield store.get_rooms_for_user(state.user_id)
+        room_ids = await store.get_rooms_for_user(state.user_id)
         for room_id in room_ids:
             room_ids_to_states.setdefault(room_id, []).append(state)
 
@@ -1300,34 +1316,36 @@ def get_interested_parties(store, states):
     return room_ids_to_states, users_to_states
 
 
-@defer.inlineCallbacks
-def get_interested_remotes(store, states, state_handler):
+async def get_interested_remotes(
+    store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
+) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
     """Given a list of presence states figure out which remote servers
     should be sent which.
 
     All the presence states should be for local users only.
 
     Args:
-        store (DataStore)
-        states (list(UserPresenceState))
+        store
+        states
+        state_handler
 
     Returns:
-        Deferred list of ([destinations], [UserPresenceState]), where for
-        each row the list of UserPresenceState should be sent to each
+        A list of 2-tuples of destinations and states, where for
+        each tuple the list of UserPresenceState should be sent to each
         destination
     """
-    hosts_and_states = []
+    hosts_and_states = []  # type: List[Tuple[Collection[str], List[UserPresenceState]]]
 
     # First we look up the rooms each user is in (as well as any explicit
     # subscriptions), then for each distinct room we look up the remote
     # hosts in those rooms.
-    room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
+    room_ids_to_states, users_to_states = await get_interested_parties(store, states)
 
-    for room_id, states in iteritems(room_ids_to_states):
-        hosts = yield state_handler.get_current_hosts_in_room(room_id)
+    for room_id, states in room_ids_to_states.items():
+        hosts = await state_handler.get_current_hosts_in_room(room_id)
         hosts_and_states.append((hosts, states))
 
-    for user_id, states in iteritems(users_to_states):
+    for user_id, states in users_to_states.items():
         host = get_domain_from_id(user_id)
         hosts_and_states.append(([host], states))