diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 161b3c933c..98d3d2d97f 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -627,23 +627,28 @@ class AuthHandler(BaseHandler):
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
- ) -> bool:
+ ) -> None:
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
+
+ Raises:
+ LoginError if the stagetype is unknown or the session is missing.
+ LoginError is raised by check_auth if authentication fails.
"""
if stagetype not in self.checkers:
- raise LoginError(400, "", Codes.MISSING_PARAM)
+ raise LoginError(
+ 400, f"Unknown UIA stage type: {stagetype}", Codes.INVALID_PARAM
+ )
if "session" not in authdict:
- raise LoginError(400, "", Codes.MISSING_PARAM)
+ raise LoginError(400, "Missing session ID", Codes.MISSING_PARAM)
+ # If authentication fails a LoginError is raised. Otherwise, store
+ # the successful result.
result = await self.checkers[stagetype].check_auth(authdict, clientip)
- if result:
- await self.store.mark_ui_auth_stage_complete(
- authdict["session"], stagetype, result
- )
- return True
- return False
+ await self.store.mark_ui_auth_stage_complete(
+ authdict["session"], stagetype, result
+ )
def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
"""
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0e13bdaac..246df43501 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -203,18 +203,13 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- async def on_receive_pdu(
- self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
- ) -> None:
- """Process a PDU received via a federation /send/ transaction, or
- via backfill of missing prev_events
+ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
+ """Process a PDU received via a federation /send/ transaction
Args:
origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state.
pdu: received PDU
- sent_to_us_directly: True if this event was pushed to us; False if
- we pulled it as the result of a missing prev_event.
"""
room_id = pdu.room_id
@@ -276,8 +271,6 @@ class FederationHandler(BaseHandler):
)
return None
- state = None
-
# Check that the event passes auth based on the state at the event. This is
# done for events that are to be added to the timeline (non-outliers).
#
@@ -285,23 +278,16 @@ class FederationHandler(BaseHandler):
# - Fetching any missing prev events to fill in gaps in the graph
# - Fetching state if we have a hole in the graph
if not pdu.internal_metadata.is_outlier():
- # We only backfill backwards to the min depth.
- min_depth = await self.get_min_depth_for_context(pdu.room_id)
-
- logger.debug("min_depth: %d", min_depth)
-
prevs = set(pdu.prev_event_ids())
seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if missing_prevs:
+ # We only backfill backwards to the min depth.
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
+ logger.debug("min_depth: %d", min_depth)
- if min_depth is not None and pdu.depth < min_depth:
- # This is so that we don't notify the user about this
- # message, to work around the fact that some events will
- # reference really really old events we really don't want to
- # send to the clients.
- pdu.internal_metadata.outlier = True
- elif min_depth is not None and pdu.depth > min_depth:
- missing_prevs = prevs - seen
- if sent_to_us_directly and missing_prevs:
+ if min_depth is not None and pdu.depth > min_depth:
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
@@ -325,42 +311,23 @@ class FederationHandler(BaseHandler):
% (event_id, e)
) from e
- # Update the set of things we've seen after trying to
- # fetch the missing stuff
- seen = await self.store.have_events_in_timeline(prevs)
-
- if not prevs - seen:
- logger.info(
- "Found all missing prev_events",
- )
-
- missing_prevs = prevs - seen
- if missing_prevs:
- # We've still not been able to get all of the prev_events for this event.
- #
- # In this case, we need to fall back to asking another server in the
- # federation for the state at this event. That's ok provided we then
- # resolve the state against other bits of the DAG before using it (which
- # will ensure that you can't just take over a room by sending an event,
- # withholding its prev_events, and declaring yourself to be an admin in
- # the subsequent state request).
- #
- # Now, if we're pulling this event as a missing prev_event, then clearly
- # this event is not going to become the only forward-extremity and we are
- # guaranteed to resolve its state against our existing forward
- # extremities, so that should be fine.
- #
- # On the other hand, if this event was pushed to us, it is possible for
- # it to become the only forward-extremity in the room, and we would then
- # trust its state to be the state for the whole room. This is very bad.
- # Further, if the event was pushed to us, there is no excuse for us not to
- # have all the prev_events. We therefore reject any such events.
- #
- # XXX this really feels like it could/should be merged with the above,
- # but there is an interaction with min_depth that I'm not really
- # following.
-
- if sent_to_us_directly:
+ # Update the set of things we've seen after trying to
+ # fetch the missing stuff
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ logger.info("Found all missing prev_events")
+
+ if missing_prevs:
+ # since this event was pushed to us, it is possible for it to
+ # become the only forward-extremity in the room, and we would then
+ # trust its state to be the state for the whole room. This is very
+ # bad. Further, if the event was pushed to us, there is no excuse
+ # for us not to have all the prev_events. (XXX: apart from
+ # min_depth?)
+ #
+ # We therefore reject any such events.
logger.warning(
"Rejecting: failed to fetch %d prev events: %s",
len(missing_prevs),
@@ -376,93 +343,7 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
- logger.info(
- "Event %s is missing prev_events %s: calculating state for a "
- "backwards extremity",
- event_id,
- shortstr(missing_prevs),
- )
-
- # Calculate the state after each of the previous events, and
- # resolve them to find the correct state at the current event.
- event_map = {event_id: pdu}
- try:
- # Get the state of the events we know about
- ours = await self.state_store.get_state_groups_ids(room_id, seen)
-
- # state_maps is a list of mappings from (type, state_key) to event_id
- state_maps: List[StateMap[str]] = list(ours.values())
-
- # we don't need this any more, let's delete it.
- del ours
-
- # Ask the remote server for the states we don't
- # know about
- for p in missing_prevs:
- logger.info("Requesting state after missing prev_event %s", p)
-
- with nested_logging_context(p):
- # note that if any of the missing prevs share missing state or
- # auth events, the requests to fetch those events are deduped
- # by the get_pdu_cache in federation_client.
- remote_state = (
- await self._get_state_after_missing_prev_event(
- origin, room_id, p
- )
- )
-
- remote_state_map = {
- (x.type, x.state_key): x.event_id for x in remote_state
- }
- state_maps.append(remote_state_map)
-
- for x in remote_state:
- event_map[x.event_id] = x
-
- room_version = await self.store.get_room_version_id(room_id)
- state_map = (
- await self._state_resolution_handler.resolve_events_with_store(
- room_id,
- room_version,
- state_maps,
- event_map,
- state_res_store=StateResolutionStore(self.store),
- )
- )
-
- # We need to give _process_received_pdu the actual state events
- # rather than event ids, so generate that now.
-
- # First though we need to fetch all the events that are in
- # state_map, so we can build up the state below.
- evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
- )
- event_map.update(evs)
-
- state = [event_map[e] for e in state_map.values()]
- except Exception:
- logger.warning(
- "Error attempting to resolve state at missing " "prev_events",
- exc_info=True,
- )
- raise FederationError(
- "ERROR",
- 403,
- "We can't get valid state history.",
- affected=event_id,
- )
-
- # A second round of checks for all events. Check that the event passes auth
- # based on `auth_events`, this allows us to assert that the event would
- # have been allowed at some point. If an event passes this check its OK
- # for it to be used as part of a returned `/state` request, as either
- # a) we received the event as part of the original join and so trust it, or
- # b) we'll do a state resolution with existing state before it becomes
- # part of the "current state", which adds more protection.
- await self._process_received_pdu(origin, pdu, state=state)
+ await self._process_received_pdu(origin, pdu, state=None)
async def _get_missing_events_for_pdu(
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
@@ -562,24 +443,7 @@ class FederationHandler(BaseHandler):
return
logger.info("Got %d prev_events", len(missing_events))
-
- # We want to sort these by depth so we process them and
- # tell clients about them in order.
- missing_events.sort(key=lambda x: x.depth)
-
- for ev in missing_events:
- logger.info("Handling received prev_event %s", ev)
- with nested_logging_context(ev.event_id):
- try:
- await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
- except FederationError as e:
- if e.code == 403:
- logger.warning(
- "Received prev_event %s failed history check.",
- ev.event_id,
- )
- else:
- raise
+ await self._process_pulled_events(origin, missing_events)
async def _get_state_for_room(
self,
@@ -1496,6 +1360,198 @@ class FederationHandler(BaseHandler):
event_infos,
)
+ async def _process_pulled_events(
+ self, origin: str, events: Iterable[EventBase]
+ ) -> None:
+ """Process a batch of events we have pulled from a remote server
+
+ Pulls in any events required to auth the events, persists the received events,
+ and notifies clients, if appropriate.
+
+ Assumes the events have already had their signatures and hashes checked.
+
+ Params:
+ origin: The server we received these events from
+ events: The received events.
+ """
+
+ # We want to sort these by depth so we process them and
+ # tell clients about them in order.
+ sorted_events = sorted(events, key=lambda x: x.depth)
+
+ for ev in sorted_events:
+ with nested_logging_context(ev.event_id):
+ await self._process_pulled_event(origin, ev)
+
+ async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
+ """Process a single event that we have pulled from a remote server
+
+ Pulls in any events required to auth the event, persists the received event,
+ and notifies clients, if appropriate.
+
+ Assumes the event has already had its signatures and hashes checked.
+
+ This is somewhat equivalent to on_receive_pdu, but applies somewhat different
+ logic in the case that we are missing prev_events (in particular, it just
+ requests the state at that point, rather than triggering a get_missing_events) -
+ so is appropriate when we have pulled the event from a remote server, rather
+ than having it pushed to us.
+
+ Params:
+ origin: The server we received this event from
+ events: The received event
+ """
+ logger.info("Processing pulled event %s", event)
+
+ # these should not be outliers.
+ assert not event.internal_metadata.is_outlier()
+
+ event_id = event.event_id
+
+ existing = await self.store.get_event(
+ event_id, allow_none=True, allow_rejected=True
+ )
+ if existing:
+ if not existing.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received event %s which we have already seen",
+ event_id,
+ )
+ return
+ logger.info("De-outliering event %s", event_id)
+
+ try:
+ self._sanity_check_event(event)
+ except SynapseError as err:
+ logger.warning("Event %s failed sanity check: %s", event_id, err)
+ return
+
+ try:
+ state = await self._resolve_state_at_missing_prevs(origin, event)
+ await self._process_received_pdu(origin, event, state=state)
+ except FederationError as e:
+ if e.code == 403:
+ logger.warning("Pulled event %s failed history check.", event_id)
+ else:
+ raise
+
+ async def _resolve_state_at_missing_prevs(
+ self, dest: str, event: EventBase
+ ) -> Optional[Iterable[EventBase]]:
+ """Calculate the state at an event with missing prev_events.
+
+ This is used when we have pulled a batch of events from a remote server, and
+ still don't have all the prev_events.
+
+ If we already have all the prev_events for `event`, this method does nothing.
+
+ Otherwise, the missing prevs become new backwards extremities, and we fall back
+ to asking the remote server for the state after each missing `prev_event`,
+ and resolving across them.
+
+ That's ok provided we then resolve the state against other bits of the DAG
+ before using it - in other words, that the received event `event` is not going
+ to become the only forwards_extremity in the room (which will ensure that you
+ can't just take over a room by sending an event, withholding its prev_events,
+ and declaring yourself to be an admin in the subsequent state request).
+
+ In other words: we should only call this method if `event` has been *pulled*
+ as part of a batch of missing prev events, or similar.
+
+ Params:
+ dest: the remote server to ask for state at the missing prevs. Typically,
+ this will be the server we got `event` from.
+ event: an event to check for missing prevs.
+
+ Returns:
+ if we already had all the prev events, `None`. Otherwise, returns a list of
+ the events in the state at `event`.
+ """
+ room_id = event.room_id
+ event_id = event.event_id
+
+ prevs = set(event.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ return None
+
+ logger.info(
+ "Event %s is missing prev_events %s: calculating state for a "
+ "backwards extremity",
+ event_id,
+ shortstr(missing_prevs),
+ )
+ # Calculate the state after each of the previous events, and
+ # resolve them to find the correct state at the current event.
+ event_map = {event_id: event}
+ try:
+ # Get the state of the events we know about
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
+
+ # state_maps is a list of mappings from (type, state_key) to event_id
+ state_maps: List[StateMap[str]] = list(ours.values())
+
+ # we don't need this any more, let's delete it.
+ del ours
+
+ # Ask the remote server for the states we don't
+ # know about
+ for p in missing_prevs:
+ logger.info("Requesting state after missing prev_event %s", p)
+
+ with nested_logging_context(p):
+ # note that if any of the missing prevs share missing state or
+ # auth events, the requests to fetch those events are deduped
+ # by the get_pdu_cache in federation_client.
+ remote_state = await self._get_state_after_missing_prev_event(
+ dest, room_id, p
+ )
+
+ remote_state_map = {
+ (x.type, x.state_key): x.event_id for x in remote_state
+ }
+ state_maps.append(remote_state_map)
+
+ for x in remote_state:
+ event_map[x.event_id] = x
+
+ room_version = await self.store.get_room_version_id(room_id)
+ state_map = await self._state_resolution_handler.resolve_events_with_store(
+ room_id,
+ room_version,
+ state_maps,
+ event_map,
+ state_res_store=StateResolutionStore(self.store),
+ )
+
+ # We need to give _process_received_pdu the actual state events
+ # rather than event ids, so generate that now.
+
+ # First though we need to fetch all the events that are in
+ # state_map, so we can build up the state below.
+ evs = await self.store.get_events(
+ list(state_map.values()),
+ get_prev_content=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ event_map.update(evs)
+
+ state = [event_map[e] for e in state_map.values()]
+ except Exception:
+ logger.warning(
+ "Error attempting to resolve state at missing prev_events",
+ exc_info=True,
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ "We can't get valid state history.",
+ affected=event_id,
+ )
+ return state
+
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
@@ -1764,7 +1820,7 @@ class FederationHandler(BaseHandler):
p,
)
with nested_logging_context(p.event_id):
- await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self.on_receive_pdu(origin, p)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -2375,6 +2431,7 @@ class FederationHandler(BaseHandler):
not event.internal_metadata.is_outlier()
and not backfilled
and not context.rejected
+ and (await self.store.get_min_depth(event.room_id)) <= event.depth
):
await self.action_generator.handle_push_actions_for_event(
event, context
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e1c544a3c9..4e8f7f1d85 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -151,7 +151,7 @@ class InitialSyncHandler(BaseHandler):
limit = 10
async def handle_room(event: RoomsForUser):
- d = {
+ d: JsonDict = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8cf614136e..0ed59d757b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -56,6 +56,22 @@ login_counter = Counter(
)
+def init_counters_for_auth_provider(auth_provider_id: str) -> None:
+ """Ensure the prometheus counters for the given auth provider are initialised
+
+ This fixes a problem where the counters are not reported for a given auth provider
+ until the user first logs in/registers.
+ """
+ for is_guest in (True, False):
+ login_counter.labels(guest=is_guest, auth_provider=auth_provider_id)
+ for shadow_banned in (True, False):
+ registration_counter.labels(
+ guest=is_guest,
+ shadow_banned=shadow_banned,
+ auth_provider=auth_provider_id,
+ )
+
+
class LoginDict(TypedDict):
device_id: str
access_token: str
@@ -96,6 +112,8 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
self.access_token_lifetime = hs.config.access_token_lifetime
+ init_counters_for_auth_provider("")
+
async def check_username(
self,
localpart: str,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ba13196218..401b84aad1 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -36,6 +36,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
+from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import (
JsonDict,
Requester,
@@ -79,7 +80,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.account_data_handler = hs.get_account_data_handler()
self.event_auth_handler = hs.get_event_auth_handler()
- self.member_linearizer = Linearizer(name="member")
+ self.member_linearizer: Linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
@@ -556,6 +557,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content.pop("displayname", None)
content.pop("avatar_url", None)
+ if len(content.get("displayname") or "") > MAX_DISPLAYNAME_LEN:
+ raise SynapseError(
+ 400,
+ f"Displayname is too long (max {MAX_DISPLAYNAME_LEN})",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if len(content.get("avatar_url") or "") > MAX_AVATAR_URL_LEN:
+ raise SynapseError(
+ 400,
+ f"Avatar URL is too long (max {MAX_AVATAR_URL_LEN})",
+ errcode=Codes.BAD_JSON,
+ )
+
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index ac6cfc0da9..906985c754 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,12 +28,11 @@ from synapse.api.constants import (
Membership,
RoomTypes,
)
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -76,6 +75,9 @@ class _PaginationSession:
class RoomSummaryHandler:
+ # A unique key used for pagination sessions for the room hierarchy endpoint.
+ _PAGINATION_SESSION_TYPE = "room_hierarchy_pagination"
+
# The time a pagination session remains valid for.
_PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
@@ -87,12 +89,6 @@ class RoomSummaryHandler:
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
- # A map of query information to the current pagination state.
- #
- # TODO Allow for multiple workers to share this data.
- # TODO Expire pagination tokens.
- self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
-
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
self._pagination_response_cache: ResponseCache[
@@ -102,21 +98,6 @@ class RoomSummaryHandler:
"get_room_hierarchy",
)
- def _expire_pagination_sessions(self):
- """Expire pagination session which are old."""
- expire_before = (
- self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
- )
- to_expire = []
-
- for key, value in self._pagination_sessions.items():
- if value.creation_time_ms < expire_before:
- to_expire.append(key)
-
- for key in to_expire:
- logger.debug("Expiring pagination session id %s", key)
- del self._pagination_sessions[key]
-
async def get_space_summary(
self,
requester: str,
@@ -327,18 +308,29 @@ class RoomSummaryHandler:
# If this is continuing a previous session, pull the persisted data.
if from_token:
- self._expire_pagination_sessions()
+ try:
+ pagination_session = await self._store.get_session(
+ session_type=self._PAGINATION_SESSION_TYPE,
+ session_id=from_token,
+ )
+ except StoreError:
+ raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
- pagination_key = _PaginationKey(
- requested_room_id, suggested_only, max_depth, from_token
- )
- if pagination_key not in self._pagination_sessions:
+ # If the requester, room ID, suggested-only, or max depth were modified
+ # the session is invalid.
+ if (
+ requester != pagination_session["requester"]
+ or requested_room_id != pagination_session["room_id"]
+ or suggested_only != pagination_session["suggested_only"]
+ or max_depth != pagination_session["max_depth"]
+ ):
raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
# Load the previous state.
- pagination_session = self._pagination_sessions[pagination_key]
- room_queue = pagination_session.room_queue
- processed_rooms = pagination_session.processed_rooms
+ room_queue = [
+ _RoomQueueEntry(*fields) for fields in pagination_session["room_queue"]
+ ]
+ processed_rooms = set(pagination_session["processed_rooms"])
else:
# The queue of rooms to process, the next room is last on the stack.
room_queue = [_RoomQueueEntry(requested_room_id, ())]
@@ -456,13 +448,21 @@ class RoomSummaryHandler:
# If there's additional data, generate a pagination token (and persist state).
if room_queue:
- next_batch = random_string(24)
- result["next_batch"] = next_batch
- pagination_key = _PaginationKey(
- requested_room_id, suggested_only, max_depth, next_batch
- )
- self._pagination_sessions[pagination_key] = _PaginationSession(
- self._clock.time_msec(), room_queue, processed_rooms
+ result["next_batch"] = await self._store.create_session(
+ session_type=self._PAGINATION_SESSION_TYPE,
+ value={
+ # Information which must be identical across pagination.
+ "requester": requester,
+ "room_id": requested_room_id,
+ "suggested_only": suggested_only,
+ "max_depth": max_depth,
+ # The stored state.
+ "room_queue": [
+ attr.astuple(room_entry) for room_entry in room_queue
+ ],
+ "processed_rooms": list(processed_rooms),
+ },
+ expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS,
)
return result
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 1b855a685c..0e6ebb574e 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -37,6 +37,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
@@ -213,6 +214,7 @@ class SsoHandler:
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
+ init_counters_for_auth_provider(p_id)
def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
"""Get the configured identity providers"""
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 590642f510..86c3c7f0df 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1,5 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2018, 2019 New Vector Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +30,8 @@ from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection
+from synapse.api.presence import UserPresenceState
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
@@ -86,20 +87,20 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
SyncRequestKey = Tuple[Any, ...]
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncConfig:
- user = attr.ib(type=UserID)
- filter_collection = attr.ib(type=FilterCollection)
- is_guest = attr.ib(type=bool)
- request_key = attr.ib(type=SyncRequestKey)
- device_id = attr.ib(type=Optional[str])
+ user: UserID
+ filter_collection: FilterCollection
+ is_guest: bool
+ request_key: SyncRequestKey
+ device_id: Optional[str]
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class TimelineBatch:
- prev_batch = attr.ib(type=StreamToken)
- events = attr.ib(type=List[EventBase])
- limited = attr.ib(type=bool)
+ prev_batch: StreamToken
+ events: List[EventBase]
+ limited: bool
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -113,16 +114,16 @@ class TimelineBatch:
# if there are updates for it, which we check after the instance has been created.
# This should not be a big deal because we update the notification counts afterwards as
# well anyway.
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class JoinedSyncResult:
- room_id = attr.ib(type=str)
- timeline = attr.ib(type=TimelineBatch)
- state = attr.ib(type=StateMap[EventBase])
- ephemeral = attr.ib(type=List[JsonDict])
- account_data = attr.ib(type=List[JsonDict])
- unread_notifications = attr.ib(type=JsonDict)
- summary = attr.ib(type=Optional[JsonDict])
- unread_count = attr.ib(type=int)
+ room_id: str
+ timeline: TimelineBatch
+ state: StateMap[EventBase]
+ ephemeral: List[JsonDict]
+ account_data: List[JsonDict]
+ unread_notifications: JsonDict
+ summary: Optional[JsonDict]
+ unread_count: int
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -138,12 +139,12 @@ class JoinedSyncResult:
)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class ArchivedSyncResult:
- room_id = attr.ib(type=str)
- timeline = attr.ib(type=TimelineBatch)
- state = attr.ib(type=StateMap[EventBase])
- account_data = attr.ib(type=List[JsonDict])
+ room_id: str
+ timeline: TimelineBatch
+ state: StateMap[EventBase]
+ account_data: List[JsonDict]
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -152,37 +153,37 @@ class ArchivedSyncResult:
return bool(self.timeline or self.state or self.account_data)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class InvitedSyncResult:
- room_id = attr.ib(type=str)
- invite = attr.ib(type=EventBase)
+ room_id: str
+ invite: EventBase
def __bool__(self) -> bool:
"""Invited rooms should always be reported to the client"""
return True
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class KnockedSyncResult:
- room_id = attr.ib(type=str)
- knock = attr.ib(type=EventBase)
+ room_id: str
+ knock: EventBase
def __bool__(self) -> bool:
"""Knocked rooms should always be reported to the client"""
return True
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class GroupsSyncResult:
- join = attr.ib(type=JsonDict)
- invite = attr.ib(type=JsonDict)
- leave = attr.ib(type=JsonDict)
+ join: JsonDict
+ invite: JsonDict
+ leave: JsonDict
def __bool__(self) -> bool:
return bool(self.join or self.invite or self.leave)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
@@ -190,27 +191,27 @@ class DeviceLists:
left: List of user_ids whose devices we no longer track
"""
- changed = attr.ib(type=Collection[str])
- left = attr.ib(type=Collection[str])
+ changed: Collection[str]
+ left: Collection[str]
def __bool__(self) -> bool:
return bool(self.changed or self.left)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
and left room IDs since last sync.
"""
- room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
- invited = attr.ib(type=List[InvitedSyncResult])
- knocked = attr.ib(type=List[KnockedSyncResult])
- newly_joined_rooms = attr.ib(type=List[str])
- newly_left_rooms = attr.ib(type=List[str])
+ room_entries: List["RoomSyncResultBuilder"]
+ invited: List[InvitedSyncResult]
+ knocked: List[KnockedSyncResult]
+ newly_joined_rooms: List[str]
+ newly_left_rooms: List[str]
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncResult:
"""
Attributes:
@@ -230,18 +231,18 @@ class SyncResult:
groups: Group updates, if any
"""
- next_batch = attr.ib(type=StreamToken)
- presence = attr.ib(type=List[JsonDict])
- account_data = attr.ib(type=List[JsonDict])
- joined = attr.ib(type=List[JoinedSyncResult])
- invited = attr.ib(type=List[InvitedSyncResult])
- knocked = attr.ib(type=List[KnockedSyncResult])
- archived = attr.ib(type=List[ArchivedSyncResult])
- to_device = attr.ib(type=List[JsonDict])
- device_lists = attr.ib(type=DeviceLists)
- device_one_time_keys_count = attr.ib(type=JsonDict)
- device_unused_fallback_key_types = attr.ib(type=List[str])
- groups = attr.ib(type=Optional[GroupsSyncResult])
+ next_batch: StreamToken
+ presence: List[UserPresenceState]
+ account_data: List[JsonDict]
+ joined: List[JoinedSyncResult]
+ invited: List[InvitedSyncResult]
+ knocked: List[KnockedSyncResult]
+ archived: List[ArchivedSyncResult]
+ to_device: List[JsonDict]
+ device_lists: DeviceLists
+ device_one_time_keys_count: JsonDict
+ device_unused_fallback_key_types: List[str]
+ groups: Optional[GroupsSyncResult]
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -701,7 +702,7 @@ class SyncHandler:
name_id = state_ids.get((EventTypes.Name, ""))
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
- summary = {}
+ summary: JsonDict = {}
empty_ms = MemberSummary([], 0)
# TODO: only send these when they change.
@@ -1843,6 +1844,9 @@ class SyncHandler:
knocked = []
for event in room_list:
+ if event.room_version_id not in KNOWN_ROOM_VERSIONS:
+ continue
+
if event.membership == Membership.JOIN:
room_entries.append(
RoomSyncResultBuilder(
@@ -2076,21 +2080,23 @@ class SyncHandler:
# If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room
# before.
- for room_id, event_pos in joined_rooms:
- if not event_pos.persisted_after(room_key):
- joined_room_ids.add(room_id)
+ for joined_room in joined_rooms:
+ if not joined_room.event_pos.persisted_after(room_key):
+ joined_room_ids.add(joined_room.room_id)
continue
- logger.info("User joined room after current token: %s", room_id)
+ logger.info("User joined room after current token: %s", joined_room.room_id)
extrems = (
await self.store.get_forward_extremities_for_room_at_stream_ordering(
- room_id, event_pos.stream
+ joined_room.room_id, joined_room.event_pos.stream
)
)
- users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
+ users_in_room = await self.state.get_current_users_in_room(
+ joined_room.room_id, extrems
+ )
if user_id in users_in_room:
- joined_room_ids.add(room_id)
+ joined_room_ids.add(joined_room.room_id)
return frozenset(joined_room_ids)
@@ -2160,7 +2166,7 @@ def _calculate_state(
return {event_id_to_key[e]: e for e in state_ids}
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class SyncResultBuilder:
"""Used to help build up a new SyncResult for a user
@@ -2172,33 +2178,33 @@ class SyncResultBuilder:
joined_room_ids: List of rooms the user is joined to
# The following mirror the fields in a sync response
- presence (list)
- account_data (list)
- joined (list[JoinedSyncResult])
- invited (list[InvitedSyncResult])
- knocked (list[KnockedSyncResult])
- archived (list[ArchivedSyncResult])
- groups (GroupsSyncResult|None)
- to_device (list)
+ presence
+ account_data
+ joined
+ invited
+ knocked
+ archived
+ groups
+ to_device
"""
- sync_config = attr.ib(type=SyncConfig)
- full_state = attr.ib(type=bool)
- since_token = attr.ib(type=Optional[StreamToken])
- now_token = attr.ib(type=StreamToken)
- joined_room_ids = attr.ib(type=FrozenSet[str])
+ sync_config: SyncConfig
+ full_state: bool
+ since_token: Optional[StreamToken]
+ now_token: StreamToken
+ joined_room_ids: FrozenSet[str]
- presence = attr.ib(type=List[JsonDict], default=attr.Factory(list))
- account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
- joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
- invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
- knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
- archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
- groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
- to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
+ presence: List[UserPresenceState] = attr.Factory(list)
+ account_data: List[JsonDict] = attr.Factory(list)
+ joined: List[JoinedSyncResult] = attr.Factory(list)
+ invited: List[InvitedSyncResult] = attr.Factory(list)
+ knocked: List[KnockedSyncResult] = attr.Factory(list)
+ archived: List[ArchivedSyncResult] = attr.Factory(list)
+ groups: Optional[GroupsSyncResult] = None
+ to_device: List[JsonDict] = attr.Factory(list)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class RoomSyncResultBuilder:
"""Stores information needed to create either a `JoinedSyncResult` or
`ArchivedSyncResult`.
@@ -2214,10 +2220,10 @@ class RoomSyncResultBuilder:
upto_token: Latest point to return events from.
"""
- room_id = attr.ib(type=str)
- rtype = attr.ib(type=str)
- events = attr.ib(type=Optional[List[EventBase]])
- newly_joined = attr.ib(type=bool)
- full_state = attr.ib(type=bool)
- since_token = attr.ib(type=Optional[StreamToken])
- upto_token = attr.ib(type=StreamToken)
+ room_id: str
+ rtype: str
+ events: Optional[List[EventBase]]
+ newly_joined: bool
+ full_state: bool
+ since_token: Optional[StreamToken]
+ upto_token: StreamToken
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 4c3b669fae..13b0c61d2e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for.
REQUEST_USER_ID = "request_user_id"
+
+ # used during registration to store the registration token used (if required) so that:
+ # - we can prevent a token being used twice by one session
+ # - we can 'use up' the token after registration has successfully completed
+ REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 5414ce77d8..d3828dec6b 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -49,7 +49,7 @@ class UserInteractiveAuthChecker:
clientip: The IP address of the client.
Raises:
- SynapseError if authentication failed
+ LoginError if authentication failed.
Returns:
The result of authentication (to pass back to the client?)
@@ -131,7 +131,9 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
)
if resp_body["success"]:
return True
- raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+ raise LoginError(
+ 401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED
+ )
class _BaseThreepidAuthChecker:
@@ -191,7 +193,9 @@ class _BaseThreepidAuthChecker:
raise AssertionError("Unrecognized threepid medium: %s" % (medium,))
if not threepid:
- raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+ raise LoginError(
+ 401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED
+ )
if threepid["medium"] != medium:
raise LoginError(
@@ -237,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return await self._check_threepid("msisdn", authdict)
+class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.REGISTRATION_TOKEN
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self.hs = hs
+ self._enabled = bool(hs.config.registration_requires_token)
+ self.store = hs.get_datastore()
+
+ def is_enabled(self) -> bool:
+ return self._enabled
+
+ async def check_auth(self, authdict: dict, clientip: str) -> Any:
+ if "token" not in authdict:
+ raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
+ if not isinstance(authdict["token"], str):
+ raise LoginError(
+ 400, "Registration token must be a string", Codes.INVALID_PARAM
+ )
+ if "session" not in authdict:
+ raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
+
+ # Get these here to avoid cyclic dependencies
+ from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+
+ auth_handler = self.hs.get_auth_handler()
+
+ session = authdict["session"]
+ token = authdict["token"]
+
+ # If the LoginType.REGISTRATION_TOKEN stage has already been completed,
+ # return early to avoid incrementing `pending` again.
+ stored_token = await auth_handler.get_session_data(
+ session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
+ )
+ if stored_token:
+ if token != stored_token:
+ raise LoginError(
+ 400, "Registration token has changed", Codes.INVALID_PARAM
+ )
+ else:
+ return token
+
+ if await self.store.registration_token_is_valid(token):
+ # Increment pending counter, so that if token has limited uses it
+ # can't be used up by someone else in the meantime.
+ await self.store.set_registration_token_pending(token)
+ # Store the token in the UIA session, so that once registration
+ # is complete `completed` can be incremented.
+ await auth_handler.set_session_data(
+ session,
+ UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+ token,
+ )
+ # The token will be stored as the result of the authentication stage
+ # in ui_auth_sessions_credentials. This allows the pending counter
+ # for tokens to be decremented when expired sessions are deleted.
+ return token
+ else:
+ raise LoginError(
+ 401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
+ )
+
+
INTERACTIVE_AUTH_CHECKERS = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
EmailIdentityAuthChecker,
MsisdnAuthChecker,
+ RegistrationTokenAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""
|