diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 2d7e6df6e4..20ec1ca01b 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
class AccountDataEventSource(object):
def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
- @defer.inlineCallbacks
- def get_new_events(self, user, from_key, **kwargs):
+ async def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
- current_stream_id = yield self.store.get_max_account_data_stream_id()
+ current_stream_id = self.store.get_max_account_data_stream_id()
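+        # get_max_account_data_stream_id returns a plain integer rather than a
+        # Deferred (see get_current_key above, which returns it directly), so
+        # it is deliberately not awaited.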
results = []
- tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+ tags = await self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
(
account_data,
room_account_data,
- ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+ ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
@@ -54,6 +51,5 @@ class AccountDataEventSource(object):
return results, current_stream_id
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
+ async def get_pagination_rows(self, user, config, key):
return [], config.to_id
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d04e0fe576..829f52eca1 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,7 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-
-from twisted.internet import defer
+from typing import List, Optional
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
- "send_renewals", self.send_renewal_emails
+ "send_renewals", self._send_renewal_emails
)
self.clock.looping_call(send_emails, 30 * 60 * 1000)
- @defer.inlineCallbacks
- def send_renewal_emails(self):
+ async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
- expiring_users = yield self.store.get_users_expiring_soon()
+ expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
- yield self._send_renewal_email(
+ await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
- @defer.inlineCallbacks
- def send_renewal_email_to_user(self, user_id):
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- yield self._send_renewal_email(user_id, expiration_ts)
+ async def send_renewal_email_to_user(self, user_id: str):
+ expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+ await self._send_renewal_email(user_id, expiration_ts)
- @defer.inlineCallbacks
- def _send_renewal_email(self, user_id, expiration_ts):
+ async def _send_renewal_email(self, user_id: str, expiration_ts: int):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.
Args:
- user_id (str): ID of the user to send email(s) to.
- expiration_ts (int): Timestamp in milliseconds for the expiration date of
+ user_id: ID of the user to send email(s) to.
+ expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
- addresses = yield self._get_email_addresses_for_user(user_id)
+ addresses = await self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
return
try:
- user_display_name = yield self.store.get_profile_displayname(
+ user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
except StoreError:
user_display_name = user_id
- renewal_token = yield self._get_renewal_token(user_id)
+ renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
logger.info("Sending renewal email to %s", address)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
)
)
- yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+ await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
- @defer.inlineCallbacks
- def _get_email_addresses_for_user(self, user_id):
+ async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
"""Retrieve the list of email addresses attached to a user's account.
Args:
- user_id (str): ID of the user to lookup email addresses for.
+ user_id: ID of the user to lookup email addresses for.
Returns:
- defer.Deferred[list[str]]: Email addresses for this account.
+ Email addresses for this account.
"""
- threepids = yield self.store.user_get_threepids(user_id)
+ threepids = await self.store.user_get_threepids(user_id)
addresses = []
for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
return addresses
- @defer.inlineCallbacks
- def _get_renewal_token(self, user_id):
+ async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
Args:
- user_id (str): ID of the user to generate a string for.
+ user_id: ID of the user to generate a string for.
Returns:
- defer.Deferred[str]: The generated string.
+ The generated string.
Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
- yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+ await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- @defer.inlineCallbacks
- def renew_account(self, renewal_token):
+ async def renew_account(self, renewal_token: str) -> bool:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
+ renewal_token: Token sent with the renewal request.
Returns:
- bool: Whether the provided token is valid.
+ Whether the provided token is valid.
"""
try:
- user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- defer.returnValue(False)
+ return False
logger.debug("Renewing an account for user %s", user_id)
- yield self.renew_account_for_user(user_id)
+ await self.renew_account_for_user(user_id)
- defer.returnValue(True)
+ return True
- @defer.inlineCallbacks
- def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+ async def renew_account_for_user(
+        self, user_id: str, expiration_ts: Optional[int] = None, email_sent: bool = False
+ ) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
- expiration_ts (int): New expiration date. Defaults to now + validity period.
- email_sent (bool): Whether an email has been sent for this validity period.
+            user_id: ID of the user to renew the account of.
+ expiration_ts: New expiration date. Defaults to now + validity period.
+            email_sent: Whether an email has been sent for this validity period.
Defaults to False.
Returns:
- defer.Deferred[int]: New expiration date for this account, as a timestamp
- in milliseconds since epoch.
+ New expiration date for this account, as a timestamp in
+ milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period
- yield self.store.set_account_validity_for_user(
+ await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 28c12753c1..57a10daefd 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -264,7 +264,6 @@ class E2eKeysHandler(object):
return ret
- @defer.inlineCallbacks
def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database
@@ -284,35 +283,14 @@ class E2eKeysHandler(object):
self_signing_keys = {}
user_signing_keys = {}
- for user_id in query:
- # XXX: consider changing the store functions to allow querying
- # multiple users simultaneously.
- key = yield self.store.get_e2e_cross_signing_key(
- user_id, "master", from_user_id
- )
- if key:
- master_keys[user_id] = key
-
- key = yield self.store.get_e2e_cross_signing_key(
- user_id, "self_signing", from_user_id
- )
- if key:
- self_signing_keys[user_id] = key
-
- # users can see other users' master and self-signing keys, but can
- # only see their own user-signing keys
- if from_user_id == user_id:
- key = yield self.store.get_e2e_cross_signing_key(
- user_id, "user_signing", from_user_id
- )
- if key:
- user_signing_keys[user_id] = key
-
- return {
- "master_keys": master_keys,
- "self_signing_keys": self_signing_keys,
- "user_signing_keys": user_signing_keys,
- }
+ # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
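+        # defer.succeed keeps this method's Deferred-returning contract for
+        # callers that still `yield` its result.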
+ return defer.succeed(
+ {
+ "master_keys": master_keys,
+ "self_signing_keys": self_signing_keys,
+ "user_signing_keys": user_signing_keys,
+ }
+ )
@trace
@defer.inlineCallbacks
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..62985bab9f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
import itertools
import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
@@ -63,8 +63,9 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -164,8 +165,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- @defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -175,17 +175,15 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
-
- Returns (Deferred): completes with None
"""
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+ logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.store.get_event(
+ existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@@ -229,7 +227,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -245,12 +243,12 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -270,7 +268,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
@@ -278,13 +276,19 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
)
- yield self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+                    except Exception as e:
+                        raise Exception(
+                            "Error fetching missing prev_events for %s: %s"
+                            % (event_id, e)
+                        ) from e
# Update the set of things we've seen after trying to
# fetch the missing stuff
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -292,14 +296,6 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
- elif missing_prevs:
- logger.info(
- "[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id,
- event_id,
- len(missing_prevs),
- shortstr(missing_prevs),
- )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -344,13 +340,19 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
+ logger.info(
+ "Event %s is missing prev_events: calculating state for a "
+ "backwards extremity",
+ event_id,
+ )
+
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@@ -364,13 +366,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "[%s %s] Requesting state at missing prev_event %s",
- room_id,
- event_id,
- p,
+ "Requesting state at missing prev_event %s", event_id,
)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
@@ -379,24 +378,10 @@ class FederationHandler(BaseHandler):
(
remote_state,
got_auth_chain,
- ) = yield self.federation_client.get_state_for_room(
- origin, room_id, p
- )
-
- # we want the state *after* p; get_state_for_room returns the
- # state *before* p.
- remote_event = yield self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True
+ ) = await self._get_state_for_room(
+ origin, room_id, p, include_event_in_state=True
)
- if remote_event is None:
- raise Exception(
- "Unable to get missing prev_event %s" % (p,)
- )
-
- if remote_event.is_state():
- remote_state.append(remote_event)
-
# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
@@ -410,7 +395,7 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- state_map = yield resolve_events_with_store(
+ state_map = await resolve_events_with_store(
room_version,
state_maps,
event_map,
@@ -422,10 +407,10 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
- evs = yield self.store.get_events(
+ evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
@@ -446,12 +431,11 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
- yield self._process_received_pdu(
+ await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
- @defer.inlineCallbacks
- def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@@ -463,12 +447,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
- latest = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -532,7 +516,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timout to 60s and see what happens.
try:
- missing_events = yield self.federation_client.get_missing_events(
+ missing_events = await self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@@ -571,7 +555,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+ await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
@@ -583,8 +567,116 @@ class FederationHandler(BaseHandler):
else:
raise
- @defer.inlineCallbacks
- def _process_received_pdu(self, origin, event, state, auth_chain):
+ async def _get_state_for_room(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ include_event_in_state: bool = False,
+ ) -> Tuple[List[EventBase], List[EventBase]]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+ include_event_in_state: if true, the event itself will be included in the
+ returned state event list.
+
+ Returns:
+ A list of events in the state, possibly including the event itself, and
+ a list of events in the auth chain for the given event.
+ """
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ desired_events = set(state_event_ids + auth_event_ids)
+
+ if include_event_in_state:
+ desired_events.add(event_id)
+
+ event_map = await self._get_events_from_store_or_dest(
+ destination, room_id, desired_events
+ )
+
+ failed_to_fetch = desired_events - event_map.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state/auth events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ event_map[e_id] for e_id in state_event_ids if e_id in event_map
+ ]
+
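+        # Note that the returned state is the state *before* the event; if the
+        # event is itself a state event, appending it here gives the state
+        # *after* the event.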
+ if include_event_in_state:
+ remote_event = event_map.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+ if remote_event.is_state():
+ remote_state.append(remote_event)
+
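+        # Sort the auth chain by depth as a cheap approximation to topological
+        # order, so that events tend to come after the events they refer to.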
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+ auth_chain.sort(key=lambda e: e.depth)
+
+ return remote_state, auth_chain
+
+ async def _get_events_from_store_or_dest(
+ self, destination: str, room_id: str, event_ids: Iterable[str]
+ ) -> Dict[str, EventBase]:
+ """Fetch events from a remote destination, checking if we already have them.
+
+ Args:
+            destination: The remote server to fetch events from.
+            room_id: The room the events belong to.
+            event_ids: The IDs of the events to fetch.
+
+        Returns:
+            A map from event_id to event.
+ """
+ fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
+
+ missing_events = set(event_ids) - fetched_events.keys()
+
+ if not missing_events:
+ return fetched_events
+
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+            room_id,
+ )
+
+ room_version = await self.store.get_room_version(room_id)
+
+ # XXX 20 requests at once? really?
+ for batch in batch_iter(missing_events, 20):
+ deferreds = [
+ run_in_background(
+ self.federation_client.get_pdu,
+ destinations=[destination],
+ event_id=e_id,
+ room_version=room_version,
+ )
+ for e_id in batch
+ ]
+
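+            # Wait for the whole batch: with consumeErrors=True, a failed
+            # fetch shows up as a (False, Failure) entry in the results rather
+            # than being raised, so one bad event doesn't abort the batch.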
+ res = await make_deferred_yieldable(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
+ for success, result in res:
+ if success and result:
+ fetched_events[result.event_id] = result
+
+ return fetched_events
+
+ async def _process_received_pdu(self, origin, event, state, auth_chain):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
"""
@@ -599,7 +691,7 @@ class FederationHandler(BaseHandler):
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
- seen_ids = yield self.store.have_seen_events(event_ids)
+ seen_ids = await self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
@@ -626,18 +718,18 @@ class FederationHandler(BaseHandler):
event_id,
[e.event.event_id for e in event_infos],
)
- yield self._handle_new_events(origin, event_infos)
+ await self._handle_new_events(origin, event_infos)
try:
- context = yield self._handle_new_event(origin, event, state=state)
+ context = await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if not room:
try:
- yield self.store.store_room(
+ await self.store.store_room(
room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
@@ -650,11 +742,11 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = await context.get_prev_state_ids(self.store)
prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
- prev_state = yield self.store.get_event(
+ prev_state = await self.store.get_event(
prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
@@ -662,11 +754,10 @@ class FederationHandler(BaseHandler):
if newly_joined:
user = UserID.from_string(event.state_key)
- yield self.user_joined_room(user, room_id)
+ await self.user_joined_room(user, room_id)
@log_function
- @defer.inlineCallbacks
- def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -683,9 +774,9 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
- events = yield self.federation_client.backfill(
+ events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -700,7 +791,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
- seen_events = yield self.store.have_events_in_timeline(
+ seen_events = await self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
@@ -723,7 +814,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self.federation_client.get_state_for_room(
+ state, auth = await self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
@@ -748,7 +839,7 @@ class FederationHandler(BaseHandler):
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
- ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+ ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)
required_auth.update(
@@ -762,7 +853,7 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch,
)
- results = yield make_deferred_yieldable(
+ results = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -789,7 +880,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events)
- seen_events = yield self.store.have_seen_events(
+ seen_events = await self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)
@@ -851,7 +942,7 @@ class FederationHandler(BaseHandler):
)
)
- yield self._handle_new_events(dest, ev_infos, backfilled=True)
+ await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -867,16 +958,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(dest, event, backfilled=True)
+ await self._handle_new_event(dest, event, backfilled=True)
return events
- @defer.inlineCallbacks
- def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+ extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -908,15 +998,17 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(list(extremities))
+ forward_events = await self.store.get_successor_events(list(extremities))
- extremities_events = yield self.store.get_events(
- forward_events, check_redacted=False, get_prev_content=False
+ extremities_events = await self.store.get_events(
+ forward_events,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
- filtered_extremities = yield filter_events_for_server(
+ filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
list(extremities_events.values()),
@@ -946,7 +1038,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = yield self.state_handler.get_current_state(room_id)
+ curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
"""Get joined domains from state
@@ -985,12 +1077,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- @defer.inlineCallbacks
- def try_backfill(domains):
+ async def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- yield self.backfill(
+ await self.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -1021,7 +1112,7 @@ class FederationHandler(BaseHandler):
return False
- success = yield try_backfill(likely_domains)
+ success = await try_backfill(likely_domains)
if success:
return True
@@ -1035,7 +1126,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
- states = yield make_deferred_yieldable(
+ states = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
@@ -1045,7 +1136,7 @@ class FederationHandler(BaseHandler):
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False,
)
@@ -1061,7 +1152,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill(
+ success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
@@ -1210,7 +1301,7 @@ class FederationHandler(BaseHandler):
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = yield self.store.get_room_predecessor(room_id)
- if not predecessor:
+ if not predecessor or not isinstance(predecessor.get("room_id"), str):
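+            # The predecessor comes from event content sent to us by other
+            # servers, so it may be missing or malformed; ignore it if so.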
return
old_room_id = predecessor["room_id"]
logger.debug(
@@ -1238,8 +1329,7 @@ class FederationHandler(BaseHandler):
return True
- @defer.inlineCallbacks
- def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@@ -1255,7 +1345,7 @@ class FederationHandler(BaseHandler):
p.room_id,
)
with nested_logging_context(p.event_id):
- yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1453,7 +1543,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts, room_id, user_id, "leave", content=content,
+ target_hosts, room_id, user_id, "leave", content=content
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -2814,7 +2904,7 @@ class FederationHandler(BaseHandler):
room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
- return user_joined_room(self.distributor, user, room_id)
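+            # user_joined_room just fires the distributor signal and returns
+            # None; wrapping it in defer.succeed keeps this function's
+            # Deferred-returning contract.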
+ return defer.succeed(user_joined_room(self.distributor, user, room_id))
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..bf9add7fe2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
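+                # EventRedactBehaviour.AS_IS fetches the event content
+                # unredacted, matching the old check_redacted=False behaviour.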
get_prev_content=False,
allow_rejected=False,
allow_none=True,
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8514ddc600..00a6afc963 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -280,8 +280,7 @@ class PaginationHandler(object):
await self.storage.purge_events.purge_room(room_id)
- @defer.inlineCallbacks
- def get_messages(
+ async def get_messages(
self,
requester,
room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_pagination()
+ await self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@@ -319,11 +318,11 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
- with (yield self.pagination_lock.read(room_id)):
+ with (await self.pagination_lock.read(room_id)):
(
membership,
member_event_id,
- ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+ ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
if room_token.topological:
max_topo = room_token.topological
else:
- max_topo = yield self.store.get_max_topological_token(
+ max_topo = await self.store.get_max_topological_token(
room_id, room_token.stream
)
@@ -339,18 +338,18 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
- leave_token = yield self.store.get_topological_token_for_event(
+ leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
- events, next_key = yield self.store.paginate_room_events(
+ events, next_key = await self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
if event_filter:
events = event_filter.filter(events)
- events = yield filter_events_for_client(
+ events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None)
)
@@ -385,19 +384,19 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events
)
- state_ids = yield self.state_store.get_state_ids_for_event(
+ state_ids = await self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
if state_ids:
- state = yield self.store.get_events(list(state_ids.values()))
+ state = await self.store.get_events(list(state_ids.values()))
state = state.values()
time_now = self.clock.time_msec()
chunk = {
"chunk": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event
)
),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
}
if state:
- chunk["state"] = yield self._event_serializer.serialize_events(
+ chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
+from typing import Tuple
import attr
import saml2
+import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ UserID,
+ map_username_to_mxid_localpart,
+ mxid_localpart_allowed_characters,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
+@attr.s
+class Saml2SessionData:
+ """Data we track about SAML2 sessions"""
+
+ # time the session was created, in milliseconds
+ creation_time = attr.ib()
+
+
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
- self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
- self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+ # plugin to do custom mapping from saml response to mxid
+ self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+ hs.config.saml2_user_mapping_provider_config
+ )
# identifier for the external_ids table
self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
remote_user_id = saml2_auth.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
- raise SynapseError(400, "uid not in SAML2 response")
-
- try:
- mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
- except KeyError:
- logger.warning(
- "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
- )
- raise SynapseError(
- 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
- )
+ raise SynapseError(400, "'uid' not in SAML2 response")
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
- displayName = saml2_auth.ava.get("displayName", [None])[0]
-
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
)
return registered_user_id
- # figure out a new mxid for this user
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ # Map saml response to user attributes using the configured mapping provider
+ for i in range(1000):
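+                # `i` counts the failed attempts so far; it is passed to the
+                # provider so that it can suggest a different localpart on
+                # each retry. We give up after 1000 attempts.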
+ attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, i
+ )
+
+ logger.debug(
+ "Retrieved SAML attributes from user mapping provider: %s "
+ "(attempt %d)",
+ attribute_dict,
+ i,
+ )
+
+ localpart = attribute_dict.get("mxid_localpart")
+ if not localpart:
+ logger.error(
+ "SAML mapping provider plugin did not return a "
+ "mxid_localpart object"
+ )
+ raise SynapseError(500, "Error parsing SAML2 response")
- suffix = 0
- while True:
- localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+ displayname = attribute_dict.get("displayname")
+
+ # Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
+ # This mxid is free
break
- suffix += 1
- logger.info("Allocating mxid for new user with localpart %s", localpart)
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise SynapseError(
+ 500, "Unable to generate a Matrix ID from the SAML response"
+ )
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=displayName
+ localpart=localpart, default_display_name=displayname
)
+
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
@@ -205,9 +236,120 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
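+# Matches every character that is *not* allowed in an mxid localpart, so that
+# dot_replace_for_mxid below can replace each one with a ".".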
+DOT_REPLACE_PATTERN = re.compile(
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+ username = username.lower()
+ username = DOT_REPLACE_PATTERN.sub(".", username)
+
+ # regular mxids aren't allowed to start with an underscore either
+ username = re.sub("^_", "", username)
+ return username
+
+
+MXID_MAPPER_MAP = {
+ "hexencode": map_username_to_mxid_localpart,
+ "dotreplace": dot_replace_for_mxid,
+}
+
+
@attr.s
-class Saml2SessionData:
- """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+ mxid_source_attribute = attr.ib()
+ mxid_mapper = attr.ib()
- # time the session was created, in milliseconds
- creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+ __version__ = "0.0.1"
+
+ def __init__(self, parsed_config: SamlConfig):
+ """The default SAML user mapping provider
+
+ Args:
+ parsed_config: Module configuration
+ """
+ self._mxid_source_attribute = parsed_config.mxid_source_attribute
+ self._mxid_mapper = parsed_config.mxid_mapper
+
+ def saml_response_to_user_attributes(
+ self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+ ) -> dict:
+ """Maps some text from a SAML response to attributes of a new user
+
+ Args:
+ saml_response: A SAML auth response object
+
+ failures: How many times a call to this function with this
+ saml_response has resulted in a failure
+
+ Returns:
+ dict: A dict containing new user attributes. Possible keys:
+ * mxid_localpart (str): Required. The localpart of the user's mxid
+ * displayname (str): The displayname of the user
+ """
+ try:
+ mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+ except KeyError:
+ logger.warning(
+ "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+ )
+ raise SynapseError(
+ 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ )
+
+ # Use the configured mapper for this mxid_source
+ base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid
+ localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+ # Retrieve the display name from the saml response
+ # If displayname is None, the mxid_localpart will be used instead
+ displayname = saml_response.ava.get("displayName", [None])[0]
+
+ return {
+ "mxid_localpart": localpart,
+ "displayname": displayname,
+ }
+
+ @staticmethod
+ def parse_config(config: dict) -> SamlConfig:
+ """Parse the dict provided by the homeserver's config
+ Args:
+ config: A dictionary containing configuration options for this provider
+ Returns:
+ SamlConfig: A custom config object for this module
+ """
+ # Parse config options and use defaults where necessary
+ mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+ mapping_type = config.get("mxid_mapping", "hexencode")
+
+ # Retrieve the associating mapping function
+ try:
+ mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+ except KeyError:
+ raise ConfigError(
+ "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+ "mxid_mapping value" % (mapping_type,)
+ )
+
+ return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+ @staticmethod
+ def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+ """Returns the required attributes of a SAML
+
+ Args:
+ config: A SamlConfig object containing configuration params for this provider
+
+ Returns:
+            tuple[set,set]: The first set contains the SAML auth response
+                attributes that are required for the module to function, whereas
+                the second set consists of those attributes which can be used if
+                available, but are not necessary
+ """
+ return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
room_id (str): id of the room to search through.
Returns:
- Deferred[iterable[unicode]]: predecessor room ids
+ Deferred[iterable[str]]: predecessor room ids
"""
historical_room_ids = []
- while True:
- predecessor = yield self.store.get_room_predecessor(room_id)
+ # The initial room must have been known for us to get this far
+ predecessor = yield self.store.get_room_predecessor(room_id)
- # If no predecessor, assume we've hit a dead end
+ while True:
if not predecessor:
+ # We have reached the end of the chain of predecessors
+ break
+
+ if not isinstance(predecessor.get("room_id"), str):
+ # This predecessor object is malformed. Exit here
+ break
+
+ predecessor_room_id = predecessor["room_id"]
+
+            # Don't add it to the list until we have checked that it is a room
+            # we know about
+ try:
+ next_predecessor_room = yield self.store.get_room_predecessor(
+ predecessor_room_id
+ )
+ except NotFoundError:
+ # The predecessor is not a known room, so we are done here
break
- # Add predecessor's room ID
- historical_room_ids.append(predecessor["room_id"])
+ historical_room_ids.append(predecessor_room_id)
- # Scan through the old room for further predecessors
- room_id = predecessor["room_id"]
+ # And repeat
+ predecessor = next_predecessor_room
return historical_room_ids
|