diff options
Diffstat (limited to 'synapse/handlers/presence.py')
-rw-r--r-- | synapse/handlers/presence.py | 98 |
1 files changed, 58 insertions, 40 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3594f3b00f..1846068150 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -25,24 +25,22 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import Dict, Iterable, List, Set - -from six import iteritems, itervalues +from typing import Dict, Iterable, List, Set, Tuple from prometheus_client import Counter from typing_extensions import ContextManager -from twisted.internet import defer - import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError +from synapse.api.presence import UserPresenceState from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.state import StateHandler +from synapse.storage.databases.main import DataStore +from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure @@ -170,14 +168,14 @@ class BasePresenceHandler(abc.ABC): for user_id in user_ids } - missing = [user_id for user_id, state in iteritems(states) if not state] + missing = [user_id for user_id, state in states.items() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = await self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in iteritems(states) if not state] + missing = [user_id for user_id, state in states.items() if not state] if missing: new = { user_id: UserPresenceState.default(user_id) for user_id in missing @@ -321,7 +319,7 @@ class PresenceHandler(BasePresenceHandler): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.store.db.is_running(): + if not self.store.db_pool.is_running(): return logger.info( @@ -632,7 +630,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) - for prev_state in itervalues(prev_states) + for prev_state in prev_states.values() ] ) self.external_process_last_updated_ms.pop(process_id, None) @@ -775,7 +773,9 @@ class PresenceHandler(BasePresenceHandler): return False - async def get_all_presence_updates(self, last_id, current_id, limit): + async def get_all_presence_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -787,10 +787,31 @@ class PresenceHandler(BasePresenceHandler): - last_user_sync_ts(int) - status_msg(int) - currently_active(int) + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ + # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = await self.store.get_all_presence_updates(last_id, current_id, limit) + rows = await self.store.get_all_presence_updates( + instance_name, last_id, current_id, limit + ) return rows def notify_new_event(self): @@ -874,16 +895,9 @@ class PresenceHandler(BasePresenceHandler): await self._on_user_joined_room(room_id, state_key) - async def _on_user_joined_room(self, room_id, user_id): + async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: """Called when we detect a user joining the room via the current state delta stream. - - Args: - room_id (str) - user_id (str) - - Returns: - Deferred """ if self.is_mine_id(user_id): @@ -914,8 +928,8 @@ class PresenceHandler(BasePresenceHandler): # TODO: Check that this is actually a new server joining the # room. - user_ids = await self.state.get_current_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, user_ids)) + users = await self.state.get_current_users_in_room(room_id) + user_ids = list(filter(self.is_mine_id, users)) states_d = await self.current_state_for_users(user_ids) @@ -1087,7 +1101,7 @@ class PresenceEventSource(object): return (list(updates.values()), max_token) else: return ( - [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE], + [s for s in updates.values() if s.state != PresenceState.OFFLINE], max_token, ) @@ -1275,22 +1289,24 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): return new_state, persist_and_notify, federation_ping -@defer.inlineCallbacks -def get_interested_parties(store, states): +async def get_interested_parties( + store: DataStore, states: List[UserPresenceState] +) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: """Given a list of states return which entities (rooms, users) are interested in the given states. Args: - states (list(UserPresenceState)) + store + states Returns: - 2-tuple: `(room_ids_to_states, users_to_states)`, + A 2-tuple of `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] users_to_states = {} # type: Dict[str, List[UserPresenceState]] for state in states: - room_ids = yield store.get_rooms_for_user(state.user_id) + room_ids = await store.get_rooms_for_user(state.user_id) for room_id in room_ids: room_ids_to_states.setdefault(room_id, []).append(state) @@ -1300,34 +1316,36 @@ def get_interested_parties(store, states): return room_ids_to_states, users_to_states -@defer.inlineCallbacks -def get_interested_remotes(store, states, state_handler): +async def get_interested_remotes( + store: DataStore, states: List[UserPresenceState], state_handler: StateHandler +) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. All the presence states should be for local users only. Args: - store (DataStore) - states (list(UserPresenceState)) + store + states + state_handler Returns: - Deferred list of ([destinations], [UserPresenceState]), where for - each row the list of UserPresenceState should be sent to each + A list of 2-tuples of destinations and states, where for + each tuple the list of UserPresenceState should be sent to each destination """ - hosts_and_states = [] + hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]] # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote # hosts in those rooms. - room_ids_to_states, users_to_states = yield get_interested_parties(store, states) + room_ids_to_states, users_to_states = await get_interested_parties(store, states) - for room_id, states in iteritems(room_ids_to_states): - hosts = yield state_handler.get_current_hosts_in_room(room_id) + for room_id, states in room_ids_to_states.items(): + hosts = await state_handler.get_current_hosts_in_room(room_id) hosts_and_states.append((hosts, states)) - for user_id, states in iteritems(users_to_states): + for user_id, states in users_to_states.items(): host = get_domain_from_id(user_id) hosts_and_states.append(([host], states)) |