diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/api/auth.py | 38 | ||||
-rw-r--r-- | synapse/api/constants.py | 6 | ||||
-rw-r--r-- | synapse/crypto/keyclient.py | 7 | ||||
-rw-r--r-- | synapse/events/__init__.py | 2 | ||||
-rw-r--r-- | synapse/events/builder.py | 5 | ||||
-rw-r--r-- | synapse/events/snapshot.py | 1 | ||||
-rw-r--r-- | synapse/events/utils.py | 11 | ||||
-rw-r--r-- | synapse/federation/federation_client.py | 414 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 444 | ||||
-rw-r--r-- | synapse/federation/replication.py | 897 | ||||
-rw-r--r-- | synapse/federation/transaction_queue.py | 317 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 16 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 21 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 394 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 37 | ||||
-rw-r--r-- | synapse/python_dependencies.py | 1 | ||||
-rw-r--r-- | synapse/state.py | 127 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 38 | ||||
-rw-r--r-- | synapse/storage/_base.py | 21 | ||||
-rw-r--r-- | synapse/storage/rejections.py | 43 | ||||
-rw-r--r-- | synapse/storage/schema/delta/v12.sql | 14 | ||||
-rw-r--r-- | synapse/storage/schema/im.sql | 8 |
22 files changed, 1798 insertions, 1064 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9c03024512..3471afd7e7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -358,9 +358,23 @@ class Auth(object): def add_auth_events(self, builder, context): yield run_on_reactor() - if builder.type == EventTypes.Create: - builder.auth_events = [] - return + auth_ids = self.compute_auth_events(builder, context) + + auth_events_entries = yield self.store.add_event_hashes( + auth_ids + ) + + builder.auth_events = auth_events_entries + + context.auth_events = { + k: v + for k, v in context.current_state.items() + if v.event_id in auth_ids + } + + def compute_auth_events(self, event, context): + if event.type == EventTypes.Create: + return [] auth_ids = [] @@ -373,7 +387,7 @@ class Auth(object): key = (EventTypes.JoinRules, "", ) join_rule_event = context.current_state.get(key) - key = (EventTypes.Member, builder.user_id, ) + key = (EventTypes.Member, event.user_id, ) member_event = context.current_state.get(key) key = (EventTypes.Create, "", ) @@ -387,8 +401,8 @@ class Auth(object): else: is_public = False - if builder.type == EventTypes.Member: - e_type = builder.content["membership"] + if event.type == EventTypes.Member: + e_type = event.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: if join_rule_event: auth_ids.append(join_rule_event.event_id) @@ -403,17 +417,7 @@ class Auth(object): if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - auth_events_entries = yield self.store.add_event_hashes( - auth_ids - ) - - builder.auth_events = auth_events_entries - - context.auth_events = { - k: v - for k, v in context.current_state.items() - if v.event_id in auth_ids - } + return auth_ids @log_function def _can_send_event(self, event, auth_events): diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 7ee6dcc46e..0d3fc629af 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -74,3 +74,9 @@ class EventTypes(object): Message = "m.room.message" Topic = "m.room.topic" Name = "m.room.name" + + +class RejectedReason(object): + AUTH_ERROR = "auth_error" + REPLACED = "replaced" + NOT_ANCESTOR = "not_ancestor" diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 9c910fa3fc..cdb6279764 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -61,9 +61,11 @@ class SynapseKeyClientProtocol(HTTPClient): def __init__(self): self.remote_key = defer.Deferred() + self.host = None def connectionMade(self): - logger.debug("Connected to %s", self.transport.getHost()) + self.host = self.transport.getHost() + logger.debug("Connected to %s", self.host) self.sendCommand(b"GET", b"/_matrix/key/v1/") self.endHeaders() self.timer = reactor.callLater( @@ -92,8 +94,7 @@ class SynapseKeyClientProtocol(HTTPClient): self.timer.cancel() def on_timeout(self): - logger.debug("Timeout waiting for response from %s", - self.transport.getHost()) + logger.debug("Timeout waiting for response from %s", self.host) self.remote_key.errback(IOError("Timeout waiting for response")) self.transport.abortConnection() diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 4252e5ab5c..bf07951027 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -18,7 +18,7 @@ from synapse.util.frozenutils import freeze, unfreeze class _EventInternalMetadata(object): def __init__(self, internal_metadata_dict): - self.__dict__ = internal_metadata_dict + self.__dict__ = dict(internal_metadata_dict) def get_dict(self): return dict(self.__dict__) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index a9b1b99a10..9d45bdb892 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -23,14 +23,15 @@ import copy class EventBuilder(EventBase): - def __init__(self, key_values={}): + def __init__(self, key_values={}, internal_metadata_dict={}): signatures = copy.deepcopy(key_values.pop("signatures", {})) unsigned = copy.deepcopy(key_values.pop("unsigned", {})) super(EventBuilder, self).__init__( key_values, signatures=signatures, - unsigned=unsigned + unsigned=unsigned, + internal_metadata_dict=internal_metadata_dict, ) def build(self): diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6bbba8d6ba..7e98bdef28 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -20,3 +20,4 @@ class EventContext(object): self.current_state = current_state self.auth_events = auth_events self.state_group = None + self.rejected = False diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e391aca4cc..7ae5d42b96 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -45,12 +45,14 @@ def prune_event(event): "membership", ] + event_dict = event.get_dict() + new_content = {} def add_fields(*fields): for field in fields: if field in event.content: - new_content[field] = event.content[field] + new_content[field] = event_dict["content"][field] if event_type == EventTypes.Member: add_fields("membership") @@ -75,7 +77,7 @@ def prune_event(event): allowed_fields = { k: v - for k, v in event.get_dict().items() + for k, v in event_dict.items() if k in allowed_keys } @@ -86,7 +88,10 @@ def prune_event(event): if "age_ts" in event.unsigned: allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] - return type(event)(allowed_fields) + return type(event)( + allowed_fields, + internal_metadata_dict=event.internal_metadata.get_dict() + ) def serialize_event(e, time_now_ms, client_event=True): diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py new file mode 100644 index 0000000000..1173ca817b --- /dev/null +++ b/synapse/federation/federation_client.py @@ -0,0 +1,414 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .units import Edu + +from synapse.util.logutils import log_function +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from syutil.jsonutil import encode_canonical_json + +from synapse.crypto.event_signing import check_event_content_hash + +from synapse.api.errors import SynapseError + +import logging + + +logger = logging.getLogger(__name__) + + +class FederationClient(object): + @log_function + def send_pdu(self, pdu, destinations): + """Informs the replication layer about a new PDU generated within the + home server that should be transmitted to others. + + TODO: Figure out when we should actually resolve the deferred. + + Args: + pdu (Pdu): The new Pdu. + + Returns: + Deferred: Completes when we have successfully processed the PDU + and replicated it to any interested remote home servers. + """ + order = self._order + self._order += 1 + + logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) + + # TODO, add errback, etc. + self._transaction_queue.enqueue_pdu(pdu, destinations, order) + + logger.debug( + "[%s] transaction_layer.enqueue_pdu... done", + pdu.event_id + ) + + @log_function + def send_edu(self, destination, edu_type, content): + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) + + # TODO, add errback, etc. + self._transaction_queue.enqueue_edu(edu) + return defer.succeed(None) + + @log_function + def send_failure(self, failure, destination): + self._transaction_queue.enqueue_failure(failure, destination) + return defer.succeed(None) + + @log_function + def make_query(self, destination, query_type, args, + retry_on_dns_fail=True): + """Sends a federation Query to a remote homeserver of the given type + and arguments. + + Args: + destination (str): Domain name of the remote homeserver + query_type (str): Category of the query type; should match the + handler name used in register_query_handler(). + args (dict): Mapping of strings to strings containing the details + of the query request. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + return self.transport_layer.make_query( + destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail + ) + + @defer.inlineCallbacks + @log_function + def backfill(self, dest, context, limit, extremities): + """Requests some more historic PDUs for the given context from the + given destination server. + + Args: + dest (str): The remote home server to ask. + context (str): The context to backfill. + limit (int): The maximum number of PDUs to return. + extremities (list): List of PDU id and origins of the first pdus + we have seen from the context + + Returns: + Deferred: Results in the received PDUs. + """ + logger.debug("backfill extrem=%s", extremities) + + # If there are no extremeties then we've (probably) reached the start. + if not extremities: + return + + transaction_data = yield self.transport_layer.backfill( + dest, context, extremities, limit) + + logger.debug("backfill transaction_data=%s", repr(transaction_data)) + + pdus = [ + self.event_from_pdu_json(p, outlier=False) + for p in transaction_data["pdus"] + ] + + for i, pdu in enumerate(pdus): + pdus[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + defer.returnValue(pdus) + + @defer.inlineCallbacks + @log_function + def get_pdu(self, destinations, event_id, outlier=False): + """Requests the PDU with given origin and ID from the remote home + servers. + + Will attempt to get the PDU from each destination in the list until + one succeeds. + + This will persist the PDU locally upon receipt. + + Args: + destinations (list): Which home servers to query + pdu_origin (str): The home server that originally sent the pdu. + event_id (str) + outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if + it's from an arbitary point in the context as opposed to part + of the current block of PDUs. Defaults to `False` + + Returns: + Deferred: Results in the requested PDU. + """ + + # TODO: Rate limit the number of times we try and get the same event. + + pdu = None + for destination in destinations: + try: + transaction_data = yield self.transport_layer.get_event( + destination, event_id + ) + + logger.debug("transaction_data %r", transaction_data) + + pdu_list = [ + self.event_from_pdu_json(p, outlier=outlier) + for p in transaction_data["pdus"] + ] + + if pdu_list: + pdu = pdu_list[0] + + # Check signatures are correct. + pdu = yield self._check_sigs_and_hash(pdu) + + break + + except Exception as e: + logger.info( + "Failed to get PDU %s from %s because %s", + event_id, destination, e, + ) + continue + + defer.returnValue(pdu) + + @defer.inlineCallbacks + @log_function + def get_state_for_room(self, destination, room_id, event_id): + """Requests all of the `current` state PDUs for a given room from + a remote home server. + + Args: + destination (str): The remote homeserver to query for the state. + room_id (str): The id of the room we're interested in. + event_id (str): The id of the event we want the state at. + + Returns: + Deferred: Results in a list of PDUs. + """ + + result = yield self.transport_layer.get_room_state( + destination, room_id, event_id=event_id, + ) + + pdus = [ + self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] + ] + + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in result.get("auth_chain", []) + ] + + for i, pdu in enumerate(pdus): + pdus[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + for i, pdu in enumerate(auth_chain): + auth_chain[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + defer.returnValue((pdus, auth_chain)) + + @defer.inlineCallbacks + @log_function + def get_event_auth(self, destination, room_id, event_id): + res = yield self.transport_layer.get_event_auth( + destination, room_id, event_id, + ) + + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in res["auth_chain"] + ] + + for i, pdu in enumerate(auth_chain): + auth_chain[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue(auth_chain) + + @defer.inlineCallbacks + def make_join(self, destination, room_id, user_id): + ret = yield self.transport_layer.make_join( + destination, room_id, user_id + ) + + pdu_dict = ret["event"] + + logger.debug("Got response to make_join: %s", pdu_dict) + + defer.returnValue(self.event_from_pdu_json(pdu_dict)) + + @defer.inlineCallbacks + def send_join(self, destination, pdu): + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_join( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + logger.debug("Got content: %s", content) + + state = [ + self.event_from_pdu_json(p, outlier=True) + for p in content.get("state", []) + ] + + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in content.get("auth_chain", []) + ] + + for i, pdu in enumerate(state): + state[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + for i, pdu in enumerate(auth_chain): + auth_chain[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue({ + "state": state, + "auth_chain": auth_chain, + }) + + @defer.inlineCallbacks + def send_invite(self, destination, room_id, event_id, pdu): + time_now = self._clock.time_msec() + code, content = yield self.transport_layer.send_invite( + destination=destination, + room_id=room_id, + event_id=event_id, + content=pdu.get_pdu_json(time_now), + ) + + pdu_dict = content["event"] + + logger.debug("Got response to send_invite: %s", pdu_dict) + + pdu = self.event_from_pdu_json(pdu_dict) + + # Check signatures are correct. + pdu = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + defer.returnValue(pdu) + + @defer.inlineCallbacks + def query_auth(self, destination, room_id, event_id, local_auth): + """ + Params: + destination (str) + event_it (str) + local_auth (list) + """ + time_now = self._clock.time_msec() + + send_content = { + "auth_chain": [e.get_pdu_json(time_now) for e in local_auth], + } + + code, content = yield self.transport_layer.send_query_auth( + destination=destination, + room_id=room_id, + event_id=event_id, + content=send_content, + ) + + auth_chain = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content["auth_chain"] + ] + + missing = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content.get("missing", []) + ] + + ret = { + "auth_chain": auth_chain, + "rejects": content.get("rejects", []), + "missing": missing, + } + + defer.returnValue(ret) + + def event_from_pdu_json(self, pdu_json, outlier=False): + event = FrozenEvent( + pdu_json + ) + + event.internal_metadata.outlier = outlier + + return event + + @defer.inlineCallbacks + def _check_sigs_and_hash(self, pdu): + """Throws a SynapseError if the PDU does not have the correct + signatures. + + Returns: + FrozenEvent: Either the given event or it redacted if it failed the + content hash check. + """ + # Check signatures are correct. + redacted_event = prune_event(pdu) + redacted_pdu_json = redacted_event.get_pdu_json() + + try: + yield self.keyring.verify_json_for_server( + pdu.origin, redacted_pdu_json + ) + except SynapseError: + logger.warn( + "Signature check failed for %s redacted to %s", + encode_canonical_json(pdu.get_pdu_json()), + encode_canonical_json(redacted_pdu_json), + ) + raise + + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s, %s", + pdu.event_id, encode_canonical_json(pdu.get_dict()) + ) + defer.returnValue(redacted_event) + + defer.returnValue(pdu) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py new file mode 100644 index 0000000000..845a07a3a3 --- /dev/null +++ b/synapse/federation/federation_server.py @@ -0,0 +1,444 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .units import Transaction, Edu + +from synapse.util.logutils import log_function +from synapse.util.logcontext import PreserveLoggingContext +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from syutil.jsonutil import encode_canonical_json + +from synapse.crypto.event_signing import check_event_content_hash + +from synapse.api.errors import FederationError, SynapseError + +import logging + + +logger = logging.getLogger(__name__) + + +class FederationServer(object): + def set_handler(self, handler): + """Sets the handler that the replication layer will use to communicate + receipt of new PDUs from other home servers. The required methods are + documented on :py:class:`.ReplicationHandler`. + """ + self.handler = handler + + def register_edu_handler(self, edu_type, handler): + if edu_type in self.edu_handlers: + raise KeyError("Already have an EDU handler for %s" % (edu_type,)) + + self.edu_handlers[edu_type] = handler + + def register_query_handler(self, query_type, handler): + """Sets the handler callable that will be used to handle an incoming + federation Query of the given type. + + Args: + query_type (str): Category name of the query, which should match + the string used by make_query. + handler (callable): Invoked to handle incoming queries of this type + + handler is invoked as: + result = handler(args) + + where 'args' is a dict mapping strings to strings of the query + arguments. It should return a Deferred that will eventually yield an + object to encode as JSON. + """ + if query_type in self.query_handlers: + raise KeyError( + "Already have a Query handler for %s" % (query_type,) + ) + + self.query_handlers[query_type] = handler + + @defer.inlineCallbacks + @log_function + def on_backfill_request(self, origin, room_id, versions, limit): + pdus = yield self.handler.on_backfill_request( + origin, room_id, versions, limit + ) + + defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) + + @defer.inlineCallbacks + @log_function + def on_incoming_transaction(self, transaction_data): + transaction = Transaction(**transaction_data) + + for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] + if "age" in p: + p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) + del p["age"] + + pdu_list = [ + self.event_from_pdu_json(p) for p in transaction.pdus + ] + + logger.debug("[%s] Got transaction", transaction.transaction_id) + + response = yield self.transaction_actions.have_responded(transaction) + + if response: + logger.debug( + "[%s] We've already responed to this request", + transaction.transaction_id + ) + defer.returnValue(response) + return + + logger.debug("[%s] Transaction is new", transaction.transaction_id) + + with PreserveLoggingContext(): + dl = [] + for pdu in pdu_list: + dl.append(self._handle_new_pdu(transaction.origin, pdu)) + + if hasattr(transaction, "edus"): + for edu in [Edu(**x) for x in transaction.edus]: + self.received_edu( + transaction.origin, + edu.edu_type, + edu.content + ) + + results = yield defer.DeferredList(dl) + + ret = [] + for r in results: + if r[0]: + ret.append({}) + else: + logger.exception(r[1]) + ret.append({"error": str(r[1])}) + + logger.debug("Returning: %s", str(ret)) + + yield self.transaction_actions.set_response( + transaction, + 200, response + ) + defer.returnValue((200, response)) + + def received_edu(self, origin, edu_type, content): + if edu_type in self.edu_handlers: + self.edu_handlers[edu_type](origin, content) + else: + logger.warn("Received EDU of type %s with no handler", edu_type) + + @defer.inlineCallbacks + @log_function + def on_context_state_request(self, origin, room_id, event_id): + if event_id: + pdus = yield self.handler.get_state_for_pdu( + origin, room_id, event_id, + ) + auth_chain = yield self.store.get_auth_chain( + [pdu.event_id for pdu in pdus] + ) + else: + raise NotImplementedError("Specify an event") + + defer.returnValue((200, { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + })) + + @defer.inlineCallbacks + @log_function + def on_pdu_request(self, origin, event_id): + pdu = yield self._get_persisted_pdu(origin, event_id) + + if pdu: + defer.returnValue( + (200, self._transaction_from_pdus([pdu]).get_dict()) + ) + else: + defer.returnValue((404, "")) + + @defer.inlineCallbacks + @log_function + def on_pull_request(self, origin, versions): + raise NotImplementedError("Pull transactions not implemented") + + @defer.inlineCallbacks + def on_query_request(self, query_type, args): + if query_type in self.query_handlers: + response = yield self.query_handlers[query_type](args) + defer.returnValue((200, response)) + else: + defer.returnValue( + (404, "No handler for Query type '%s'" % (query_type,)) + ) + + @defer.inlineCallbacks + def on_make_join_request(self, room_id, user_id): + pdu = yield self.handler.on_make_join_request(room_id, user_id) + time_now = self._clock.time_msec() + defer.returnValue({"event": pdu.get_pdu_json(time_now)}) + + @defer.inlineCallbacks + def on_invite_request(self, origin, content): + pdu = self.event_from_pdu_json(content) + ret_pdu = yield self.handler.on_invite_request(origin, pdu) + time_now = self._clock.time_msec() + defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) + + @defer.inlineCallbacks + def on_send_join_request(self, origin, content): + logger.debug("on_send_join_request: content: %s", content) + pdu = self.event_from_pdu_json(content) + logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) + res_pdus = yield self.handler.on_send_join_request(origin, pdu) + time_now = self._clock.time_msec() + defer.returnValue((200, { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], + })) + + @defer.inlineCallbacks + def on_event_auth(self, origin, room_id, event_id): + time_now = self._clock.time_msec() + auth_pdus = yield self.handler.on_event_auth(event_id) + defer.returnValue((200, { + "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], + })) + + @defer.inlineCallbacks + def on_query_auth_request(self, origin, content, event_id): + auth_chain = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content["auth_chain"] + ] + + missing = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content.get("missing", []) + ] + + ret = yield self.handler.on_query_auth( + origin, event_id, auth_chain, content.get("rejects", []), missing + ) + + time_now = self._clock.time_msec() + send_content = { + "auth_chain": [ + e.get_pdu_json(time_now) + for e in ret["auth_chain"] + ], + "rejects": content.get("rejects", []), + "missing": [ + e.get_pdu_json(time_now) + for e in ret.get("missing", []) + ], + } + + defer.returnValue( + (200, send_content) + ) + + @log_function + def _get_persisted_pdu(self, origin, event_id, do_auth=True): + """ Get a PDU from the database with given origin and id. + + Returns: + Deferred: Results in a `Pdu`. + """ + return self.handler.get_persisted_pdu( + origin, event_id, do_auth=do_auth + ) + + def _transaction_from_pdus(self, pdu_list): + """Returns a new Transaction containing the given PDUs suitable for + transmission. + """ + time_now = self._clock.time_msec() + pdus = [p.get_pdu_json(time_now) for p in pdu_list] + return Transaction( + origin=self.server_name, + pdus=pdus, + origin_server_ts=int(time_now), + destination=None, + ) + + @defer.inlineCallbacks + @log_function + def _handle_new_pdu(self, origin, pdu, max_recursion=10): + # We reprocess pdus when we have seen them only as outliers + existing = yield self._get_persisted_pdu( + origin, pdu.event_id, do_auth=False + ) + + # FIXME: Currently we fetch an event again when we already have it + # if it has been marked as an outlier. + + already_seen = ( + existing and ( + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() + ) + ) + if already_seen: + logger.debug("Already seen pdu %s", pdu.event_id) + defer.returnValue({}) + return + + # Check signature. + try: + pdu = yield self._check_sigs_and_hash(pdu) + except SynapseError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=pdu.event_id, + ) + + state = None + + auth_chain = [] + + have_seen = yield self.store.have_events( + [ev for ev, _ in pdu.prev_events] + ) + + fetch_state = False + + # Get missing pdus if necessary. + if not pdu.internal_metadata.is_outlier(): + # We only backfill backwards to the min depth. + min_depth = yield self.handler.get_min_depth_for_context( + pdu.room_id + ) + + logger.debug( + "_handle_new_pdu min_depth for %s: %d", + pdu.room_id, min_depth + ) + + if min_depth and pdu.depth > min_depth and max_recursion > 0: + for event_id, hashes in pdu.prev_events: + if event_id not in have_seen: + logger.debug( + "_handle_new_pdu requesting pdu %s", + event_id + ) + + try: + new_pdu = yield self.federation_client.get_pdu( + [origin, pdu.origin], + event_id=event_id, + ) + + if new_pdu: + yield self._handle_new_pdu( + origin, + new_pdu, + max_recursion=max_recursion-1 + ) + + logger.debug("Processed pdu %s", event_id) + else: + logger.warn("Failed to get PDU %s", event_id) + fetch_state = True + except: + # TODO(erikj): Do some more intelligent retries. + logger.exception("Failed to get PDU") + fetch_state = True + else: + fetch_state = True + else: + fetch_state = True + + if fetch_state: + # We need to get the state at this event, since we haven't + # processed all the prev events. + logger.debug( + "_handle_new_pdu getting state for %s", + pdu.room_id + ) + state, auth_chain = yield self.get_state_for_room( + origin, pdu.room_id, pdu.event_id, + ) + + ret = yield self.handler.on_receive_pdu( + origin, + pdu, + backfilled=False, + state=state, + auth_chain=auth_chain, + ) + + defer.returnValue(ret) + + def __str__(self): + return "<ReplicationLayer(%s)>" % self.server_name + + def event_from_pdu_json(self, pdu_json, outlier=False): + event = FrozenEvent( + pdu_json + ) + + event.internal_metadata.outlier = outlier + + return event + + @defer.inlineCallbacks + def _check_sigs_and_hash(self, pdu): + """Throws a SynapseError if the PDU does not have the correct + signatures. + + Returns: + FrozenEvent: Either the given event or it redacted if it failed the + content hash check. + """ + # Check signatures are correct. + redacted_event = prune_event(pdu) + redacted_pdu_json = redacted_event.get_pdu_json() + + try: + yield self.keyring.verify_json_for_server( + pdu.origin, redacted_pdu_json + ) + except SynapseError: + logger.warn( + "Signature check failed for %s redacted to %s", + encode_canonical_json(pdu.get_pdu_json()), + encode_canonical_json(redacted_pdu_json), + ) + raise + + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s, %s", + pdu.event_id, encode_canonical_json(pdu.get_dict()) + ) + defer.returnValue(redacted_event) + + defer.returnValue(pdu) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 6620532a60..e442c6c5d5 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -17,23 +17,20 @@ a given transport. """ -from twisted.internet import defer +from .federation_client import FederationClient +from .federation_server import FederationServer -from .units import Transaction, Edu +from .transaction_queue import TransactionQueue from .persistence import TransactionActions -from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext -from synapse.events import FrozenEvent - import logging logger = logging.getLogger(__name__) -class ReplicationLayer(object): +class ReplicationLayer(FederationClient, FederationServer): """This layer is responsible for replicating with remote home servers over the given transport. I.e., does the sending and receiving of PDUs to remote home servers. @@ -54,898 +51,26 @@ class ReplicationLayer(object): def __init__(self, hs, transport_layer): self.server_name = hs.hostname + self.keyring = hs.get_keyring() + self.transport_layer = transport_layer self.transport_layer.register_received_handler(self) self.transport_layer.register_request_handler(self) - self.store = hs.get_datastore() - # self.pdu_actions = PduActions(self.store) - self.transaction_actions = TransactionActions(self.store) + self.federation_client = self - self._transaction_queue = _TransactionQueue( - hs, self.transaction_actions, transport_layer - ) + self.store = hs.get_datastore() self.handler = None self.edu_handlers = {} self.query_handlers = {} - self._order = 0 - self._clock = hs.get_clock() - self.event_builder_factory = hs.get_event_builder_factory() - - def set_handler(self, handler): - """Sets the handler that the replication layer will use to communicate - receipt of new PDUs from other home servers. The required methods are - documented on :py:class:`.ReplicationHandler`. - """ - self.handler = handler - - def register_edu_handler(self, edu_type, handler): - if edu_type in self.edu_handlers: - raise KeyError("Already have an EDU handler for %s" % (edu_type,)) - - self.edu_handlers[edu_type] = handler - - def register_query_handler(self, query_type, handler): - """Sets the handler callable that will be used to handle an incoming - federation Query of the given type. - - Args: - query_type (str): Category name of the query, which should match - the string used by make_query. - handler (callable): Invoked to handle incoming queries of this type - - handler is invoked as: - result = handler(args) - - where 'args' is a dict mapping strings to strings of the query - arguments. It should return a Deferred that will eventually yield an - object to encode as JSON. - """ - if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) - - self.query_handlers[query_type] = handler - - @log_function - def send_pdu(self, pdu, destinations): - """Informs the replication layer about a new PDU generated within the - home server that should be transmitted to others. - - TODO: Figure out when we should actually resolve the deferred. - - Args: - pdu (Pdu): The new Pdu. - - Returns: - Deferred: Completes when we have successfully processed the PDU - and replicated it to any interested remote home servers. - """ - order = self._order - self._order += 1 - - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_pdu(pdu, destinations, order) - - logger.debug( - "[%s] transaction_layer.enqueue_pdu... done", - pdu.event_id - ) - - @log_function - def send_edu(self, destination, edu_type, content): - edu = Edu( - origin=self.server_name, - destination=destination, - edu_type=edu_type, - content=content, - ) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_edu(edu) - return defer.succeed(None) - - @log_function - def send_failure(self, failure, destination): - self._transaction_queue.enqueue_failure(failure, destination) - return defer.succeed(None) - - @log_function - def make_query(self, destination, query_type, args, - retry_on_dns_fail=True): - """Sends a federation Query to a remote homeserver of the given type - and arguments. - - Args: - destination (str): Domain name of the remote homeserver - query_type (str): Category of the query type; should match the - handler name used in register_query_handler(). - args (dict): Mapping of strings to strings containing the details - of the query request. - - Returns: - a Deferred which will eventually yield a JSON object from the - response - """ - return self.transport_layer.make_query( - destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail - ) - - @defer.inlineCallbacks - @log_function - def backfill(self, dest, context, limit, extremities): - """Requests some more historic PDUs for the given context from the - given destination server. - - Args: - dest (str): The remote home server to ask. - context (str): The context to backfill. - limit (int): The maximum number of PDUs to return. - extremities (list): List of PDU id and origins of the first pdus - we have seen from the context - - Returns: - Deferred: Results in the received PDUs. - """ - logger.debug("backfill extrem=%s", extremities) - - # If there are no extremeties then we've (probably) reached the start. - if not extremities: - return - - transaction_data = yield self.transport_layer.backfill( - dest, context, extremities, limit) - - logger.debug("backfill transaction_data=%s", repr(transaction_data)) - - transaction = Transaction(**transaction_data) - - pdus = [ - self.event_from_pdu_json(p, outlier=False) - for p in transaction.pdus - ] - for pdu in pdus: - yield self._handle_new_pdu(dest, pdu, backfilled=True) - - defer.returnValue(pdus) - - @defer.inlineCallbacks - @log_function - def get_pdu(self, destination, event_id, outlier=False): - """Requests the PDU with given origin and ID from the remote home - server. - - This will persist the PDU locally upon receipt. - - Args: - destination (str): Which home server to query - pdu_origin (str): The home server that originally sent the pdu. - event_id (str) - outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitary point in the context as opposed to part - of the current block of PDUs. Defaults to `False` - - Returns: - Deferred: Results in the requested PDU. - """ - - transaction_data = yield self.transport_layer.get_event( - destination, event_id - ) - - transaction = Transaction(**transaction_data) - - pdu_list = [ - self.event_from_pdu_json(p, outlier=outlier) - for p in transaction.pdus - ] - - pdu = None - if pdu_list: - pdu = pdu_list[0] - yield self._handle_new_pdu(destination, pdu) - - defer.returnValue(pdu) - - @defer.inlineCallbacks - @log_function - def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the `current` state PDUs for a given room from - a remote home server. - - Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. - - Returns: - Deferred: Results in a list of PDUs. - """ - - result = yield self.transport_layer.get_room_state( - destination, room_id, event_id=event_id, - ) - - pdus = [ - self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] - ] - - auth_chain = [ - self.event_from_pdu_json(p, outlier=True) - for p in result.get("auth_chain", []) - ] - - defer.returnValue((pdus, auth_chain)) - - @defer.inlineCallbacks - @log_function - def get_event_auth(self, destination, room_id, event_id): - res = yield self.transport_layer.get_event_auth( - destination, room_id, event_id, - ) - - auth_chain = [ - self.event_from_pdu_json(p, outlier=True) - for p in res["auth_chain"] - ] - - auth_chain.sort(key=lambda e: e.depth) - - defer.returnValue(auth_chain) - - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - pdus = yield self.handler.on_backfill_request( - origin, room_id, versions, limit - ) - - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) - - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, transaction_data): - transaction = Transaction(**transaction_data) - - for p in transaction.pdus: - if "unsigned" in p: - unsigned = p["unsigned"] - if "age" in unsigned: - p["age"] = unsigned["age"] - if "age" in p: - p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) - del p["age"] - - pdu_list = [ - self.event_from_pdu_json(p) for p in transaction.pdus - ] - - logger.debug("[%s] Got transaction", transaction.transaction_id) - - response = yield self.transaction_actions.have_responded(transaction) - - if response: - logger.debug("[%s] We've already responed to this request", - transaction.transaction_id) - defer.returnValue(response) - return - - logger.debug("[%s] Transaction is new", transaction.transaction_id) - - with PreserveLoggingContext(): - dl = [] - for pdu in pdu_list: - dl.append(self._handle_new_pdu(transaction.origin, pdu)) - - if hasattr(transaction, "edus"): - for edu in [Edu(**x) for x in transaction.edus]: - self.received_edu( - transaction.origin, - edu.edu_type, - edu.content - ) - - results = yield defer.DeferredList(dl) - - ret = [] - for r in results: - if r[0]: - ret.append({}) - else: - logger.exception(r[1]) - ret.append({"error": str(r[1])}) - - logger.debug("Returning: %s", str(ret)) - - yield self.transaction_actions.set_response( - transaction, - 200, response - ) - defer.returnValue((200, response)) - - def received_edu(self, origin, edu_type, content): - if edu_type in self.edu_handlers: - self.edu_handlers[edu_type](origin, content) - else: - logger.warn("Received EDU of type %s with no handler", edu_type) - - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): - if event_id: - pdus = yield self.handler.get_state_for_pdu( - origin, room_id, event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) - else: - raise NotImplementedError("Specify an event") - - defer.returnValue((200, { - "pdus": [pdu.get_pdu_json() for pdu in pdus], - "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - })) - - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self._get_persisted_pdu(origin, event_id) - - if pdu: - defer.returnValue( - (200, self._transaction_from_pdus([pdu]).get_dict()) - ) - else: - defer.returnValue((404, "")) - - @defer.inlineCallbacks - @log_function - def on_pull_request(self, origin, versions): - raise NotImplementedError("Pull transactions not implemented") - - @defer.inlineCallbacks - def on_query_request(self, query_type, args): - if query_type in self.query_handlers: - response = yield self.query_handlers[query_type](args) - defer.returnValue((200, response)) - else: - defer.returnValue( - (404, "No handler for Query type '%s'" % (query_type,)) - ) - - @defer.inlineCallbacks - def on_make_join_request(self, room_id, user_id): - pdu = yield self.handler.on_make_join_request(room_id, user_id) - time_now = self._clock.time_msec() - defer.returnValue({"event": pdu.get_pdu_json(time_now)}) - - @defer.inlineCallbacks - def on_invite_request(self, origin, content): - pdu = self.event_from_pdu_json(content) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) - time_now = self._clock.time_msec() - defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) - - @defer.inlineCallbacks - def on_send_join_request(self, origin, content): - logger.debug("on_send_join_request: content: %s", content) - pdu = self.event_from_pdu_json(content) - logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) - time_now = self._clock.time_msec() - defer.returnValue((200, { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - })) - - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) - defer.returnValue((200, { - "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], - })) - - @defer.inlineCallbacks - def make_join(self, destination, room_id, user_id): - ret = yield self.transport_layer.make_join( - destination, room_id, user_id - ) - - pdu_dict = ret["event"] - - logger.debug("Got response to make_join: %s", pdu_dict) - - defer.returnValue(self.event_from_pdu_json(pdu_dict)) - - @defer.inlineCallbacks - def send_join(self, destination, pdu): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) - - logger.debug("Got content: %s", content) - - state = [ - self.event_from_pdu_json(p, outlier=True) - for p in content.get("state", []) - ] - - auth_chain = [ - self.event_from_pdu_json(p, outlier=True) - for p in content.get("auth_chain", []) - ] - - auth_chain.sort(key=lambda e: e.depth) - - defer.returnValue({ - "state": state, - "auth_chain": auth_chain, - }) - - @defer.inlineCallbacks - def send_invite(self, destination, room_id, event_id, pdu): - time_now = self._clock.time_msec() - code, content = yield self.transport_layer.send_invite( - destination=destination, - room_id=room_id, - event_id=event_id, - content=pdu.get_pdu_json(time_now), - ) - - pdu_dict = content["event"] - - logger.debug("Got response to send_invite: %s", pdu_dict) - - defer.returnValue(self.event_from_pdu_json(pdu_dict)) - - @log_function - def _get_persisted_pdu(self, origin, event_id, do_auth=True): - """ Get a PDU from the database with given origin and id. - - Returns: - Deferred: Results in a `Pdu`. - """ - return self.handler.get_persisted_pdu( - origin, event_id, do_auth=do_auth - ) - - def _transaction_from_pdus(self, pdu_list): - """Returns a new Transaction containing the given PDUs suitable for - transmission. - """ - time_now = self._clock.time_msec() - pdus = [p.get_pdu_json(time_now) for p in pdu_list] - return Transaction( - origin=self.server_name, - pdus=pdus, - origin_server_ts=int(time_now), - destination=None, - ) - - @defer.inlineCallbacks - @log_function - def _handle_new_pdu(self, origin, pdu, backfilled=False): - # We reprocess pdus when we have seen them only as outliers - existing = yield self._get_persisted_pdu( - origin, pdu.event_id, do_auth=False - ) - - already_seen = ( - existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) - ) - if already_seen: - logger.debug("Already seen pdu %s", pdu.event_id) - defer.returnValue({}) - return - - state = None - - auth_chain = [] - - # We need to make sure we have all the auth events. - # for e_id, _ in pdu.auth_events: - # exists = yield self._get_persisted_pdu( - # origin, - # e_id, - # do_auth=False - # ) - # - # if not exists: - # try: - # logger.debug( - # "_handle_new_pdu fetch missing auth event %s from %s", - # e_id, - # origin, - # ) - # - # yield self.get_pdu( - # origin, - # event_id=e_id, - # outlier=True, - # ) - # - # logger.debug("Processed pdu %s", e_id) - # except: - # logger.warn( - # "Failed to get auth event %s from %s", - # e_id, - # origin - # ) - - # Get missing pdus if necessary. - if not pdu.internal_metadata.is_outlier(): - # We only backfill backwards to the min depth. - min_depth = yield self.handler.get_min_depth_for_context( - pdu.room_id - ) - - logger.debug( - "_handle_new_pdu min_depth for %s: %d", - pdu.room_id, min_depth - ) - - if min_depth and pdu.depth > min_depth: - for event_id, hashes in pdu.prev_events: - exists = yield self._get_persisted_pdu( - origin, - event_id, - do_auth=False - ) - - if not exists: - logger.debug( - "_handle_new_pdu requesting pdu %s", - event_id - ) - - try: - yield self.get_pdu( - origin, - event_id=event_id, - ) - logger.debug("Processed pdu %s", event_id) - except: - # TODO(erikj): Do some more intelligent retries. - logger.exception("Failed to get PDU") - else: - # We need to get the state at this event, since we have reached - # a backward extremity edge. - logger.debug( - "_handle_new_pdu getting state for %s", - pdu.room_id - ) - state, auth_chain = yield self.get_state_for_room( - origin, pdu.room_id, pdu.event_id, - ) - - if not backfilled: - ret = yield self.handler.on_receive_pdu( - origin, - pdu, - backfilled=backfilled, - state=state, - auth_chain=auth_chain, - ) - else: - ret = None - - # yield self.pdu_actions.mark_as_processed(pdu) + self.transaction_actions = TransactionActions(self.store) + self._transaction_queue = TransactionQueue(hs, transport_layer) - defer.returnValue(ret) + self._order = 0 def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - - def event_from_pdu_json(self, pdu_json, outlier=False): - event = FrozenEvent( - pdu_json - ) - - event.internal_metadata.outlier = outlier - - return event - - -class _TransactionQueue(object): - """This class makes sure we only have one transaction in flight at - a time for a given destination. - - It batches pending PDUs into single transactions. - """ - - def __init__(self, hs, transaction_actions, transport_layer): - self.server_name = hs.hostname - self.transaction_actions = transaction_actions - self.transport_layer = transport_layer - - self._clock = hs.get_clock() - self.store = hs.get_datastore() - - # Is a mapping from destinations -> deferreds. Used to keep track - # of which destinations have transactions in flight and when they are - # done - self.pending_transactions = {} - - # Is a mapping from destination -> list of - # tuple(pending pdus, deferred, order) - self.pending_pdus_by_dest = {} - # destination -> list of tuple(edu, deferred) - self.pending_edus_by_dest = {} - - # destination -> list of tuple(failure, deferred) - self.pending_failures_by_dest = {} - - # HACK to get unique tx id - self._next_txn_id = int(self._clock.time_msec()) - - @defer.inlineCallbacks - @log_function - def enqueue_pdu(self, pdu, destinations, order): - # We loop through all destinations to see whether we already have - # a transaction in progress. If we do, stick it in the pending_pdus - # table and we'll get back to it later. - - destinations = set(destinations) - destinations.discard(self.server_name) - destinations.discard("localhost") - - logger.debug("Sending to: %s", str(destinations)) - - if not destinations: - return - - deferreds = [] - - for destination in destinations: - deferred = defer.Deferred() - self.pending_pdus_by_dest.setdefault(destination, []).append( - (pdu, deferred, order) - ) - - def eb(failure): - if not deferred.called: - deferred.errback(failure) - else: - logger.warn("Failed to send pdu", failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) - - deferreds.append(deferred) - - yield defer.DeferredList(deferreds) - - # NO inlineCallbacks - def enqueue_edu(self, edu): - destination = edu.destination - - if destination == self.server_name: - return - - deferred = defer.Deferred() - self.pending_edus_by_dest.setdefault(destination, []).append( - (edu, deferred) - ) - - def eb(failure): - if not deferred.called: - deferred.errback(failure) - else: - logger.warn("Failed to send edu", failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) - - return deferred - - @defer.inlineCallbacks - def enqueue_failure(self, failure, destination): - deferred = defer.Deferred() - - self.pending_failures_by_dest.setdefault( - destination, [] - ).append( - (failure, deferred) - ) - - yield deferred - - @defer.inlineCallbacks - @log_function - def _attempt_new_transaction(self, destination): - - (retry_last_ts, retry_interval) = (0, 0) - retry_timings = yield self.store.get_destination_retry_timings( - destination - ) - if retry_timings: - (retry_last_ts, retry_interval) = ( - retry_timings.retry_last_ts, retry_timings.retry_interval - ) - if retry_last_ts + retry_interval > int(self._clock.time_msec()): - logger.info( - "TX [%s] not ready for retry yet - " - "dropping transaction for now", - destination, - ) - return - else: - logger.info("TX [%s] is ready for retry", destination) - - logger.info("TX [%s] _attempt_new_transaction", destination) - - if destination in self.pending_transactions: - # XXX: pending_transactions can get stuck on by a never-ending - # request at which point pending_pdus_by_dest just keeps growing. - # we need application-layer timeouts of some flavour of these - # requests - return - - # list of (pending_pdu, deferred, order) - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_failures = self.pending_failures_by_dest.pop(destination, []) - - if pending_pdus: - logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) - - if not pending_pdus and not pending_edus and not pending_failures: - return - - logger.debug( - "TX [%s] Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", - destination, - len(pending_pdus), - len(pending_edus), - len(pending_failures) - ) - - # Sort based on the order field - pending_pdus.sort(key=lambda t: t[2]) - - pdus = [x[0] for x in pending_pdus] - edus = [x[0] for x in pending_edus] - failures = [x[0].get_dict() for x in pending_failures] - deferreds = [ - x[1] - for x in pending_pdus + pending_edus + pending_failures - ] - - try: - self.pending_transactions[destination] = 1 - - logger.debug("TX [%s] Persisting transaction...", destination) - - transaction = Transaction.create_new( - origin_server_ts=int(self._clock.time_msec()), - transaction_id=str(self._next_txn_id), - origin=self.server_name, - destination=destination, - pdus=pdus, - edus=edus, - pdu_failures=failures, - ) - - self._next_txn_id += 1 - - yield self.transaction_actions.prepare_to_send(transaction) - - logger.debug("TX [%s] Persisted transaction", destination) - logger.info( - "TX [%s] Sending transaction [%s]", - destination, - transaction.transaction_id, - ) - - # Actually send the transaction - - # FIXME (erikj): This is a bit of a hack to make the Pdu age - # keys work - def json_data_cb(): - data = transaction.get_dict() - now = int(self._clock.time_msec()) - if "pdus" in data: - for p in data["pdus"]: - if "age_ts" in p: - unsigned = p.setdefault("unsigned", {}) - unsigned["age"] = now - int(p["age_ts"]) - del p["age_ts"] - return data - - code, response = yield self.transport_layer.send_transaction( - transaction, json_data_cb - ) - - logger.info("TX [%s] got %d response", destination, code) - - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) - - yield self.transaction_actions.delivered( - transaction, code, response - ) - - logger.debug("TX [%s] Marked as delivered", destination) - logger.debug("TX [%s] Yielding to callbacks...", destination) - - for deferred in deferreds: - if code == 200: - if retry_last_ts: - # this host is alive! reset retry schedule - yield self.store.set_destination_retry_timings( - destination, 0, 0 - ) - deferred.callback(None) - else: - self.set_retrying(destination, retry_interval) - deferred.errback(RuntimeError("Got status %d" % code)) - - # Ensures we don't continue until all callbacks on that - # deferred have fired - try: - yield deferred - except: - pass - - logger.debug("TX [%s] Yielded to callbacks", destination) - - except Exception as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) - - self.set_retrying(destination, retry_interval) - - for deferred in deferreds: - if not deferred.called: - deferred.errback(e) - - finally: - # We want to be *very* sure we delete this after we stop processing - self.pending_transactions.pop(destination, None) - - # Check to see if there is anything else to send. - self._attempt_new_transaction(destination) - - @defer.inlineCallbacks - def set_retrying(self, destination, retry_interval): - # track that this destination is having problems and we should - # give it a chance to recover before trying it again - - if retry_interval: - retry_interval *= 2 - # plateau at hourly retries for now - if retry_interval >= 60 * 60 * 1000: - retry_interval = 60 * 60 * 1000 - else: - retry_interval = 2000 # try again at first after 2 seconds - - yield self.store.set_destination_retry_timings( - destination, - int(self._clock.time_msec()), - retry_interval - ) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py new file mode 100644 index 0000000000..9d4f2c09a2 --- /dev/null +++ b/synapse/federation/transaction_queue.py @@ -0,0 +1,317 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .persistence import TransactionActions +from .units import Transaction + +from synapse.util.logutils import log_function +from synapse.util.logcontext import PreserveLoggingContext + +import logging + + +logger = logging.getLogger(__name__) + + +class TransactionQueue(object): + """This class makes sure we only have one transaction in flight at + a time for a given destination. + + It batches pending PDUs into single transactions. + """ + + def __init__(self, hs, transport_layer): + self.server_name = hs.hostname + + self.store = hs.get_datastore() + self.transaction_actions = TransactionActions(self.store) + + self.transport_layer = transport_layer + + self._clock = hs.get_clock() + + # Is a mapping from destinations -> deferreds. Used to keep track + # of which destinations have transactions in flight and when they are + # done + self.pending_transactions = {} + + # Is a mapping from destination -> list of + # tuple(pending pdus, deferred, order) + self.pending_pdus_by_dest = {} + # destination -> list of tuple(edu, deferred) + self.pending_edus_by_dest = {} + + # destination -> list of tuple(failure, deferred) + self.pending_failures_by_dest = {} + + # HACK to get unique tx id + self._next_txn_id = int(self._clock.time_msec()) + + @defer.inlineCallbacks + @log_function + def enqueue_pdu(self, pdu, destinations, order): + # We loop through all destinations to see whether we already have + # a transaction in progress. If we do, stick it in the pending_pdus + # table and we'll get back to it later. + + destinations = set(destinations) + destinations.discard(self.server_name) + destinations.discard("localhost") + + logger.debug("Sending to: %s", str(destinations)) + + if not destinations: + return + + deferreds = [] + + for destination in destinations: + deferred = defer.Deferred() + self.pending_pdus_by_dest.setdefault(destination, []).append( + (pdu, deferred, order) + ) + + def eb(failure): + if not deferred.called: + deferred.errback(failure) + else: + logger.warn("Failed to send pdu", failure) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(eb) + + deferreds.append(deferred) + + yield defer.DeferredList(deferreds) + + # NO inlineCallbacks + def enqueue_edu(self, edu): + destination = edu.destination + + if destination == self.server_name: + return + + deferred = defer.Deferred() + self.pending_edus_by_dest.setdefault(destination, []).append( + (edu, deferred) + ) + + def eb(failure): + if not deferred.called: + deferred.errback(failure) + else: + logger.warn("Failed to send edu", failure) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(eb) + + return deferred + + @defer.inlineCallbacks + def enqueue_failure(self, failure, destination): + deferred = defer.Deferred() + + self.pending_failures_by_dest.setdefault( + destination, [] + ).append( + (failure, deferred) + ) + + yield deferred + + @defer.inlineCallbacks + @log_function + def _attempt_new_transaction(self, destination): + + (retry_last_ts, retry_interval) = (0, 0) + retry_timings = yield self.store.get_destination_retry_timings( + destination + ) + if retry_timings: + (retry_last_ts, retry_interval) = ( + retry_timings.retry_last_ts, retry_timings.retry_interval + ) + if retry_last_ts + retry_interval > int(self._clock.time_msec()): + logger.info( + "TX [%s] not ready for retry yet - " + "dropping transaction for now", + destination, + ) + return + else: + logger.info("TX [%s] is ready for retry", destination) + + logger.info("TX [%s] _attempt_new_transaction", destination) + + if destination in self.pending_transactions: + # XXX: pending_transactions can get stuck on by a never-ending + # request at which point pending_pdus_by_dest just keeps growing. + # we need application-layer timeouts of some flavour of these + # requests + return + + # list of (pending_pdu, deferred, order) + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_failures = self.pending_failures_by_dest.pop(destination, []) + + if pending_pdus: + logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) + + if not pending_pdus and not pending_edus and not pending_failures: + return + + logger.debug( + "TX [%s] Attempting new transaction" + " (pdus: %d, edus: %d, failures: %d)", + destination, + len(pending_pdus), + len(pending_edus), + len(pending_failures) + ) + + # Sort based on the order field + pending_pdus.sort(key=lambda t: t[2]) + + pdus = [x[0] for x in pending_pdus] + edus = [x[0] for x in pending_edus] + failures = [x[0].get_dict() for x in pending_failures] + deferreds = [ + x[1] + for x in pending_pdus + pending_edus + pending_failures + ] + + try: + self.pending_transactions[destination] = 1 + + logger.debug("TX [%s] Persisting transaction...", destination) + + transaction = Transaction.create_new( + origin_server_ts=int(self._clock.time_msec()), + transaction_id=str(self._next_txn_id), + origin=self.server_name, + destination=destination, + pdus=pdus, + edus=edus, + pdu_failures=failures, + ) + + self._next_txn_id += 1 + + yield self.transaction_actions.prepare_to_send(transaction) + + logger.debug("TX [%s] Persisted transaction", destination) + logger.info( + "TX [%s] Sending transaction [%s]", + destination, + transaction.transaction_id, + ) + + # Actually send the transaction + + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def json_data_cb(): + data = transaction.get_dict() + now = int(self._clock.time_msec()) + if "pdus" in data: + for p in data["pdus"]: + if "age_ts" in p: + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] + return data + + code, response = yield self.transport_layer.send_transaction( + transaction, json_data_cb + ) + + logger.info("TX [%s] got %d response", destination, code) + + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) + + yield self.transaction_actions.delivered( + transaction, code, response + ) + + logger.debug("TX [%s] Marked as delivered", destination) + logger.debug("TX [%s] Yielding to callbacks...", destination) + + for deferred in deferreds: + if code == 200: + if retry_last_ts: + # this host is alive! reset retry schedule + yield self.store.set_destination_retry_timings( + destination, 0, 0 + ) + deferred.callback(None) + else: + self.set_retrying(destination, retry_interval) + deferred.errback(RuntimeError("Got status %d" % code)) + + # Ensures we don't continue until all callbacks on that + # deferred have fired + try: + yield deferred + except: + pass + + logger.debug("TX [%s] Yielded to callbacks", destination) + + except Exception as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. + logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) + + self.set_retrying(destination, retry_interval) + + for deferred in deferreds: + if not deferred.called: + deferred.errback(e) + + finally: + # We want to be *very* sure we delete this after we stop processing + self.pending_transactions.pop(destination, None) + + # Check to see if there is anything else to send. + self._attempt_new_transaction(destination) + + @defer.inlineCallbacks + def set_retrying(self, destination, retry_interval): + # track that this destination is having problems and we should + # give it a chance to recover before trying it again + + if retry_interval: + retry_interval *= 2 + # plateau at hourly retries for now + if retry_interval >= 60 * 60 * 1000: + retry_interval = 60 * 60 * 1000 + else: + retry_interval = 2000 # try again at first after 2 seconds + + yield self.store.set_destination_retry_timings( + destination, + int(self._clock.time_msec()), + retry_interval + ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index e634a3a213..4cb1dea2de 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -213,3 +213,19 @@ class TransportLayerClient(object): ) defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def send_query_auth(self, destination, room_id, event_id, content): + path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id) + + code, content = yield self.client.post_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index a380a6910b..9c9f8d525b 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -42,7 +42,7 @@ class TransportLayerServer(object): content = None origin = None - if request.method == "PUT": + if request.method in ["PUT", "POST"]: # TODO: Handle other method types? other content types? try: content_bytes = request.content.read() @@ -234,6 +234,16 @@ class TransportLayerServer(object): ) ) ) + self.server.register_path( + "POST", + re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + self._on_query_auth_request( + origin, content, event_id, + ) + ) + ) @defer.inlineCallbacks @log_function @@ -325,3 +335,12 @@ class TransportLayerServer(object): ) defer.returnValue((200, content)) + + @defer.inlineCallbacks + @log_function + def _on_query_auth_request(self, origin, content, event_id): + new_content = yield self.request_handler.on_query_auth_request( + origin, content, event_id + ) + + defer.returnValue((200, new_content)) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bcdcc90a18..cc22f21cd1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -17,19 +17,16 @@ from ._base import BaseHandler -from synapse.events.utils import prune_event from synapse.api.errors import ( - AuthError, FederationError, SynapseError, StoreError, + AuthError, FederationError, StoreError, ) -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.crypto.event_signing import ( - compute_event_signature, check_event_content_hash, - add_hashes_and_signatures, + compute_event_signature, add_hashes_and_signatures, ) from synapse.types import UserID -from syutil.jsonutil import encode_canonical_json from twisted.internet import defer @@ -113,33 +110,6 @@ class FederationHandler(BaseHandler): logger.debug("Processing event: %s", event.event_id) - redacted_event = prune_event(event) - - redacted_pdu_json = redacted_event.get_pdu_json() - try: - yield self.keyring.verify_json_for_server( - event.origin, redacted_pdu_json - ) - except SynapseError as e: - logger.warn( - "Signature check failed for %s redacted to %s", - encode_canonical_json(pdu.get_pdu_json()), - encode_canonical_json(redacted_pdu_json), - ) - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) - - if not check_event_content_hash(event): - logger.warn( - "Event content has been tampered, redacting %s, %s", - event.event_id, encode_canonical_json(event.get_dict()) - ) - event = redacted_event - logger.debug("Event: %s", event) # FIXME (erikj): Awful hack to make the case where we are not currently @@ -149,14 +119,14 @@ class FederationHandler(BaseHandler): event.room_id, self.server_name ) - if not is_in_room and not event.internal_metadata.outlier: + if not is_in_room and not event.internal_metadata.is_outlier(): logger.debug("Got event for room we're not in.") replication = self.replication_layer if not state: state, auth_chain = yield replication.get_state_for_room( - origin, context=event.room_id, event_id=event.event_id, + origin, room_id=event.room_id, event_id=event.event_id, ) if not auth_chain: @@ -169,7 +139,7 @@ class FederationHandler(BaseHandler): for e in auth_chain: e.internal_metadata.outlier = True try: - yield self._handle_new_event(e, fetch_auth_from=origin) + yield self._handle_new_event(origin, e) except: logger.exception( "Failed to handle auth event %s", @@ -180,10 +150,9 @@ class FederationHandler(BaseHandler): if state: for e in state: - logging.info("A :) %r", e) e.internal_metadata.outlier = True try: - yield self._handle_new_event(e) + yield self._handle_new_event(origin, e) except: logger.exception( "Failed to handle state event %s", @@ -192,6 +161,7 @@ class FederationHandler(BaseHandler): try: yield self._handle_new_event( + origin, event, state=state, backfilled=backfilled, @@ -394,7 +364,14 @@ class FederationHandler(BaseHandler): for e in auth_chain: e.internal_metadata.outlier = True try: - yield self._handle_new_event(e) + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event( + target_host, e, auth_events=auth + ) except: logger.exception( "Failed to handle auth event %s", @@ -405,8 +382,13 @@ class FederationHandler(BaseHandler): # FIXME: Auth these. e.internal_metadata.outlier = True try: + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } yield self._handle_new_event( - e, fetch_auth_from=target_host + target_host, e, auth_events=auth ) except: logger.exception( @@ -415,6 +397,7 @@ class FederationHandler(BaseHandler): ) yield self._handle_new_event( + target_host, new_event, state=state, current_state=state, @@ -481,7 +464,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = yield self._handle_new_event(event) + context = yield self._handle_new_event(origin, event) logger.debug( "on_send_join_request: After _handle_new_event: %s, sigs: %s", @@ -682,11 +665,12 @@ class FederationHandler(BaseHandler): waiters.pop().callback(None) @defer.inlineCallbacks - def _handle_new_event(self, event, state=None, backfilled=False, - current_state=None, fetch_auth_from=None): + @log_function + def _handle_new_event(self, origin, event, state=None, backfilled=False, + current_state=None, auth_events=None): logger.debug( - "_handle_new_event: Before annotate: %s, sigs: %s", + "_handle_new_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -694,65 +678,44 @@ class FederationHandler(BaseHandler): event, old_state=state ) + if not auth_events: + auth_events = context.auth_events + logger.debug( - "_handle_new_event: Before auth fetch: %s, sigs: %s", - event.event_id, event.signatures, + "_handle_new_event: %s, auth_events: %s", + event.event_id, auth_events, ) is_new_state = not event.internal_metadata.is_outlier() - known_ids = set( - [s.event_id for s in context.auth_events.values()] - ) - - for e_id, _ in event.auth_events: - if e_id not in known_ids: - e = yield self.store.get_event(e_id, allow_none=True) - - if not e and fetch_auth_from is not None: - # Grab the auth_chain over federation if we are missing - # auth events. - auth_chain = yield self.replication_layer.get_event_auth( - fetch_auth_from, event.event_id, event.room_id - ) - for auth_event in auth_chain: - yield self._handle_new_event(auth_event) - e = yield self.store.get_event(e_id, allow_none=True) - - if not e: - # TODO: Do some conflict res to make sure that we're - # not the ones who are wrong. - logger.info( - "Rejecting %s as %s not in db or %s", - event.event_id, e_id, known_ids, - ) - # FIXME: How does raising AuthError work with federation? - raise AuthError(403, "Cannot find auth event") - - context.auth_events[(e.type, e.state_key)] = e - - logger.debug( - "_handle_new_event: Before hack: %s, sigs: %s", - event.event_id, event.signatures, - ) - + # This is a hack to fix some old rooms where the initial join event + # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_events: if len(event.prev_events) == 1: c = yield self.store.get_event(event.prev_events[0][0]) if c.type == EventTypes.Create: - context.auth_events[(c.type, c.state_key)] = c + auth_events[(c.type, c.state_key)] = c - logger.debug( - "_handle_new_event: Before auth check: %s, sigs: %s", - event.event_id, event.signatures, - ) + try: + yield self.do_auth( + origin, event, context, auth_events=auth_events + ) + except AuthError as e: + logger.warn( + "Rejecting %s because %s", + event.event_id, e.msg + ) - self.auth.check(event, auth_events=context.auth_events) + context.rejected = RejectedReason.AUTH_ERROR - logger.debug( - "_handle_new_event: Before persist_event: %s, sigs: %s", - event.event_id, event.signatures, - ) + yield self.store.persist_event( + event, + context=context, + backfilled=backfilled, + is_new_state=False, + current_state=current_state, + ) + raise yield self.store.persist_event( event, @@ -762,9 +725,250 @@ class FederationHandler(BaseHandler): current_state=current_state, ) - logger.debug( - "_handle_new_event: After persist_event: %s, sigs: %s", - event.event_id, event.signatures, + defer.returnValue(context) + + @defer.inlineCallbacks + def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, + missing): + # Just go through and process each event in `remote_auth_chain`. We + # don't want to fall into the trap of `missing` being wrong. + for e in remote_auth_chain: + try: + yield self._handle_new_event(origin, e) + except AuthError: + pass + + # Now get the current auth_chain for the event. + local_auth_chain = yield self.store.get_auth_chain([event_id]) + + # TODO: Check if we would now reject event_id. If so we need to tell + # everyone. + + ret = yield self.construct_auth_difference( + local_auth_chain, remote_auth_chain ) - defer.returnValue(context) + logger.debug("on_query_auth reutrning: %s", ret) + + defer.returnValue(ret) + + @defer.inlineCallbacks + @log_function + def do_auth(self, origin, event, context, auth_events): + # Check if we have all the auth events. + res = yield self.store.have_events( + [e_id for e_id, _ in event.auth_events] + ) + + event_auth_events = set(e_id for e_id, _ in event.auth_events) + seen_events = set(res.keys()) + + missing_auth = event_auth_events - seen_events + + if missing_auth: + logger.debug("Missing auth: %s", missing_auth) + # If we don't have all the auth events, we need to get them. + remote_auth_chain = yield self.replication_layer.get_event_auth( + origin, event.room_id, event.event_id + ) + + for e in remote_auth_chain: + try: + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in remote_auth_chain + if e.event_id in auth_ids + } + e.internal_metadata.outlier = True + yield self._handle_new_event( + origin, e, auth_events=auth + ) + auth_events[(e.type, e.state_key)] = e + except AuthError: + pass + + # FIXME: Assumes we have and stored all the state for all the + # prev_events + current_state = set(e.event_id for e in auth_events.values()) + different_auth = event_auth_events - current_state + + if different_auth and not event.internal_metadata.is_outlier(): + # Do auth conflict res. + logger.debug("Different auth: %s", different_auth) + + # 1. Get what we think is the auth chain. + auth_ids = self.auth.compute_auth_events(event, context) + local_auth_chain = yield self.store.get_auth_chain(auth_ids) + + # 2. Get remote difference. + result = yield self.replication_layer.query_auth( + origin, + event.room_id, + event.event_id, + local_auth_chain, + ) + + # 3. Process any remote auth chain events we haven't seen. + for e in result.get("missing", []): + try: + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in result["auth_chain"] + if e.event_id in auth_ids + } + e.internal_metadata.outlier = True + yield self._handle_new_event( + origin, e, auth_events=auth + ) + auth_events[(e.type, e.state_key)] = e + except AuthError: + pass + + # 4. Look at rejects and their proofs. + # TODO. + + context.current_state.update(auth_events) + context.state_group = None + + try: + self.auth.check(event, auth_events=auth_events) + except AuthError: + raise + + @defer.inlineCallbacks + def construct_auth_difference(self, local_auth, remote_auth): + """ Given a local and remote auth chain, find the differences. This + assumes that we have already processed all events in remote_auth + + Params: + local_auth (list) + remote_auth (list) + + Returns: + dict + """ + + logger.debug("construct_auth_difference Start!") + + # TODO: Make sure we are OK with local_auth or remote_auth having more + # auth events in them than strictly necessary. + + def sort_fun(ev): + return ev.depth, ev.event_id + + logger.debug("construct_auth_difference after sort_fun!") + + # We find the differences by starting at the "bottom" of each list + # and iterating up on both lists. The lists are ordered by depth and + # then event_id, we iterate up both lists until we find the event ids + # don't match. Then we look at depth/event_id to see which side is + # missing that event, and iterate only up that list. Repeat. + + remote_list = list(remote_auth) + remote_list.sort(key=sort_fun) + + local_list = list(local_auth) + local_list.sort(key=sort_fun) + + local_iter = iter(local_list) + remote_iter = iter(remote_list) + + logger.debug("construct_auth_difference before get_next!") + + def get_next(it, opt=None): + try: + return it.next() + except: + return opt + + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + + logger.debug("construct_auth_difference before while") + + missing_remotes = [] + missing_locals = [] + while current_local or current_remote: + if current_remote is None: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local is None: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + if current_local.event_id == current_remote.event_id: + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + continue + + if current_local.depth < current_remote.depth: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local.depth > current_remote.depth: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + # They have the same depth, so we fall back to the event_id order + if current_local.event_id < current_remote.event_id: + missing_locals.append(current_local) + current_local = get_next(local_iter) + + if current_local.event_id > current_remote.event_id: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + logger.debug("construct_auth_difference after while") + + # missing locals should be sent to the server + # We should find why we are missing remotes, as they will have been + # rejected. + + # Remove events from missing_remotes if they are referencing a missing + # remote. We only care about the "root" rejected ones. + missing_remote_ids = [e.event_id for e in missing_remotes] + base_remote_rejected = list(missing_remotes) + for e in missing_remotes: + for e_id, _ in e.auth_events: + if e_id in missing_remote_ids: + base_remote_rejected.remove(e) + + reason_map = {} + + for e in base_remote_rejected: + reason = yield self.store.get_rejection_reason(e.event_id) + if reason is None: + # FIXME: ERRR?! + logger.warn("Could not find reason for %s", e.event_id) + raise RuntimeError("") + + reason_map[e.event_id] = reason + + if reason == RejectedReason.AUTH_ERROR: + pass + elif reason == RejectedReason.REPLACED: + # TODO: Get proof + pass + elif reason == RejectedReason.NOT_ANCESTOR: + # TODO: Get proof. + pass + + logger.debug("construct_auth_difference returning") + + defer.returnValue({ + "auth_chain": local_auth, + "rejects": { + e.event_id: { + "reason": reason_map[e.event_id], + "proof": None, + } + for e in base_remote_rejected + }, + "missing": missing_locals, + }) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 1dda3ba2c7..c7bf1b47b8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -245,6 +245,43 @@ class MatrixFederationHttpClient(object): defer.returnValue((response.code, body)) @defer.inlineCallbacks + def post_json(self, destination, path, data={}): + """ Sends the specifed json data using POST + + Args: + destination (str): The remote server to send the HTTP request + to. + path (str): The HTTP path. + data (dict): A dict containing the data that will be used as + the request body. This will be encoded as JSON. + + Returns: + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. On a 4xx or 5xx error response a + CodeMessageException is raised. + """ + + def body_callback(method, url_bytes, headers_dict): + self.sign_request( + destination, method, url_bytes, headers_dict, data + ) + return _JsonProducer(data) + + response = yield self._create_request( + destination.encode("ascii"), + "POST", + path.encode("ascii"), + body_callback=body_callback, + headers_dict={"Content-Type": ["application/json"]}, + ) + + logger.debug("Getting resp body") + body = yield readBody(response) + logger.debug("Got resp body") + + defer.returnValue((response.code, body)) + + @defer.inlineCallbacks def get_json(self, destination, path, args={}, retry_on_dns_fail=True): """ GETs some json from the given host homeserver and path diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 4182ad990f..ba9308803c 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) REQUIREMENTS = { "syutil==0.0.2": ["syutil"], "matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"], + "matrix_angular_sdk>=0.6.0": ["syweb>=0.6.0"], "Twisted>=14.0.0": ["twisted>=14.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], diff --git a/synapse/state.py b/synapse/state.py index 8144fa02b4..d9fdfb34be 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext from collections import namedtuple @@ -42,6 +43,8 @@ class StateHandler(object): def __init__(self, hs): self.store = hs.get_datastore() + # self.auth = hs.get_auth() + self.hs = hs @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): @@ -210,64 +213,96 @@ class StateHandler(object): else: prev_states = [] + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) + } + try: - new_state = {} - new_state.update(unconflicted_state) - for key, events in conflicted_state.items(): - new_state[key] = self._resolve_state_events(events) + resolved_state = self._resolve_state_events( + conflicted_state, auth_events + ) except: logger.exception("Failed to resolve state") raise + new_state = unconflicted_state + new_state.update(resolved_state) + defer.returnValue((None, new_state, prev_states)) - def _get_power_level_from_event_state(self, event, user_id): - if hasattr(event, "old_state_events") and event.old_state_events: - key = (EventTypes.PowerLevels, "", ) - power_level_event = event.old_state_events.get(key) - level = None - if power_level_event: - level = power_level_event.content.get("users", {}).get( - user_id + @log_function + def _resolve_state_events(self, conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. memberships + 3. other events. + + :param conflicted_state: + :param auth_events: + :return: + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state.items(): + power_levels = conflicted_state[power_key] + resolved_state[power_key] = self._resolve_auth_events(power_levels) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + resolved_state[key] = self._resolve_auth_events( + events, + auth_events ) - if not level: - level = power_level_event.content.get("users_default", 0) - return level - else: - return 0 + auth_events.update(resolved_state) - @log_function - def _resolve_state_events(self, events): - curr_events = events + for key, events in conflicted_state.items(): + if key not in resolved_state: + resolved_state[key] = self._resolve_normal_events( + events, auth_events + ) - new_powers = [ - self._get_power_level_from_event_state(e, e.user_id) - for e in curr_events - ] + return resolved_state - new_powers = [ - int(p) if p else 0 for p in new_powers - ] + def _resolve_auth_events(self, events, auth_events): + reverse = [i for i in reversed(self._ordered_events(events))] - max_power = max(new_powers) + auth_events = dict(auth_events) - curr_events = [ - z[0] for z in zip(curr_events, new_powers) - if z[1] == max_power - ] + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + prev_event = event + except AuthError: + return prev_event - if not curr_events: - raise RuntimeError("Max didn't get a max?") - elif len(curr_events) == 1: - return curr_events[0] - - # TODO: For now, just choose the one with the largest event_id. - return ( - sorted( - curr_events, - key=lambda e: hashlib.sha1( - e.event_id + e.user_id + e.room_id + e.type - ).hexdigest() - )[0] - ) + return event + + def _resolve_normal_events(self, events, auth_events): + for event in self._ordered_events(events): + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + return event + except AuthError: + pass + + # Oh dear. + return event + + def _ordered_events(self, events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + + return sorted(events, key=key_func) \ No newline at end of file diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 277581b4e2..adcb038020 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -32,6 +32,7 @@ from .event_federation import EventFederationStore from .pusher import PusherStore from .push_rule import PushRuleStore from .media_repository import MediaRepositoryStore +from .rejections import RejectionsStore from .state import StateStore from .signatures import SignatureStore @@ -85,6 +86,7 @@ class DataStore(RoomMemberStore, RoomStore, DirectoryStore, KeyStore, StateStore, SignatureStore, EventFederationStore, MediaRepositoryStore, + RejectionsStore, PusherStore, PushRuleStore ): @@ -229,6 +231,9 @@ class DataStore(RoomMemberStore, RoomStore, if not outlier: self._store_state_groups_txn(txn, event, context) + if context.rejected: + self._store_rejections_txn(txn, event.event_id, context.rejected) + if current_state: txn.execute( "DELETE FROM current_state_events WHERE room_id = ?", @@ -267,7 +272,7 @@ class DataStore(RoomMemberStore, RoomStore, or_replace=True, ) - if is_new_state: + if is_new_state and not context.rejected: self._simple_insert_txn( txn, "current_state_events", @@ -293,7 +298,7 @@ class DataStore(RoomMemberStore, RoomStore, or_ignore=True, ) - if not backfilled: + if not backfilled and not context.rejected: self._simple_insert_txn( txn, table="state_forward_extremities", @@ -457,6 +462,35 @@ class DataStore(RoomMemberStore, RoomStore, ], ) + def have_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Returns: + dict: Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps to + None. + """ + def f(txn): + sql = ( + "SELECT e.event_id, reason FROM events as e " + "LEFT JOIN rejections as r ON e.event_id = r.event_id " + "WHERE e.event_id = ?" + ) + + res = {} + for event_id in event_ids: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if row: + _, rejected = row + res[event_id] = rejected + + return res + + return self.runInteraction( + "have_events", f, + ) + def schema_path(schema): """ Get a filesystem path for the named database schema diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 4e8bd3faa9..1f5e74a16a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -502,10 +502,12 @@ class SQLBaseStore(object): return [e for e in events if e] def _get_event_txn(self, txn, event_id, check_redacted=True, - get_prev_content=False): + get_prev_content=False, allow_rejected=False): sql = ( - "SELECT internal_metadata, json, r.event_id FROM event_json as e " + "SELECT internal_metadata, json, r.event_id, reason " + "FROM event_json as e " "LEFT JOIN redactions as r ON e.event_id = r.redacts " + "LEFT JOIN rejections as rej on rej.event_id = e.event_id " "WHERE e.event_id = ? " "LIMIT 1 " ) @@ -517,13 +519,16 @@ class SQLBaseStore(object): if not res: return None - internal_metadata, js, redacted = res + internal_metadata, js, redacted, rejected_reason = res - return self._get_event_from_row_txn( - txn, internal_metadata, js, redacted, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - ) + if allow_rejected or not rejected_reason: + return self._get_event_from_row_txn( + txn, internal_metadata, js, redacted, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + ) + else: + return None def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, check_redacted=True, get_prev_content=False): diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py new file mode 100644 index 0000000000..4e1a9a2783 --- /dev/null +++ b/synapse/storage/rejections.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore + +import logging + +logger = logging.getLogger(__name__) + + +class RejectionsStore(SQLBaseStore): + def _store_rejections_txn(self, txn, event_id, reason): + self._simple_insert_txn( + txn, + table="rejections", + values={ + "event_id": event_id, + "reason": reason, + "last_check": self._clock.time_msec(), + } + ) + + def get_rejection_reason(self, event_id): + return self._simple_select_one_onecol( + table="rejections", + retcol="reason", + keyvalues={ + "event_id": event_id, + }, + allow_none=True, + ) diff --git a/synapse/storage/schema/delta/v12.sql b/synapse/storage/schema/delta/v12.sql index 8c4dfd5c1b..d83c3b049e 100644 --- a/synapse/storage/schema/delta/v12.sql +++ b/synapse/storage/schema/delta/v12.sql @@ -1,4 +1,8 @@ +<<<<<<< HEAD +/* Copyright 2015 OpenMarket Ltd +======= /* Copyright 2014 OpenMarket Ltd +>>>>>>> fc946f3b8da8c7f71a9c25bf542c04472147bc5b * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -12,6 +16,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + root_rejected TEXT, + CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE +); + -- Push notification endpoints that users have configured CREATE TABLE IF NOT EXISTS pushers ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -44,3 +57,4 @@ CREATE TABLE IF NOT EXISTS push_rules ( ); CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); + diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql index dd00c1cd2f..5866a387f6 100644 --- a/synapse/storage/schema/im.sql +++ b/synapse/storage/schema/im.sql @@ -123,3 +123,11 @@ CREATE TABLE IF NOT EXISTS room_hosts( ); CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id); + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + root_rejected TEXT, + CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE +); |