diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1971e373ed..e2ac595a62 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -30,6 +30,7 @@ from typing import (
Optional,
Tuple,
Union,
+ cast,
)
import attr
@@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode
from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING:
+ from synapse.rest.client.v1.login import LoginResponse
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -777,6 +779,108 @@ class AuthHandler(BaseHandler):
"params": params,
}
+ async def refresh_token(
+ self,
+ refresh_token: str,
+ valid_until_ms: Optional[int],
+ ) -> Tuple[str, str]:
+ """
+ Consumes a refresh token and generate both a new access token and a new refresh token from it.
+
+ The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
+
+ Args:
+ refresh_token: The token to consume.
+ valid_until_ms: The expiration timestamp of the new access token.
+
+ Returns:
+ A tuple containing the new access token and refresh token
+ """
+
+ # Verify the token signature first before looking up the token
+ if not self._verify_refresh_token(refresh_token):
+ raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+
+ existing_token = await self.store.lookup_refresh_token(refresh_token)
+ if existing_token is None:
+ raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+
+ if (
+ existing_token.has_next_access_token_been_used
+ or existing_token.has_next_refresh_token_been_refreshed
+ ):
+ raise SynapseError(
+ 403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+ )
+
+ (
+ new_refresh_token,
+ new_refresh_token_id,
+ ) = await self.get_refresh_token_for_user_id(
+ user_id=existing_token.user_id, device_id=existing_token.device_id
+ )
+ access_token = await self.get_access_token_for_user_id(
+ user_id=existing_token.user_id,
+ device_id=existing_token.device_id,
+ valid_until_ms=valid_until_ms,
+ refresh_token_id=new_refresh_token_id,
+ )
+ await self.store.replace_refresh_token(
+ existing_token.token_id, new_refresh_token_id
+ )
+ return access_token, new_refresh_token
+
+ def _verify_refresh_token(self, token: str) -> bool:
+ """
+ Verifies the shape of a refresh token.
+
+ Args:
+ token: The refresh token to verify
+
+ Returns:
+ Whether the token has the right shape
+ """
+ parts = token.split("_", maxsplit=4)
+ if len(parts) != 4:
+ return False
+
+ type, localpart, rand, crc = parts
+
+ # Refresh tokens are prefixed by "syr_", let's check that
+ if type != "syr":
+ return False
+
+ # Check the CRC
+ base = f"{type}_{localpart}_{rand}"
+ expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+ if crc != expected_crc:
+ return False
+
+ return True
+
+ async def get_refresh_token_for_user_id(
+ self,
+ user_id: str,
+ device_id: str,
+ ) -> Tuple[str, int]:
+ """
+ Creates a new refresh token for the user with the given user ID.
+
+ Args:
+ user_id: canonical user ID
+ device_id: the device ID to associate with the token.
+
+ Returns:
+ The newly created refresh token and its ID in the database
+ """
+ refresh_token = self.generate_refresh_token(UserID.from_string(user_id))
+ refresh_token_id = await self.store.add_refresh_token_to_user(
+ user_id=user_id,
+ token=refresh_token,
+ device_id=device_id,
+ )
+ return refresh_token, refresh_token_id
+
async def get_access_token_for_user_id(
self,
user_id: str,
@@ -784,6 +888,7 @@ class AuthHandler(BaseHandler):
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
is_appservice_ghost: bool = False,
+ refresh_token_id: Optional[int] = None,
) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -801,6 +906,8 @@ class AuthHandler(BaseHandler):
valid_until_ms: when the token is valid until. None for
no expiry.
is_appservice_ghost: Whether the user is an application ghost user
+ refresh_token_id: the refresh token ID that will be associated with
+ this access token.
Returns:
The access token for the user's session.
Raises:
@@ -836,6 +943,7 @@ class AuthHandler(BaseHandler):
device_id=device_id,
valid_until_ms=valid_until_ms,
puppets_user_id=puppets_user_id,
+ refresh_token_id=refresh_token_id,
)
# the device *should* have been registered before we got here; however,
@@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler):
self,
login_submission: Dict[str, Any],
ratelimit: bool = False,
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+ ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler):
self,
username: str,
login_submission: Dict[str, Any],
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+ ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
- ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+ ) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler):
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
return f"{base}_{crc}"
+ def generate_refresh_token(self, for_user: UserID) -> str:
+ """Generates an opaque string, for use as a refresh token"""
+
+ # we use the following format for refresh tokens:
+ # syr_<base64 local part>_<random string>_<base62 crc check>
+
+ b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
+ random_string = stringutils.random_string(20)
+ base = f"syr_{b64local}_{random_string}"
+
+ crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+ return f"{base}_{crc}"
+
async def validate_short_term_login_token(
self, login_token: str
) -> LoginTokenAttributes:
@@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler):
)
respond_with_html(request, 200, html)
- async def _sso_login_callback(self, login_result: JsonDict) -> None:
+ async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
"""
A login callback which might add additional attributes to the login response.
@@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler):
extra_attributes = self._extra_attributes.get(login_result["user_id"])
if extra_attributes:
- login_result.update(extra_attributes.extra_attributes)
+ login_result_dict = cast(Dict[str, Any], login_result)
+ login_result_dict.update(extra_attributes.extra_attributes)
def _expire_sso_extra_attributes(self) -> None:
"""
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1b566dbf2d..d929c65131 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1711,80 +1711,6 @@ class FederationHandler(BaseHandler):
return event
- async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
- """We have received a join event for a room. Fully process it and
- respond with the current state and auth chains.
- """
- event = pdu
-
- logger.debug(
- "on_send_join_request from %s: Got event: %s, signatures: %s",
- origin,
- event.event_id,
- event.signatures,
- )
-
- if get_domain_from_id(event.sender) != origin:
- logger.info(
- "Got /send_join request for user %r from different origin %s",
- event.sender,
- origin,
- )
- raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
-
- event.internal_metadata.outlier = False
- # Send this event on behalf of the origin server.
- #
- # The reasons we have the destination server rather than the origin
- # server send it are slightly mysterious: the origin server should have
- # all the necessary state once it gets the response to the send_join,
- # so it could send the event itself if it wanted to. It may be that
- # doing it this way reduces failure modes, or avoids certain attacks
- # where a new server selectively tells a subset of the federation that
- # it has joined.
- #
- # The fact is that, as of the current writing, Synapse doesn't send out
- # the join event over federation after joining, and changing it now
- # would introduce the danger of backwards-compatibility problems.
- event.internal_metadata.send_on_behalf_of = origin
-
- # Calculate the event context.
- context = await self.state_handler.compute_event_context(event)
-
- # Get the state before the new event.
- prev_state_ids = await context.get_prev_state_ids()
-
- # Check if the user is already in the room or invited to the room.
- user_id = event.state_key
- prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- prev_member_event = None
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
-
- # Check if the member should be allowed access via membership in a space.
- await self._event_auth_handler.check_restricted_join_rules(
- prev_state_ids,
- event.room_version,
- user_id,
- prev_member_event,
- )
-
- # Persist the event.
- await self._auth_and_persist_event(origin, event, context)
-
- logger.debug(
- "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
- event.event_id,
- event.signatures,
- )
-
- state_ids = list(prev_state_ids.values())
- auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
-
- state = await self.store.get_events(list(prev_state_ids.values()))
-
- return {"state": list(state.values()), "auth_chain": auth_chain}
-
async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion
) -> EventBase:
@@ -1960,37 +1886,6 @@ class FederationHandler(BaseHandler):
return event
- async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
- """We have received a leave event for a room. Fully process it."""
- event = pdu
-
- logger.debug(
- "on_send_leave_request: Got event: %s, signatures: %s",
- event.event_id,
- event.signatures,
- )
-
- if get_domain_from_id(event.sender) != origin:
- logger.info(
- "Got /send_leave request for user %r from different origin %s",
- event.sender,
- origin,
- )
- raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
-
- event.internal_metadata.outlier = False
-
- context = await self.state_handler.compute_event_context(event)
- await self._auth_and_persist_event(origin, event, context)
-
- logger.debug(
- "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
- event.event_id,
- event.signatures,
- )
-
- return None
-
@log_function
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str
@@ -2054,51 +1949,115 @@ class FederationHandler(BaseHandler):
return event
@log_function
- async def on_send_knock_request(
+ async def on_send_membership_event(
self, origin: str, event: EventBase
) -> EventContext:
"""
- We have received a knock event for a room. Verify that event and send it into the room
- on the knocking homeserver's behalf.
+ We have received a join/leave/knock event for a room via send_join/leave/knock.
+
+ Verify that event and send it into the room on the remote homeserver's behalf.
+
+ This is quite similar to on_receive_pdu, with the following principal
+ differences:
+ * only membership events are permitted (and only events with
+ sender==state_key -- ie, no kicks or bans)
+ * *We* send out the event on behalf of the remote server.
+ * We enforce the membership restrictions of restricted rooms.
+ * Rejected events result in an exception rather than being stored.
+
+ There are also other differences, however it is not clear if these are by
+ design or omission. In particular, we do not attempt to backfill any missing
+ prev_events.
Args:
- origin: The remote homeserver of the knocking user.
- event: The knocking member event that has been signed by the remote homeserver.
+ origin: The homeserver of the remote (joining/invited/knocking) user.
+ event: The member event that has been signed by the remote homeserver.
Returns:
The context of the event after inserting it into the room graph.
+
+ Raises:
+ SynapseError if the event is not accepted into the room
"""
logger.debug(
- "on_send_knock_request: Got event: %s, signatures: %s",
+ "on_send_membership_event: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
if get_domain_from_id(event.sender) != origin:
logger.info(
- "Got /send_knock request for user %r from different origin %s",
+ "Got send_membership request for user %r from different origin %s",
event.sender,
origin,
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
- event.internal_metadata.outlier = False
+ if event.sender != event.state_key:
+ raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON)
- context = await self.state_handler.compute_event_context(event)
+ assert not event.internal_metadata.outlier
- event_allowed = await self.third_party_event_rules.check_event_allowed(
- event, context
- )
- if not event_allowed:
- logger.info("Sending of knock %s forbidden by third-party rules", event)
+ # Send this event on behalf of the other server.
+ #
+ # The remote server isn't a full participant in the room at this point, so
+ # may not have an up-to-date list of the other homeservers participating in
+ # the room, so we send it on their behalf.
+ event.internal_metadata.send_on_behalf_of = origin
+
+ context = await self.state_handler.compute_event_context(event)
+ context = await self._check_event_auth(origin, event, context)
+ if context.rejected:
raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
)
- await self._auth_and_persist_event(origin, event, context)
+ # for joins, we need to check the restrictions of restricted rooms
+ if event.membership == Membership.JOIN:
+ await self._check_join_restrictions(context, event)
+ # for knock events, we run the third-party event rules. It's not entirely clear
+ # why we don't do this for other sorts of membership events.
+ if event.membership == Membership.KNOCK:
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Sending of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ # all looks good, we can persist the event.
+ await self._run_push_actions_and_persist_event(event, context)
return context
+ async def _check_join_restrictions(
+ self, context: EventContext, event: EventBase
+ ) -> None:
+ """Check that restrictions in restricted join rules are matched
+
+ Called when we receive a join event via send_join.
+
+ Raises an auth error if the restrictions are not matched.
+ """
+ prev_state_ids = await context.get_prev_state_ids()
+
+ # Check if the user is already in the room or invited to the room.
+ user_id = event.state_key
+ prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+
+ # Check if the member should be allowed access via membership in a space.
+ await self._event_auth_handler.check_restricted_join_rules(
+ prev_state_ids,
+ event.room_version,
+ user_id,
+ prev_member_event,
+ )
+
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
@@ -2240,6 +2199,18 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
+ await self._run_push_actions_and_persist_event(event, context, backfilled)
+
+ async def _run_push_actions_and_persist_event(
+ self, event: EventBase, context: EventContext, backfilled: bool = False
+ ):
+ """Run the push actions for a received event, and persist it.
+
+ Args:
+ event: The event itself.
+ context: The event context.
+ backfilled: True if the event was backfilled.
+ """
try:
if (
not event.internal_metadata.is_outlier()
@@ -2553,9 +2524,9 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
- state: Optional[Iterable[EventBase]],
- auth_events: Optional[MutableStateMap[EventBase]],
- backfilled: bool,
+ state: Optional[Iterable[EventBase]] = None,
+ auth_events: Optional[MutableStateMap[EventBase]] = None,
+ backfilled: bool = False,
) -> EventContext:
"""
Checks whether an event should be rejected (for failing auth checks).
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ca1ed6a5c0..26ef016179 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,9 +15,10 @@
"""Contains functions for registering clients."""
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from prometheus_client import Counter
+from typing_extensions import TypedDict
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -54,6 +55,16 @@ login_counter = Counter(
["guest", "auth_provider"],
)
+LoginDict = TypedDict(
+ "LoginDict",
+ {
+ "device_id": str,
+ "access_token": str,
+ "valid_until_ms": Optional[int],
+ "refresh_token": Optional[str],
+ },
+)
+
class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
@@ -85,6 +96,7 @@ class RegistrationHandler(BaseHandler):
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
+ self.access_token_lifetime = hs.config.access_token_lifetime
async def check_username(
self,
@@ -386,11 +398,32 @@ class RegistrationHandler(BaseHandler):
room_alias = RoomAlias.from_string(r)
if self.hs.hostname != room_alias.domain:
- logger.warning(
- "Cannot create room alias %s, "
- "it does not match server domain",
+ # If the alias is remote, try to join the room. This might fail
+ # because the room might be invite only, but we don't have any local
+ # user in the room to invite this one with, so at this point that's
+ # the best we can do.
+ logger.info(
+ "Cannot automatically create room with alias %s as it isn't"
+ " local, trying to join the room instead",
r,
)
+
+ (
+ room,
+ remote_room_hosts,
+ ) = await room_member_handler.lookup_room_alias(room_alias)
+ room_id = room.to_string()
+
+ await room_member_handler.update_membership(
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
+ target=UserID.from_string(user_id),
+ room_id=room_id,
+ remote_room_hosts=remote_room_hosts,
+ action="join",
+ ratelimit=False,
+ )
else:
# A shallow copy is OK here since the only key that is
# modified is room_alias_name.
@@ -448,22 +481,32 @@ class RegistrationHandler(BaseHandler):
)
# Calculate whether the room requires an invite or can be
- # joined directly. Note that unless a join rule of public exists,
- # it is treated as requiring an invite.
- requires_invite = True
-
- state = await self.store.get_filtered_current_state_ids(
- room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
+ # joined directly. By default, we consider the room as requiring an
+ # invite if the homeserver is in the room (unless told otherwise by the
+ # join rules). Otherwise we consider it as being joinable, at the risk of
+ # failing to join, but in this case there's little more we can do since
+ # we don't have a local user in the room to craft up an invite with.
+ requires_invite = await self.store.is_host_joined(
+ room_id,
+ self.server_name,
)
- event_id = state.get((EventTypes.JoinRules, ""))
- if event_id:
- join_rules_event = await self.store.get_event(
- event_id, allow_none=True
+ if requires_invite:
+ # If the server is in the room, check if the room is public.
+ state = await self.store.get_filtered_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
)
- if join_rules_event:
- join_rule = join_rules_event.content.get("join_rule", None)
- requires_invite = join_rule and join_rule != JoinRules.PUBLIC
+
+ event_id = state.get((EventTypes.JoinRules, ""))
+ if event_id:
+ join_rules_event = await self.store.get_event(
+ event_id, allow_none=True
+ )
+ if join_rules_event:
+ join_rule = join_rules_event.content.get("join_rule", None)
+ requires_invite = (
+ join_rule and join_rule != JoinRules.PUBLIC
+ )
# Send the invite, if necessary.
if requires_invite:
@@ -665,7 +708,8 @@ class RegistrationHandler(BaseHandler):
is_guest: bool = False,
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
- ) -> Tuple[str, str]:
+ should_issue_refresh_token: bool = False,
+ ) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
@@ -677,8 +721,9 @@ class RegistrationHandler(BaseHandler):
is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
+ should_issue_refresh_token: Whether it should also issue a refresh token
Returns:
- Tuple of device ID and access token
+ Tuple of device ID, access token, access token expiration time and refresh token
"""
res = await self._register_device_client(
user_id=user_id,
@@ -686,6 +731,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name=initial_display_name,
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
+ should_issue_refresh_token=should_issue_refresh_token,
)
login_counter.labels(
@@ -693,7 +739,12 @@ class RegistrationHandler(BaseHandler):
auth_provider=(auth_provider_id or ""),
).inc()
- return res["device_id"], res["access_token"]
+ return (
+ res["device_id"],
+ res["access_token"],
+ res["valid_until_ms"],
+ res["refresh_token"],
+ )
async def register_device_inner(
self,
@@ -702,7 +753,8 @@ class RegistrationHandler(BaseHandler):
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
- ) -> Dict[str, str]:
+ should_issue_refresh_token: bool = False,
+ ) -> LoginDict:
"""Helper for register_device
Does the bits that need doing on the main process. Not for use outside this
@@ -717,6 +769,9 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
+ refresh_token = None
+ refresh_token_id = None
+
registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
@@ -724,14 +779,30 @@ class RegistrationHandler(BaseHandler):
assert valid_until_ms is None
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
else:
+ if should_issue_refresh_token:
+ (
+ refresh_token,
+ refresh_token_id,
+ ) = await self._auth_handler.get_refresh_token_for_user_id(
+ user_id,
+ device_id=registered_device_id,
+ )
+ valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
+
access_token = await self._auth_handler.get_access_token_for_user_id(
user_id,
device_id=registered_device_id,
valid_until_ms=valid_until_ms,
is_appservice_ghost=is_appservice_ghost,
+ refresh_token_id=refresh_token_id,
)
- return {"device_id": registered_device_id, "access_token": access_token}
+ return {
+ "device_id": registered_device_id,
+ "access_token": access_token,
+ "valid_until_ms": valid_until_ms,
+ "refresh_token": refresh_token,
+ }
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
|