diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 3940da5c88..8d5b2177d2 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -41,7 +41,7 @@ class Membership:
INVITE = "invite"
JOIN = "join"
- KNOCK = "knock"
+ KNOCK = "xyz.amorgan.knock"
LEAVE = "leave"
BAN = "ban"
LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
@@ -58,7 +58,7 @@ class PresenceState:
class JoinRules:
PUBLIC = "public"
- KNOCK = "knock"
+ KNOCK = "xyz.amorgan.knock"
INVITE = "invite"
PRIVATE = "private"
# As defined for MSC3083.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 0231c79079..4cb8bbaf70 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -449,7 +449,7 @@ class IncompatibleRoomVersionError(SynapseError):
super().__init__(
code=400,
msg="Your homeserver does not support the features required to "
- "join this room",
+ "interact with this room",
errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
)
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 373a4669d0..3349f399ba 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -56,7 +56,7 @@ class RoomVersion:
state_res = attr.ib(type=int) # one of the StateResolutionVersions
enforce_key_validity = attr.ib(type=bool)
- # Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
+ # Before MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool)
# Strictly enforce canonicaljson, do not allow:
# * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@@ -70,6 +70,9 @@ class RoomVersion:
msc2176_redaction_rules = attr.ib(type=bool)
# MSC3083: Support the 'restricted' join_rule.
msc3083_join_rules = attr.ib(type=bool)
+ # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
+ # m.room.membership event with membership 'knock'.
+ msc2403_knocking = attr.ib(type=bool)
class RoomVersions:
@@ -84,6 +87,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V2 = RoomVersion(
"2",
@@ -96,6 +100,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V3 = RoomVersion(
"3",
@@ -108,6 +113,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V4 = RoomVersion(
"4",
@@ -120,6 +126,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V5 = RoomVersion(
"5",
@@ -132,6 +139,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V6 = RoomVersion(
"6",
@@ -144,6 +152,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@@ -156,6 +165,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=True,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
MSC3083 = RoomVersion(
"org.matrix.msc3083",
@@ -168,6 +178,20 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=True,
+ msc2403_knocking=False,
+ )
+ MSC2403 = RoomVersion(
+ "xyz.amorgan.knock",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=False,
+ msc2403_knocking=True,
)
@@ -183,4 +207,5 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.MSC2176,
RoomVersions.MSC3083,
)
+ # Note that we do not include MSC2043 here unless it is enabled in the config.
} # type: Dict[str, RoomVersion]
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 68ae19c977..2878d2c140 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -36,7 +36,6 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
@@ -54,7 +53,6 @@ class AdminCmdSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
- SlavedPresenceStore,
SlavedGroupServerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index fe04d7a672..61152b2c46 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events import EventBase
from synapse.events.utils import serialize_event
@@ -247,9 +247,14 @@ class ApplicationServiceApi(SimpleHttpClient):
e,
time_now,
as_client_event=True,
- is_invite=(
+ # If this is an invite or a knock membership event, and we're interested
+ # in this user, then include any stripped state alongside the event.
+ include_stripped_room_state=(
e.type == EventTypes.Member
- and e.membership == "invite"
+ and (
+ e.membership == Membership.INVITE
+ or e.membership == Membership.KNOCK
+ )
and service.is_interested_in_user(e.state_key)
),
)
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index c58a7d95a7..957de7f3a6 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 6ebce4b2f7..37668079e7 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.config._base import Config
from synapse.types import JsonDict
@@ -29,3 +30,9 @@ class ExperimentalConfig(Config):
# MSC3026 (busy presence state)
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
+
+ # MSC2403 (room knocking)
+ self.msc2403_enabled = experimental.get("msc2403_enabled", False) # type: bool
+ if self.msc2403_enabled:
+ # Enable the MSC2403 unstable room version
+ KNOWN_ROOM_VERSIONS[RoomVersions.MSC2403.identifier] = RoomVersions.MSC2403
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index c78a83abe1..2f77d6703d 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -248,6 +248,10 @@ class ContentRepositoryConfig(Config):
# The largest allowed upload size in bytes
#
+ # If you are using a reverse proxy you may also need to set this value in
+ # your reverse proxy's config. Notably Nginx has a small max body size by default.
+ # See https://matrix-org.github.io/synapse/develop/reverse_proxy.html.
+ #
#max_upload_size: 50M
# Maximum number of pixels that will be thumbnailed
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 70c556566e..33d7c60241 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -160,6 +160,7 @@ def check(
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
+ # 5. If type is m.room.membership
if event.type == EventTypes.Member:
_is_membership_change_allowed(room_version_obj, event, auth_events)
logger.debug("Allowing! %s", event)
@@ -257,6 +258,11 @@ def _is_membership_change_allowed(
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
+ caller_knocked = (
+ caller
+ and room_version.msc2403_knocking
+ and caller.membership == Membership.KNOCK
+ )
# get info about the target
key = (EventTypes.Member, target_user_id)
@@ -283,6 +289,7 @@ def _is_membership_change_allowed(
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
+ "caller_knocked": caller_knocked,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
@@ -299,9 +306,14 @@ def _is_membership_change_allowed(
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
return
- if Membership.JOIN != membership:
+ # Require the user to be in the room for membership changes other than join/knock.
+ if Membership.JOIN != membership and (
+ RoomVersion.msc2403_knocking and Membership.KNOCK != membership
+ ):
+ # If the user has been invited or has knocked, they are allowed to change their
+ # membership event to leave
if (
- caller_invited
+ (caller_invited or caller_knocked)
and Membership.LEAVE == membership
and target_user_id == event.user_id
):
@@ -339,7 +351,9 @@ def _is_membership_change_allowed(
and join_rule == JoinRules.MSC3083_RESTRICTED
):
pass
- elif join_rule == JoinRules.INVITE:
+ elif join_rule == JoinRules.INVITE or (
+ room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+ ):
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
@@ -358,6 +372,17 @@ def _is_membership_change_allowed(
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
+ elif room_version.msc2403_knocking and Membership.KNOCK == membership:
+ if join_rule != JoinRules.KNOCK:
+ raise AuthError(403, "You don't have permission to knock")
+ elif target_user_id != event.user_id:
+ raise AuthError(403, "You cannot knock for other users")
+ elif target_in_room:
+ raise AuthError(403, "You cannot knock on a room you are already in")
+ elif caller_invited:
+ raise AuthError(403, "You are already invited to this room")
+ elif target_banned:
+ raise AuthError(403, "You are banned from this room")
else:
raise AuthError(500, "Unknown membership %s" % membership)
@@ -718,7 +743,7 @@ def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
if event.type == EventTypes.Member:
membership = event.content["membership"]
- if membership in [Membership.JOIN, Membership.INVITE]:
+ if membership in [Membership.JOIN, Membership.INVITE, Membership.KNOCK]:
auth_types.add((EventTypes.JoinRules, ""))
auth_types.add((EventTypes.Member, event.state_key))
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 7d7cd9aaee..ec96999e4e 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -242,6 +242,7 @@ def format_event_for_client_v1(d):
"replaces_state",
"prev_content",
"invite_room_state",
+ "knock_room_state",
)
for key in copy_keys:
if key in d["unsigned"]:
@@ -278,7 +279,7 @@ def serialize_event(
event_format=format_event_for_client_v1,
token_id=None,
only_event_fields=None,
- is_invite=False,
+ include_stripped_room_state=False,
):
"""Serialize event for clients
@@ -289,8 +290,10 @@ def serialize_event(
event_format
token_id
only_event_fields
- is_invite (bool): Whether this is an invite that is being sent to the
- invitee
+ include_stripped_room_state (bool): Some events can have stripped room state
+ stored in the `unsigned` field. This is required for invite and knock
+ functionality. If this option is False, that state will be removed from the
+ event before it is returned. Otherwise, it will be kept.
Returns:
dict
@@ -322,11 +325,13 @@ def serialize_event(
if txn_id is not None:
d["unsigned"]["transaction_id"] = txn_id
- # If this is an invite for somebody else, then we don't care about the
- # invite_room_state as that's meant solely for the invitee. Other clients
- # will already have the state since they're in the room.
- if not is_invite:
+ # invite_room_state and knock_room_state are a list of stripped room state events
+ # that are meant to provide metadata about a room to an invitee/knocker. They are
+ # intended to only be included in specific circumstances, such as down sync, and
+ # should not be included in any other case.
+ if not include_stripped_room_state:
d["unsigned"].pop("invite_room_state", None)
+ d["unsigned"].pop("knock_room_state", None)
if as_client_event:
d = event_format(d)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 1076ebc036..03ec14ce87 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1,4 +1,5 @@
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -89,6 +90,7 @@ class FederationClient(FederationBase):
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
+ self._msc2403_enabled = hs.config.experimental.msc2403_enabled
self.hostname = hs.hostname
self.signing_key = hs.signing_key
@@ -620,6 +622,11 @@ class FederationClient(FederationBase):
no servers successfully handle the request.
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
+
+ # Allow knocking if the feature is enabled
+ if self._msc2403_enabled:
+ valid_memberships.add(Membership.KNOCK)
+
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
@@ -638,6 +645,13 @@ class FederationClient(FederationBase):
if not room_version:
raise UnsupportedRoomVersionError()
+ if not room_version.msc2403_knocking and membership == Membership.KNOCK:
+ raise SynapseError(
+ 400,
+ "This room version does not support knocking",
+ errcode=Codes.FORBIDDEN,
+ )
+
pdu_dict = ret.get("event", None)
if not isinstance(pdu_dict, dict):
raise InvalidResponseError("Bad 'event' field in response")
@@ -946,6 +960,62 @@ class FederationClient(FederationBase):
# content.
return resp[1]
+ async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
+ """Attempts to send a knock event to given a list of servers. Iterates
+ through the list until one attempt succeeds.
+
+ Doing so will cause the remote server to add the event to the graph,
+ and send the event out to the rest of the federation.
+
+ Args:
+ destinations: A list of candidate homeservers which are likely to be
+ participating in the room.
+ pdu: The event to be sent.
+
+ Returns:
+ The remote homeserver return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+
+ Raises:
+ SynapseError: If the chosen remote server returns a 3xx/4xx code.
+ RuntimeError: If no servers were reachable.
+ """
+
+ async def send_request(destination: str) -> JsonDict:
+ return await self._do_send_knock(destination, pdu)
+
+ return await self._try_destination_list(
+ "xyz.amorgan.knock/send_knock", destinations, send_request
+ )
+
+ async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict:
+ """Send a knock event to a remote homeserver.
+
+ Args:
+ destination: The homeserver to send to.
+ pdu: The event to send.
+
+ Returns:
+ The remote homeserver can optionally return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+ """
+ time_now = self._clock.time_msec()
+
+ return await self.transport_layer.send_knock_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
async def get_public_rooms(
self,
remote_server: str,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ace30aa450..2b07f18529 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -129,7 +129,7 @@ class FederationServer(FederationBase):
# come in waves.
self._state_resp_cache = ResponseCache(
hs.get_clock(), "state_resp", timeout_ms=30000
- ) # type: ResponseCache[Tuple[str, str]]
+ ) # type: ResponseCache[Tuple[str, Optional[str]]]
self._state_ids_resp_cache = ResponseCache(
hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
@@ -138,6 +138,8 @@ class FederationServer(FederationBase):
hs.config.federation.federation_metrics_domains
)
+ self._room_prejoin_state_types = hs.config.api.room_prejoin_state
+
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
@@ -406,7 +408,7 @@ class FederationServer(FederationBase):
)
async def on_room_state_request(
- self, origin: str, room_id: str, event_id: str
+ self, origin: str, room_id: str, event_id: Optional[str]
) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -463,7 +465,7 @@ class FederationServer(FederationBase):
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
- self, room_id: str, event_id: str
+ self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]:
if event_id:
pdus = await self.handler.get_state_for_pdu(
@@ -586,6 +588,103 @@ class FederationServer(FederationBase):
await self.handler.on_send_leave_request(origin, pdu)
return {}
+ async def on_make_knock_request(
+ self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
+ ) -> Dict[str, Union[EventBase, str]]:
+ """We've received a /make_knock/ request, so we create a partial knock
+ event for the room and hand that back, along with the room version, to the knocking
+ homeserver. We do *not* persist or process this event until the other server has
+ signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: The room to create the knock event in.
+ user_id: The user to create the knock for.
+ supported_versions: The room versions supported by the requesting server.
+
+ Returns:
+ The partial knock event.
+ """
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, room_id)
+
+ room_version = await self.store.get_room_version(room_id)
+
+ # Check that this room version is supported by the remote homeserver
+ if room_version.identifier not in supported_versions:
+ logger.warning(
+ "Room version %s not in %s", room_version.identifier, supported_versions
+ )
+ raise IncompatibleRoomVersionError(room_version=room_version.identifier)
+
+ # Check that this room supports knocking as defined by its room version
+ if not room_version.msc2403_knocking:
+ raise SynapseError(
+ 403,
+ "This room version does not support knocking",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
+ time_now = self._clock.time_msec()
+ return {
+ "event": pdu.get_pdu_json(time_now),
+ "room_version": room_version.identifier,
+ }
+
+ async def on_send_knock_request(
+ self,
+ origin: str,
+ content: JsonDict,
+ room_id: str,
+ ) -> Dict[str, List[JsonDict]]:
+ """
+ We have received a knock event for a room. Verify and send the event into the room
+ on the knocking homeserver's behalf. Then reply with some stripped state from the
+ room for the knockee.
+
+ Args:
+ origin: The remote homeserver of the knocking user.
+ content: The content of the request.
+ room_id: The ID of the room to knock on.
+
+ Returns:
+ The stripped room state.
+ """
+ logger.debug("on_send_knock_request: content: %s", content)
+
+ room_version = await self.store.get_room_version(room_id)
+
+ # Check that this room supports knocking as defined by its room version
+ if not room_version.msc2403_knocking:
+ raise SynapseError(
+ 403,
+ "This room version does not support knocking",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ pdu = event_from_pdu_json(content, room_version)
+
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
+
+ logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
+
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+
+ # Handle the event, and retrieve the EventContext
+ event_context = await self.handler.on_send_knock_request(origin, pdu)
+
+ # Retrieve stripped state events from the room and send them back to the remote
+ # server. This will allow the remote server's clients to display information
+ # related to the room while the knock request is pending.
+ stripped_room_state = (
+ await self.store.get_stripped_room_state_from_event_context(
+ event_context, self._room_prejoin_state_types
+ )
+ )
+ return {"knock_state_events": stripped_room_state}
+
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 5b4f5d17f7..af0c679ed9 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,5 +1,5 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,6 +47,7 @@ class TransportLayerClient:
def __init__(self, hs):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
+ self._msc2403_enabled = hs.config.experimental.msc2403_enabled
@log_function
def get_room_state_ids(self, destination, room_id, event_id):
@@ -221,12 +222,28 @@ class TransportLayerClient:
is not in our federation whitelist
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
+
+ # Allow knocking if the feature is enabled
+ if self._msc2403_enabled:
+ valid_memberships.add(Membership.KNOCK)
+
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
% (membership, ",".join(valid_memberships))
)
- path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
+
+ # Knock currently uses an unstable prefix
+ if membership == Membership.KNOCK:
+ # Create a path in the form of /unstable/xyz.amorgan.knock/make_knock/...
+ path = _create_path(
+ FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock",
+ "/make_knock/%s/%s",
+ room_id,
+ user_id,
+ )
+ else:
+ path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False
retry_on_dns_fail = False
@@ -322,6 +339,45 @@ class TransportLayerClient:
return response
@log_function
+ async def send_knock_v1(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
+ """
+ Sends a signed knock membership event to a remote server. This is the second
+ step for knocking after make_knock.
+
+ Args:
+ destination: The remote homeserver.
+ room_id: The ID of the room to knock on.
+ event_id: The ID of the knock membership event that we're sending.
+ content: The knock membership event that we're sending. Note that this is not the
+ `content` field of the membership event, but the entire signed membership event
+ itself represented as a JSON dict.
+
+ Returns:
+ The remote homeserver can optionally return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+ """
+ path = _create_path(
+ FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock",
+ "/send_knock/%s/%s",
+ room_id,
+ event_id,
+ )
+
+ return await self.client.put_json(
+ destination=destination, path=path, data=content
+ )
+
+ @log_function
async def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 5756fcb551..fe5fb6bee7 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,6 +1,5 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import functools
import logging
import re
@@ -28,12 +26,14 @@ from synapse.api.urls import (
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
+from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
parse_string_from_args,
+ parse_strings_from_args,
)
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@@ -275,10 +275,17 @@ class BaseFederationServlet:
RATELIMIT = True # Whether to rate limit requests or not
- def __init__(self, handler, authenticator, ratelimiter, server_name):
- self.handler = handler
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ self.hs = hs
self.authenticator = authenticator
self.ratelimiter = ratelimiter
+ self.server_name = server_name
def _wrap(self, func):
authenticator = self.authenticator
@@ -375,17 +382,30 @@ class BaseFederationServlet:
)
-class FederationSendServlet(BaseFederationServlet):
+class BaseFederationServerServlet(BaseFederationServlet):
+ """Abstract base class for federation servlet classes which provides a federation server handler.
+
+ See BaseFederationServlet for more information.
+ """
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_federation_server()
+
+
+class FederationSendServlet(BaseFederationServerServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/?"
# We ratelimit manually in the handler as we queue up the requests and we
# don't want to fill up the ratelimiter with blocked requests.
RATELIMIT = False
- def __init__(self, handler, server_name, **kwargs):
- super().__init__(handler, server_name=server_name, **kwargs)
- self.server_name = server_name
-
# This is when someone is trying to send us a bunch of data.
async def on_PUT(self, origin, content, query, transaction_id):
"""Called on PUT /send/<transaction_id>/
@@ -434,7 +454,7 @@ class FederationSendServlet(BaseFederationServlet):
return code, response
-class FederationEventServlet(BaseFederationServlet):
+class FederationEventServlet(BaseFederationServerServlet):
PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair.
@@ -442,7 +462,7 @@ class FederationEventServlet(BaseFederationServlet):
return await self.handler.on_pdu_request(origin, event_id)
-class FederationStateV1Servlet(BaseFederationServlet):
+class FederationStateV1Servlet(BaseFederationServerServlet):
PATH = "/state/(?P<room_id>[^/]*)/?"
# This is when someone asks for all data for a given room.
@@ -454,7 +474,7 @@ class FederationStateV1Servlet(BaseFederationServlet):
)
-class FederationStateIdsServlet(BaseFederationServlet):
+class FederationStateIdsServlet(BaseFederationServerServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, room_id):
@@ -465,7 +485,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
)
-class FederationBackfillServlet(BaseFederationServlet):
+class FederationBackfillServlet(BaseFederationServerServlet):
PATH = "/backfill/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, room_id):
@@ -478,7 +498,7 @@ class FederationBackfillServlet(BaseFederationServlet):
return await self.handler.on_backfill_request(origin, room_id, versions, limit)
-class FederationQueryServlet(BaseFederationServlet):
+class FederationQueryServlet(BaseFederationServerServlet):
PATH = "/query/(?P<query_type>[^/]*)"
# This is when we receive a server-server Query
@@ -488,7 +508,7 @@ class FederationQueryServlet(BaseFederationServlet):
return await self.handler.on_query_request(query_type, args)
-class FederationMakeJoinServlet(BaseFederationServlet):
+class FederationMakeJoinServlet(BaseFederationServerServlet):
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, _content, query, room_id, user_id):
@@ -518,7 +538,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
return 200, content
-class FederationMakeLeaveServlet(BaseFederationServlet):
+class FederationMakeLeaveServlet(BaseFederationServerServlet):
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, room_id, user_id):
@@ -526,7 +546,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationV1SendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@@ -534,7 +554,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet):
return 200, (200, content)
-class FederationV2SendLeaveServlet(BaseFederationServlet):
+class FederationV2SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@@ -544,14 +564,42 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationEventAuthServlet(BaseFederationServlet):
+class FederationMakeKnockServlet(BaseFederationServerServlet):
+ PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+ PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock"
+
+ async def on_GET(self, origin, content, query, room_id, user_id):
+ try:
+ # Retrieve the room versions the remote homeserver claims to support
+ supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
+ except KeyError:
+ raise SynapseError(400, "Missing required query parameter 'ver'")
+
+ content = await self.handler.on_make_knock_request(
+ origin, room_id, user_id, supported_versions=supported_versions
+ )
+ return 200, content
+
+
+class FederationV1SendKnockServlet(BaseFederationServerServlet):
+ PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+ PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock"
+
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_knock_request(origin, content, room_id)
+ return 200, content
+
+
+class FederationEventAuthServlet(BaseFederationServerServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_GET(self, origin, content, query, room_id, event_id):
return await self.handler.on_event_auth(origin, room_id, event_id)
-class FederationV1SendJoinServlet(BaseFederationServlet):
+class FederationV1SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@@ -561,7 +609,7 @@ class FederationV1SendJoinServlet(BaseFederationServlet):
return 200, (200, content)
-class FederationV2SendJoinServlet(BaseFederationServlet):
+class FederationV2SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@@ -573,7 +621,7 @@ class FederationV2SendJoinServlet(BaseFederationServlet):
return 200, content
-class FederationV1InviteServlet(BaseFederationServlet):
+class FederationV1InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@@ -590,7 +638,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
return 200, (200, content)
-class FederationV2InviteServlet(BaseFederationServlet):
+class FederationV2InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@@ -614,7 +662,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
return 200, content
-class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
+class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id):
@@ -622,21 +670,21 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
return 200, {}
-class FederationClientKeysQueryServlet(BaseFederationServlet):
+class FederationClientKeysQueryServlet(BaseFederationServerServlet):
PATH = "/user/keys/query"
async def on_POST(self, origin, content, query):
return await self.handler.on_query_client_keys(origin, content)
-class FederationUserDevicesQueryServlet(BaseFederationServlet):
+class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, user_id):
return await self.handler.on_query_user_devices(origin, user_id)
-class FederationClientKeysClaimServlet(BaseFederationServlet):
+class FederationClientKeysClaimServlet(BaseFederationServerServlet):
PATH = "/user/keys/claim"
async def on_POST(self, origin, content, query):
@@ -644,7 +692,7 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
return 200, response
-class FederationGetMissingEventsServlet(BaseFederationServlet):
+class FederationGetMissingEventsServlet(BaseFederationServerServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
@@ -664,7 +712,7 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
return 200, content
-class On3pidBindServlet(BaseFederationServlet):
+class On3pidBindServlet(BaseFederationServerServlet):
PATH = "/3pid/onbind"
REQUIRE_AUTH = False
@@ -694,7 +742,7 @@ class On3pidBindServlet(BaseFederationServlet):
return 200, {}
-class OpenIdUserInfo(BaseFederationServlet):
+class OpenIdUserInfo(BaseFederationServerServlet):
"""
Exchange a bearer token for information about a user.
@@ -770,8 +818,16 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
- def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
- super().__init__(handler, authenticator, ratelimiter, server_name)
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ allow_access: bool,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_room_list_handler()
self.allow_access = allow_access
async def on_GET(self, origin, content, query):
@@ -856,7 +912,24 @@ class FederationVersionServlet(BaseFederationServlet):
)
-class FederationGroupsProfileServlet(BaseFederationServlet):
+class BaseGroupsServerServlet(BaseFederationServlet):
+ """Abstract base class for federation servlet classes which provides a groups server handler.
+
+ See BaseFederationServlet for more information.
+ """
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_groups_server_handler()
+
+
+class FederationGroupsProfileServlet(BaseGroupsServerServlet):
"""Get/set the basic profile of a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/profile"
@@ -882,7 +955,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsSummaryServlet(BaseFederationServlet):
+class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/summary"
async def on_GET(self, origin, content, query, group_id):
@@ -895,7 +968,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsRoomsServlet(BaseFederationServlet):
+class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
"""Get the rooms in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
@@ -910,7 +983,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsAddRoomsServlet(BaseFederationServlet):
+class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
"""Add/remove room from group"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@@ -938,7 +1011,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
+class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
"""Update room config in group"""
PATH = (
@@ -958,7 +1031,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
return 200, result
-class FederationGroupsUsersServlet(BaseFederationServlet):
+class FederationGroupsUsersServlet(BaseGroupsServerServlet):
"""Get the users in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/users"
@@ -973,7 +1046,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
+class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
"""Get the users that have been invited to a group"""
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
@@ -990,7 +1063,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsInviteServlet(BaseFederationServlet):
+class FederationGroupsInviteServlet(BaseGroupsServerServlet):
"""Ask a group server to invite someone to the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@@ -1007,7 +1080,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
+class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
"""Accept an invitation from the group server"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
@@ -1021,7 +1094,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsJoinServlet(BaseFederationServlet):
+class FederationGroupsJoinServlet(BaseGroupsServerServlet):
"""Attempt to join a group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
@@ -1035,7 +1108,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsRemoveUserServlet(BaseFederationServlet):
+class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
"""Leave or kick a user from the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@@ -1052,7 +1125,24 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsLocalInviteServlet(BaseFederationServlet):
+class BaseGroupsLocalServlet(BaseFederationServlet):
+ """Abstract base class for federation servlet classes which provides a groups local handler.
+
+ See BaseFederationServlet for more information.
+ """
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_groups_local_handler()
+
+
+class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
"""A group server has invited a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@@ -1061,12 +1151,16 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
+ assert isinstance(
+ self.handler, GroupsLocalHandler
+ ), "Workers cannot handle group invites."
+
new_content = await self.handler.on_invite(group_id, user_id, content)
return 200, new_content
-class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
+class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
"""A group server has removed a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@@ -1075,6 +1169,10 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
+ assert isinstance(
+ self.handler, GroupsLocalHandler
+ ), "Workers cannot handle group removals."
+
new_content = await self.handler.user_removed_from_group(
group_id, user_id, content
)
@@ -1087,6 +1185,16 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_groups_attestation_renewer()
+
async def on_POST(self, origin, content, query, group_id, user_id):
# We don't need to check auth here as we check the attestation signatures
@@ -1097,7 +1205,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
+class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
"""Add/remove a room from the group summary, with optional category.
Matches both:
@@ -1154,7 +1262,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsCategoriesServlet(BaseFederationServlet):
+class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
"""Get all categories for a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
@@ -1169,7 +1277,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsCategoryServlet(BaseFederationServlet):
+class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
"""Add/remove/get a category in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
@@ -1222,7 +1330,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsRolesServlet(BaseFederationServlet):
+class FederationGroupsRolesServlet(BaseGroupsServerServlet):
"""Get roles in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
@@ -1237,7 +1345,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsRoleServlet(BaseFederationServlet):
+class FederationGroupsRoleServlet(BaseGroupsServerServlet):
"""Add/remove/get a role in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
@@ -1290,7 +1398,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
+class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
"""Add/remove a user from the group summary, with optional role.
Matches both:
@@ -1345,7 +1453,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
+class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
"""Get roles in a group"""
PATH = "/get_groups_publicised"
@@ -1358,7 +1466,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
+class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
"""Sets whether a group is joinable without an invite or knock"""
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
@@ -1379,6 +1487,16 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
PATH = "/spaces/(?P<room_id>[^/]*)"
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_space_summary_handler()
+
async def on_GET(
self,
origin: str,
@@ -1444,16 +1562,25 @@ class RoomComplexityServlet(BaseFederationServlet):
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
- async def on_GET(self, origin, content, query, room_id):
-
- store = self.handler.hs.get_datastore()
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self._store = self.hs.get_datastore()
- is_public = await store.is_room_world_readable_or_publicly_joinable(room_id)
+ async def on_GET(self, origin, content, query, room_id):
+ is_public = await self._store.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
- complexity = await store.get_room_complexity(room_id)
+ complexity = await self._store.get_room_complexity(room_id)
return 200, complexity
@@ -1482,6 +1609,7 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationSpaceSummaryServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
@@ -1523,6 +1651,13 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
+
+MSC2403_SERVLET_CLASSES = (
+ FederationV1SendKnockServlet,
+ FederationMakeKnockServlet,
+)
+
+
DEFAULT_SERVLET_GROUPS = (
"federation",
"room_list",
@@ -1559,23 +1694,26 @@ def register_servlets(
if "federation" in servlet_groups:
for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass(
- handler=hs.get_federation_server(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
- FederationSpaceSummaryServlet(
- handler=hs.get_space_summary_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
+ # Register msc2403 (knocking) servlets if the feature is enabled
+ if hs.config.experimental.msc2403_enabled:
+ for servletclass in MSC2403_SERVLET_CLASSES:
+ servletclass(
+ hs=hs,
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
if "openid" in servlet_groups:
for servletclass in OPENID_SERVLET_CLASSES:
servletclass(
- handler=hs.get_federation_server(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1584,7 +1722,7 @@ def register_servlets(
if "room_list" in servlet_groups:
for servletclass in ROOM_LIST_CLASSES:
servletclass(
- handler=hs.get_room_list_handler(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1594,7 +1732,7 @@ def register_servlets(
if "group_server" in servlet_groups:
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
servletclass(
- handler=hs.get_groups_server_handler(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1603,7 +1741,7 @@ def register_servlets(
if "group_local" in servlet_groups:
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
servletclass(
- handler=hs.get_groups_local_handler(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1612,7 +1750,7 @@ def register_servlets(
if "group_attestation" in servlet_groups:
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
servletclass(
- handler=hs.get_groups_attestation_renewer(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 974487800d..3972849d4d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -79,9 +79,15 @@ class E2eKeysHandler:
"client_keys", self.on_federation_query_client_keys
)
+ # Limit the number of in-flight requests from a single device.
+ self._query_devices_linearizer = Linearizer(
+ name="query_devices",
+ max_count=10,
+ )
+
@trace
async def query_devices(
- self, query_body: JsonDict, timeout: int, from_user_id: str
+ self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
) -> JsonDict:
"""Handle a device key query from a client
@@ -105,191 +111,197 @@ class E2eKeysHandler:
from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
+ from_device_id: the device making the query. This is used to limit
+ the number of in-flight queries at a time.
"""
-
- device_keys_query = query_body.get(
- "device_keys", {}
- ) # type: Dict[str, Iterable[str]]
-
- # separate users by domain.
- # make a map from domain to user_id to device_ids
- local_query = {}
- remote_queries = {}
-
- for user_id, device_ids in device_keys_query.items():
- # we use UserID.from_string to catch invalid user ids
- if self.is_mine(UserID.from_string(user_id)):
- local_query[user_id] = device_ids
- else:
- remote_queries[user_id] = device_ids
-
- set_tag("local_key_query", local_query)
- set_tag("remote_key_query", remote_queries)
-
- # First get local devices.
- # A map of destination -> failure response.
- failures = {} # type: Dict[str, JsonDict]
- results = {}
- if local_query:
- local_result = await self.query_local_devices(local_query)
- for user_id, keys in local_result.items():
- if user_id in local_query:
- results[user_id] = keys
-
- # Get cached cross-signing keys
- cross_signing_keys = await self.get_cross_signing_keys_from_cache(
- device_keys_query, from_user_id
- )
-
- # Now attempt to get any remote devices from our local cache.
- # A map of destination -> user ID -> device IDs.
- remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
- if remote_queries:
- query_list = [] # type: List[Tuple[str, Optional[str]]]
- for user_id, device_ids in remote_queries.items():
- if device_ids:
- query_list.extend((user_id, device_id) for device_id in device_ids)
+ with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Iterable[str]]
+
+ # separate users by domain.
+ # make a map from domain to user_id to device_ids
+ local_query = {}
+ remote_queries = {}
+
+ for user_id, device_ids in device_keys_query.items():
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
+ local_query[user_id] = device_ids
else:
- query_list.append((user_id, None))
-
- (
- user_ids_not_in_cache,
- remote_results,
- ) = await self.store.get_user_devices_from_cache(query_list)
- for user_id, devices in remote_results.items():
- user_devices = results.setdefault(user_id, {})
- for device_id, device in devices.items():
- keys = device.get("keys", None)
- device_display_name = device.get("device_display_name", None)
- if keys:
- result = dict(keys)
- unsigned = result.setdefault("unsigned", {})
- if device_display_name:
- unsigned["device_display_name"] = device_display_name
- user_devices[device_id] = result
-
- # check for missing cross-signing keys.
- for user_id in remote_queries.keys():
- cached_cross_master = user_id in cross_signing_keys["master_keys"]
- cached_cross_selfsigning = (
- user_id in cross_signing_keys["self_signing_keys"]
- )
-
- # check if we are missing only one of cross-signing master or
- # self-signing key, but the other one is cached.
- # as we need both, this will issue a federation request.
- # if we don't have any of the keys, either the user doesn't have
- # cross-signing set up, or the cached device list
- # is not (yet) updated.
- if cached_cross_master ^ cached_cross_selfsigning:
- user_ids_not_in_cache.add(user_id)
-
- # add those users to the list to fetch over federation.
- for user_id in user_ids_not_in_cache:
- domain = get_domain_from_id(user_id)
- r = remote_queries_not_in_cache.setdefault(domain, {})
- r[user_id] = remote_queries[user_id]
-
- # Now fetch any devices that we don't have in our cache
- @trace
- async def do_remote_query(destination):
- """This is called when we are querying the device list of a user on
- a remote homeserver and their device list is not in the device list
- cache. If we share a room with this user and we're not querying for
- specific user we will update the cache with their device list.
- """
-
- destination_query = remote_queries_not_in_cache[destination]
-
- # We first consider whether we wish to update the device list cache with
- # the users device list. We want to track a user's devices when the
- # authenticated user shares a room with the queried user and the query
- # has not specified a particular device.
- # If we update the cache for the queried user we remove them from further
- # queries. We use the more efficient batched query_client_keys for all
- # remaining users
- user_ids_updated = []
- for (user_id, device_list) in destination_query.items():
- if user_id in user_ids_updated:
- continue
-
- if device_list:
- continue
+ remote_queries[user_id] = device_ids
+
+ set_tag("local_key_query", local_query)
+ set_tag("remote_key_query", remote_queries)
+
+ # First get local devices.
+ # A map of destination -> failure response.
+ failures = {} # type: Dict[str, JsonDict]
+ results = {}
+ if local_query:
+ local_result = await self.query_local_devices(local_query)
+ for user_id, keys in local_result.items():
+ if user_id in local_query:
+ results[user_id] = keys
- room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
- continue
+ # Get cached cross-signing keys
+ cross_signing_keys = await self.get_cross_signing_keys_from_cache(
+ device_keys_query, from_user_id
+ )
- # We've decided we're sharing a room with this user and should
- # probably be tracking their device lists. However, we haven't
- # done an initial sync on the device list so we do it now.
- try:
- if self._is_master:
- user_devices = await self.device_handler.device_list_updater.user_device_resync(
- user_id
+ # Now attempt to get any remote devices from our local cache.
+ # A map of destination -> user ID -> device IDs.
+ remote_queries_not_in_cache = (
+ {}
+ ) # type: Dict[str, Dict[str, Iterable[str]]]
+ if remote_queries:
+ query_list = [] # type: List[Tuple[str, Optional[str]]]
+ for user_id, device_ids in remote_queries.items():
+ if device_ids:
+ query_list.extend(
+ (user_id, device_id) for device_id in device_ids
)
else:
- user_devices = await self._user_device_resync_client(
- user_id=user_id
- )
-
- user_devices = user_devices["devices"]
- user_results = results.setdefault(user_id, {})
- for device in user_devices:
- user_results[device["device_id"]] = device["keys"]
- user_ids_updated.append(user_id)
- except Exception as e:
- failures[destination] = _exception_to_failure(e)
-
- if len(destination_query) == len(user_ids_updated):
- # We've updated all the users in the query and we do not need to
- # make any further remote calls.
- return
+ query_list.append((user_id, None))
- # Remove all the users from the query which we have updated
- for user_id in user_ids_updated:
- destination_query.pop(user_id)
+ (
+ user_ids_not_in_cache,
+ remote_results,
+ ) = await self.store.get_user_devices_from_cache(query_list)
+ for user_id, devices in remote_results.items():
+ user_devices = results.setdefault(user_id, {})
+ for device_id, device in devices.items():
+ keys = device.get("keys", None)
+ device_display_name = device.get("device_display_name", None)
+ if keys:
+ result = dict(keys)
+ unsigned = result.setdefault("unsigned", {})
+ if device_display_name:
+ unsigned["device_display_name"] = device_display_name
+ user_devices[device_id] = result
+
+ # check for missing cross-signing keys.
+ for user_id in remote_queries.keys():
+ cached_cross_master = user_id in cross_signing_keys["master_keys"]
+ cached_cross_selfsigning = (
+ user_id in cross_signing_keys["self_signing_keys"]
+ )
- try:
- remote_result = await self.federation.query_client_keys(
- destination, {"device_keys": destination_query}, timeout=timeout
- )
+ # check if we are missing only one of cross-signing master or
+ # self-signing key, but the other one is cached.
+ # as we need both, this will issue a federation request.
+ # if we don't have any of the keys, either the user doesn't have
+ # cross-signing set up, or the cached device list
+ # is not (yet) updated.
+ if cached_cross_master ^ cached_cross_selfsigning:
+ user_ids_not_in_cache.add(user_id)
+
+ # add those users to the list to fetch over federation.
+ for user_id in user_ids_not_in_cache:
+ domain = get_domain_from_id(user_id)
+ r = remote_queries_not_in_cache.setdefault(domain, {})
+ r[user_id] = remote_queries[user_id]
+
+ # Now fetch any devices that we don't have in our cache
+ @trace
+ async def do_remote_query(destination):
+ """This is called when we are querying the device list of a user on
+ a remote homeserver and their device list is not in the device list
+ cache. If we share a room with this user and we're not querying for
+ specific user we will update the cache with their device list.
+ """
+
+ destination_query = remote_queries_not_in_cache[destination]
+
+ # We first consider whether we wish to update the device list cache with
+ # the users device list. We want to track a user's devices when the
+ # authenticated user shares a room with the queried user and the query
+ # has not specified a particular device.
+ # If we update the cache for the queried user we remove them from further
+ # queries. We use the more efficient batched query_client_keys for all
+ # remaining users
+ user_ids_updated = []
+ for (user_id, device_list) in destination_query.items():
+ if user_id in user_ids_updated:
+ continue
+
+ if device_list:
+ continue
+
+ room_ids = await self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ continue
+
+ # We've decided we're sharing a room with this user and should
+ # probably be tracking their device lists. However, we haven't
+ # done an initial sync on the device list so we do it now.
+ try:
+ if self._is_master:
+ user_devices = await self.device_handler.device_list_updater.user_device_resync(
+ user_id
+ )
+ else:
+ user_devices = await self._user_device_resync_client(
+ user_id=user_id
+ )
+
+ user_devices = user_devices["devices"]
+ user_results = results.setdefault(user_id, {})
+ for device in user_devices:
+ user_results[device["device_id"]] = device["keys"]
+ user_ids_updated.append(user_id)
+ except Exception as e:
+ failures[destination] = _exception_to_failure(e)
+
+ if len(destination_query) == len(user_ids_updated):
+ # We've updated all the users in the query and we do not need to
+ # make any further remote calls.
+ return
+
+ # Remove all the users from the query which we have updated
+ for user_id in user_ids_updated:
+ destination_query.pop(user_id)
- for user_id, keys in remote_result["device_keys"].items():
- if user_id in destination_query:
- results[user_id] = keys
+ try:
+ remote_result = await self.federation.query_client_keys(
+ destination, {"device_keys": destination_query}, timeout=timeout
+ )
- if "master_keys" in remote_result:
- for user_id, key in remote_result["master_keys"].items():
+ for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
+ results[user_id] = keys
- if "self_signing_keys" in remote_result:
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
- except Exception as e:
- failure = _exception_to_failure(e)
- failures[destination] = failure
- set_tag("error", True)
- set_tag("reason", failure)
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
- await make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(do_remote_query, destination)
- for destination in remote_queries_not_in_cache
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
- )
+ except Exception as e:
+ failure = _exception_to_failure(e)
+ failures[destination] = failure
+ set_tag("error", True)
+ set_tag("reason", failure)
+
+ await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(do_remote_query, destination)
+ for destination in remote_queries_not_in_cache
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
- ret = {"device_keys": results, "failures": failures}
+ ret = {"device_keys": results, "failures": failures}
- ret.update(cross_signing_keys)
+ ret.update(cross_signing_keys)
- return ret
+ return ret
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index abbb71424d..6647063485 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,6 +1,5 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,6 +33,7 @@ from typing import (
)
import attr
+from prometheus_client import Counter
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -102,6 +102,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+soft_failed_event_counter = Counter(
+ "synapse_federation_soft_failed_events_total",
+ "Events received over federation that we marked as soft_failed",
+)
+
@attr.s(slots=True)
class _NewEventInfo:
@@ -1550,6 +1555,77 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
+ @log_function
+ async def do_knock(
+ self,
+ target_hosts: List[str],
+ room_id: str,
+ knockee: str,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Sends the knock to the remote server.
+
+ This first triggers a make_knock request that returns a partial
+ event that we can fill out and sign. This is then sent to the
+ remote server via send_knock.
+
+ Knock events must be signed by the knockee's server before distributing.
+
+ Args:
+ target_hosts: A list of hosts that we want to try knocking through.
+ room_id: The ID of the room to knock on.
+ knockee: The ID of the user who is knocking.
+ content: The content of the knock event.
+
+ Returns:
+ A tuple of (event ID, stream ID).
+
+ Raises:
+ SynapseError: If the chosen remote server returns a 3xx/4xx code.
+ RuntimeError: If no servers were reachable.
+ """
+ logger.debug("Knocking on room %s on behalf of user %s", room_id, knockee)
+
+ # Inform the remote server of the room versions we support
+ supported_room_versions = list(KNOWN_ROOM_VERSIONS.keys())
+
+ # Ask the remote server to create a valid knock event for us. Once received,
+ # we sign the event
+ params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
+ origin, event, event_format_version = await self._make_and_verify_event(
+ target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
+ )
+
+ # Record the room ID and its version so that we have a record of the room
+ await self._maybe_store_room_on_outlier_membership(
+ room_id=event.room_id, room_version=event_format_version
+ )
+
+ # Initially try the host that we successfully called /make_knock on
+ try:
+ target_hosts.remove(origin)
+ target_hosts.insert(0, origin)
+ except ValueError:
+ pass
+
+ # Send the signed event back to the room, and potentially receive some
+ # further information about the room in the form of partial state events
+ stripped_room_state = await self.federation_client.send_knock(
+ target_hosts, event
+ )
+
+ # Store any stripped room state events in the "unsigned" key of the event.
+ # This is a bit of a hack and is cribbing off of invites. Basically we
+ # store the room state here and retrieve it again when this event appears
+ # in the invitee's sync stream. It is stripped out for all other local users.
+ event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
+
+ context = await self.state_handler.compute_event_context(event)
+ stream_id = await self.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
+ return event.event_id, stream_id
+
async def _handle_queued_pdus(
self, room_queue: List[Tuple[EventBase, str]]
) -> None:
@@ -1915,6 +1991,116 @@ class FederationHandler(BaseHandler):
return None
+ @log_function
+ async def on_make_knock_request(
+ self, origin: str, room_id: str, user_id: str
+ ) -> EventBase:
+ """We've received a make_knock request, so we create a partial
+ knock event for the room and return that. We do *not* persist or
+ process it until the other server has signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: The room to create the knock event in.
+ user_id: The user to create the knock for.
+
+ Returns:
+ The partial knock event.
+ """
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Get /xyz.amorgan.knock/make_knock request for user %r"
+ "from different origin %s, ignoring",
+ user_id,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ room_version = await self.store.get_room_version_id(room_id)
+
+ builder = self.event_builder_factory.new(
+ room_version,
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.KNOCK},
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": user_id,
+ },
+ )
+
+ event, context = await self.event_creation_handler.create_new_client_event(
+ builder=builder
+ )
+
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.warning("Creation of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ try:
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_knock_request`
+ await self.auth.check_from_context(
+ room_version, event, context, do_sig_check=False
+ )
+ except AuthError as e:
+ logger.warning("Failed to create new knock %r because %s", event, e)
+ raise e
+
+ return event
+
+ @log_function
+ async def on_send_knock_request(
+ self, origin: str, event: EventBase
+ ) -> EventContext:
+ """
+ We have received a knock event for a room. Verify that event and send it into the room
+ on the knocking homeserver's behalf.
+
+ Args:
+ origin: The remote homeserver of the knocking user.
+ event: The knocking member event that has been signed by the remote homeserver.
+
+ Returns:
+ The context of the event after inserting it into the room graph.
+ """
+ logger.debug(
+ "on_send_knock_request: Got event: %s, signatures: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ if get_domain_from_id(event.sender) != origin:
+ logger.info(
+ "Got /xyz.amorgan.knock/send_knock request for user %r "
+ "from different origin %s",
+ event.sender,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ event.internal_metadata.outlier = False
+
+ context = await self.state_handler.compute_event_context(event)
+
+ await self._auth_and_persist_event(origin, event, context)
+
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Sending of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ return context
+
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
@@ -2318,6 +2504,7 @@ class FederationHandler(BaseHandler):
event_auth.check(room_version_obj, event, auth_events=current_auth_events)
except AuthError as e:
logger.warning("Soft-failing %r because %s", event, e)
+ soft_failed_event_counter.inc()
event.internal_metadata.soft_failed = True
async def on_get_missing_events(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9f365eb5ad..4d2255bdf1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,6 +1,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+# Copyrignt 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -398,13 +399,14 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
- self.room_invite_state_types = self.hs.config.api.room_prejoin_state
+ self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
- self.membership_types_to_include_profile_data_in = (
- {Membership.JOIN, Membership.INVITE}
- if self.hs.config.include_profile_data_on_invite
- else {Membership.JOIN}
- )
+ self.membership_types_to_include_profile_data_in = {
+ Membership.JOIN,
+ Membership.KNOCK,
+ }
+ if self.hs.config.include_profile_data_on_invite:
+ self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
@@ -961,8 +963,8 @@ class EventCreationHandler:
room_version = await self.store.get_room_version_id(event.room_id)
if event.internal_metadata.is_out_of_band_membership():
- # the only sort of out-of-band-membership events we expect to see here
- # are invite rejections we have generated ourselves.
+ # the only sort of out-of-band-membership events we expect to see here are
+ # invite rejections and rescinded knocks that we have generated ourselves.
assert event.type == EventTypes.Member
assert event.content["membership"] == Membership.LEAVE
else:
@@ -1239,7 +1241,7 @@ class EventCreationHandler:
"invite_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
- self.room_invite_state_types,
+ self.room_prejoin_state_types,
membership_user_id=event.sender,
)
@@ -1257,6 +1259,14 @@ class EventCreationHandler:
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
+ if event.content["membership"] == Membership.KNOCK:
+ event.unsigned[
+ "knock_room_state"
+ ] = await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ )
+
if event.type == EventTypes.Redaction:
original_event = await self.store.get_event(
event.redacts,
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 141c9c0444..5e3ef7ce3a 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -44,7 +44,7 @@ class RoomListHandler(BaseHandler):
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(
hs.get_clock(), "room_list"
- ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
+ ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
self.remote_response_cache = ResponseCache(
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
@@ -54,7 +54,7 @@ class RoomListHandler(BaseHandler):
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
- network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+ network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
) -> JsonDict:
"""Generate a local public room list.
@@ -111,7 +111,7 @@ class RoomListHandler(BaseHandler):
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
- network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+ network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
) -> JsonDict:
"""Generate a public room list.
@@ -169,6 +169,7 @@ class RoomListHandler(BaseHandler):
"world_readable": room["history_visibility"]
== HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
+ "join_rule": room["join_rules"],
}
# Filter out Nones – rather omit the field altogether
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d6fc43e798..c26963b1e1 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,4 +1,5 @@
# Copyright 2016-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import abc
import logging
import random
@@ -30,7 +30,15 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
+from synapse.types import (
+ JsonDict,
+ Requester,
+ RoomAlias,
+ RoomID,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@@ -126,6 +134,24 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def remote_knock(
+ self,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Tuple[str, int]:
+ """Try and knock on a room that this server is not in
+
+ Args:
+ remote_room_hosts: List of servers that can be used to knock via.
+ room_id: Room that we are trying to knock on.
+ user: User who is trying to knock.
+ content: A dict that should be used as the content of the knock event.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
async def remote_reject_invite(
self,
invite_event_id: str,
@@ -149,6 +175,27 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Rescind a local knock made on a remote room.
+
+ Args:
+ knock_event_id: The ID of the knock event to rescind.
+ txn_id: An optional transaction ID supplied by the client.
+ requester: The user making the request, according to the access token.
+ content: The content of the generated leave event.
+
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
room.
@@ -603,53 +650,82 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
- # perhaps we've been invited
+ # Figure out the user's current membership state for the room
(
current_membership_type,
current_membership_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
target.to_string(), room_id
)
- if (
- current_membership_type != Membership.INVITE
- or not current_membership_event_id
- ):
+ if not current_membership_type or not current_membership_event_id:
logger.info(
"%s sent a leave request to %s, but that is not an active room "
- "on this server, and there is no pending invite",
+ "on this server, or there is no pending invite or knock",
target,
room_id,
)
raise SynapseError(404, "Not a known room")
- invite = await self.store.get_event(current_membership_event_id)
- logger.info(
- "%s rejects invite to %s from %s", target, room_id, invite.sender
- )
+ # perhaps we've been invited
+ if current_membership_type == Membership.INVITE:
+ invite = await self.store.get_event(current_membership_event_id)
+ logger.info(
+ "%s rejects invite to %s from %s",
+ target,
+ room_id,
+ invite.sender,
+ )
- if not self.hs.is_mine_id(invite.sender):
- # send the rejection to the inviter's HS (with fallback to
- # local event)
- return await self.remote_reject_invite(
- invite.event_id,
- txn_id,
- requester,
- content,
+ if not self.hs.is_mine_id(invite.sender):
+ # send the rejection to the inviter's HS (with fallback to
+ # local event)
+ return await self.remote_reject_invite(
+ invite.event_id,
+ txn_id,
+ requester,
+ content,
+ )
+
+ # the inviter was on our server, but has now left. Carry on
+ # with the normal rejection codepath, which will also send the
+ # rejection out to any other servers we believe are still in the room.
+
+ # thanks to overzealous cleaning up of event_forward_extremities in
+ # `delete_old_current_state_events`, it's possible to end up with no
+ # forward extremities here. If that happens, let's just hang the
+ # rejection off the invite event.
+ #
+ # see: https://github.com/matrix-org/synapse/issues/7139
+ if len(latest_event_ids) == 0:
+ latest_event_ids = [invite.event_id]
+
+ # or perhaps this is a remote room that a local user has knocked on
+ elif current_membership_type == Membership.KNOCK:
+ knock = await self.store.get_event(current_membership_event_id)
+ return await self.remote_rescind_knock(
+ knock.event_id, txn_id, requester, content
)
- # the inviter was on our server, but has now left. Carry on
- # with the normal rejection codepath, which will also send the
- # rejection out to any other servers we believe are still in the room.
+ elif (
+ self.config.experimental.msc2403_enabled
+ and effective_membership_state == Membership.KNOCK
+ ):
+ if not is_host_in_room:
+ # The knock needs to be sent over federation instead
+ remote_room_hosts.append(get_domain_from_id(room_id))
- # thanks to overzealous cleaning up of event_forward_extremities in
- # `delete_old_current_state_events`, it's possible to end up with no
- # forward extremities here. If that happens, let's just hang the
- # rejection off the invite event.
- #
- # see: https://github.com/matrix-org/synapse/issues/7139
- if len(latest_event_ids) == 0:
- latest_event_ids = [invite.event_id]
+ content["membership"] = Membership.KNOCK
+
+ profile = self.profile_handler
+ if "displayname" not in content:
+ content["displayname"] = await profile.get_displayname(target)
+ if "avatar_url" not in content:
+ content["avatar_url"] = await profile.get_avatar_url(target)
+
+ return await self.remote_knock(
+ remote_room_hosts, room_id, target, content
+ )
return await self._local_membership_update(
requester=requester,
@@ -1209,6 +1285,35 @@ class RoomMemberMasterHandler(RoomMemberHandler):
invite_event, txn_id, requester, content
)
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rescinds a local knock made on a remote room
+
+ Args:
+ knock_event_id: The ID of the knock event to rescind.
+ txn_id: The transaction ID to use.
+ requester: The originator of the request.
+ content: The content of the leave event.
+
+ Implements RoomMemberHandler.remote_rescind_knock
+ """
+ # TODO: We don't yet support rescinding knocks over federation
+ # as we don't know which homeserver to send it to. An obvious
+ # candidate is the remote homeserver we originally knocked through,
+ # however we don't currently store that information.
+
+ # Just rescind the knock locally
+ knock_event = await self.store.get_event(knock_event_id)
+ return await self._generate_local_out_of_band_leave(
+ knock_event, txn_id, requester, content
+ )
+
async def _generate_local_out_of_band_leave(
self,
previous_membership_event: EventBase,
@@ -1272,6 +1377,36 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return result_event.event_id, result_event.internal_metadata.stream_ordering
+ async def remote_knock(
+ self,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Tuple[str, int]:
+ """Sends a knock to a room. Attempts to do so via one remote out of a given list.
+
+ Args:
+ remote_room_hosts: A list of homeservers to try knocking through.
+ room_id: The ID of the room to knock on.
+ user: The user to knock on behalf of.
+ content: The content of the knock event.
+
+ Returns:
+ A tuple of (event ID, stream ID).
+ """
+ # filter ourselves out of remote_room_hosts
+ remote_room_hosts = [
+ host for host in remote_room_hosts if host != self.hs.hostname
+ ]
+
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ return await self.federation_handler.do_knock(
+ remote_room_hosts, room_id, user.to_string(), content=content
+ )
+
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room"""
user_left_room(self.distributor, target, room_id)
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 3e89dd2315..221552a2a6 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,10 +19,12 @@ from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import (
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
+ ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
+ ReplicationRemoteRescindKnockRestServlet as ReplRescindKnock,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
)
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -35,7 +37,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
super().__init__(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
+ self._remote_knock_client = ReplRemoteKnock.make_client(hs)
self._remote_reject_client = ReplRejectInvite.make_client(hs)
+ self._remote_rescind_client = ReplRescindKnock.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
async def _remote_join(
@@ -80,6 +84,53 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
)
return ret["event_id"], ret["stream_id"]
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rescinds a local knock made on a remote room
+
+ Args:
+ knock_event_id: the knock event
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the request, according to the access token
+ content: additional content to include in the leave event.
+ Normally an empty dict.
+
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event)
+ """
+ ret = await self._remote_rescind_client(
+ knock_event_id=knock_event_id,
+ txn_id=txn_id,
+ requester=requester,
+ content=content,
+ )
+ return ret["event_id"], ret["stream_id"]
+
+ async def remote_knock(
+ self,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Tuple[str, int]:
+ """Sends a knock to a room.
+
+ Implements RoomMemberHandler.remote_knock
+ """
+ ret = await self._remote_knock_client(
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user=user,
+ content=content,
+ )
+ return ret["event_id"], ret["stream_id"]
+
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room"""
await self._notify_change_client(
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 383e34026e..4e45d1da57 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -1,4 +1,5 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -230,6 +231,8 @@ class StatsHandler:
room_stats_delta["left_members"] -= 1
elif prev_membership == Membership.BAN:
room_stats_delta["banned_members"] -= 1
+ elif prev_membership == Membership.KNOCK:
+ room_stats_delta["knocked_members"] -= 1
else:
raise ValueError(
"%r is not a valid prev_membership" % (prev_membership,)
@@ -251,6 +254,8 @@ class StatsHandler:
room_stats_delta["left_members"] += 1
elif membership == Membership.BAN:
room_stats_delta["banned_members"] += 1
+ elif membership == Membership.KNOCK:
+ room_stats_delta["knocked_members"] += 1
else:
raise ValueError("%r is not a valid membership" % (membership,))
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b1c58ffdc8..7f2138d804 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -160,6 +160,16 @@ class InvitedSyncResult:
@attr.s(slots=True, frozen=True)
+class KnockedSyncResult:
+ room_id = attr.ib(type=str)
+ knock = attr.ib(type=EventBase)
+
+ def __bool__(self) -> bool:
+ """Knocked rooms should always be reported to the client"""
+ return True
+
+
+@attr.s(slots=True, frozen=True)
class GroupsSyncResult:
join = attr.ib(type=JsonDict)
invite = attr.ib(type=JsonDict)
@@ -192,6 +202,7 @@ class _RoomChanges:
room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
invited = attr.ib(type=List[InvitedSyncResult])
+ knocked = attr.ib(type=List[KnockedSyncResult])
newly_joined_rooms = attr.ib(type=List[str])
newly_left_rooms = attr.ib(type=List[str])
@@ -205,6 +216,7 @@ class SyncResult:
account_data: List of account_data events for the user.
joined: JoinedSyncResult for each joined room.
invited: InvitedSyncResult for each invited room.
+ knocked: KnockedSyncResult for each knocked on room.
archived: ArchivedSyncResult for each archived room.
to_device: List of direct messages for the device.
device_lists: List of user_ids whose devices have changed
@@ -220,6 +232,7 @@ class SyncResult:
account_data = attr.ib(type=List[JsonDict])
joined = attr.ib(type=List[JoinedSyncResult])
invited = attr.ib(type=List[InvitedSyncResult])
+ knocked = attr.ib(type=List[KnockedSyncResult])
archived = attr.ib(type=List[ArchivedSyncResult])
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
@@ -236,6 +249,7 @@ class SyncResult:
self.presence
or self.joined
or self.invited
+ or self.knocked
or self.archived
or self.account_data
or self.to_device
@@ -1031,7 +1045,7 @@ class SyncHandler:
res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
+ newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
@@ -1040,7 +1054,9 @@ class SyncHandler:
if self.hs_config.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
- sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
+ sync_result_builder,
+ newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users,
)
logger.debug("Fetching to-device data")
@@ -1049,7 +1065,7 @@ class SyncHandler:
device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
- newly_joined_or_invited_users=newly_joined_or_invited_users,
+ newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
@@ -1083,6 +1099,7 @@ class SyncHandler:
account_data=sync_result_builder.account_data,
joined=sync_result_builder.joined,
invited=sync_result_builder.invited,
+ knocked=sync_result_builder.knocked,
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
@@ -1142,7 +1159,7 @@ class SyncHandler:
self,
sync_result_builder: "SyncResultBuilder",
newly_joined_rooms: Set[str],
- newly_joined_or_invited_users: Set[str],
+ newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str],
newly_left_users: Set[str],
) -> DeviceLists:
@@ -1151,8 +1168,9 @@ class SyncHandler:
Args:
sync_result_builder
newly_joined_rooms: Set of rooms user has joined since previous sync
- newly_joined_or_invited_users: Set of users that have joined or
- been invited to a room since previous sync.
+ newly_joined_or_invited_or_knocked_users: Set of users that have joined,
+ been invited to a room or are knocking on a room since
+ previous sync.
newly_left_rooms: Set of rooms user has left since previous sync
newly_left_users: Set of users that have left a room we're in since
previous sync
@@ -1163,7 +1181,9 @@ class SyncHandler:
# We're going to mutate these fields, so lets copy them rather than
# assume they won't get used later.
- newly_joined_or_invited_users = set(newly_joined_or_invited_users)
+ newly_joined_or_invited_or_knocked_users = set(
+ newly_joined_or_invited_or_knocked_users
+ )
newly_left_users = set(newly_left_users)
if since_token and since_token.device_list_key:
@@ -1202,11 +1222,11 @@ class SyncHandler:
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
joined_users = await self.store.get_users_in_room(room_id)
- newly_joined_or_invited_users.update(joined_users)
+ newly_joined_or_invited_or_knocked_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
- users_that_have_changed.update(newly_joined_or_invited_users)
+ users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
user_signatures_changed = (
await self.store.get_users_whose_signatures_changed(
@@ -1452,6 +1472,7 @@ class SyncHandler:
room_entries = room_changes.room_entries
invited = room_changes.invited
+ knocked = room_changes.knocked
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
@@ -1472,9 +1493,10 @@ class SyncHandler:
await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited)
+ sync_result_builder.knocked.extend(knocked)
- # Now we want to get any newly joined or invited users
- newly_joined_or_invited_users = set()
+ # Now we want to get any newly joined, invited or knocking users
+ newly_joined_or_invited_or_knocked_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
@@ -1486,19 +1508,22 @@ class SyncHandler:
if (
event.membership == Membership.JOIN
or event.membership == Membership.INVITE
+ or event.membership == Membership.KNOCK
):
- newly_joined_or_invited_users.add(event.state_key)
+ newly_joined_or_invited_or_knocked_users.add(
+ event.state_key
+ )
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
- newly_left_users -= newly_joined_or_invited_users
+ newly_left_users -= newly_joined_or_invited_or_knocked_users
return (
set(newly_joined_rooms),
- newly_joined_or_invited_users,
+ newly_joined_or_invited_or_knocked_users,
set(newly_left_rooms),
newly_left_users,
)
@@ -1553,6 +1578,7 @@ class SyncHandler:
newly_left_rooms = []
room_entries = []
invited = []
+ knocked = []
for room_id, events in mem_change_events_by_room_id.items():
logger.debug(
"Membership changes in %s: [%s]",
@@ -1632,9 +1658,17 @@ class SyncHandler:
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
if event.sender not in ignored_users:
- room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
- if room_sync:
- invited.append(room_sync)
+ invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+ if invite_room_sync:
+ invited.append(invite_room_sync)
+
+ # Only bother if our latest membership in the room is knock (and we haven't
+ # been accepted/rejected in the meantime).
+ should_knock = non_joins[-1].membership == Membership.KNOCK
+ if should_knock:
+ knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
+ if knock_room_sync:
+ knocked.append(knock_room_sync)
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
@@ -1738,7 +1772,13 @@ class SyncHandler:
)
room_entries.append(entry)
- return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
+ return _RoomChanges(
+ room_entries,
+ invited,
+ knocked,
+ newly_joined_rooms,
+ newly_left_rooms,
+ )
async def _get_all_rooms(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
@@ -1758,6 +1798,7 @@ class SyncHandler:
membership_list = (
Membership.INVITE,
+ Membership.KNOCK,
Membership.JOIN,
Membership.LEAVE,
Membership.BAN,
@@ -1769,6 +1810,7 @@ class SyncHandler:
room_entries = []
invited = []
+ knocked = []
for event in room_list:
if event.membership == Membership.JOIN:
@@ -1788,8 +1830,11 @@ class SyncHandler:
continue
invite = await self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
+ elif event.membership == Membership.KNOCK:
+ knock = await self.store.get_event(event.event_id)
+ knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock))
elif event.membership in (Membership.LEAVE, Membership.BAN):
- # Always send down rooms we were banned or kicked from.
+ # Always send down rooms we were banned from or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if user_id == event.sender:
@@ -1810,7 +1855,7 @@ class SyncHandler:
)
)
- return _RoomChanges(room_entries, invited, [], [])
+ return _RoomChanges(room_entries, invited, knocked, [], [])
async def _generate_room_entry(
self,
@@ -2101,6 +2146,7 @@ class SyncResultBuilder:
account_data (list)
joined (list[JoinedSyncResult])
invited (list[InvitedSyncResult])
+ knocked (list[KnockedSyncResult])
archived (list[ArchivedSyncResult])
groups (GroupsSyncResult|None)
to_device (list)
@@ -2116,6 +2162,7 @@ class SyncResultBuilder:
account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
+ knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 1998990a14..629373fc47 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -65,13 +65,9 @@ from synapse.http.client import (
read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable
-from synapse.logging.opentracing import (
- inject_active_span_byte_dict,
- set_tag,
- start_active_span,
- tags,
-)
+from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@@ -497,7 +493,7 @@ class MatrixFederationHttpClient:
# Inject the span into the headers
headers_dict = {} # type: Dict[bytes, List[bytes]]
- inject_active_span_byte_dict(headers_dict, request.destination)
+ opentracing.inject_header_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index d61563d39b..3c43f32586 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,7 +13,6 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
-
import logging
from typing import Dict, Iterable, List, Optional, overload
@@ -295,6 +294,30 @@ def parse_strings_from_args(
return default
+@overload
+def parse_string_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[str] = None,
+ required: Literal[True] = True,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> str:
+ ...
+
+
+@overload
+def parse_string_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[str] = None,
+ required: bool = False,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> Optional[str]:
+ ...
+
+
def parse_string_from_args(
args: Dict[bytes, List[bytes]],
name: str,
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index dd9377340e..5b4725e035 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -168,7 +168,7 @@ import inspect
import logging
import re
from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Pattern, Type
+from typing import TYPE_CHECKING, Dict, List, Optional, Pattern, Type
import attr
@@ -574,59 +574,22 @@ def set_operation_name(operation_name):
# Injection and extraction
-@ensure_active_span("inject the span into a header")
-def inject_active_span_twisted_headers(headers, destination, check_destination=True):
+@ensure_active_span("inject the span into a header dict")
+def inject_header_dict(
+ headers: Dict[bytes, List[bytes]],
+ destination: Optional[str] = None,
+ check_destination: bool = True,
+) -> None:
"""
- Injects a span context into twisted headers in-place
+ Injects a span context into a dict of HTTP headers
Args:
- headers (twisted.web.http_headers.Headers)
- destination (str): address of entity receiving the span context. If check_destination
- is true the context will only be injected if the destination matches the
- opentracing whitelist
+ headers: the dict to inject headers into
+ destination: address of entity receiving the span context. Must be given unless
+ check_destination is False. The context will only be injected if the
+ destination matches the opentracing whitelist
check_destination (bool): If false, destination will be ignored and the context
will always be injected.
- span (opentracing.Span)
-
- Returns:
- In-place modification of headers
-
- Note:
- The headers set by the tracer are custom to the tracer implementation which
- should be unique enough that they don't interfere with any headers set by
- synapse or twisted. If we're still using jaeger these headers would be those
- here:
- https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
- """
-
- if check_destination and not whitelisted_homeserver(destination):
- return
-
- span = opentracing.tracer.active_span
- carrier = {} # type: Dict[str, str]
- opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
-
- for key, value in carrier.items():
- headers.addRawHeaders(key, value)
-
-
-@ensure_active_span("inject the span into a byte dict")
-def inject_active_span_byte_dict(headers, destination, check_destination=True):
- """
- Injects a span context into a dict where the headers are encoded as byte
- strings
-
- Args:
- headers (dict)
- destination (str): address of entity receiving the span context. If check_destination
- is true the context will only be injected if the destination matches the
- opentracing whitelist
- check_destination (bool): If false, destination will be ignored and the context
- will always be injected.
- span (opentracing.Span)
-
- Returns:
- In-place modification of headers
Note:
The headers set by the tracer are custom to the tracer implementation which
@@ -635,8 +598,13 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
here:
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
- if check_destination and not whitelisted_homeserver(destination):
- return
+ if check_destination:
+ if destination is None:
+ raise ValueError(
+ "destination must be given unless check_destination is False"
+ )
+ if not whitelisted_homeserver(destination):
+ return
span = opentracing.tracer.active_span
@@ -647,38 +615,6 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
headers[key.encode()] = [value.encode()]
-@ensure_active_span("inject the span into a text map")
-def inject_active_span_text_map(carrier, destination, check_destination=True):
- """
- Injects a span context into a dict
-
- Args:
- carrier (dict)
- destination (str): address of entity receiving the span context. If check_destination
- is true the context will only be injected if the destination matches the
- opentracing whitelist
- check_destination (bool): If false, destination will be ignored and the context
- will always be injected.
-
- Returns:
- In-place modification of carrier
-
- Note:
- The headers set by the tracer are custom to the tracer implementation which
- should be unique enough that they don't interfere with any headers set by
- synapse or twisted. If we're still using jaeger these headers would be those
- here:
- https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
- """
-
- if check_destination and not whitelisted_homeserver(destination):
- return
-
- opentracing.tracer.inject(
- opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
- )
-
-
@ensure_active_span("get the active span context as a dict", ret={})
def get_active_span_text_map(destination=None):
"""
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 5685cf2121..2a13026e9a 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -23,7 +23,8 @@ from prometheus_client import Counter, Gauge
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
-from synapse.logging.opentracing import inject_active_span_byte_dict, trace
+from synapse.logging import opentracing
+from synapse.logging.opentracing import trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -235,7 +236,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
- inject_active_span_byte_dict(headers, None, check_destination=False)
+ opentracing.inject_header_dict(headers, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
break
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 289a397d68..043c25f63d 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -97,6 +97,76 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id}
+class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
+ """Perform a remote knock for the given user on the given room
+
+ Request format:
+
+ POST /_synapse/replication/remote_knock/:room_id/:user_id
+
+ {
+ "requester": ...,
+ "remote_room_hosts": [...],
+ "content": { ... }
+ }
+ """
+
+ NAME = "remote_knock"
+ PATH_ARGS = ("room_id", "user_id")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.federation_handler = hs.get_federation_handler()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore
+ requester: Requester,
+ room_id: str,
+ user_id: str,
+ remote_room_hosts: List[str],
+ content: JsonDict,
+ ):
+ """
+ Args:
+ requester: The user making the request, according to the access token.
+ room_id: The ID of the room to knock on.
+ user_id: The ID of the knocking user.
+ remote_room_hosts: Servers to try and send the knock via.
+ content: The event content to use for the knock event.
+ """
+ return {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ "content": content,
+ }
+
+ async def _handle_request( # type: ignore
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ user_id: str,
+ ):
+ content = parse_json_object_from_request(request)
+
+ remote_room_hosts = content["remote_room_hosts"]
+ event_content = content["content"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ request.requester = requester
+
+ logger.debug("remote_knock: %s on room: %s", user_id, room_id)
+
+ event_id, stream_id = await self.federation_handler.do_knock(
+ remote_room_hosts, room_id, user_id, event_content
+ )
+
+ return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects an out-of-band invite we have received from a remote server
@@ -167,6 +237,75 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id}
+class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
+ """Rescinds a local knock made on a remote room
+
+ Request format:
+
+ POST /_synapse/replication/remote_rescind_knock/:event_id
+
+ {
+ "txn_id": ...,
+ "requester": ...,
+ "content": { ... }
+ }
+ """
+
+ NAME = "remote_rescind_knock"
+ PATH_ARGS = ("knock_event_id",)
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.member_handler = hs.get_room_member_handler()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ):
+ """
+ Args:
+ knock_event_id: The ID of the knock to be rescinded.
+ txn_id: An optional transaction ID supplied by the client.
+ requester: The user making the rescind request, according to the access token.
+ content: The content to include in the rescind event.
+ """
+ return {
+ "txn_id": txn_id,
+ "requester": requester.serialize(),
+ "content": content,
+ }
+
+ async def _handle_request( # type: ignore
+ self,
+ request: SynapseRequest,
+ knock_event_id: str,
+ ):
+ content = parse_json_object_from_request(request)
+
+ txn_id = content["txn_id"]
+ event_content = content["content"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ request.requester = requester
+
+ # hopefully we're now on the master, so this won't recurse!
+ event_id, stream_id = await self.member_handler.remote_rescind_knock(
+ knock_event_id,
+ txn_id,
+ requester,
+ event_content,
+ )
+
+ return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 79d52d2dcb..138411ad19 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -38,6 +38,7 @@ from synapse.rest.client.v2_alpha import (
filter,
groups,
keys,
+ knock,
notifications,
openid,
password_policy,
@@ -121,6 +122,10 @@ class ClientRestResource(JsonResource):
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
+ # Register msc2403 (knocking) servlets if the feature is enabled
+ if hs.config.experimental.msc2403_enabled:
+ knock.register_servlets(hs, client_resource)
+
# moving to /_synapse/admin
admin.register_servlets_for_client_rest_resource(hs, client_resource)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 122105854a..16d087ea60 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -14,10 +14,9 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
-
import logging
import re
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@@ -38,6 +37,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
parse_string,
+ parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag
@@ -278,7 +278,12 @@ class JoinRoomAliasServlet(TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_identifier, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_identifier: str,
+ txn_id: Optional[str] = None,
+ ):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
@@ -290,17 +295,18 @@ class JoinRoomAliasServlet(TransactionRestServlet):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
- try:
- remote_room_hosts = [
- x.decode("ascii") for x in request.args[b"server_name"]
- ] # type: Optional[List[str]]
- except Exception:
- remote_room_hosts = None
+
+ # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
+
+ remote_room_hosts = parse_strings_from_args(
+ args, "server_name", required=False
+ )
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
- room_id = room_id.to_string()
+ room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
+ room_id = room_id_obj.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index a57ccbb5e5..4a28f2c072 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
+ device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
+ result = await self.e2e_keys_handler.query_devices(
+ body, timeout, user_id, device_id
+ )
return 200, result
diff --git a/synapse/rest/client/v2_alpha/knock.py b/synapse/rest/client/v2_alpha/knock.py
new file mode 100644
index 0000000000..f046bf9cb3
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/knock.py
@@ -0,0 +1,109 @@
+# Copyright 2020 Sorunome
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.constants import Membership
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_json_object_from_request,
+ parse_strings_from_args,
+)
+from synapse.http.site import SynapseRequest
+from synapse.logging.opentracing import set_tag
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import JsonDict, RoomAlias, RoomID
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class KnockRoomAliasServlet(RestServlet):
+ """
+ POST /xyz.amorgan.knock/{roomIdOrAlias}
+ """
+
+ PATTERNS = client_patterns(
+ "/xyz.amorgan.knock/(?P<room_identifier>[^/]*)", releases=()
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.txns = HttpTransactionCache(hs)
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_identifier: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+
+ content = parse_json_object_from_request(request)
+ event_content = None
+ if "reason" in content:
+ event_content = {"reason": content["reason"]}
+
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+
+ # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
+
+ remote_room_hosts = parse_strings_from_args(
+ args, "server_name", required=False
+ )
+ elif RoomAlias.is_valid(room_identifier):
+ handler = self.room_member_handler
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
+ room_id = room_id_obj.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+
+ await self.room_member_handler.update_membership(
+ requester=requester,
+ target=requester.user,
+ room_id=room_id,
+ action=Membership.KNOCK,
+ txn_id=txn_id,
+ third_party_signed=None,
+ remote_room_hosts=remote_room_hosts,
+ content=event_content,
+ )
+
+ return 200, {"room_id": room_id}
+
+ def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
+ set_tag("txn_id", txn_id)
+
+ return self.txns.fetch_or_execute_request(
+ request, self.on_POST, request, room_identifier, txn_id
+ )
+
+
+def register_servlets(hs, http_server):
+ KnockRoomAliasServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 95ee3f1b84..042e1788b6 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
-from synapse.api.constants import PresenceState
+from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
@@ -24,7 +23,7 @@ from synapse.events.utils import (
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import SyncConfig
+from synapse.handlers.sync import KnockedSyncResult, SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken
@@ -220,6 +219,10 @@ class SyncRestServlet(RestServlet):
sync_result.invited, time_now, access_token_id, event_formatter
)
+ knocked = await self.encode_knocked(
+ sync_result.knocked, time_now, access_token_id, event_formatter
+ )
+
archived = await self.encode_archived(
sync_result.archived,
time_now,
@@ -237,11 +240,16 @@ class SyncRestServlet(RestServlet):
"left": list(sync_result.device_lists.left),
},
"presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
- "rooms": {"join": joined, "invite": invited, "leave": archived},
+ "rooms": {
+ Membership.JOIN: joined,
+ Membership.INVITE: invited,
+ Membership.KNOCK: knocked,
+ Membership.LEAVE: archived,
+ },
"groups": {
- "join": sync_result.groups.join,
- "invite": sync_result.groups.invite,
- "leave": sync_result.groups.leave,
+ Membership.JOIN: sync_result.groups.join,
+ Membership.INVITE: sync_result.groups.invite,
+ Membership.LEAVE: sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
@@ -303,7 +311,7 @@ class SyncRestServlet(RestServlet):
Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
- sync results for rooms this user is joined to
+ sync results for rooms this user is invited to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
@@ -322,7 +330,7 @@ class SyncRestServlet(RestServlet):
time_now,
token_id=token_id,
event_format=event_formatter,
- is_invite=True,
+ include_stripped_room_state=True,
)
unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
@@ -332,6 +340,60 @@ class SyncRestServlet(RestServlet):
return invited
+ async def encode_knocked(
+ self,
+ rooms: List[KnockedSyncResult],
+ time_now: int,
+ token_id: int,
+ event_formatter: Callable[[Dict], Dict],
+ ) -> Dict[str, Dict[str, Any]]:
+ """
+ Encode the rooms we've knocked on in a sync result.
+
+ Args:
+ rooms: list of sync results for rooms this user is knocking on
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing of transaction IDs
+ event_formatter: function to convert from federation format to client format
+
+ Returns:
+ The list of rooms the user has knocked on, in our response format.
+ """
+ knocked = {}
+ for room in rooms:
+ knock = await self._event_serializer.serialize_event(
+ room.knock,
+ time_now,
+ token_id=token_id,
+ event_format=event_formatter,
+ include_stripped_room_state=True,
+ )
+
+ # Extract the `unsigned` key from the knock event.
+ # This is where we (cheekily) store the knock state events
+ unsigned = knock.setdefault("unsigned", {})
+
+ # Duplicate the dictionary in order to avoid modifying the original
+ unsigned = dict(unsigned)
+
+ # Extract the stripped room state from the unsigned dict
+ # This is for clients to get a little bit of information about
+ # the room they've knocked on, without revealing any sensitive information
+ knocked_state = list(unsigned.pop("knock_room_state", []))
+
+ # Append the actual knock membership event itself as well. This provides
+ # the client with:
+ #
+ # * A knock state event that they can use for easier internal tracking
+ # * The rough timestamp of when the knock occurred contained within the event
+ knocked_state.append(knock)
+
+ # Build the `knock_state` dictionary, which will contain the state of the
+ # room that the client has knocked on
+ knocked[room.room_id] = {"knock_state": {"events": knocked_state}}
+
+ return knocked
+
async def encode_archived(
self, rooms, time_now, token_id, event_fields, event_formatter
):
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 2a96bcd314..9f0d64a325 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -19,7 +19,7 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -177,11 +177,13 @@ class RoomWorkerStore(SQLBaseStore):
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR history_visibility = 'world_readable'
+ join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ OR history_visibility = 'world_readable'
)
AND joined_members > 0
""" % {
- "published_sql": published_sql
+ "published_sql": published_sql,
+ "knock_join_rule": JoinRules.KNOCK,
}
txn.execute(sql, query_args)
@@ -303,7 +305,7 @@ class RoomWorkerStore(SQLBaseStore):
sql = """
SELECT
room_id, name, topic, canonical_alias, joined_members,
- avatar, history_visibility, joined_members, guest_access
+ avatar, history_visibility, guest_access, join_rules
FROM (
%(published_sql)s
) published
@@ -311,7 +313,8 @@ class RoomWorkerStore(SQLBaseStore):
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR history_visibility = 'world_readable'
+ join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ OR history_visibility = 'world_readable'
)
AND joined_members > 0
%(where_clause)s
@@ -320,6 +323,7 @@ class RoomWorkerStore(SQLBaseStore):
"published_sql": published_sql,
"where_clause": where_clause,
"dir": "DESC" if forwards else "ASC",
+ "knock_join_rule": JoinRules.KNOCK,
}
if limit is not None:
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index ae9f880965..82a1833509 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -41,6 +41,7 @@ ABSOLUTE_STATS_FIELDS = {
"current_state_events",
"joined_members",
"invited_members",
+ "knocked_members",
"left_members",
"banned_members",
"local_users_in_room",
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 3799d46734..683e5e3b90 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -1,5 +1,4 @@
-# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2014 - 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,7 +25,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.schema import SCHEMA_VERSION
+from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION
from synapse.storage.types import Cursor
logger = logging.getLogger(__name__)
@@ -59,6 +58,28 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
)
+@attr.s
+class _SchemaState:
+ current_version: int = attr.ib()
+ """The current schema version of the database"""
+
+ compat_version: Optional[int] = attr.ib()
+ """The SCHEMA_VERSION of the oldest version of Synapse for this database
+
+ If this is None, we have an old version of the database without the necessary
+ table.
+ """
+
+ applied_deltas: Collection[str] = attr.ib(factory=tuple)
+ """Any delta files for `current_version` which have already been applied"""
+
+ upgraded: bool = attr.ib(default=False)
+ """Whether the current state was reached by applying deltas.
+
+ If False, we have run the full schema for `current_version`, and have applied no
+ deltas since. If True, we have run some deltas since the original creation."""
+
+
def prepare_database(
db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
@@ -96,12 +117,11 @@ def prepare_database(
version_info = _get_or_create_schema_state(cur, database_engine)
if version_info:
- user_version, delta_files, upgraded = version_info
logger.info(
"%r: Existing schema is %i (+%i deltas)",
databases,
- user_version,
- len(delta_files),
+ version_info.current_version,
+ len(version_info.applied_deltas),
)
# config should only be None when we are preparing an in-memory SQLite db,
@@ -113,16 +133,18 @@ def prepare_database(
# if it's a worker app, refuse to upgrade the database, to avoid multiple
# workers doing it at once.
- if config.worker_app is not None and user_version != SCHEMA_VERSION:
+ if (
+ config.worker_app is not None
+ and version_info.current_version != SCHEMA_VERSION
+ ):
raise UpgradeDatabaseException(
- OUTDATED_SCHEMA_ON_WORKER_ERROR % (SCHEMA_VERSION, user_version)
+ OUTDATED_SCHEMA_ON_WORKER_ERROR
+ % (SCHEMA_VERSION, version_info.current_version)
)
_upgrade_existing_database(
cur,
- user_version,
- delta_files,
- upgraded,
+ version_info,
database_engine,
config,
databases=databases,
@@ -261,9 +283,7 @@ def _setup_new_database(
_upgrade_existing_database(
cur,
- current_version=max_current_ver,
- applied_delta_files=[],
- upgraded=False,
+ _SchemaState(current_version=max_current_ver, compat_version=None),
database_engine=database_engine,
config=None,
databases=databases,
@@ -273,9 +293,7 @@ def _setup_new_database(
def _upgrade_existing_database(
cur: Cursor,
- current_version: int,
- applied_delta_files: List[str],
- upgraded: bool,
+ current_schema_state: _SchemaState,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str],
@@ -321,12 +339,8 @@ def _upgrade_existing_database(
Args:
cur
- current_version: The current version of the schema.
- applied_delta_files: A list of deltas that have already been applied.
- upgraded: Whether the current version was generated by having
- applied deltas or from full schema file. If `True` the function
- will never apply delta files for the given `current_version`, since
- the current_version wasn't generated by applying those delta files.
+ current_schema_state: The current version of the schema, as
+ returned by _get_or_create_schema_state
database_engine
config:
None if we are initialising a blank database, otherwise the application
@@ -337,13 +351,16 @@ def _upgrade_existing_database(
upgrade portions of the delta scripts.
"""
if is_empty:
- assert not applied_delta_files
+ assert not current_schema_state.applied_deltas
else:
assert config
is_worker = config and config.worker_app is not None
- if current_version > SCHEMA_VERSION:
+ if (
+ current_schema_state.compat_version is not None
+ and current_schema_state.compat_version > SCHEMA_VERSION
+ ):
raise ValueError(
"Cannot use this database as it is too "
+ "new for the server to understand"
@@ -357,14 +374,26 @@ def _upgrade_existing_database(
assert config is not None
check_database_before_upgrade(cur, database_engine, config)
- start_ver = current_version
+ # update schema_compat_version before we run any upgrades, so that if synapse
+ # gets downgraded again, it won't try to run against the upgraded database.
+ if (
+ current_schema_state.compat_version is None
+ or current_schema_state.compat_version < SCHEMA_COMPAT_VERSION
+ ):
+ cur.execute("DELETE FROM schema_compat_version")
+ cur.execute(
+ "INSERT INTO schema_compat_version(compat_version) VALUES (?)",
+ (SCHEMA_COMPAT_VERSION,),
+ )
+
+ start_ver = current_schema_state.current_version
# if we got to this schema version by running a full_schema rather than a series
# of deltas, we should not run the deltas for this version.
- if not upgraded:
+ if not current_schema_state.upgraded:
start_ver += 1
- logger.debug("applied_delta_files: %s", applied_delta_files)
+ logger.debug("applied_delta_files: %s", current_schema_state.applied_deltas)
if isinstance(database_engine, PostgresEngine):
specific_engine_extension = ".postgres"
@@ -440,7 +469,7 @@ def _upgrade_existing_database(
absolute_path = entry.absolute_path
logger.debug("Found file: %s (%s)", relative_path, absolute_path)
- if relative_path in applied_delta_files:
+ if relative_path in current_schema_state.applied_deltas:
continue
root_name, ext = os.path.splitext(file_name)
@@ -621,7 +650,7 @@ def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
def _get_or_create_schema_state(
txn: Cursor, database_engine: BaseDatabaseEngine
-) -> Optional[Tuple[int, List[str], bool]]:
+) -> Optional[_SchemaState]:
# Bluntly try creating the schema_version tables.
sql_path = os.path.join(schema_path, "common", "schema_version.sql")
executescript(txn, sql_path)
@@ -629,17 +658,31 @@ def _get_or_create_schema_state(
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
+ if row is None:
+ # new database
+ return None
+
+ current_version = int(row[0])
+ upgraded = bool(row[1])
+
+ compat_version: Optional[int] = None
+ txn.execute("SELECT compat_version FROM schema_compat_version")
+ row = txn.fetchone()
if row is not None:
- current_version = int(row[0])
- txn.execute(
- "SELECT file FROM applied_schema_deltas WHERE version >= ?",
- (current_version,),
- )
- applied_deltas = [d for d, in txn]
- upgraded = bool(row[1])
- return current_version, applied_deltas, upgraded
+ compat_version = int(row[0])
+
+ txn.execute(
+ "SELECT file FROM applied_schema_deltas WHERE version >= ?",
+ (current_version,),
+ )
+ applied_deltas = tuple(d for d, in txn)
- return None
+ return _SchemaState(
+ current_version=current_version,
+ compat_version=compat_version,
+ applied_deltas=applied_deltas,
+ upgraded=upgraded,
+ )
@attr.s(slots=True)
diff --git a/synapse/storage/schema/README.md b/synapse/storage/schema/README.md
index 030153db64..729f44ea6c 100644
--- a/synapse/storage/schema/README.md
+++ b/synapse/storage/schema/README.md
@@ -1,37 +1,4 @@
# Synapse Database Schemas
-This directory contains the schema files used to build Synapse databases.
-
-Synapse supports splitting its datastore across multiple physical databases (which can
-be useful for large installations), and the schema files are therefore split according
-to the logical database they are apply to.
-
-At the time of writing, the following "logical" databases are supported:
-
-* `state` - used to store Matrix room state (more specifically, `state_groups`,
- their relationships and contents.)
-* `main` - stores everything else.
-
-Addionally, the `common` directory contains schema files for tables which must be
-present on *all* physical databases.
-
-## Full schema dumps
-
-In the `full_schemas` directories, only the most recently-numbered snapshot is useful
-(`54` at the time of writing). Older snapshots (eg, `16`) are present for historical
-reference only.
-
-## Building full schema dumps
-
-If you want to recreate these schemas, they need to be made from a database that
-has had all background updates run.
-
-To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
-`full.sql.postgres` and `full.sql.sqlite` files.
-
-Ensure postgres is installed, then run:
-
- ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
-
-NB at the time of writing, this script predates the split into separate `state`/`main`
-databases so will require updates to handle that correctly.
+This directory contains the schema files used to build Synapse databases. For more
+information, see /docs/development/database_schema.md.
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index f0d9f23167..d36ba1d773 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,6 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Remember to update this number every time a change is made to database
-# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 59
+"""Represents the expectations made by the codebase about the database schema
+
+This should be incremented whenever the codebase changes its requirements on the
+shape of the database schema (even if those requirements are backwards-compatible with
+older versions of Synapse).
+
+See `README.md <synapse/storage/schema/README.md>`_ for more information on how this
+works.
+"""
+
+
+SCHEMA_COMPAT_VERSION = 59
+"""Limit on how far the synapse codebase can be rolled back without breaking db compat
+
+This value is stored in the database, and checked on startup. If the value in the
+database is greater than SCHEMA_VERSION, then Synapse will refuse to start.
+"""
diff --git a/synapse/storage/schema/common/schema_version.sql b/synapse/storage/schema/common/schema_version.sql
index 42e5cb6df5..f41fde5d2d 100644
--- a/synapse/storage/schema/common/schema_version.sql
+++ b/synapse/storage/schema/common/schema_version.sql
@@ -20,6 +20,13 @@ CREATE TABLE IF NOT EXISTS schema_version(
CHECK (Lock='X')
);
+CREATE TABLE IF NOT EXISTS schema_compat_version(
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ -- The SCHEMA_VERSION of the oldest synapse this database can be used with
+ compat_version INTEGER NOT NULL,
+ CHECK (Lock='X')
+);
+
CREATE TABLE IF NOT EXISTS applied_schema_deltas(
version INTEGER NOT NULL,
file TEXT NOT NULL,
diff --git a/synapse/storage/schema/main/delta/59/11add_knock_members_to_stats.sql b/synapse/storage/schema/main/delta/59/11add_knock_members_to_stats.sql
new file mode 100644
index 0000000000..56c0ad0003
--- /dev/null
+++ b/synapse/storage/schema/main/delta/59/11add_knock_members_to_stats.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 Sorunome
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_stats_current ADD COLUMN knocked_members INT NOT NULL DEFAULT '0';
+ALTER TABLE room_stats_historical ADD COLUMN knocked_members BIGINT NOT NULL DEFAULT '0';
\ No newline at end of file
|