diff options
-rw-r--r-- | synapse/federation/transport/server.py | 339 | ||||
-rw-r--r-- | synapse/handlers/events.py | 2 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 6 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 3 | ||||
-rw-r--r-- | synapse/handlers/room.py | 6 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 6 | ||||
-rw-r--r-- | tests/handlers/test_presence.py | 4 | ||||
-rw-r--r-- | tests/handlers/test_presencelike.py | 8 | ||||
-rw-r--r-- | tests/rest/client/v1/test_presence.py | 8 |
9 files changed, 179 insertions, 203 deletions
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index ece6dbcf62..6c624977d7 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError from synapse.util.logutils import log_function +import functools import logging import simplejson as json import re @@ -30,8 +31,9 @@ logger = logging.getLogger(__name__) class TransportLayerServer(object): """Handles incoming federation HTTP requests""" + # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks - def _authenticate_request(self, request): + def authenticate_request(self, request): json_request = { "method": request.method, "uri": request.uri, @@ -93,28 +95,6 @@ class TransportLayerServer(object): defer.returnValue((origin, content)) - def _with_authentication(self, handler): - @defer.inlineCallbacks - def new_handler(request, *args, **kwargs): - try: - (origin, content) = yield self._authenticate_request(request) - with self.ratelimiter.ratelimit(origin) as d: - yield d - response = yield handler( - origin, content, request.args, *args, **kwargs - ) - except: - logger.exception("_authenticate_request failed") - raise - defer.returnValue(response) - return new_handler - - def rate_limit_origin(self, handler): - def new_handler(origin, *args, **kwargs): - response = yield handler(origin, *args, **kwargs) - defer.returnValue(response) - return new_handler() - @log_function def register_received_handler(self, handler): """ Register a handler that will be fired when we receive data. @@ -122,14 +102,12 @@ class TransportLayerServer(object): Args: handler (TransportReceivedHandler) """ - self.received_handler = handler - - # This is when someone is trying to send us a bunch of data. - self.server.register_path( - "PUT", - re.compile("^" + PREFIX + "/send/([^/]*)/$"), - self._with_authentication(self._on_send_request) - ) + FederationSendServlet( + handler, + authenticator=self, + ratelimiter=self.ratelimiter, + server_name=self.server_name, + ).register(self.server) @log_function def register_request_handler(self, handler): @@ -138,136 +116,61 @@ class TransportLayerServer(object): Args: handler (TransportRequestHandler) """ - self.request_handler = handler - - # This is for when someone asks us for everything since version X - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/pull/$"), - self._with_authentication( - lambda origin, content, query: - handler.on_pull_request(query["origin"][0], query["v"]) - ) - ) + for servletclass in SERVLET_CLASSES: + servletclass( + handler, + authenticator=self, + ratelimiter=self.ratelimiter, + ).register(self.server) - # This is when someone asks for a data item for a given server - # data_id pair. - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/event/([^/]*)/$"), - self._with_authentication( - lambda origin, content, query, event_id: - handler.on_pdu_request(origin, event_id) - ) - ) - # This is when someone asks for all data for a given context. - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/state/([^/]*)/$"), - self._with_authentication( - lambda origin, content, query, context: - handler.on_context_state_request( - origin, - context, - query.get("event_id", [None])[0], - ) - ) - ) +class BaseFederationServlet(object): + def __init__(self, handler, authenticator, ratelimiter): + self.handler = handler + self.authenticator = authenticator + self.ratelimiter = ratelimiter - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/backfill/([^/]*)/$"), - self._with_authentication( - lambda origin, content, query, context: - self._on_backfill_request( - origin, context, query["v"], query["limit"] - ) - ) - ) + def _wrap(self, code): + authenticator = self.authenticator + ratelimiter = self.ratelimiter - # This is when we receive a server-server Query - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/query/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, query_type: - handler.on_query_request( - query_type, - {k: v[0].decode("utf-8") for k, v in query.items()} - ) - ) - ) + @defer.inlineCallbacks + @functools.wraps(code) + def new_code(request, *args, **kwargs): + try: + (origin, content) = yield authenticator.authenticate_request(request) + with ratelimiter.ratelimit(origin) as d: + yield d + response = yield code( + origin, content, request.args, *args, **kwargs + ) + except: + logger.exception("authenticate_request failed") + raise + defer.returnValue(response) + return new_code - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, user_id: - self._on_make_join_request( - origin, content, query, context, user_id - ) - ) - ) + def register(self, server): + pattern = re.compile("^" + PREFIX + self.PATH + "$") - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, event_id: - handler.on_event_auth( - origin, context, event_id, - ) - ) - ) + for method in ("GET", "PUT", "POST"): + code = getattr(self, "on_%s" % (method), None) + if code is None: + continue - self.server.register_path( - "PUT", - re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, event_id: - self._on_send_join_request( - origin, content, query, - ) - ) - ) + server.register_path(method, pattern, self._wrap(code)) - self.server.register_path( - "PUT", - re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, event_id: - self._on_invite_request( - origin, content, query, - ) - ) - ) - self.server.register_path( - "POST", - re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, event_id: - self._on_query_auth_request( - origin, content, event_id, - ) - ) - ) +class FederationSendServlet(BaseFederationServlet): + PATH = "/send/([^/]*)/" - self.server.register_path( - "POST", - re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"), - self._with_authentication( - lambda origin, content, query, room_id: - self._get_missing_events( - origin, content, room_id, - ) - ) - ) + def __init__(self, handler, server_name, **kwargs): + super(FederationSendServlet, self).__init__(handler, **kwargs) + self.server_name = server_name + # This is when someone is trying to send us a bunch of data. @defer.inlineCallbacks - @log_function - def _on_send_request(self, origin, content, query, transaction_id): + def on_PUT(self, origin, content, query, transaction_id): """ Called on PUT /send/<transaction_id>/ Args: @@ -305,8 +208,7 @@ class TransportLayerServer(object): return try: - handler = self.received_handler - code, response = yield handler.on_incoming_transaction( + code, response = yield self.handler.on_incoming_transaction( transaction_data ) except: @@ -315,65 +217,123 @@ class TransportLayerServer(object): defer.returnValue((code, response)) - @log_function - def _on_backfill_request(self, origin, context, v_list, limits): + +class FederationPullServlet(BaseFederationServlet): + PATH = "/pull/" + + # This is for when someone asks us for everything since version X + def on_GET(self, origin, content, query): + return self.handler.on_pull_request(query["origin"][0], query["v"]) + + +class FederationEventServlet(BaseFederationServlet): + PATH = "/event/([^/]*)/" + + # This is when someone asks for a data item for a given server data_id pair. + def on_GET(self, origin, content, query, event_id): + return self.handler.on_pdu_request(origin, event_id) + + +class FederationStateServlet(BaseFederationServlet): + PATH = "/state/([^/]*)/" + + # This is when someone asks for all data for a given context. + def on_GET(self, origin, content, query, context): + return self.handler.on_context_state_request( + origin, + context, + query.get("event_id", [None])[0], + ) + + +class FederationBackfillServlet(BaseFederationServlet): + PATH = "/backfill/([^/]*)/" + + def on_GET(self, origin, content, query, context): + versions = query["v"] + limits = query["limit"] + if not limits: - return defer.succeed( - (400, {"error": "Did not include limit param"}) - ) + return defer.succeed((400, {"error": "Did not include limit param"})) limit = int(limits[-1]) - versions = v_list + return self.handler.on_backfill_request(origin, context, versions, limit) - return self.request_handler.on_backfill_request( - origin, context, versions, limit + +class FederationQueryServlet(BaseFederationServlet): + PATH = "/query/([^/]*)" + + # This is when we receive a server-server Query + def on_GET(self, origin, content, query, query_type): + return self.handler.on_query_request( + query_type, + {k: v[0].decode("utf-8") for k, v in query.items()} ) + +class FederationMakeJoinServlet(BaseFederationServlet): + PATH = "/make_join/([^/]*)/([^/]*)" + @defer.inlineCallbacks - @log_function - def _on_make_join_request(self, origin, content, query, context, user_id): - content = yield self.request_handler.on_make_join_request( - context, user_id, - ) + def on_GET(self, origin, content, query, context, user_id): + content = yield self.handler.on_make_join_request(context, user_id) defer.returnValue((200, content)) - @defer.inlineCallbacks - @log_function - def _on_send_join_request(self, origin, content, query): - content = yield self.request_handler.on_send_join_request( - origin, content, - ) - defer.returnValue((200, content)) +class FederationEventAuthServlet(BaseFederationServlet): + PATH = "/event_auth/([^/]*)/([^/]*)" + + def on_GET(self, origin, content, query, context, event_id): + return self.handler.on_event_auth(origin, context, event_id) + + +class FederationSendJoinServlet(BaseFederationServlet): + PATH = "/send_join/([^/]*)/([^/]*)" @defer.inlineCallbacks - @log_function - def _on_invite_request(self, origin, content, query): - content = yield self.request_handler.on_invite_request( - origin, content, - ) + def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = yield self.handler.on_send_join_request(origin, content) + defer.returnValue((200, content)) + + +class FederationInviteServlet(BaseFederationServlet): + PATH = "/invite/([^/]*)/([^/]*)" + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = yield self.handler.on_invite_request(origin, content) defer.returnValue((200, content)) + +class FederationQueryAuthServlet(BaseFederationServlet): + PATH = "/query_auth/([^/]*)/([^/]*)" + @defer.inlineCallbacks - @log_function - def _on_query_auth_request(self, origin, content, event_id): - new_content = yield self.request_handler.on_query_auth_request( + def on_POST(self, origin, content, query, context, event_id): + new_content = yield self.handler.on_query_auth_request( origin, content, event_id ) defer.returnValue((200, new_content)) + +class FederationGetMissingEventsServlet(BaseFederationServlet): + # TODO(paul): Why does this path alone end with "/?" optional? + PATH = "/get_missing_events/([^/]*)/?" + @defer.inlineCallbacks - @log_function - def _get_missing_events(self, origin, content, room_id): + def on_POST(self, origin, content, query, room_id): limit = int(content.get("limit", 10)) min_depth = int(content.get("min_depth", 0)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) - content = yield self.request_handler.on_get_missing_events( + content = yield self.handler.on_get_missing_events( origin, room_id=room_id, earliest_events=earliest_events, @@ -383,3 +343,18 @@ class TransportLayerServer(object): ) defer.returnValue((200, content)) + + +SERVLET_CLASSES = ( + FederationPullServlet, + FederationEventServlet, + FederationStateServlet, + FederationBackfillServlet, + FederationQueryServlet, + FederationMakeJoinServlet, + FederationEventServlet, + FederationSendJoinServlet, + FederationInviteServlet, + FederationQueryAuthServlet, + FederationGetMissingEventsServlet, +) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index d3297b7292..f9f855213b 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -71,7 +71,7 @@ class EventStreamHandler(BaseHandler): self._streams_per_user[auth_user] += 1 rm_handler = self.hs.get_handlers().room_member_handler - room_ids = yield rm_handler.get_rooms_for_user(auth_user) + room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user) if timeout: # If they've set a timeout set a minimum limit. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8ef248ecf2..28e922f79b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -452,7 +452,7 @@ class PresenceHandler(BaseHandler): # Also include people in all my rooms rm_handler = self.homeserver.get_handlers().room_member_handler - room_ids = yield rm_handler.get_rooms_for_user(user) + room_ids = yield rm_handler.get_joined_rooms_for_user(user) if state is None: state = yield self.store.get_presence_state(user.localpart) @@ -596,7 +596,7 @@ class PresenceHandler(BaseHandler): localusers.add(user) rm_handler = self.homeserver.get_handlers().room_member_handler - room_ids = yield rm_handler.get_rooms_for_user(user) + room_ids = yield rm_handler.get_joined_rooms_for_user(user) if not localusers and not room_ids: defer.returnValue(None) @@ -663,7 +663,7 @@ class PresenceHandler(BaseHandler): ) rm_handler = self.homeserver.get_handlers().room_member_handler - room_ids = yield rm_handler.get_rooms_for_user(user) + room_ids = yield rm_handler.get_joined_rooms_for_user(user) if room_ids: logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 2ddf9d5378..ee2732b848 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -197,9 +197,8 @@ class ProfileHandler(BaseHandler): self.ratelimit(user.to_string()) - joins = yield self.store.get_rooms_for_user_where_membership_is( + joins = yield self.store.get_rooms_for_user( user.to_string(), - [Membership.JOIN], ) for j in joins: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 80f7ee3f12..823affc380 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -507,7 +507,7 @@ class RoomMemberHandler(BaseHandler): defer.returnValue((is_remote_invite_join, room_host)) @defer.inlineCallbacks - def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]): + def get_joined_rooms_for_user(self, user): """Returns a list of roomids that the user has any of the given membership states in.""" @@ -517,8 +517,8 @@ class RoomMemberHandler(BaseHandler): if app_service: rooms = yield self.store.get_app_service_rooms(app_service) else: - rooms = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user.to_string(), membership_list=membership_list + rooms = yield self.store.get_rooms_for_user( + user.to_string(), ) # For some reason the list of events contains duplicates diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7883bbd834..35a62fda47 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -96,7 +96,9 @@ class SyncHandler(BaseHandler): return self.current_sync_for_user(sync_config, since_token) rm_handler = self.hs.get_handlers().room_member_handler - room_ids = yield rm_handler.get_rooms_for_user(sync_config.user) + room_ids = yield rm_handler.get_joined_rooms_for_user( + sync_config.user + ) result = yield self.notifier.wait_for_events( sync_config.user, room_ids, sync_config.filter, timeout, current_sync_callback @@ -227,7 +229,7 @@ class SyncHandler(BaseHandler): logger.debug("Typing %r", typing_by_room) rm_handler = self.hs.get_handlers().room_member_handler - room_ids = yield rm_handler.get_rooms_for_user(sync_config.user) + room_ids = yield rm_handler.get_joined_rooms_for_user(sync_config.user) # TODO (mjark): Does public mean "published"? published_rooms = yield self.store.get_rooms(is_public=True) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 6ffc3c99cc..04eba4289e 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -100,7 +100,7 @@ class PresenceTestCase(unittest.TestCase): self.room_members = [] room_member_handler = handlers.room_member_handler = Mock(spec=[ - "get_rooms_for_user", + "get_joined_rooms_for_user", "get_room_members", "fetch_room_distributions_into", ]) @@ -111,7 +111,7 @@ class PresenceTestCase(unittest.TestCase): return defer.succeed([self.room_id]) else: return defer.succeed([]) - room_member_handler.get_rooms_for_user = get_rooms_for_user + room_member_handler.get_joined_rooms_for_user = get_rooms_for_user def get_room_members(room_id): if room_id == self.room_id: diff --git a/tests/handlers/test_presencelike.py b/tests/handlers/test_presencelike.py index 18cac9a846..977e832da7 100644 --- a/tests/handlers/test_presencelike.py +++ b/tests/handlers/test_presencelike.py @@ -64,7 +64,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase): "set_presence_state", "is_presence_visible", "set_profile_displayname", - "get_rooms_for_user_where_membership_is", + "get_rooms_for_user", ]), handlers=None, resource_for_federation=Mock(), @@ -124,9 +124,9 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase): self.mock_update_client) hs.handlers.room_member_handler = Mock(spec=[ - "get_rooms_for_user", + "get_joined_rooms_for_user", ]) - hs.handlers.room_member_handler.get_rooms_for_user = ( + hs.handlers.room_member_handler.get_joined_rooms_for_user = ( lambda u: defer.succeed([])) # Some local users to test with @@ -138,7 +138,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase): self.u_potato = UserID.from_string("@potato:remote") self.mock_get_joined = ( - self.datastore.get_rooms_for_user_where_membership_is + self.datastore.get_rooms_for_user ) @defer.inlineCallbacks diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 5f2ef64efc..b9c03383a2 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -79,13 +79,13 @@ class PresenceStateTestCase(unittest.TestCase): room_member_handler = hs.handlers.room_member_handler = Mock( spec=[ - "get_rooms_for_user", + "get_joined_rooms_for_user", ] ) def get_rooms_for_user(user): return defer.succeed([]) - room_member_handler.get_rooms_for_user = get_rooms_for_user + room_member_handler.get_joined_rooms_for_user = get_rooms_for_user presence.register_servlets(hs, self.mock_resource) @@ -166,7 +166,7 @@ class PresenceListTestCase(unittest.TestCase): hs.handlers.room_member_handler = Mock( spec=[ - "get_rooms_for_user", + "get_joined_rooms_for_user", ] ) @@ -291,7 +291,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): return ["a-room"] else: return [] - hs.handlers.room_member_handler.get_rooms_for_user = get_rooms_for_user + hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user self.mock_datastore = hs.get_datastore() self.mock_datastore.get_app_service_by_token = Mock(return_value=None) |