diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 2d7e6df6e4..a8d3fbc6de 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
class AccountDataEventSource(object):
def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
- @defer.inlineCallbacks
- def get_new_events(self, user, from_key, **kwargs):
+ async def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
- current_stream_id = yield self.store.get_max_account_data_stream_id()
+ current_stream_id = self.store.get_max_account_data_stream_id()
results = []
- tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+ tags = await self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
(
account_data,
room_account_data,
- ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+ ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
@@ -53,7 +50,3 @@ class AccountDataEventSource(object):
)
return results, current_stream_id
-
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
- return [], config.to_id
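Review note: the conversion pattern applied throughout this PR is mechanical: drop `@defer.inlineCallbacks`, replace each `yield` with `await`, and replace `defer.returnValue(x)` with a plain `return x`. A minimal before/after sketch (the `store` method here is just a stand-in):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def get_tags_old(store, user_id, from_key):
    # Before: a generator-based coroutine driven by Twisted.
    tags = yield store.get_updated_tags(user_id, from_key)
    defer.returnValue(tags)

async def get_tags_new(store, user_id, from_key):
    # After: a native coroutine. Awaiting a Deferred works because
    # Deferred implements __await__ (Twisted >= 16.4).
    tags = await store.get_updated_tags(user_id, from_key)
    return tags
```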
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d04e0fe576..829f52eca1 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,7 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-
-from twisted.internet import defer
+from typing import List, Optional
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
- "send_renewals", self.send_renewal_emails
+ "send_renewals", self._send_renewal_emails
)
self.clock.looping_call(send_emails, 30 * 60 * 1000)
- @defer.inlineCallbacks
- def send_renewal_emails(self):
+ async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
- expiring_users = yield self.store.get_users_expiring_soon()
+ expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
- yield self._send_renewal_email(
+ await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
- @defer.inlineCallbacks
- def send_renewal_email_to_user(self, user_id):
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- yield self._send_renewal_email(user_id, expiration_ts)
+ async def send_renewal_email_to_user(self, user_id: str):
+ expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+ await self._send_renewal_email(user_id, expiration_ts)
- @defer.inlineCallbacks
- def _send_renewal_email(self, user_id, expiration_ts):
+ async def _send_renewal_email(self, user_id: str, expiration_ts: int):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.
Args:
- user_id (str): ID of the user to send email(s) to.
- expiration_ts (int): Timestamp in milliseconds for the expiration date of
+ user_id: ID of the user to send email(s) to.
+ expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
- addresses = yield self._get_email_addresses_for_user(user_id)
+ addresses = await self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
return
try:
- user_display_name = yield self.store.get_profile_displayname(
+ user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
except StoreError:
user_display_name = user_id
- renewal_token = yield self._get_renewal_token(user_id)
+ renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
logger.info("Sending renewal email to %s", address)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
)
)
- yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+ await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
- @defer.inlineCallbacks
- def _get_email_addresses_for_user(self, user_id):
+ async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
"""Retrieve the list of email addresses attached to a user's account.
Args:
- user_id (str): ID of the user to lookup email addresses for.
+ user_id: ID of the user to lookup email addresses for.
Returns:
- defer.Deferred[list[str]]: Email addresses for this account.
+ Email addresses for this account.
"""
- threepids = yield self.store.user_get_threepids(user_id)
+ threepids = await self.store.user_get_threepids(user_id)
addresses = []
for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
return addresses
- @defer.inlineCallbacks
- def _get_renewal_token(self, user_id):
+ async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
Args:
- user_id (str): ID of the user to generate a string for.
+ user_id: ID of the user to generate a string for.
Returns:
- defer.Deferred[str]: The generated string.
+ The generated string.
Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
- yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+ await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- @defer.inlineCallbacks
- def renew_account(self, renewal_token):
+ async def renew_account(self, renewal_token: str) -> bool:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
+ renewal_token: Token sent with the renewal request.
Returns:
- bool: Whether the provided token is valid.
+ Whether the provided token is valid.
"""
try:
- user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- defer.returnValue(False)
+ return False
logger.debug("Renewing an account for user %s", user_id)
- yield self.renew_account_for_user(user_id)
+ await self.renew_account_for_user(user_id)
- defer.returnValue(True)
+ return True
- @defer.inlineCallbacks
- def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+ async def renew_account_for_user(
+ self, user_id: str, expiration_ts: Optional[int] = None, email_sent: bool = False
+ ) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
- expiration_ts (int): New expiration date. Defaults to now + validity period.
- email_sent (bool): Whether an email has been sent for this validity period.
+ user_id: ID of the user to renew the account of.
+ expiration_ts: New expiration date. Defaults to now + validity period.
+ email_sent: Whether an email has been sent for this validity period.
Defaults to False.
Returns:
- defer.Deferred[int]: New expiration date for this account, as a timestamp
- in milliseconds since epoch.
+ New expiration date for this account, as a timestamp in
+ milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period
- yield self.store.set_account_validity_for_user(
+ await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
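Review note: `_get_renewal_token` retries up to five times on uniqueness collisions before giving up. The same retry shape, standalone (a sketch; `save_token` is a hypothetical coroutine that raises on collision, where the handler uses `StoreError`):

```python
import secrets

async def generate_unique_token(save_token, max_attempts: int = 5) -> str:
    # Try fresh random tokens until one saves without a uniqueness
    # collision, mirroring the loop in _get_renewal_token above.
    for _ in range(max_attempts):
        token = secrets.token_urlsafe(24)  # ~32 URL-safe characters
        try:
            await save_token(token)
            return token
        except Exception:
            continue
    raise RuntimeError("Couldn't generate a unique renewal token")
```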
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 57a10daefd..2d889364d4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -264,6 +264,7 @@ class E2eKeysHandler(object):
return ret
+ @defer.inlineCallbacks
def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database
@@ -283,14 +284,32 @@ class E2eKeysHandler(object):
self_signing_keys = {}
user_signing_keys = {}
- # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
- return defer.succeed(
- {
- "master_keys": master_keys,
- "self_signing_keys": self_signing_keys,
- "user_signing_keys": user_signing_keys,
- }
- )
+ user_ids = list(query)
+
+ keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
+
+ for user_id, user_info in keys.items():
+ if user_info is None:
+ continue
+ if "master" in user_info:
+ master_keys[user_id] = user_info["master"]
+ if "self_signing" in user_info:
+ self_signing_keys[user_id] = user_info["self_signing"]
+
+ if (
+ from_user_id in keys
+ and keys[from_user_id] is not None
+ and "user_signing" in keys[from_user_id]
+ ):
+ # users can see other users' master and self-signing keys, but can
+ # only see their own user-signing keys
+ user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+
+ return {
+ "master_keys": master_keys,
+ "self_signing_keys": self_signing_keys,
+ "user_signing_keys": user_signing_keys,
+ }
@trace
@defer.inlineCallbacks
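Review note: the new `get_cross_signing_keys_from_cache` body enforces that master and self-signing keys are visible for every queried user, while user-signing keys are only returned for the requesting user. The same filtering, as a standalone sketch:

```python
def filter_cross_signing_keys(keys: dict, from_user_id: str) -> dict:
    # keys: map of user_id -> {"master": ..., "self_signing": ...,
    # "user_signing": ...}, or None for unknown users.
    master_keys, self_signing_keys, user_signing_keys = {}, {}, {}
    for user_id, user_info in keys.items():
        if user_info is None:
            continue
        if "master" in user_info:
            master_keys[user_id] = user_info["master"]
        if "self_signing" in user_info:
            self_signing_keys[user_id] = user_info["self_signing"]

    from_info = keys.get(from_user_id)
    if from_info is not None and "user_signing" in from_info:
        # Users can see other users' master and self-signing keys, but
        # only their own user-signing keys.
        user_signing_keys[from_user_id] = from_info["user_signing"]

    return {
        "master_keys": master_keys,
        "self_signing_keys": self_signing_keys,
        "user_signing_keys": user_signing_keys,
    }
```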
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index abe02907b9..60bb00fc6a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
import itertools
import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
@@ -63,6 +63,7 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
@@ -163,8 +164,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- @defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -174,17 +174,15 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
-
- Returns (Deferred): completes with None
"""
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+ logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.store.get_event(
+ existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@@ -228,7 +226,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -243,12 +241,12 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -268,7 +266,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
@@ -276,13 +274,19 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
)
- yield self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ )
# Update the set of things we've seen after trying to
# fetch the missing stuff
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -290,14 +294,6 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
- elif missing_prevs:
- logger.info(
- "[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id,
- event_id,
- len(missing_prevs),
- shortstr(missing_prevs),
- )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -342,12 +338,18 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
+ logger.info(
+ "Event %s is missing prev_events: calculating state for a "
+ "backwards extremity",
+ event_id,
+ )
+
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@@ -361,17 +363,14 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "[%s %s] Requesting state at missing prev_event %s",
- room_id,
- event_id,
- p,
+ "Requesting state at missing prev_event %s", event_id,
)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- (remote_state, _,) = yield self._get_state_for_room(
+ (remote_state, _,) = await self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
)
@@ -383,8 +382,8 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- room_version = yield self.store.get_room_version(room_id)
- state_map = yield resolve_events_with_store(
+ room_version = await self.store.get_room_version(room_id)
+ state_map = await resolve_events_with_store(
room_id,
room_version,
state_maps,
@@ -397,10 +396,10 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
- evs = yield self.store.get_events(
+ evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
@@ -420,10 +419,9 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
- yield self._process_received_pdu(origin, pdu, state=state)
+ await self._process_received_pdu(origin, pdu, state=state)
- @defer.inlineCallbacks
- def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@@ -435,12 +433,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
- latest = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -504,7 +502,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timeout to 60s and see what happens.
try:
- missing_events = yield self.federation_client.get_missing_events(
+ missing_events = await self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@@ -543,7 +541,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+ await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
@@ -555,29 +553,30 @@ class FederationHandler(BaseHandler):
else:
raise
- @defer.inlineCallbacks
- @log_function
- def _get_state_for_room(
- self, destination, room_id, event_id, include_event_in_state
- ):
+ async def _get_state_for_room(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ include_event_in_state: bool = False,
+ ) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
- destination (str): The remote homeserver to query for the state.
- room_id (str): The id of the room we're interested in.
- event_id (str): The id of the event we want the state at.
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
include_event_in_state: if true, the event itself will be included in the
returned state event list.
Returns:
- Deferred[Tuple[List[EventBase], List[EventBase]]]:
- A list of events in the state, and a list of events in the auth chain
- for the given event.
+ A list of events in the state, possibly including the event itself, and
+ a list of events in the auth chain for the given event.
"""
(
state_event_ids,
auth_event_ids,
- ) = yield self.federation_client.get_room_state_ids(
+ ) = await self.federation_client.get_room_state_ids(
destination, room_id, event_id=event_id
)
@@ -586,15 +585,15 @@ class FederationHandler(BaseHandler):
if include_event_in_state:
desired_events.add(event_id)
- event_map = yield self._get_events_from_store_or_dest(
+ event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
failed_to_fetch = desired_events - event_map.keys()
if failed_to_fetch:
logger.warning(
- "Failed to fetch missing state/auth events for %s: %s",
- room_id,
+ "Failed to fetch missing state/auth events for %s %s",
+ event_id,
failed_to_fetch,
)
@@ -614,15 +613,11 @@ class FederationHandler(BaseHandler):
return remote_state, auth_chain
- @defer.inlineCallbacks
- def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+ async def _get_events_from_store_or_dest(
+ self, destination: str, room_id: str, event_ids: Iterable[str]
+ ) -> Dict[str, EventBase]:
"""Fetch events from a remote destination, checking if we already have them.
- Args:
- destination (str)
- room_id (str)
- event_ids (Iterable[str])
-
Persists any events we don't already have as outliers.
If we fail to fetch any of the events, a warning will be logged, and the event
@@ -630,10 +625,9 @@ class FederationHandler(BaseHandler):
be in the given room.
Returns:
- Deferred[dict[str, EventBase]]: A deferred resolving to a map
- from event_id to event
+ map from event_id to event
"""
- fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+ fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
missing_events = set(event_ids) - fetched_events.keys()
@@ -644,14 +638,14 @@ class FederationHandler(BaseHandler):
room_id,
)
- yield self._get_events_and_persist(
+ await self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
- (yield self.store.get_events(missing_events, allow_rejected=True))
+ (await self.store.get_events(missing_events, allow_rejected=True))
)
# check for events which were in the wrong room.
@@ -677,12 +671,14 @@ class FederationHandler(BaseHandler):
bad_room_id,
room_id,
)
+
del fetched_events[bad_event_id]
return fetched_events
- @defer.inlineCallbacks
- def _process_received_pdu(self, origin, event, state):
+ async def _process_received_pdu(
+ self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+ ):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@@ -701,15 +697,15 @@ class FederationHandler(BaseHandler):
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
try:
- context = yield self._handle_new_event(origin, event, state=state)
+ context = await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if not room:
try:
- yield self.store.store_room(
+ await self.store.store_room(
room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
@@ -722,11 +718,11 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = await context.get_prev_state_ids(self.store)
prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
- prev_state = yield self.store.get_event(
+ prev_state = await self.store.get_event(
prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
@@ -734,11 +730,10 @@ class FederationHandler(BaseHandler):
if newly_joined:
user = UserID.from_string(event.state_key)
- yield self.user_joined_room(user, room_id)
+ await self.user_joined_room(user, room_id)
@log_function
- @defer.inlineCallbacks
- def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -755,7 +750,7 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- events = yield self.federation_client.backfill(
+ events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -770,7 +765,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
- seen_events = yield self.store.have_events_in_timeline(
+ seen_events = await self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
@@ -796,7 +791,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self._get_state_for_room(
+ state, auth = await self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
@@ -840,7 +835,7 @@ class FederationHandler(BaseHandler):
)
)
- yield self._handle_new_events(dest, ev_infos, backfilled=True)
+ await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -856,16 +851,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(dest, event, backfilled=True)
+ await self._handle_new_event(dest, event, backfilled=True)
return events
- @defer.inlineCallbacks
- def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+ extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -897,15 +891,17 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(list(extremities))
+ forward_events = await self.store.get_successor_events(list(extremities))
- extremities_events = yield self.store.get_events(
- forward_events, check_redacted=False, get_prev_content=False
+ extremities_events = await self.store.get_events(
+ forward_events,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
- filtered_extremities = yield filter_events_for_server(
+ filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
list(extremities_events.values()),
@@ -935,7 +931,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = yield self.state_handler.get_current_state(room_id)
+ curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
"""Get joined domains from state
@@ -974,12 +970,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- @defer.inlineCallbacks
- def try_backfill(domains):
+ async def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- yield self.backfill(
+ await self.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -1010,7 +1005,7 @@ class FederationHandler(BaseHandler):
return False
- success = yield try_backfill(likely_domains)
+ success = await try_backfill(likely_domains)
if success:
return True
@@ -1024,7 +1019,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
- states = yield make_deferred_yieldable(
+ states = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
@@ -1034,7 +1029,7 @@ class FederationHandler(BaseHandler):
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False,
)
@@ -1050,7 +1045,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill(
+ success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
@@ -1060,8 +1055,7 @@ class FederationHandler(BaseHandler):
return False
- @defer.inlineCallbacks
- def _get_events_and_persist(
+ async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str]
):
"""Fetch the given events from a server, and persist them as outliers.
@@ -1069,7 +1063,7 @@ class FederationHandler(BaseHandler):
Logs a warning if we can't find the given event.
"""
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
event_infos = []
@@ -1105,9 +1099,9 @@ class FederationHandler(BaseHandler):
e,
)
- yield concurrently_execute(get_event, events, 5)
+ await concurrently_execute(get_event, events, 5)
- yield self._handle_new_events(
+ await self._handle_new_events(
destination, event_infos,
)
@@ -1250,7 +1244,7 @@ class FederationHandler(BaseHandler):
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = yield self.store.get_room_predecessor(room_id)
- if not predecessor:
+ if not predecessor or not isinstance(predecessor.get("room_id"), str):
return
old_room_id = predecessor["room_id"]
logger.debug(
@@ -1278,8 +1272,7 @@ class FederationHandler(BaseHandler):
return True
- @defer.inlineCallbacks
- def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@@ -1295,7 +1288,7 @@ class FederationHandler(BaseHandler):
p.room_id,
)
with nested_logging_context(p.event_id):
- yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1493,7 +1486,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts, room_id, user_id, "leave", content=content,
+ target_hosts, room_id, user_id, "leave", content=content
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -2854,7 +2847,7 @@ class FederationHandler(BaseHandler):
room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
- return user_joined_room(self.distributor, user, room_id)
+ return defer.succeed(user_joined_room(self.distributor, user, room_id))
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
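Review note: several call sites replace the boolean `check_redacted=False` with `redact_behaviour=EventRedactBehaviour.AS_IS`. The diff only imports the enum; a sketch of its shape, inferred from the call sites here (the real definition lives in `synapse/storage/data_stores/main/events_worker.py` and may differ):

```python
from enum import Enum, auto

class EventRedactBehaviour(Enum):
    """What to do when fetching an event that has been redacted."""

    AS_IS = auto()   # return the stored event unmodified (old check_redacted=False)
    REDACT = auto()  # return the pruned/redacted form (old check_redacted=True)
```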
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..44ec3e66ae 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
- self.snapshot_cache = SnapshotCache()
+ self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@@ -79,21 +79,17 @@ class InitialSyncHandler(BaseHandler):
as_client_event,
include_archived,
)
- now_ms = self.clock.time_msec()
- result = self.snapshot_cache.get(now_ms, key)
- if result is not None:
- return result
- return self.snapshot_cache.set(
- now_ms,
+ return self.snapshot_cache.wrap(
key,
- self._snapshot_all_rooms(
- user_id, pagin_config, as_client_event, include_archived
- ),
+ self._snapshot_all_rooms,
+ user_id,
+ pagin_config,
+ as_client_event,
+ include_archived,
)
- @defer.inlineCallbacks
- def _snapshot_all_rooms(
+ async def _snapshot_all_rooms(
self,
user_id=None,
pagin_config=None,
@@ -105,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
if include_archived:
memberships.append(Membership.LEAVE)
- room_list = yield self.store.get_rooms_for_user_where_membership_is(
+ room_list = await self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
@@ -113,33 +109,32 @@ class InitialSyncHandler(BaseHandler):
rooms_ret = []
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
- presence, _ = yield presence_stream.get_pagination_rows(
+ presence, _ = await presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
- receipt, _ = yield receipt_stream.get_pagination_rows(
+ receipt, _ = await receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
- tags_by_room = yield self.store.get_tags_for_user(user_id)
+ tags_by_room = await self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = yield self.store.get_account_data_for_user(
+ account_data, account_data_by_room = await self.store.get_account_data_for_user(
user_id
)
- public_room_ids = yield self.store.get_public_room_ids()
+ public_room_ids = await self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
- @defer.inlineCallbacks
- def handle_room(event):
+ async def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
@@ -152,8 +147,8 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
d["inviter"] = event.sender
- invite_event = yield self.store.get_event(event.event_id)
- d["invite"] = yield self._event_serializer.serialize_event(
+ invite_event = await self.store.get_event(event.event_id)
+ d["invite"] = await self._event_serializer.serialize_event(
invite_event, time_now, as_client_event
)
@@ -177,7 +172,7 @@ class InitialSyncHandler(BaseHandler):
lambda states: states[event.event_id]
)
- (messages, token), current_state = yield make_deferred_yieldable(
+ (messages, token), current_state = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -191,7 +186,7 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
- messages = yield filter_events_for_client(
+ messages = await filter_events_for_client(
self.storage, user_id, messages
)
@@ -201,7 +196,7 @@ class InitialSyncHandler(BaseHandler):
d["messages"] = {
"chunk": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
messages, time_now=time_now, as_client_event=as_client_event
)
),
@@ -209,7 +204,7 @@ class InitialSyncHandler(BaseHandler):
"end": end_token.to_string(),
}
- d["state"] = yield self._event_serializer.serialize_events(
+ d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
as_client_event=as_client_event,
@@ -232,7 +227,7 @@ class InitialSyncHandler(BaseHandler):
except Exception:
logger.exception("Failed to get snapshot")
- yield concurrently_execute(handle_room, room_list, 10)
+ await concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
@@ -256,8 +251,7 @@ class InitialSyncHandler(BaseHandler):
return ret
- @defer.inlineCallbacks
- def room_initial_sync(self, requester, room_id, pagin_config=None):
+ async def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
@@ -274,32 +268,32 @@ class InitialSyncHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room.
"""
- blocked = yield self.store.is_room_blocked(room_id)
+ blocked = await self.store.is_room_blocked(room_id)
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
user_id = requester.user.to_string()
- membership, member_event_id = yield self._check_in_room_or_world_readable(
+ membership, member_event_id = await self._check_in_room_or_world_readable(
room_id, user_id
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
- result = yield self._room_initial_sync_joined(
+ result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
- result = yield self._room_initial_sync_parted(
+ result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
- tags = yield self.store.get_tags_for_room(user_id, room_id)
+ tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
- account_data = yield self.store.get_account_data_for_room(user_id, room_id)
+ account_data = await self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({"type": account_data_type, "content": content})
@@ -307,11 +301,10 @@ class InitialSyncHandler(BaseHandler):
return result
- @defer.inlineCallbacks
- def _room_initial_sync_parted(
+ async def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
):
- room_state = yield self.state_store.get_state_for_events([member_event_id])
+ room_state = await self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -319,13 +312,13 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- stream_token = yield self.store.get_stream_token_for_event(member_event_id)
+ stream_token = await self.store.get_stream_token_for_event(member_event_id)
- messages, token = yield self.store.get_recent_events_for_room(
+ messages, token = await self.store.get_recent_events_for_room(
room_id, limit=limit, end_token=stream_token
)
- messages = yield filter_events_for_client(
+ messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
)
@@ -339,13 +332,13 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id,
"messages": {
"chunk": (
- yield self._event_serializer.serialize_events(messages, time_now)
+ await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
room_state.values(), time_now
)
),
@@ -353,19 +346,18 @@ class InitialSyncHandler(BaseHandler):
"receipts": [],
}
- @defer.inlineCallbacks
- def _room_initial_sync_joined(
+ async def _room_initial_sync_joined(
self, user_id, room_id, pagin_config, membership, is_peeking
):
- current_state = yield self.state.get_current_state(room_id=room_id)
+ current_state = await self.state.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
- state = yield self._event_serializer.serialize_events(
+ state = await self._event_serializer.serialize_events(
current_state.values(), time_now
)
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
@@ -380,28 +372,26 @@ class InitialSyncHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
- @defer.inlineCallbacks
- def get_presence():
+ async def get_presence():
# If presence is disabled, return an empty list
if not self.hs.config.use_presence:
return []
- states = yield presence_handler.get_states(
+ states = await presence_handler.get_states(
[m.user_id for m in room_members], as_event=True
)
return states
- @defer.inlineCallbacks
- def get_receipts():
- receipts = yield self.store.get_linearized_receipts_for_room(
+ async def get_receipts():
+ receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)
if not receipts:
receipts = []
return receipts
- presence, receipts, (messages, token) = yield make_deferred_yieldable(
+ presence, receipts, (messages, token) = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(get_presence),
@@ -417,7 +407,7 @@ class InitialSyncHandler(BaseHandler):
).addErrback(unwrapFirstError)
)
- messages = yield filter_events_for_client(
+ messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
)
@@ -430,7 +420,7 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id,
"messages": {
"chunk": (
- yield self._event_serializer.serialize_events(messages, time_now)
+ await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
@@ -444,18 +434,17 @@ class InitialSyncHandler(BaseHandler):
return ret
- @defer.inlineCallbacks
- def _check_in_room_or_world_readable(self, room_id, user_id):
+ async def _check_in_room_or_world_readable(self, room_id, user_id):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
- member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+ member_event = await self.auth.check_user_was_in_room(room_id, user_id)
return member_event.membership, member_event.event_id
except AuthError:
- visibility = yield self.state_handler.get_current_state(
+ visibility = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
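Review note: `ResponseCache.wrap` replaces the hand-rolled `SnapshotCache` get/set dance: concurrent callers with the same key share a single in-flight computation. A simplified sketch of that pattern (an asyncio stand-in, not Synapse's implementation, which also keeps results for a timeout):

```python
import asyncio

class MiniResponseCache:
    def __init__(self):
        self._pending = {}  # key -> in-flight Future

    async def wrap(self, key, callback, *args, **kwargs):
        fut = self._pending.get(key)
        if fut is None:
            # First caller starts the work; later callers just await it.
            fut = asyncio.ensure_future(callback(*args, **kwargs))
            self._pending[key] = fut
            fut.add_done_callback(lambda _: self._pending.pop(key, None))
        return await fut
```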
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..bf9add7fe2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8514ddc600..00a6afc963 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -280,8 +280,7 @@ class PaginationHandler(object):
await self.storage.purge_events.purge_room(room_id)
- @defer.inlineCallbacks
- def get_messages(
+ async def get_messages(
self,
requester,
room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_pagination()
+ await self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@@ -319,11 +318,11 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
- with (yield self.pagination_lock.read(room_id)):
+ with (await self.pagination_lock.read(room_id)):
(
membership,
member_event_id,
- ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+ ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
if room_token.topological:
max_topo = room_token.topological
else:
- max_topo = yield self.store.get_max_topological_token(
+ max_topo = await self.store.get_max_topological_token(
room_id, room_token.stream
)
@@ -339,18 +338,18 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
- leave_token = yield self.store.get_topological_token_for_event(
+ leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
- events, next_key = yield self.store.paginate_room_events(
+ events, next_key = await self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
if event_filter:
events = event_filter.filter(events)
- events = yield filter_events_for_client(
+ events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None)
)
@@ -385,19 +384,19 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events
)
- state_ids = yield self.state_store.get_state_ids_for_event(
+ state_ids = await self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
if state_ids:
- state = yield self.store.get_events(list(state_ids.values()))
+ state = await self.store.get_events(list(state_ids.values()))
state = state.values()
time_now = self.clock.time_msec()
chunk = {
"chunk": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event
)
),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
}
if state:
- chunk["state"] = yield self._event_serializer.serialize_events(
+ chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
)
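Review note: `with (await self.pagination_lock.read(room_id)):` awaits a coroutine that resolves to a context manager, then enters it synchronously. A toy lock with the same calling shape (illustrative only, far simpler than Synapse's ReadWriteLock):

```python
import asyncio
from contextlib import contextmanager

class ToyLock:
    def __init__(self):
        self._lock = asyncio.Lock()

    async def read(self, key):
        # Acquire asynchronously, then hand back a plain context manager
        # that releases the lock on exit, so callers can write:
        #     with (await lock.read(room_id)):
        #         ...
        await self._lock.acquire()

        @contextmanager
        def held():
            try:
                yield
            finally:
                self._lock.release()

        return held()
```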
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 1e5a4613c9..f9579d69ee 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -295,12 +295,16 @@ class BaseProfileHandler(BaseHandler):
be found to be in any room the server is in, and therefore the query
is denied.
"""
+
# Implementation of MSC1301: don't allow looking up profiles if the
# requester isn't in the same room as the target. We expect requester to
# be None when this function is called outside of a profile query, e.g.
# when building a membership event. In this case, we must allow the
# lookup.
- if not self.hs.config.require_auth_for_profile_requests or not requester:
+ if (
+ not self.hs.config.limit_profile_requests_to_users_who_share_rooms
+ or not requester
+ ):
return
# Always allow the user to query their own profile.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 60b8bbc7a5..d3a1a7b4a6 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1011,15 +1011,3 @@ class RoomEventSource(object):
def get_current_key_for_room(self, room_id):
return self.store.get_room_events_max_id(room_id)
-
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
- events, next_key = yield self.store.paginate_room_events(
- room_id=key,
- from_key=config.from_key,
- to_key=config.to_key,
- direction=config.direction,
- limit=config.limit,
- )
-
- return (events, next_key)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
+from typing import Tuple
import attr
import saml2
+import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ UserID,
+ map_username_to_mxid_localpart,
+ mxid_localpart_allowed_characters,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
+@attr.s
+class Saml2SessionData:
+ """Data we track about SAML2 sessions"""
+
+ # time the session was created, in milliseconds
+ creation_time = attr.ib()
+
+
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
- self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
- self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+ # plugin to do custom mapping from saml response to mxid
+ self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+ hs.config.saml2_user_mapping_provider_config
+ )
# identifier for the external_ids table
self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
remote_user_id = saml2_auth.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
- raise SynapseError(400, "uid not in SAML2 response")
-
- try:
- mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
- except KeyError:
- logger.warning(
- "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
- )
- raise SynapseError(
- 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
- )
+ raise SynapseError(400, "'uid' not in SAML2 response")
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
- displayName = saml2_auth.ava.get("displayName", [None])[0]
-
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
)
return registered_user_id
- # figure out a new mxid for this user
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ # Map saml response to user attributes using the configured mapping provider
+ for i in range(1000):
+ attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, i
+ )
+
+ logger.debug(
+ "Retrieved SAML attributes from user mapping provider: %s "
+ "(attempt %d)",
+ attribute_dict,
+ i,
+ )
+
+ localpart = attribute_dict.get("mxid_localpart")
+ if not localpart:
+ logger.error(
+ "SAML mapping provider plugin did not return a "
+ "mxid_localpart object"
+ )
+ raise SynapseError(500, "Error parsing SAML2 response")
- suffix = 0
- while True:
- localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+ displayname = attribute_dict.get("displayname")
+
+ # Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
+ # This mxid is free
break
- suffix += 1
- logger.info("Allocating mxid for new user with localpart %s", localpart)
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise SynapseError(
+ 500, "Unable to generate a Matrix ID from the SAML response"
+ )
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=displayName
+ localpart=localpart, default_display_name=displayname
)
+
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
@@ -205,9 +236,120 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
+DOT_REPLACE_PATTERN = re.compile(
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+ username = username.lower()
+ username = DOT_REPLACE_PATTERN.sub(".", username)
+
+ # regular mxids aren't allowed to start with an underscore either
+ username = re.sub("^_", "", username)
+ return username
+
+
+MXID_MAPPER_MAP = {
+ "hexencode": map_username_to_mxid_localpart,
+ "dotreplace": dot_replace_for_mxid,
+}
+
+
@attr.s
-class Saml2SessionData:
- """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+ mxid_source_attribute = attr.ib()
+ mxid_mapper = attr.ib()
- # time the session was created, in milliseconds
- creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+ __version__ = "0.0.1"
+
+ def __init__(self, parsed_config: SamlConfig):
+ """The default SAML user mapping provider
+
+ Args:
+ parsed_config: Module configuration
+ """
+ self._mxid_source_attribute = parsed_config.mxid_source_attribute
+ self._mxid_mapper = parsed_config.mxid_mapper
+
+ def saml_response_to_user_attributes(
+ self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+ ) -> dict:
+ """Maps some text from a SAML response to attributes of a new user
+
+ Args:
+ saml_response: A SAML auth response object
+
+ failures: How many times a call to this function with this
+ saml_response has resulted in a failure
+
+ Returns:
+ dict: A dict containing new user attributes. Possible keys:
+ * mxid_localpart (str): Required. The localpart of the user's mxid
+ * displayname (str): The displayname of the user
+ """
+ try:
+ mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+ except KeyError:
+ logger.warning(
+ "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+ )
+ raise SynapseError(
+ 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ )
+
+ # Use the configured mapper for this mxid_source
+ base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid
+ localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+ # Retrieve the display name from the saml response
+ # If displayname is None, the mxid_localpart will be used instead
+ displayname = saml_response.ava.get("displayName", [None])[0]
+
+ return {
+ "mxid_localpart": localpart,
+ "displayname": displayname,
+ }
+
+ @staticmethod
+ def parse_config(config: dict) -> SamlConfig:
+ """Parse the dict provided by the homeserver's config
+ Args:
+ config: A dictionary containing configuration options for this provider
+ Returns:
+ SamlConfig: A custom config object for this module
+ """
+ # Parse config options and use defaults where necessary
+ mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+ mapping_type = config.get("mxid_mapping", "hexencode")
+
+ # Retrieve the associating mapping function
+ try:
+ mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+ except KeyError:
+ raise ConfigError(
+ "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+ "mxid_mapping value" % (mapping_type,)
+ )
+
+ return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+ @staticmethod
+ def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+ """Returns the required attributes of a SAML
+
+ Args:
+ config: A SamlConfig object containing configuration params for this provider
+
+ Returns:
+ tuple[set,set]: The first set contains the SAML auth response
+ attributes that are required for the module to function, whereas the
+ second set consists of those attributes which are used if
+ available, but are not necessary.
+ """
+ return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
room_id (str): id of the room to search through.
Returns:
- Deferred[iterable[unicode]]: predecessor room ids
+ Deferred[iterable[str]]: predecessor room ids
"""
historical_room_ids = []
- while True:
- predecessor = yield self.store.get_room_predecessor(room_id)
+ # The initial room must have been known for us to get this far
+ predecessor = yield self.store.get_room_predecessor(room_id)
- # If no predecessor, assume we've hit a dead end
+ while True:
if not predecessor:
+ # We have reached the end of the chain of predecessors
+ break
+
+ if not isinstance(predecessor.get("room_id"), str):
+ # This predecessor object is malformed. Exit here
+ break
+
+ predecessor_room_id = predecessor["room_id"]
+
+ # Don't add it to the list until we have checked that it is a known room
+ try:
+ next_predecessor_room = yield self.store.get_room_predecessor(
+ predecessor_room_id
+ )
+ except NotFoundError:
+ # The predecessor is not a known room, so we are done here
break
- # Add predecessor's room ID
- historical_room_ids.append(predecessor["room_id"])
+ historical_room_ids.append(predecessor_room_id)
- # Scan through the old room for further predecessors
- room_id = predecessor["room_id"]
+ # And repeat
+ predecessor = next_predecessor_room
return historical_room_ids
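Review note: the reworked loop only records a predecessor after confirming it resolves to a known room, and bails out on malformed entries. The same walk, standalone (a sketch; `get_predecessor` raising `KeyError` stands in for the store raising `NotFoundError`):

```python
def walk_predecessor_chain(get_predecessor, room_id: str) -> list:
    history = []
    # The initial room must be known for us to get this far.
    predecessor = get_predecessor(room_id)
    while predecessor:
        prev_id = predecessor.get("room_id")
        if not isinstance(prev_id, str):
            # Malformed predecessor entry: stop here.
            break
        try:
            next_predecessor = get_predecessor(prev_id)
        except KeyError:
            # The predecessor is not a known room, so we are done.
            break
        history.append(prev_id)
        predecessor = next_predecessor
    return history
```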
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6f78454322..b635c339ed 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -317,6 +317,3 @@ class TypingNotificationEventSource(object):
def get_current_key(self):
return self.get_typing_handler()._latest_room_serial
-
- def get_pagination_rows(self, user, pagination_config, key):
- return [], pagination_config.from_key
|