diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 54293d0b9c..7e76db3e2a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -631,7 +631,7 @@ class DeviceListUpdater:
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
- )
+ ) # type: ExpiringCache[str, Set[str]]
# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
@@ -760,7 +760,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
- seen_updates = self._seen_updates.get(user_id, set())
+ seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 739653a3fa..92b18378fc 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -38,7 +38,6 @@ from synapse.types import (
)
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
# user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
- # Recently seen stream ids. We don't bother keeping these in the DB,
- # but they're useful to have them about to reduce the number of spurious
- # resyncs.
- self._seen_updates = ExpiringCache(
- cache_name="signing_key_update_edu",
- clock=self.clock,
- max_len=10000,
- expiry_ms=30 * 60 * 1000,
- iterable=True,
- )
-
async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
) -> None:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3ebee38ebe..5ea8a7b603 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,7 +21,17 @@ import itertools
import logging
from collections.abc import Container
from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
import attr
from signedjson.key import decode_verify_key_bytes
@@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
+ async def on_receive_pdu(
+ self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
+ ) -> None:
"""Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
Args:
- origin (str): server which initiated the /send/ transaction. Will
+ origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state.
- pdu (FrozenEvent): received PDU
- sent_to_us_directly (bool): True if this event was pushed to us; False if
+ pdu: received PDU
+ sent_to_us_directly: True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
"""
@@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
await self._process_received_pdu(origin, pdu, state=state)
- async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(
+ self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+ ) -> None:
"""
Args:
- origin (str): Origin of the pdu. Will be called to get the missing events
+ origin: Origin of the pdu. Will be called to get the missing events
pdu: received pdu
- prevs (set(str)): List of event ids which we are missing
- min_depth (int): Minimum depth of events to return.
+ prevs: Set of event IDs which we are missing
+ min_depth: Minimum depth of events to return.
"""
room_id = pdu.room_id
@@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
- ):
+ ) -> None:
"""Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
logger.exception("Failed to resync device for %s", sender)
@log_function
- async def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(
+ self, dest: str, room_id: str, limit: int, extremities: List[str]
+ ) -> List[EventBase]:
"""Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
curr_state = await self.state_handler.get_current_state(room_id)
- def get_domains_from_state(state):
+ def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state
Args:
- state (dict[tuple, FrozenEvent]): State map from type/state
- key to event.
+ state: State map from type/state key to event.
Returns:
- list[tuple[str, int]]: Returns a list of servers with the
- lowest depth of their joins. Sorted by lowest depth first.
+ A list of servers with the lowest depth of their joins, sorted by
+ lowest depth first.
"""
joined_users = [
(state_key, int(event.depth))
@@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- async def try_backfill(domains):
+ async def try_backfill(domains: List[str]) -> bool:
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
@@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
}
for e_id, _ in sorted_extremeties_tuple:
- likely_domains = get_domains_from_state(states[e_id])
+ likely_extremeties_domains = get_domains_from_state(states[e_id])
success = await try_backfill(
- [dom for dom, _ in likely_domains if dom not in tried_domains]
+ [
+ dom
+ for dom, _ in likely_extremeties_domains
+ if dom not in tried_domains
+ ]
)
if success:
return True
- tried_domains.update(dom for dom, _ in likely_domains)
+ tried_domains.update(dom for dom, _ in likely_extremeties_domains)
return False
async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str]
- ):
+ ) -> None:
"""Fetch the given events from a server, and persist them as outliers.
This function *does not* recursively get missing auth events of the
@@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
event_infos,
)
- def _sanity_check_event(self, ev):
+ def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
@@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
or cascade of event fetches.
Args:
- ev (synapse.events.EventBase): event to be checked
-
- Returns: None
+ ev: event to be checked
Raises:
SynapseError if the event does not pass muster
@@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
- async def send_invite(self, target_host, event):
+ async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution.
@@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
- async def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(
+ self, room_queue: List[Tuple[EventBase, str]]
+ ) -> None:
"""Process PDUs which got queued up while we were busy send_joining.
Args:
- room_queue (list[FrozenEvent, str]): list of PDUs to be processed
- and the servers that sent them
+ room_queue: list of PDUs to be processed and the servers that sent them
"""
for p, origin in room_queue:
try:
@@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
return event
- async def on_send_join_request(self, origin, pdu):
+ async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
"""We have received a join event for a room. Fully process it and
respond with the current state and auth chains.
"""
@@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion
- ):
+ ) -> EventBase:
"""We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event.
@@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
return event
- async def on_send_leave_request(self, origin, pdu):
+ async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
""" We have received a leave event for a room. Fully process it."""
event = pdu
@@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
else:
return None
- async def get_min_depth_for_context(self, context):
+ async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context)
async def _handle_new_event(
- self, origin, event, state=None, auth_events=None, backfilled=False
- ):
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]] = None,
+ auth_events: Optional[MutableStateMap[EventBase]] = None,
+ backfilled: bool = False,
+ ) -> EventContext:
context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
@@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
- async def on_query_auth(
- self, origin, event_id, room_id, remote_auth_chain, rejects, missing
- ):
- in_room = await self.auth.check_host_in_room(room_id, origin)
- if not in_room:
- raise AuthError(403, "Host not in room.")
-
- event = await self.store.get_event(event_id, check_room_id=room_id)
-
- # Just go through and process each event in `remote_auth_chain`. We
- # don't want to fall into the trap of `missing` being wrong.
- for e in remote_auth_chain:
- try:
- await self._handle_new_event(origin, e)
- except AuthError:
- pass
-
- # Now get the current auth_chain for the event.
- local_auth_chain = await self.store.get_auth_chain(
- room_id, list(event.auth_event_ids()), include_given=True
- )
-
- # TODO: Check if we would now reject event_id. If so we need to tell
- # everyone.
-
- ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
-
- logger.debug("on_query_auth returning: %s", ret)
-
- return ret
-
async def on_get_missing_events(
- self, origin, room_id, earliest_events, latest_events, limit
- ):
+ self,
+ origin: str,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
assumes that we have already processed all events in remote_auth
Params:
- local_auth (list)
- remote_auth (list)
+ local_auth
+ remote_auth
Returns:
dict
@@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
@log_function
async def exchange_third_party_invite(
- self, sender_user_id, target_user_id, room_id, signed
- ):
+ self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
+ ) -> None:
third_party_invite = {"signed": signed}
event_dict = {
@@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
await member_handler.send_membership_event(None, event, context)
async def add_display_name_to_third_party_invite(
- self, room_version, event_dict, event, context
- ):
+ self,
+ room_version: str,
+ event_dict: JsonDict,
+ event: EventBase,
+ context: EventContext,
+ ) -> Tuple[EventBase, EventContext]:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"],
@@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
EventValidator().validate_new(event, self.config)
return (event, context)
- async def _check_signature(self, event, context):
+ async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
Checks that the signature in the event is consistent with its invite.
Args:
- event (Event): The m.room.member event to check
- context (EventContext):
+ event: The m.room.member event to check
+ context:
Raises:
AuthError: if signature didn't match any keys, or key has been
@@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
raise last_exception
- async def _check_key_revocation(self, public_key, url):
+ async def _check_key_revocation(self, public_key: str, url: str) -> None:
"""
Checks whether public_key has been revoked.
Args:
- public_key (str): base-64 encoded public key.
- url (str): Key revocation URL.
+ public_key: base-64 encoded public key.
+ url: Key revocation URL.
Raises:
AuthError: if the key has been revoked.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index da92feacc9..c817f2952d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,17 @@ The methods that define policy are:
import abc
import logging
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
from prometheus_client import Counter
from typing_extensions import ContextManager
@@ -34,6 +44,7 @@ import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
+from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
@@ -42,7 +53,7 @@ from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler):
self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler()
+ self.presence_router = hs.get_presence_router()
self._presence_enabled = hs.config.use_presence
federation_registry = hs.get_federation_registry()
@@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler):
"""
stream_id, max_token = await self.store.update_presence(states)
- parties = await get_interested_parties(self.store, states)
+ parties = await get_interested_parties(self.store, self.presence_router, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@@ -1041,7 +1053,12 @@ class PresenceEventSource:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
#
+ # Same with get_module_api, get_presence_router
+ #
+ # AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
self.get_presence_handler = hs.get_presence_handler
+ self.get_module_api = hs.get_module_api
+ self.get_presence_router = hs.get_presence_router
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@@ -1055,7 +1072,7 @@ class PresenceEventSource:
include_offline=True,
explicit_room_id=None,
**kwargs
- ):
+ ) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are:
# 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms.
@@ -1068,7 +1085,17 @@ class PresenceEventSource:
# We don't try and limit the presence updates by the current token, as
# sending down the rare duplicate is not a concern.
+ user_id = user.to_string()
+ stream_change_cache = self.store.presence_stream_cache
+
with Measure(self.clock, "presence.get_new_events"):
+ if user_id in self.get_module_api()._send_full_presence_to_local_users:
+ # This user has been specified by a module to receive all current, online
+ # user presence. Removing from_key and setting include_offline to false
+ # will effectively do this.
+ from_key = None
+ include_offline = False
+
if from_key is not None:
from_key = int(from_key)
@@ -1091,59 +1118,209 @@ class PresenceEventSource:
# doesn't return. C.f. #5503.
return [], max_token
- presence = self.get_presence_handler()
- stream_change_cache = self.store.presence_stream_cache
-
+ # Figure out which other users this user should receive updates for
users_interested_in = await self._get_interested_in(user, explicit_room_id)
- user_ids_changed = set() # type: Collection[str]
- changed = None
- if from_key:
- changed = stream_change_cache.get_all_entities_changed(from_key)
+ # We have a set of users that we're interested in the presence of. We want to
+ # cross-reference that with the users that have actually changed their presence.
- if changed is not None and len(changed) < 500:
- assert isinstance(user_ids_changed, set)
+ # Check whether this user should see all user updates
- # For small deltas, its quicker to get all changes and then
- # work out if we share a room or they're in our presence list
- get_updates_counter.labels("stream").inc()
- for other_user_id in changed:
- if other_user_id in users_interested_in:
- user_ids_changed.add(other_user_id)
- else:
- # Too many possible updates. Find all users we can see and check
- # if any of them have changed.
- get_updates_counter.labels("full").inc()
+ if users_interested_in == PresenceRouter.ALL_USERS:
+ # Provide presence state for all users
+ presence_updates = await self._filter_all_presence_updates_for_user(
+ user_id, include_offline, from_key
+ )
- if from_key:
- user_ids_changed = stream_change_cache.get_entities_changed(
- users_interested_in, from_key
+ # Remove the user from the list of users to receive all presence
+ if user_id in self.get_module_api()._send_full_presence_to_local_users:
+ self.get_module_api()._send_full_presence_to_local_users.remove(
+ user_id
)
+
+ return presence_updates, max_token
+
+ # Make mypy happy. users_interested_in should now be a set
+ assert not isinstance(users_interested_in, str)
+
+ # The set of users that we're interested in and that have had a presence update.
+ # We'll actually pull the presence updates for these users at the end.
+ interested_and_updated_users = (
+ set()
+ ) # type: Union[Set[str], FrozenSet[str]]
+
+ if from_key:
+ # First get all users that have had a presence update
+ updated_users = stream_change_cache.get_all_entities_changed(from_key)
+
+ # Cross-reference users we're interested in with those that have had updates.
+ # Use a slightly-optimised method for processing smaller sets of updates.
+ if updated_users is not None and len(updated_users) < 500:
+ # For small deltas, it's quicker to get all changes and then
+ # cross-reference with the users we're interested in
+ get_updates_counter.labels("stream").inc()
+ for other_user_id in updated_users:
+ if other_user_id in users_interested_in:
+ # mypy thinks this variable could be a FrozenSet as it's possibly set
+ # to one in the `get_entities_changed` call below, and `add()` is not
+ # a method on a FrozenSet. That doesn't affect us here though, as
+ # `interested_and_updated_users` is clearly a set() above.
+ interested_and_updated_users.add(other_user_id) # type: ignore
else:
- user_ids_changed = users_interested_in
+ # Too many possible updates. Find all users we can see and check
+ # if any of them have changed.
+ get_updates_counter.labels("full").inc()
- updates = await presence.current_state_for_users(user_ids_changed)
+ interested_and_updated_users = (
+ stream_change_cache.get_entities_changed(
+ users_interested_in, from_key
+ )
+ )
+ else:
+ # No from_key has been specified. Return the presence for all users
+ # this user is interested in
+ interested_and_updated_users = users_interested_in
+
+ # Retrieve the current presence state for each user
+ users_to_state = await self.get_presence_handler().current_state_for_users(
+ interested_and_updated_users
+ )
+ presence_updates = list(users_to_state.values())
- if include_offline:
- return (list(updates.values()), max_token)
+ # Remove the user from the list of users to receive all presence
+ if user_id in self.get_module_api()._send_full_presence_to_local_users:
+ self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
+
+ if not include_offline:
+ # Filter out offline presence states
+ presence_updates = self._filter_offline_presence_state(presence_updates)
+
+ return presence_updates, max_token
+
+ async def _filter_all_presence_updates_for_user(
+ self,
+ user_id: str,
+ include_offline: bool,
+ from_key: Optional[int] = None,
+ ) -> List[UserPresenceState]:
+ """
+ Computes the presence updates a user should receive.
+
+ First pulls presence updates from the database. Then consults PresenceRouter
+ for whether any updates should be excluded by user ID.
+
+ Args:
+ user_id: The User ID of the user to compute presence updates for.
+ include_offline: Whether to include offline presence states in the results.
+ from_key: The minimum stream ID of updates to pull from the database
+ before filtering.
+
+ Returns:
+ A list of presence states for the given user to receive.
+ """
+ if from_key:
+ # Only return updates since the last sync
+ updated_users = self.store.presence_stream_cache.get_all_entities_changed(
+ from_key
+ )
+ if not updated_users:
+ updated_users = []
+
+ # Get the actual presence update for each change
+ users_to_state = await self.get_presence_handler().current_state_for_users(
+ updated_users
+ )
+ presence_updates = list(users_to_state.values())
+
+ if not include_offline:
+ # Filter out offline states
+ presence_updates = self._filter_offline_presence_state(presence_updates)
else:
- return (
- [s for s in updates.values() if s.state != PresenceState.OFFLINE],
- max_token,
+ users_to_state = await self.store.get_presence_for_all_users(
+ include_offline=include_offline
)
+ presence_updates = list(users_to_state.values())
+
+ # TODO: This feels wildly inefficient, and it's unfortunate we need to ask the
+ # module for information on a number of users when we then only take the info
+ # for a single user
+
+ # Filter through the presence router
+ users_to_state_set = await self.get_presence_router().get_users_for_states(
+ presence_updates
+ )
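+ # (Assumed shape, based on how the result is used just below: get_users_for_states
+ # returns a mapping of user ID -> set of UserPresenceState that the user should
+ # receive, i.e. Dict[str, Set[UserPresenceState]].)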
+
+ # We only want the mapping for the syncing user
+ presence_updates = list(users_to_state_set[user_id])
+
+ # Return presence information for all users
+ return presence_updates
+
+ def _filter_offline_presence_state(
+ self, presence_updates: Iterable[UserPresenceState]
+ ) -> List[UserPresenceState]:
+ """Given an iterable containing user presence updates, return a list with any offline
+ presence states removed.
+
+ Args:
+ presence_updates: Presence states to filter
+
+ Returns:
+ A new list with any offline presence states removed.
+ """
+ return [
+ update
+ for update in presence_updates
+ if update.state != PresenceState.OFFLINE
+ ]
+
def get_current_key(self):
return self.store.get_current_presence_token()
@cached(num_args=2, cache_context=True)
- async def _get_interested_in(self, user, explicit_room_id, cache_context):
+ async def _get_interested_in(
+ self,
+ user: UserID,
+ explicit_room_id: Optional[str] = None,
+ cache_context: Optional[_CacheContext] = None,
+ ) -> Union[Set[str], str]:
"""Returns the set of users that the given user should see presence
- updates for
+ updates for.
+
+ Args:
+ user: The user to retrieve presence updates for.
+ explicit_room_id: If provided, the users in this room are also included.
+
+ Returns:
+ A set of user IDs to return presence updates for, or PresenceRouter.ALL_USERS
+ to return all known updates.
"""
user_id = user.to_string()
users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence
+ # cache_context isn't likely to ever be None due to the @cached decorator,
+ # but we can't have a non-optional argument after the optional argument
+ # explicit_room_id either. Assert cache_context is not None so we can use it
+ # without mypy complaining.
+ assert cache_context
+
+ # Check with the presence router whether we should poll additional users for
+ # their presence information
+ additional_users = await self.get_presence_router().get_interested_users(
+ user.to_string()
+ )
+ if additional_users == PresenceRouter.ALL_USERS:
+ # If the module requested that this user see the presence updates of *all*
+ # users, then simply return that instead of calculating what rooms this
+ # user shares
+ return PresenceRouter.ALL_USERS
+
+ # Add the additional users from the router
+ users_interested_in.update(additional_users)
+
+ # Find the users who share a room with this user
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate
)
@@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
async def get_interested_parties(
- store: DataStore, states: List[UserPresenceState]
+ store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState]
) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
"""Given a list of states return which entities (rooms, users)
are interested in the given states.
Args:
- store
- states
+ store: The homeserver's data store.
+ presence_router: A module for augmenting the destinations for presence updates.
+ states: A list of incoming user presence updates.
Returns:
A 2-tuple of `(room_ids_to_states, users_to_states)`,
@@ -1337,11 +1515,22 @@ async def get_interested_parties(
# Always notify self
users_to_states.setdefault(state.user_id, []).append(state)
+ # Ask a presence routing module for any additional parties if one
+ # is loaded.
+ router_users_to_states = await presence_router.get_users_for_states(states)
+
+ # Update the dictionaries with additional destinations and state to send
+ for user_id, user_states in router_users_to_states.items():
+ users_to_states.setdefault(user_id, []).extend(user_states)
+
return room_ids_to_states, users_to_states
async def get_interested_remotes(
- store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
+ store: DataStore,
+ presence_router: PresenceRouter,
+ states: List[UserPresenceState],
+ state_handler: StateHandler,
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@@ -1349,9 +1538,10 @@ async def get_interested_remotes(
All the presence states should be for local users only.
Args:
- store
- states
- state_handler
+ store: The homeserver's data store.
+ presence_router: A module for augmenting the destinations for presence updates.
+ states: A list of incoming user presence updates.
+ state_handler: Used to look up the current hosts in each room.
Returns:
A list of 2-tuples of destinations and states, where for
@@ -1363,7 +1553,9 @@ async def get_interested_remotes(
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
# hosts in those rooms.
- room_ids_to_states, users_to_states = await get_interested_parties(store, states)
+ room_ids_to_states, users_to_states = await get_interested_parties(
+ store, presence_router, states
+ )
for room_id, states in room_ids_to_states.items():
hosts = await state_handler.get_current_hosts_in_room(room_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1cf12f3255..894ef859f4 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -20,7 +20,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -29,6 +29,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
@@ -178,6 +179,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
+ async def _can_join_without_invite(
+ self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
+ ) -> bool:
+ """
+ Check whether a user can join a room without an invite.
+
+ When joining a room with restricted joined rules (as defined in MSC3083),
+ the membership of spaces must be checked during join.
+
+ Args:
+ state_ids: The state of the room as it currently is.
+ room_version: The room version of the room being joined.
+ user_id: The user joining the room.
+
+ Returns:
+ True if the user can join the room, False otherwise.
+ """
+ # This only applies to room versions which support the new join rule.
+ if not room_version.msc3083_join_rules:
+ return True
+
+ # If there's no join rule, then it defaults to public (so this doesn't apply).
+ join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+ if not join_rules_event_id:
+ return True
+
+ # If the join rule is not restricted, this doesn't apply.
+ join_rules_event = await self.store.get_event(join_rules_event_id)
+ if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
+ return True
+
+ # If allowed is of the wrong form, then only allow invited users.
+ allowed_spaces = join_rules_event.content.get("allow", [])
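+ # For illustration only (assumed shape, per the MSC3083 proposal): each entry in
+ # "allow" is expected to look roughly like
+ #     {"space": "!space_id:example.com", "via": ["example.com"]}
+ # and only the "space" field is inspected below.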
+ if not isinstance(allowed_spaces, list):
+ return False
+
+ # Get the list of joined rooms and see if there's an overlap.
+ joined_rooms = await self.store.get_rooms_for_user(user_id)
+
+ # Pull out the other room IDs; invalid data is filtered out.
+ for space in allowed_spaces:
+ if not isinstance(space, dict):
+ continue
+
+ space_id = space.get("space")
+ if not isinstance(space_id, str):
+ continue
+
+ # The user was joined to one of the spaces specified; they can join
+ # this room!
+ if space_id in joined_rooms:
+ return True
+
+ # The user was not in any of the required spaces.
+ return False
+
async def _local_membership_update(
self,
requester: Requester,
@@ -235,9 +292,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if event.membership == Membership.JOIN:
newly_joined = True
+ user_is_invited = False
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
+ user_is_invited = prev_member_event.membership == Membership.INVITE
+
+ # If the member is not already in the room and is not accepting an invite,
+ # check if they should be allowed access via membership in a space.
+ if (
+ newly_joined
+ and not user_is_invited
+ and not await self._can_join_without_invite(
+ prev_state_ids, event.room_version, user_id
+ )
+ ):
+ raise AuthError(
+ 403,
+ "You do not belong to any of the required spaces to join this room.",
+ )
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7b356ba7e5..ff11266c67 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -252,13 +252,13 @@ class SyncHandler:
self.storage = hs.get_storage()
self.state_store = self.storage.state
- # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
+ # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
"lazy_loaded_members_cache",
self.clock,
max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
- )
+ ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
async def wait_for_sync_for_user(
self,
@@ -733,8 +733,10 @@ class SyncHandler:
def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]]
- ) -> LruCache:
- cache = self.lazy_loaded_members_cache.get(cache_key)
+ ) -> LruCache[str, str]:
+ cache = self.lazy_loaded_members_cache.get(
+ cache_key
+ ) # type: Optional[LruCache[str, str]]
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
|