diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index edba5a9808..8e2be218e2 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -21,7 +21,7 @@ from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
-from synapse.api.urls import FEDERATION_PREFIX as PREFIX
+from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.util.logutils import log_function
logger = logging.getLogger(__name__)
@@ -51,7 +51,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state dest=%s, room=%s",
destination, room_id)
- path = _create_path(PREFIX, "/state/%s/", room_id)
+ path = _create_v1_path("/state/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -73,7 +73,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
- path = _create_path(PREFIX, "/state_ids/%s/", room_id)
+ path = _create_v1_path("/state_ids/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -95,7 +95,7 @@ class TransportLayerClient(object):
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
- path = _create_path(PREFIX, "/event/%s/", event_id)
+ path = _create_v1_path("/event/%s/", event_id)
return self.client.get_json(destination, path=path, timeout=timeout)
@log_function
@@ -121,7 +121,7 @@ class TransportLayerClient(object):
# TODO: raise?
return
- path = _create_path(PREFIX, "/backfill/%s/", room_id)
+ path = _create_v1_path("/backfill/%s/", room_id)
args = {
"v": event_tuples,
@@ -167,7 +167,7 @@ class TransportLayerClient(object):
# generated by the json_data_callback.
json_data = transaction.get_dict()
- path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
+ path = _create_v1_path("/send/%s/", transaction.transaction_id)
response = yield self.client.put_json(
transaction.destination,
@@ -184,7 +184,7 @@ class TransportLayerClient(object):
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
- path = _create_path(PREFIX, "/query/%s", query_type)
+ path = _create_v1_path("/query/%s", query_type)
content = yield self.client.get_json(
destination=destination,
@@ -231,7 +231,7 @@ class TransportLayerClient(object):
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
- path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
+ path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False
retry_on_dns_fail = False
@@ -258,7 +258,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_join(self, destination, room_id, event_id, content):
- path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
+ path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -271,7 +271,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
- path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
+ path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -289,8 +289,22 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def send_invite(self, destination, room_id, event_id, content):
- path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
+ def send_invite_v1(self, destination, room_id, event_id, content):
+ path = _create_v1_path("/invite/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_invite_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -306,7 +320,7 @@ class TransportLayerClient(object):
def get_public_rooms(self, remote_server, limit, since_token,
search_filter=None, include_all_networks=False,
third_party_instance_id=None):
- path = PREFIX + "/publicRooms"
+ path = _create_v1_path("/publicRooms")
args = {
"include_all_networks": "true" if include_all_networks else "false",
@@ -332,7 +346,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
- path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
+ path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
response = yield self.client.put_json(
destination=destination,
@@ -345,7 +359,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
- path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
+ path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(
destination=destination,
@@ -357,7 +371,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
- path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
+ path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json(
destination=destination,
@@ -392,7 +406,7 @@ class TransportLayerClient(object):
Returns:
A dict containg the device keys.
"""
- path = PREFIX + "/user/keys/query"
+ path = _create_v1_path("/user/keys/query")
content = yield self.client.post_json(
destination=destination,
@@ -419,7 +433,7 @@ class TransportLayerClient(object):
Returns:
A dict containg the device keys.
"""
- path = _create_path(PREFIX, "/user/devices/%s", user_id)
+ path = _create_v1_path("/user/devices/%s", user_id)
content = yield self.client.get_json(
destination=destination,
@@ -455,7 +469,7 @@ class TransportLayerClient(object):
A dict containg the one-time keys.
"""
- path = PREFIX + "/user/keys/claim"
+ path = _create_v1_path("/user/keys/claim")
content = yield self.client.post_json(
destination=destination,
@@ -469,7 +483,7 @@ class TransportLayerClient(object):
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth, timeout):
- path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
+ path = _create_v1_path("/get_missing_events/%s", room_id,)
content = yield self.client.post_json(
destination=destination,
@@ -489,7 +503,7 @@ class TransportLayerClient(object):
def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile
"""
- path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+ path = _create_v1_path("/groups/%s/profile", group_id,)
return self.client.get_json(
destination=destination,
@@ -508,7 +522,7 @@ class TransportLayerClient(object):
requester_user_id (str)
content (dict): The new profile of the group
"""
- path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+ path = _create_v1_path("/groups/%s/profile", group_id,)
return self.client.post_json(
destination=destination,
@@ -522,7 +536,7 @@ class TransportLayerClient(object):
def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary
"""
- path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
+ path = _create_v1_path("/groups/%s/summary", group_id,)
return self.client.get_json(
destination=destination,
@@ -535,7 +549,7 @@ class TransportLayerClient(object):
def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group
"""
- path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
+ path = _create_v1_path("/groups/%s/rooms", group_id,)
return self.client.get_json(
destination=destination,
@@ -548,7 +562,7 @@ class TransportLayerClient(object):
content):
"""Add a room to a group
"""
- path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
return self.client.post_json(
destination=destination,
@@ -562,8 +576,8 @@ class TransportLayerClient(object):
config_key, content):
"""Update room in group
"""
- path = _create_path(
- PREFIX, "/groups/%s/room/%s/config/%s",
+ path = _create_v1_path(
+ "/groups/%s/room/%s/config/%s",
group_id, room_id, config_key,
)
@@ -578,7 +592,7 @@ class TransportLayerClient(object):
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""
- path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
return self.client.delete_json(
destination=destination,
@@ -591,7 +605,7 @@ class TransportLayerClient(object):
def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group
"""
- path = _create_path(PREFIX, "/groups/%s/users", group_id,)
+ path = _create_v1_path("/groups/%s/users", group_id,)
return self.client.get_json(
destination=destination,
@@ -604,7 +618,7 @@ class TransportLayerClient(object):
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group
"""
- path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
+ path = _create_v1_path("/groups/%s/invited_users", group_id,)
return self.client.get_json(
destination=destination,
@@ -617,8 +631,8 @@ class TransportLayerClient(object):
def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite
"""
- path = _create_path(
- PREFIX, "/groups/%s/users/%s/accept_invite",
+ path = _create_v1_path(
+ "/groups/%s/users/%s/accept_invite",
group_id, user_id,
)
@@ -633,7 +647,7 @@ class TransportLayerClient(object):
def join_group(self, destination, group_id, user_id, content):
"""Attempts to join a group
"""
- path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
+ path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -646,7 +660,7 @@ class TransportLayerClient(object):
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
"""Invite a user to a group
"""
- path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
+ path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -662,7 +676,7 @@ class TransportLayerClient(object):
invited.
"""
- path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
+ path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -676,7 +690,7 @@ class TransportLayerClient(object):
user_id, content):
"""Remove a user fron a group
"""
- path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
+ path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -693,7 +707,7 @@ class TransportLayerClient(object):
kicked from the group.
"""
- path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
+ path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -708,7 +722,7 @@ class TransportLayerClient(object):
the attestations
"""
- path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
+ path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -723,12 +737,12 @@ class TransportLayerClient(object):
"""Update a room entry in a group summary
"""
if category_id:
- path = _create_path(
- PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
+ path = _create_v1_path(
+ "/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
)
else:
- path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.post_json(
destination=destination,
@@ -744,12 +758,12 @@ class TransportLayerClient(object):
"""Delete a room entry in a group summary
"""
if category_id:
- path = _create_path(
- PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
+ path = _create_v1_path(
+ "/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
)
else:
- path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.delete_json(
destination=destination,
@@ -762,7 +776,7 @@ class TransportLayerClient(object):
def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group
"""
- path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
+ path = _create_v1_path("/groups/%s/categories", group_id,)
return self.client.get_json(
destination=destination,
@@ -775,7 +789,7 @@ class TransportLayerClient(object):
def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group
"""
- path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+ path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
return self.client.get_json(
destination=destination,
@@ -789,7 +803,7 @@ class TransportLayerClient(object):
content):
"""Update a category in a group
"""
- path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+ path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
return self.client.post_json(
destination=destination,
@@ -804,7 +818,7 @@ class TransportLayerClient(object):
category_id):
"""Delete a category in a group
"""
- path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+ path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
return self.client.delete_json(
destination=destination,
@@ -817,7 +831,7 @@ class TransportLayerClient(object):
def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group
"""
- path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
+ path = _create_v1_path("/groups/%s/roles", group_id,)
return self.client.get_json(
destination=destination,
@@ -830,7 +844,7 @@ class TransportLayerClient(object):
def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info
"""
- path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+ path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
return self.client.get_json(
destination=destination,
@@ -844,7 +858,7 @@ class TransportLayerClient(object):
content):
"""Update a role in a group
"""
- path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+ path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
return self.client.post_json(
destination=destination,
@@ -858,7 +872,7 @@ class TransportLayerClient(object):
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group
"""
- path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+ path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
return self.client.delete_json(
destination=destination,
@@ -873,12 +887,12 @@ class TransportLayerClient(object):
"""Update a users entry in a group
"""
if role_id:
- path = _create_path(
- PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+ path = _create_v1_path(
+ "/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
)
else:
- path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+ path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.post_json(
destination=destination,
@@ -893,7 +907,7 @@ class TransportLayerClient(object):
content):
"""Sets the join policy for a group
"""
- path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
+ path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
return self.client.put_json(
destination=destination,
@@ -909,12 +923,12 @@ class TransportLayerClient(object):
"""Delete a users entry in a group
"""
if role_id:
- path = _create_path(
- PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+ path = _create_v1_path(
+ "/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
)
else:
- path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+ path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.delete_json(
destination=destination,
@@ -927,7 +941,7 @@ class TransportLayerClient(object):
"""Get the groups a list of users are publicising
"""
- path = PREFIX + "/get_groups_publicised"
+ path = _create_v1_path("/get_groups_publicised")
content = {"user_ids": user_ids}
@@ -939,20 +953,43 @@ class TransportLayerClient(object):
)
-def _create_path(prefix, path, *args):
- """Creates a path from the prefix, path template and args. Ensures that
- all args are url encoded.
+def _create_v1_path(path, *args):
+ """Creates a path against V1 federation API from the path template and
+ args. Ensures that all args are url encoded.
+
+ Example:
+
+ _create_v1_path("/event/%s/", event_id)
+
+ Args:
+ path (str): String template for the path
+ args: ([str]): Args to insert into path. Each arg will be url encoded
+
+ Returns:
+ str
+ """
+ return (
+ FEDERATION_V1_PREFIX
+ + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+ )
+
+
+def _create_v2_path(path, *args):
+ """Creates a path against V2 federation API from the path template and
+ args. Ensures that all args are url encoded.
Example:
- _create_path(PREFIX, "/event/%s/", event_id)
+ _create_v2_path("/event/%s/", event_id)
Args:
- prefix (str)
path (str): String template for the path
args: ([str]): Args to insert into path. Each arg will be url encoded
Returns:
str
"""
- return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+ return (
+ FEDERATION_V2_PREFIX
+ + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+ )
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 6d4a26f595..5ba94be2ec 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -21,8 +21,9 @@ import re
from twisted.internet import defer
import synapse
+from synapse.api.constants import RoomVersions
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
-from synapse.api.urls import FEDERATION_PREFIX as PREFIX
+from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
@@ -42,9 +43,20 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
- def __init__(self, hs):
+ def __init__(self, hs, servlet_groups=None):
+ """Initialize the TransportLayerServer
+
+ Will by default register all servlets. For custom behaviour, pass in
+ a list of servlet_groups to register.
+
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ servlet_groups (list[str], optional): List of servlet groups to register.
+ Defaults to ``DEFAULT_SERVLET_GROUPS``.
+ """
self.hs = hs
self.clock = hs.get_clock()
+ self.servlet_groups = servlet_groups
super(TransportLayerServer, self).__init__(hs, canonical_json=False)
@@ -66,6 +78,7 @@ class TransportLayerServer(JsonResource):
resource=self,
ratelimiter=self.ratelimiter,
authenticator=self.authenticator,
+ servlet_groups=self.servlet_groups,
)
@@ -227,6 +240,8 @@ class BaseFederationServlet(object):
"""
REQUIRE_AUTH = True
+ PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
+
def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
@@ -286,7 +301,7 @@ class BaseFederationServlet(object):
return new_func
def register(self, server):
- pattern = re.compile("^" + PREFIX + self.PATH + "$")
+ pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
for method in ("GET", "PUT", "POST"):
code = getattr(self, "on_%s" % (method), None)
@@ -362,14 +377,6 @@ class FederationSendServlet(BaseFederationServlet):
defer.returnValue((code, response))
-class FederationPullServlet(BaseFederationServlet):
- PATH = "/pull/"
-
- # This is for when someone asks us for everything since version X
- def on_GET(self, origin, content, query):
- return self.handler.on_pull_request(query["origin"][0], query["v"])
-
-
class FederationEventServlet(BaseFederationServlet):
PATH = "/event/(?P<event_id>[^/]*)/"
@@ -474,7 +481,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, event_id):
- content = yield self.handler.on_send_leave_request(origin, content)
+ content = yield self.handler.on_send_leave_request(origin, content, room_id)
defer.returnValue((200, content))
@@ -492,18 +499,50 @@ class FederationSendJoinServlet(BaseFederationServlet):
def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
- content = yield self.handler.on_send_join_request(origin, content)
+ content = yield self.handler.on_send_join_request(origin, content, context)
defer.returnValue((200, content))
-class FederationInviteServlet(BaseFederationServlet):
+class FederationV1InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, context, event_id):
+ # We don't get a room version, so we have to assume its EITHER v1 or
+ # v2. This is "fine" as the only difference between V1 and V2 is the
+ # state resolution algorithm, and we don't use that for processing
+ # invites
+ content = yield self.handler.on_invite_request(
+ origin, content, room_version=RoomVersions.V1,
+ )
+
+ # V1 federation API is defined to return a content of `[200, {...}]`
+ # due to a historical bug.
+ defer.returnValue((200, (200, content)))
+
+
+class FederationV2InviteServlet(BaseFederationServlet):
+ PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+
+ PREFIX = FEDERATION_V2_PREFIX
+
+ @defer.inlineCallbacks
+ def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
- content = yield self.handler.on_invite_request(origin, content)
+
+ room_version = content["room_version"]
+ event = content["event"]
+ invite_room_state = content["invite_room_state"]
+
+ # Synapse expects invite_room_state to be in unsigned, as it is in v1
+ # API
+
+ event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
+
+ content = yield self.handler.on_invite_request(
+ origin, event, room_version=room_version,
+ )
defer.returnValue((200, content))
@@ -1262,7 +1301,6 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
- FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
FederationStateIdsServlet,
@@ -1273,7 +1311,8 @@ FEDERATION_SERVLET_CLASSES = (
FederationEventServlet,
FederationSendJoinServlet,
FederationSendLeaveServlet,
- FederationInviteServlet,
+ FederationV1InviteServlet,
+ FederationV2InviteServlet,
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
@@ -1282,10 +1321,12 @@ FEDERATION_SERVLET_CLASSES = (
FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
- OpenIdUserInfo,
FederationVersionServlet,
)
+OPENID_SERVLET_CLASSES = (
+ OpenIdUserInfo,
+)
ROOM_LIST_CLASSES = (
PublicRoomList,
@@ -1324,44 +1365,83 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
)
+DEFAULT_SERVLET_GROUPS = (
+ "federation",
+ "room_list",
+ "group_server",
+ "group_local",
+ "group_attestation",
+ "openid",
+)
+
-def register_servlets(hs, resource, authenticator, ratelimiter):
- for servletclass in FEDERATION_SERVLET_CLASSES:
- servletclass(
- handler=hs.get_federation_server(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
- for servletclass in ROOM_LIST_CLASSES:
- servletclass(
- handler=hs.get_room_list_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
- for servletclass in GROUP_SERVER_SERVLET_CLASSES:
- servletclass(
- handler=hs.get_groups_server_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
- for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
- servletclass(
- handler=hs.get_groups_local_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
- for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
- servletclass(
- handler=hs.get_groups_attestation_renewer(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
+def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None):
+ """Initialize and register servlet classes.
+
+ Will by default register all servlets. For custom behaviour, pass in
+ a list of servlet_groups to register.
+
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ resource (TransportLayerServer): resource class to register to
+ authenticator (Authenticator): authenticator to use
+ ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
+ servlet_groups (list[str], optional): List of servlet groups to register.
+ Defaults to ``DEFAULT_SERVLET_GROUPS``.
+ """
+ if not servlet_groups:
+ servlet_groups = DEFAULT_SERVLET_GROUPS
+
+ if "federation" in servlet_groups:
+ for servletclass in FEDERATION_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_federation_server(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "openid" in servlet_groups:
+ for servletclass in OPENID_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_federation_server(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "room_list" in servlet_groups:
+ for servletclass in ROOM_LIST_CLASSES:
+ servletclass(
+ handler=hs.get_room_list_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "group_server" in servlet_groups:
+ for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_server_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "group_local" in servlet_groups:
+ for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_local_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "group_attestation" in servlet_groups:
+ for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_attestation_renewer(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
|