diff options
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r-- | synapse/federation/federation_server.py | 212 |
1 files changed, 125 insertions, 87 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2a589524a4..aba19639c7 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -21,10 +21,11 @@ from .units import Transaction, Edu from synapse.util.async import Linearizer from synapse.util.logutils import log_function +from synapse.util.caches.response_cache import ResponseCache from synapse.events import FrozenEvent import synapse.metrics -from synapse.api.errors import FederationError, SynapseError +from synapse.api.errors import AuthError, FederationError, SynapseError from synapse.crypto.event_signing import compute_event_signature @@ -48,7 +49,14 @@ class FederationServer(FederationBase): def __init__(self, hs): super(FederationServer, self).__init__(hs) + self.auth = hs.get_auth() + self._room_pdu_linearizer = Linearizer() + self._server_linearizer = Linearizer() + + # We cache responses to state queries, as they take a while and often + # come in waves. + self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) def set_handler(self, handler): """Sets the handler that the replication layer will use to communicate @@ -89,11 +97,14 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_backfill_request(self, origin, room_id, versions, limit): - pdus = yield self.handler.on_backfill_request( - origin, room_id, versions, limit - ) + with (yield self._server_linearizer.queue((origin, room_id))): + pdus = yield self.handler.on_backfill_request( + origin, room_id, versions, limit + ) - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) + res = self._transaction_from_pdus(pdus).get_dict() + + defer.returnValue((200, res)) @defer.inlineCallbacks @log_function @@ -184,32 +195,71 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_context_state_request(self, origin, room_id, event_id): - if event_id: - pdus = yield self.handler.get_state_for_pdu( - origin, room_id, event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) + if not event_id: + raise NotImplementedError("Specify an event") - for event in auth_chain: - # We sign these again because there was a bug where we - # incorrectly signed things the first time round - if self.hs.is_mine_id(event.event_id): - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] - ) - ) + in_room = yield self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + result = self._state_resp_cache.get((room_id, event_id)) + if not result: + with (yield self._server_linearizer.queue((origin, room_id))): + resp = yield self._state_resp_cache.set( + (room_id, event_id), + self._on_context_state_request_compute(room_id, event_id) + ) else: + resp = yield result + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_state_ids_request(self, origin, room_id, event_id): + if not event_id: raise NotImplementedError("Specify an event") + in_room = yield self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + pdus = yield self.handler.get_state_for_pdu( + room_id, event_id, + ) + auth_chain = yield self.store.get_auth_chain( + [pdu.event_id for pdu in pdus] + ) + defer.returnValue((200, { + "pdu_ids": [pdu.event_id for pdu in pdus], + "auth_chain_ids": [pdu.event_id for pdu in auth_chain], + })) + + @defer.inlineCallbacks + def _on_context_state_request_compute(self, room_id, event_id): + pdus = yield self.handler.get_state_for_pdu( + room_id, event_id, + ) + auth_chain = yield self.store.get_auth_chain( + [pdu.event_id for pdu in pdus] + ) + + for event in auth_chain: + # We sign these again because there was a bug where we + # incorrectly signed things the first time round + if self.hs.is_mine_id(event.event_id): + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) + ) + + defer.returnValue({ "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - })) + }) @defer.inlineCallbacks @log_function @@ -283,14 +333,16 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): - time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) - defer.returnValue((200, { - "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], - })) + with (yield self._server_linearizer.queue((origin, room_id))): + time_now = self._clock.time_msec() + auth_pdus = yield self.handler.on_event_auth(event_id) + res = { + "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], + } + defer.returnValue((200, res)) @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, event_id): + def on_query_auth_request(self, origin, content, room_id, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. @@ -309,58 +361,41 @@ class FederationServer(FederationBase): Returns: Deferred: Results in `dict` with the same format as `content` """ - auth_chain = [ - self.event_from_pdu_json(e) - for e in content["auth_chain"] - ] - - signed_auth = yield self._check_sigs_and_hash_and_fetch( - origin, auth_chain, outlier=True - ) + with (yield self._server_linearizer.queue((origin, room_id))): + auth_chain = [ + self.event_from_pdu_json(e) + for e in content["auth_chain"] + ] + + signed_auth = yield self._check_sigs_and_hash_and_fetch( + origin, auth_chain, outlier=True + ) - ret = yield self.handler.on_query_auth( - origin, - event_id, - signed_auth, - content.get("rejects", []), - content.get("missing", []), - ) + ret = yield self.handler.on_query_auth( + origin, + event_id, + signed_auth, + content.get("rejects", []), + content.get("missing", []), + ) - time_now = self._clock.time_msec() - send_content = { - "auth_chain": [ - e.get_pdu_json(time_now) - for e in ret["auth_chain"] - ], - "rejects": ret.get("rejects", []), - "missing": ret.get("missing", []), - } + time_now = self._clock.time_msec() + send_content = { + "auth_chain": [ + e.get_pdu_json(time_now) + for e in ret["auth_chain"] + ], + "rejects": ret.get("rejects", []), + "missing": ret.get("missing", []), + } defer.returnValue( (200, send_content) ) - @defer.inlineCallbacks @log_function def on_query_client_keys(self, origin, content): - query = [] - for user_id, device_ids in content.get("device_keys", {}).items(): - if not device_ids: - query.append((user_id, None)) - else: - for device_id in device_ids: - query.append((user_id, device_id)) - - results = yield self.store.get_e2e_device_keys(query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - json_result.setdefault(user_id, {})[device_id] = json.loads( - json_bytes - ) - - defer.returnValue({"device_keys": json_result}) + return self.on_query_request("client_keys", content) @defer.inlineCallbacks @log_function @@ -386,21 +421,24 @@ class FederationServer(FederationBase): @log_function def on_get_missing_events(self, origin, room_id, earliest_events, latest_events, limit, min_depth): - logger.info( - "on_get_missing_events: earliest_events: %r, latest_events: %r," - " limit: %d, min_depth: %d", - earliest_events, latest_events, limit, min_depth - ) - missing_events = yield self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit, min_depth - ) + with (yield self._server_linearizer.queue((origin, room_id))): + logger.info( + "on_get_missing_events: earliest_events: %r, latest_events: %r," + " limit: %d, min_depth: %d", + earliest_events, latest_events, limit, min_depth + ) + missing_events = yield self.handler.on_get_missing_events( + origin, room_id, earliest_events, latest_events, limit, min_depth + ) - if len(missing_events) < 5: - logger.info("Returning %d events: %r", len(missing_events), missing_events) - else: - logger.info("Returning %d events", len(missing_events)) + if len(missing_events) < 5: + logger.info( + "Returning %d events: %r", len(missing_events), missing_events + ) + else: + logger.info("Returning %d events", len(missing_events)) - time_now = self._clock.time_msec() + time_now = self._clock.time_msec() defer.returnValue({ "events": [ev.get_pdu_json(time_now) for ev in missing_events], @@ -567,7 +605,7 @@ class FederationServer(FederationBase): origin, pdu.room_id, pdu.event_id, ) except: - logger.warn("Failed to get state for event: %s", pdu.event_id) + logger.exception("Failed to get state for event: %s", pdu.event_id) yield self.handler.on_receive_pdu( origin, |