diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 3594f3b00f..91a3aec1cc 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,24 +25,22 @@ The methods that define policy are:
import abc
import logging
from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set
-
-from six import iteritems, itervalues
+from typing import Dict, Iterable, List, Set, Tuple
from prometheus_client import Counter
from typing_extensions import ContextManager
-from twisted.internet import defer
-
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
+from synapse.api.presence import UserPresenceState
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.presence import UserPresenceState
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.state import StateHandler
+from synapse.storage.databases.main import DataStore
+from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
@@ -170,14 +168,14 @@ class BasePresenceHandler(abc.ABC):
for user_id in user_ids
}
- missing = [user_id for user_id, state in iteritems(states) if not state]
+ missing = [user_id for user_id, state in states.items() if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = await self.store.get_presence_for_users(missing)
states.update(res)
- missing = [user_id for user_id, state in iteritems(states) if not state]
+ missing = [user_id for user_id, state in states.items() if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id) for user_id in missing
@@ -321,7 +319,7 @@ class PresenceHandler(BasePresenceHandler):
is some spurious presence changes that will self-correct.
"""
# If the DB pool has already terminated, don't try updating
- if not self.store.db.is_running():
+ if not self.store.db_pool.is_running():
return
logger.info(
@@ -632,7 +630,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states(
[
prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
- for prev_state in itervalues(prev_states)
+ for prev_state in prev_states.values()
]
)
self.external_process_last_updated_ms.pop(process_id, None)
@@ -775,7 +773,9 @@ class PresenceHandler(BasePresenceHandler):
return False
- async def get_all_presence_updates(self, last_id, current_id, limit):
+ async def get_all_presence_updates(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -787,10 +787,31 @@ class PresenceHandler(BasePresenceHandler):
- last_user_sync_ts(int)
- status_msg(int)
- currently_active(int)
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
+
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
- rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
+ rows = await self.store.get_all_presence_updates(
+ instance_name, last_id, current_id, limit
+ )
return rows
def notify_new_event(self):
@@ -874,16 +895,9 @@ class PresenceHandler(BasePresenceHandler):
await self._on_user_joined_room(room_id, state_key)
- async def _on_user_joined_room(self, room_id, user_id):
+ async def _on_user_joined_room(self, room_id: str, user_id: str) -> None:
"""Called when we detect a user joining the room via the current state
delta stream.
-
- Args:
- room_id (str)
- user_id (str)
-
- Returns:
- Deferred
"""
if self.is_mine_id(user_id):
@@ -914,8 +928,8 @@ class PresenceHandler(BasePresenceHandler):
# TODO: Check that this is actually a new server joining the
# room.
- user_ids = await self.state.get_current_users_in_room(room_id)
- user_ids = list(filter(self.is_mine_id, user_ids))
+ users = await self.state.get_current_users_in_room(room_id)
+ user_ids = list(filter(self.is_mine_id, users))
states_d = await self.current_state_for_users(user_ids)
@@ -996,7 +1010,7 @@ def format_user_presence_state(state, now, include_user_id=True):
return content
-class PresenceEventSource(object):
+class PresenceEventSource:
def __init__(self, hs):
# We can't call get_presence_handler here because there's a cycle:
#
@@ -1087,7 +1101,7 @@ class PresenceEventSource(object):
return (list(updates.values()), max_token)
else:
return (
- [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE],
+ [s for s in updates.values() if s.state != PresenceState.OFFLINE],
max_token,
)
@@ -1275,22 +1289,24 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
return new_state, persist_and_notify, federation_ping
-@defer.inlineCallbacks
-def get_interested_parties(store, states):
+async def get_interested_parties(
+ store: DataStore, states: List[UserPresenceState]
+) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
"""Given a list of states return which entities (rooms, users)
are interested in the given states.
Args:
- states (list(UserPresenceState))
+ store
+ states
Returns:
- 2-tuple: `(room_ids_to_states, users_to_states)`,
+ A 2-tuple of `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]`
"""
room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
users_to_states = {} # type: Dict[str, List[UserPresenceState]]
for state in states:
- room_ids = yield store.get_rooms_for_user(state.user_id)
+ room_ids = await store.get_rooms_for_user(state.user_id)
for room_id in room_ids:
room_ids_to_states.setdefault(room_id, []).append(state)
@@ -1300,34 +1316,36 @@ def get_interested_parties(store, states):
return room_ids_to_states, users_to_states
-@defer.inlineCallbacks
-def get_interested_remotes(store, states, state_handler):
+async def get_interested_remotes(
+ store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
+) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
All the presence states should be for local users only.
Args:
- store (DataStore)
- states (list(UserPresenceState))
+ store
+ states
+ state_handler
Returns:
- Deferred list of ([destinations], [UserPresenceState]), where for
- each row the list of UserPresenceState should be sent to each
+ A list of 2-tuples of destinations and states, where for
+ each tuple the list of UserPresenceState should be sent to each
destination
"""
- hosts_and_states = []
+ hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
# hosts in those rooms.
- room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
+ room_ids_to_states, users_to_states = await get_interested_parties(store, states)
- for room_id, states in iteritems(room_ids_to_states):
- hosts = yield state_handler.get_current_hosts_in_room(room_id)
+ for room_id, states in room_ids_to_states.items():
+ hosts = await state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states))
- for user_id, states in iteritems(users_to_states):
+ for user_id, states in users_to_states.items():
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
|