diff --git a/setup.py b/setup.py
index 043cd044a7..733cfa8318 100755
--- a/setup.py
+++ b/setup.py
@@ -32,7 +32,7 @@ setup(
description="Reference Synapse Home Server",
install_requires=[
"syutil==0.0.2",
- "matrix_angular_sdk==0.6.0",
+ "matrix_angular_sdk==0.6.1",
"Twisted>=14.0.0",
"service_identity>=1.0.0",
"pyopenssl>=0.14",
@@ -47,7 +47,7 @@ setup(
dependency_links=[
"https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2",
"https://github.com/pyca/pynacl/tarball/d4d3175589b892f6ea7c22f466e0e223853516fa#egg=pynacl-0.3.0",
- "https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.6.0/#egg=matrix_angular_sdk-0.6.0",
+ "https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.6.0/#egg=matrix_angular_sdk-0.6.1",
],
setup_requires=[
"setuptools_trial",
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 9c03024512..3471afd7e7 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -358,9 +358,23 @@ class Auth(object):
def add_auth_events(self, builder, context):
yield run_on_reactor()
- if builder.type == EventTypes.Create:
- builder.auth_events = []
- return
+ auth_ids = self.compute_auth_events(builder, context)
+
+ auth_events_entries = yield self.store.add_event_hashes(
+ auth_ids
+ )
+
+ builder.auth_events = auth_events_entries
+
+ context.auth_events = {
+ k: v
+ for k, v in context.current_state.items()
+ if v.event_id in auth_ids
+ }
+
+ def compute_auth_events(self, event, context):
+ if event.type == EventTypes.Create:
+ return []
auth_ids = []
@@ -373,7 +387,7 @@ class Auth(object):
key = (EventTypes.JoinRules, "", )
join_rule_event = context.current_state.get(key)
- key = (EventTypes.Member, builder.user_id, )
+ key = (EventTypes.Member, event.user_id, )
member_event = context.current_state.get(key)
key = (EventTypes.Create, "", )
@@ -387,8 +401,8 @@ class Auth(object):
else:
is_public = False
- if builder.type == EventTypes.Member:
- e_type = builder.content["membership"]
+ if event.type == EventTypes.Member:
+ e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event:
auth_ids.append(join_rule_event.event_id)
@@ -403,17 +417,7 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id)
- auth_events_entries = yield self.store.add_event_hashes(
- auth_ids
- )
-
- builder.auth_events = auth_events_entries
-
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
+ return auth_ids
@log_function
def _can_send_event(self, event, auth_events):
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 7ee6dcc46e..0d3fc629af 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -74,3 +74,9 @@ class EventTypes(object):
Message = "m.room.message"
Topic = "m.room.topic"
Name = "m.room.name"
+
+
+class RejectedReason(object):
+ AUTH_ERROR = "auth_error"
+ REPLACED = "replaced"
+ NOT_ANCESTOR = "not_ancestor"
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 9c910fa3fc..cdb6279764 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -61,9 +61,11 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self):
self.remote_key = defer.Deferred()
+ self.host = None
def connectionMade(self):
- logger.debug("Connected to %s", self.transport.getHost())
+ self.host = self.transport.getHost()
+ logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", b"/_matrix/key/v1/")
self.endHeaders()
self.timer = reactor.callLater(
@@ -92,8 +94,7 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel()
def on_timeout(self):
- logger.debug("Timeout waiting for response from %s",
- self.transport.getHost())
+ logger.debug("Timeout waiting for response from %s", self.host)
self.remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 4252e5ab5c..bf07951027 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -18,7 +18,7 @@ from synapse.util.frozenutils import freeze, unfreeze
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
- self.__dict__ = internal_metadata_dict
+ self.__dict__ = dict(internal_metadata_dict)
def get_dict(self):
return dict(self.__dict__)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index a9b1b99a10..9d45bdb892 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -23,14 +23,15 @@ import copy
class EventBuilder(EventBase):
- def __init__(self, key_values={}):
+ def __init__(self, key_values={}, internal_metadata_dict={}):
signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__(
key_values,
signatures=signatures,
- unsigned=unsigned
+ unsigned=unsigned,
+ internal_metadata_dict=internal_metadata_dict,
)
def build(self):
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6bbba8d6ba..7e98bdef28 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -20,3 +20,4 @@ class EventContext(object):
self.current_state = current_state
self.auth_events = auth_events
self.state_group = None
+ self.rejected = False
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index e391aca4cc..7ae5d42b96 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -45,12 +45,14 @@ def prune_event(event):
"membership",
]
+ event_dict = event.get_dict()
+
new_content = {}
def add_fields(*fields):
for field in fields:
if field in event.content:
- new_content[field] = event.content[field]
+ new_content[field] = event_dict["content"][field]
if event_type == EventTypes.Member:
add_fields("membership")
@@ -75,7 +77,7 @@ def prune_event(event):
allowed_fields = {
k: v
- for k, v in event.get_dict().items()
+ for k, v in event_dict.items()
if k in allowed_keys
}
@@ -86,7 +88,10 @@ def prune_event(event):
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
- return type(event)(allowed_fields)
+ return type(event)(
+ allowed_fields,
+ internal_metadata_dict=event.internal_metadata.get_dict()
+ )
def serialize_event(e, time_now_ms, client_event=True):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
new file mode 100644
index 0000000000..1173ca817b
--- /dev/null
+++ b/synapse/federation/federation_client.py
@@ -0,0 +1,414 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from .units import Edu
+
+from synapse.util.logutils import log_function
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from syutil.jsonutil import encode_canonical_json
+
+from synapse.crypto.event_signing import check_event_content_hash
+
+from synapse.api.errors import SynapseError
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class FederationClient(object):
+ @log_function
+ def send_pdu(self, pdu, destinations):
+ """Informs the replication layer about a new PDU generated within the
+ home server that should be transmitted to others.
+
+ TODO: Figure out when we should actually resolve the deferred.
+
+ Args:
+ pdu (Pdu): The new Pdu.
+
+ Returns:
+ Deferred: Completes when we have successfully processed the PDU
+ and replicated it to any interested remote home servers.
+ """
+ order = self._order
+ self._order += 1
+
+ logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
+
+ # TODO, add errback, etc.
+ self._transaction_queue.enqueue_pdu(pdu, destinations, order)
+
+ logger.debug(
+ "[%s] transaction_layer.enqueue_pdu... done",
+ pdu.event_id
+ )
+
+ @log_function
+ def send_edu(self, destination, edu_type, content):
+ edu = Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type=edu_type,
+ content=content,
+ )
+
+ # TODO, add errback, etc.
+ self._transaction_queue.enqueue_edu(edu)
+ return defer.succeed(None)
+
+ @log_function
+ def send_failure(self, failure, destination):
+ self._transaction_queue.enqueue_failure(failure, destination)
+ return defer.succeed(None)
+
+ @log_function
+ def make_query(self, destination, query_type, args,
+ retry_on_dns_fail=True):
+ """Sends a federation Query to a remote homeserver of the given type
+ and arguments.
+
+ Args:
+ destination (str): Domain name of the remote homeserver
+ query_type (str): Category of the query type; should match the
+ handler name used in register_query_handler().
+ args (dict): Mapping of strings to strings containing the details
+ of the query request.
+
+ Returns:
+ a Deferred which will eventually yield a JSON object from the
+ response
+ """
+ return self.transport_layer.make_query(
+ destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def backfill(self, dest, context, limit, extremities):
+ """Requests some more historic PDUs for the given context from the
+ given destination server.
+
+ Args:
+ dest (str): The remote home server to ask.
+ context (str): The context to backfill.
+ limit (int): The maximum number of PDUs to return.
+ extremities (list): List of PDU id and origins of the first pdus
+ we have seen from the context
+
+ Returns:
+ Deferred: Results in the received PDUs.
+ """
+ logger.debug("backfill extrem=%s", extremities)
+
+ # If there are no extremeties then we've (probably) reached the start.
+ if not extremities:
+ return
+
+ transaction_data = yield self.transport_layer.backfill(
+ dest, context, extremities, limit)
+
+ logger.debug("backfill transaction_data=%s", repr(transaction_data))
+
+ pdus = [
+ self.event_from_pdu_json(p, outlier=False)
+ for p in transaction_data["pdus"]
+ ]
+
+ for i, pdu in enumerate(pdus):
+ pdus[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ defer.returnValue(pdus)
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_pdu(self, destinations, event_id, outlier=False):
+ """Requests the PDU with given origin and ID from the remote home
+ servers.
+
+ Will attempt to get the PDU from each destination in the list until
+ one succeeds.
+
+ This will persist the PDU locally upon receipt.
+
+ Args:
+ destinations (list): Which home servers to query
+ pdu_origin (str): The home server that originally sent the pdu.
+ event_id (str)
+ outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
+ it's from an arbitary point in the context as opposed to part
+ of the current block of PDUs. Defaults to `False`
+
+ Returns:
+ Deferred: Results in the requested PDU.
+ """
+
+ # TODO: Rate limit the number of times we try and get the same event.
+
+ pdu = None
+ for destination in destinations:
+ try:
+ transaction_data = yield self.transport_layer.get_event(
+ destination, event_id
+ )
+
+ logger.debug("transaction_data %r", transaction_data)
+
+ pdu_list = [
+ self.event_from_pdu_json(p, outlier=outlier)
+ for p in transaction_data["pdus"]
+ ]
+
+ if pdu_list:
+ pdu = pdu_list[0]
+
+ # Check signatures are correct.
+ pdu = yield self._check_sigs_and_hash(pdu)
+
+ break
+
+ except Exception as e:
+ logger.info(
+ "Failed to get PDU %s from %s because %s",
+ event_id, destination, e,
+ )
+ continue
+
+ defer.returnValue(pdu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_state_for_room(self, destination, room_id, event_id):
+ """Requests all of the `current` state PDUs for a given room from
+ a remote home server.
+
+ Args:
+ destination (str): The remote homeserver to query for the state.
+ room_id (str): The id of the room we're interested in.
+ event_id (str): The id of the event we want the state at.
+
+ Returns:
+ Deferred: Results in a list of PDUs.
+ """
+
+ result = yield self.transport_layer.get_room_state(
+ destination, room_id, event_id=event_id,
+ )
+
+ pdus = [
+ self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+ ]
+
+ auth_chain = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in result.get("auth_chain", [])
+ ]
+
+ for i, pdu in enumerate(pdus):
+ pdus[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ for i, pdu in enumerate(auth_chain):
+ auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ defer.returnValue((pdus, auth_chain))
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_event_auth(self, destination, room_id, event_id):
+ res = yield self.transport_layer.get_event_auth(
+ destination, room_id, event_id,
+ )
+
+ auth_chain = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in res["auth_chain"]
+ ]
+
+ for i, pdu in enumerate(auth_chain):
+ auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ defer.returnValue(auth_chain)
+
+ @defer.inlineCallbacks
+ def make_join(self, destination, room_id, user_id):
+ ret = yield self.transport_layer.make_join(
+ destination, room_id, user_id
+ )
+
+ pdu_dict = ret["event"]
+
+ logger.debug("Got response to make_join: %s", pdu_dict)
+
+ defer.returnValue(self.event_from_pdu_json(pdu_dict))
+
+ @defer.inlineCallbacks
+ def send_join(self, destination, pdu):
+ time_now = self._clock.time_msec()
+ _, content = yield self.transport_layer.send_join(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ logger.debug("Got content: %s", content)
+
+ state = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in content.get("state", [])
+ ]
+
+ auth_chain = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in content.get("auth_chain", [])
+ ]
+
+ for i, pdu in enumerate(state):
+ state[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ for i, pdu in enumerate(auth_chain):
+ auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ defer.returnValue({
+ "state": state,
+ "auth_chain": auth_chain,
+ })
+
+ @defer.inlineCallbacks
+ def send_invite(self, destination, room_id, event_id, pdu):
+ time_now = self._clock.time_msec()
+ code, content = yield self.transport_layer.send_invite(
+ destination=destination,
+ room_id=room_id,
+ event_id=event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ pdu_dict = content["event"]
+
+ logger.debug("Got response to send_invite: %s", pdu_dict)
+
+ pdu = self.event_from_pdu_json(pdu_dict)
+
+ # Check signatures are correct.
+ pdu = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ defer.returnValue(pdu)
+
+ @defer.inlineCallbacks
+ def query_auth(self, destination, room_id, event_id, local_auth):
+ """
+ Params:
+ destination (str)
+ event_it (str)
+ local_auth (list)
+ """
+ time_now = self._clock.time_msec()
+
+ send_content = {
+ "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
+ }
+
+ code, content = yield self.transport_layer.send_query_auth(
+ destination=destination,
+ room_id=room_id,
+ event_id=event_id,
+ content=send_content,
+ )
+
+ auth_chain = [
+ (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+ for e in content["auth_chain"]
+ ]
+
+ missing = [
+ (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+ for e in content.get("missing", [])
+ ]
+
+ ret = {
+ "auth_chain": auth_chain,
+ "rejects": content.get("rejects", []),
+ "missing": missing,
+ }
+
+ defer.returnValue(ret)
+
+ def event_from_pdu_json(self, pdu_json, outlier=False):
+ event = FrozenEvent(
+ pdu_json
+ )
+
+ event.internal_metadata.outlier = outlier
+
+ return event
+
+ @defer.inlineCallbacks
+ def _check_sigs_and_hash(self, pdu):
+ """Throws a SynapseError if the PDU does not have the correct
+ signatures.
+
+ Returns:
+ FrozenEvent: Either the given event or it redacted if it failed the
+ content hash check.
+ """
+ # Check signatures are correct.
+ redacted_event = prune_event(pdu)
+ redacted_pdu_json = redacted_event.get_pdu_json()
+
+ try:
+ yield self.keyring.verify_json_for_server(
+ pdu.origin, redacted_pdu_json
+ )
+ except SynapseError:
+ logger.warn(
+ "Signature check failed for %s redacted to %s",
+ encode_canonical_json(pdu.get_pdu_json()),
+ encode_canonical_json(redacted_pdu_json),
+ )
+ raise
+
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s, %s",
+ pdu.event_id, encode_canonical_json(pdu.get_dict())
+ )
+ defer.returnValue(redacted_event)
+
+ defer.returnValue(pdu)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
new file mode 100644
index 0000000000..845a07a3a3
--- /dev/null
+++ b/synapse/federation/federation_server.py
@@ -0,0 +1,444 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from .units import Transaction, Edu
+
+from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from syutil.jsonutil import encode_canonical_json
+
+from synapse.crypto.event_signing import check_event_content_hash
+
+from synapse.api.errors import FederationError, SynapseError
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class FederationServer(object):
+ def set_handler(self, handler):
+ """Sets the handler that the replication layer will use to communicate
+ receipt of new PDUs from other home servers. The required methods are
+ documented on :py:class:`.ReplicationHandler`.
+ """
+ self.handler = handler
+
+ def register_edu_handler(self, edu_type, handler):
+ if edu_type in self.edu_handlers:
+ raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+
+ self.edu_handlers[edu_type] = handler
+
+ def register_query_handler(self, query_type, handler):
+ """Sets the handler callable that will be used to handle an incoming
+ federation Query of the given type.
+
+ Args:
+ query_type (str): Category name of the query, which should match
+ the string used by make_query.
+ handler (callable): Invoked to handle incoming queries of this type
+
+ handler is invoked as:
+ result = handler(args)
+
+ where 'args' is a dict mapping strings to strings of the query
+ arguments. It should return a Deferred that will eventually yield an
+ object to encode as JSON.
+ """
+ if query_type in self.query_handlers:
+ raise KeyError(
+ "Already have a Query handler for %s" % (query_type,)
+ )
+
+ self.query_handlers[query_type] = handler
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_backfill_request(self, origin, room_id, versions, limit):
+ pdus = yield self.handler.on_backfill_request(
+ origin, room_id, versions, limit
+ )
+
+ defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_incoming_transaction(self, transaction_data):
+ transaction = Transaction(**transaction_data)
+
+ for p in transaction.pdus:
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
+ if "age" in p:
+ p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
+ del p["age"]
+
+ pdu_list = [
+ self.event_from_pdu_json(p) for p in transaction.pdus
+ ]
+
+ logger.debug("[%s] Got transaction", transaction.transaction_id)
+
+ response = yield self.transaction_actions.have_responded(transaction)
+
+ if response:
+ logger.debug(
+ "[%s] We've already responed to this request",
+ transaction.transaction_id
+ )
+ defer.returnValue(response)
+ return
+
+ logger.debug("[%s] Transaction is new", transaction.transaction_id)
+
+ with PreserveLoggingContext():
+ dl = []
+ for pdu in pdu_list:
+ dl.append(self._handle_new_pdu(transaction.origin, pdu))
+
+ if hasattr(transaction, "edus"):
+ for edu in [Edu(**x) for x in transaction.edus]:
+ self.received_edu(
+ transaction.origin,
+ edu.edu_type,
+ edu.content
+ )
+
+ results = yield defer.DeferredList(dl)
+
+ ret = []
+ for r in results:
+ if r[0]:
+ ret.append({})
+ else:
+ logger.exception(r[1])
+ ret.append({"error": str(r[1])})
+
+ logger.debug("Returning: %s", str(ret))
+
+ yield self.transaction_actions.set_response(
+ transaction,
+ 200, response
+ )
+ defer.returnValue((200, response))
+
+ def received_edu(self, origin, edu_type, content):
+ if edu_type in self.edu_handlers:
+ self.edu_handlers[edu_type](origin, content)
+ else:
+ logger.warn("Received EDU of type %s with no handler", edu_type)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_context_state_request(self, origin, room_id, event_id):
+ if event_id:
+ pdus = yield self.handler.get_state_for_pdu(
+ origin, room_id, event_id,
+ )
+ auth_chain = yield self.store.get_auth_chain(
+ [pdu.event_id for pdu in pdus]
+ )
+ else:
+ raise NotImplementedError("Specify an event")
+
+ defer.returnValue((200, {
+ "pdus": [pdu.get_pdu_json() for pdu in pdus],
+ "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+ }))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_pdu_request(self, origin, event_id):
+ pdu = yield self._get_persisted_pdu(origin, event_id)
+
+ if pdu:
+ defer.returnValue(
+ (200, self._transaction_from_pdus([pdu]).get_dict())
+ )
+ else:
+ defer.returnValue((404, ""))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_pull_request(self, origin, versions):
+ raise NotImplementedError("Pull transactions not implemented")
+
+ @defer.inlineCallbacks
+ def on_query_request(self, query_type, args):
+ if query_type in self.query_handlers:
+ response = yield self.query_handlers[query_type](args)
+ defer.returnValue((200, response))
+ else:
+ defer.returnValue(
+ (404, "No handler for Query type '%s'" % (query_type,))
+ )
+
+ @defer.inlineCallbacks
+ def on_make_join_request(self, room_id, user_id):
+ pdu = yield self.handler.on_make_join_request(room_id, user_id)
+ time_now = self._clock.time_msec()
+ defer.returnValue({"event": pdu.get_pdu_json(time_now)})
+
+ @defer.inlineCallbacks
+ def on_invite_request(self, origin, content):
+ pdu = self.event_from_pdu_json(content)
+ ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+ time_now = self._clock.time_msec()
+ defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
+
+ @defer.inlineCallbacks
+ def on_send_join_request(self, origin, content):
+ logger.debug("on_send_join_request: content: %s", content)
+ pdu = self.event_from_pdu_json(content)
+ logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
+ res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+ time_now = self._clock.time_msec()
+ defer.returnValue((200, {
+ "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+ "auth_chain": [
+ p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
+ ],
+ }))
+
+ @defer.inlineCallbacks
+ def on_event_auth(self, origin, room_id, event_id):
+ time_now = self._clock.time_msec()
+ auth_pdus = yield self.handler.on_event_auth(event_id)
+ defer.returnValue((200, {
+ "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+ }))
+
+ @defer.inlineCallbacks
+ def on_query_auth_request(self, origin, content, event_id):
+ auth_chain = [
+ (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+ for e in content["auth_chain"]
+ ]
+
+ missing = [
+ (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+ for e in content.get("missing", [])
+ ]
+
+ ret = yield self.handler.on_query_auth(
+ origin, event_id, auth_chain, content.get("rejects", []), missing
+ )
+
+ time_now = self._clock.time_msec()
+ send_content = {
+ "auth_chain": [
+ e.get_pdu_json(time_now)
+ for e in ret["auth_chain"]
+ ],
+ "rejects": content.get("rejects", []),
+ "missing": [
+ e.get_pdu_json(time_now)
+ for e in ret.get("missing", [])
+ ],
+ }
+
+ defer.returnValue(
+ (200, send_content)
+ )
+
+ @log_function
+ def _get_persisted_pdu(self, origin, event_id, do_auth=True):
+ """ Get a PDU from the database with given origin and id.
+
+ Returns:
+ Deferred: Results in a `Pdu`.
+ """
+ return self.handler.get_persisted_pdu(
+ origin, event_id, do_auth=do_auth
+ )
+
+ def _transaction_from_pdus(self, pdu_list):
+ """Returns a new Transaction containing the given PDUs suitable for
+ transmission.
+ """
+ time_now = self._clock.time_msec()
+ pdus = [p.get_pdu_json(time_now) for p in pdu_list]
+ return Transaction(
+ origin=self.server_name,
+ pdus=pdus,
+ origin_server_ts=int(time_now),
+ destination=None,
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _handle_new_pdu(self, origin, pdu, max_recursion=10):
+ # We reprocess pdus when we have seen them only as outliers
+ existing = yield self._get_persisted_pdu(
+ origin, pdu.event_id, do_auth=False
+ )
+
+ # FIXME: Currently we fetch an event again when we already have it
+ # if it has been marked as an outlier.
+
+ already_seen = (
+ existing and (
+ not existing.internal_metadata.is_outlier()
+ or pdu.internal_metadata.is_outlier()
+ )
+ )
+ if already_seen:
+ logger.debug("Already seen pdu %s", pdu.event_id)
+ defer.returnValue({})
+ return
+
+ # Check signature.
+ try:
+ pdu = yield self._check_sigs_and_hash(pdu)
+ except SynapseError as e:
+ raise FederationError(
+ "ERROR",
+ e.code,
+ e.msg,
+ affected=pdu.event_id,
+ )
+
+ state = None
+
+ auth_chain = []
+
+ have_seen = yield self.store.have_events(
+ [ev for ev, _ in pdu.prev_events]
+ )
+
+ fetch_state = False
+
+ # Get missing pdus if necessary.
+ if not pdu.internal_metadata.is_outlier():
+ # We only backfill backwards to the min depth.
+ min_depth = yield self.handler.get_min_depth_for_context(
+ pdu.room_id
+ )
+
+ logger.debug(
+ "_handle_new_pdu min_depth for %s: %d",
+ pdu.room_id, min_depth
+ )
+
+ if min_depth and pdu.depth > min_depth and max_recursion > 0:
+ for event_id, hashes in pdu.prev_events:
+ if event_id not in have_seen:
+ logger.debug(
+ "_handle_new_pdu requesting pdu %s",
+ event_id
+ )
+
+ try:
+ new_pdu = yield self.federation_client.get_pdu(
+ [origin, pdu.origin],
+ event_id=event_id,
+ )
+
+ if new_pdu:
+ yield self._handle_new_pdu(
+ origin,
+ new_pdu,
+ max_recursion=max_recursion-1
+ )
+
+ logger.debug("Processed pdu %s", event_id)
+ else:
+ logger.warn("Failed to get PDU %s", event_id)
+ fetch_state = True
+ except:
+ # TODO(erikj): Do some more intelligent retries.
+ logger.exception("Failed to get PDU")
+ fetch_state = True
+ else:
+ fetch_state = True
+ else:
+ fetch_state = True
+
+ if fetch_state:
+ # We need to get the state at this event, since we haven't
+ # processed all the prev events.
+ logger.debug(
+ "_handle_new_pdu getting state for %s",
+ pdu.room_id
+ )
+ state, auth_chain = yield self.get_state_for_room(
+ origin, pdu.room_id, pdu.event_id,
+ )
+
+ ret = yield self.handler.on_receive_pdu(
+ origin,
+ pdu,
+ backfilled=False,
+ state=state,
+ auth_chain=auth_chain,
+ )
+
+ defer.returnValue(ret)
+
+ def __str__(self):
+ return "<ReplicationLayer(%s)>" % self.server_name
+
+ def event_from_pdu_json(self, pdu_json, outlier=False):
+ event = FrozenEvent(
+ pdu_json
+ )
+
+ event.internal_metadata.outlier = outlier
+
+ return event
+
+ @defer.inlineCallbacks
+ def _check_sigs_and_hash(self, pdu):
+ """Throws a SynapseError if the PDU does not have the correct
+ signatures.
+
+ Returns:
+ FrozenEvent: Either the given event or it redacted if it failed the
+ content hash check.
+ """
+ # Check signatures are correct.
+ redacted_event = prune_event(pdu)
+ redacted_pdu_json = redacted_event.get_pdu_json()
+
+ try:
+ yield self.keyring.verify_json_for_server(
+ pdu.origin, redacted_pdu_json
+ )
+ except SynapseError:
+ logger.warn(
+ "Signature check failed for %s redacted to %s",
+ encode_canonical_json(pdu.get_pdu_json()),
+ encode_canonical_json(redacted_pdu_json),
+ )
+ raise
+
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s, %s",
+ pdu.event_id, encode_canonical_json(pdu.get_dict())
+ )
+ defer.returnValue(redacted_event)
+
+ defer.returnValue(pdu)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 6620532a60..e442c6c5d5 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -17,23 +17,20 @@
a given transport.
"""
-from twisted.internet import defer
+from .federation_client import FederationClient
+from .federation_server import FederationServer
-from .units import Transaction, Edu
+from .transaction_queue import TransactionQueue
from .persistence import TransactionActions
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
-from synapse.events import FrozenEvent
-
import logging
logger = logging.getLogger(__name__)
-class ReplicationLayer(object):
+class ReplicationLayer(FederationClient, FederationServer):
"""This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to
remote home servers.
@@ -54,898 +51,26 @@ class ReplicationLayer(object):
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
+ self.keyring = hs.get_keyring()
+
self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self)
- self.store = hs.get_datastore()
- # self.pdu_actions = PduActions(self.store)
- self.transaction_actions = TransactionActions(self.store)
+ self.federation_client = self
- self._transaction_queue = _TransactionQueue(
- hs, self.transaction_actions, transport_layer
- )
+ self.store = hs.get_datastore()
self.handler = None
self.edu_handlers = {}
self.query_handlers = {}
- self._order = 0
-
self._clock = hs.get_clock()
- self.event_builder_factory = hs.get_event_builder_factory()
-
- def set_handler(self, handler):
- """Sets the handler that the replication layer will use to communicate
- receipt of new PDUs from other home servers. The required methods are
- documented on :py:class:`.ReplicationHandler`.
- """
- self.handler = handler
-
- def register_edu_handler(self, edu_type, handler):
- if edu_type in self.edu_handlers:
- raise KeyError("Already have an EDU handler for %s" % (edu_type,))
-
- self.edu_handlers[edu_type] = handler
-
- def register_query_handler(self, query_type, handler):
- """Sets the handler callable that will be used to handle an incoming
- federation Query of the given type.
-
- Args:
- query_type (str): Category name of the query, which should match
- the string used by make_query.
- handler (callable): Invoked to handle incoming queries of this type
-
- handler is invoked as:
- result = handler(args)
-
- where 'args' is a dict mapping strings to strings of the query
- arguments. It should return a Deferred that will eventually yield an
- object to encode as JSON.
- """
- if query_type in self.query_handlers:
- raise KeyError(
- "Already have a Query handler for %s" % (query_type,)
- )
-
- self.query_handlers[query_type] = handler
-
- @log_function
- def send_pdu(self, pdu, destinations):
- """Informs the replication layer about a new PDU generated within the
- home server that should be transmitted to others.
-
- TODO: Figure out when we should actually resolve the deferred.
-
- Args:
- pdu (Pdu): The new Pdu.
-
- Returns:
- Deferred: Completes when we have successfully processed the PDU
- and replicated it to any interested remote home servers.
- """
- order = self._order
- self._order += 1
-
- logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
-
- # TODO, add errback, etc.
- self._transaction_queue.enqueue_pdu(pdu, destinations, order)
-
- logger.debug(
- "[%s] transaction_layer.enqueue_pdu... done",
- pdu.event_id
- )
-
- @log_function
- def send_edu(self, destination, edu_type, content):
- edu = Edu(
- origin=self.server_name,
- destination=destination,
- edu_type=edu_type,
- content=content,
- )
-
- # TODO, add errback, etc.
- self._transaction_queue.enqueue_edu(edu)
- return defer.succeed(None)
-
- @log_function
- def send_failure(self, failure, destination):
- self._transaction_queue.enqueue_failure(failure, destination)
- return defer.succeed(None)
-
- @log_function
- def make_query(self, destination, query_type, args,
- retry_on_dns_fail=True):
- """Sends a federation Query to a remote homeserver of the given type
- and arguments.
-
- Args:
- destination (str): Domain name of the remote homeserver
- query_type (str): Category of the query type; should match the
- handler name used in register_query_handler().
- args (dict): Mapping of strings to strings containing the details
- of the query request.
-
- Returns:
- a Deferred which will eventually yield a JSON object from the
- response
- """
- return self.transport_layer.make_query(
- destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
- )
-
- @defer.inlineCallbacks
- @log_function
- def backfill(self, dest, context, limit, extremities):
- """Requests some more historic PDUs for the given context from the
- given destination server.
-
- Args:
- dest (str): The remote home server to ask.
- context (str): The context to backfill.
- limit (int): The maximum number of PDUs to return.
- extremities (list): List of PDU id and origins of the first pdus
- we have seen from the context
-
- Returns:
- Deferred: Results in the received PDUs.
- """
- logger.debug("backfill extrem=%s", extremities)
-
- # If there are no extremeties then we've (probably) reached the start.
- if not extremities:
- return
-
- transaction_data = yield self.transport_layer.backfill(
- dest, context, extremities, limit)
-
- logger.debug("backfill transaction_data=%s", repr(transaction_data))
-
- transaction = Transaction(**transaction_data)
-
- pdus = [
- self.event_from_pdu_json(p, outlier=False)
- for p in transaction.pdus
- ]
- for pdu in pdus:
- yield self._handle_new_pdu(dest, pdu, backfilled=True)
-
- defer.returnValue(pdus)
-
- @defer.inlineCallbacks
- @log_function
- def get_pdu(self, destination, event_id, outlier=False):
- """Requests the PDU with given origin and ID from the remote home
- server.
-
- This will persist the PDU locally upon receipt.
-
- Args:
- destination (str): Which home server to query
- pdu_origin (str): The home server that originally sent the pdu.
- event_id (str)
- outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
- it's from an arbitary point in the context as opposed to part
- of the current block of PDUs. Defaults to `False`
-
- Returns:
- Deferred: Results in the requested PDU.
- """
-
- transaction_data = yield self.transport_layer.get_event(
- destination, event_id
- )
-
- transaction = Transaction(**transaction_data)
-
- pdu_list = [
- self.event_from_pdu_json(p, outlier=outlier)
- for p in transaction.pdus
- ]
-
- pdu = None
- if pdu_list:
- pdu = pdu_list[0]
- yield self._handle_new_pdu(destination, pdu)
-
- defer.returnValue(pdu)
-
- @defer.inlineCallbacks
- @log_function
- def get_state_for_room(self, destination, room_id, event_id):
- """Requests all of the `current` state PDUs for a given room from
- a remote home server.
-
- Args:
- destination (str): The remote homeserver to query for the state.
- room_id (str): The id of the room we're interested in.
- event_id (str): The id of the event we want the state at.
-
- Returns:
- Deferred: Results in a list of PDUs.
- """
-
- result = yield self.transport_layer.get_room_state(
- destination, room_id, event_id=event_id,
- )
-
- pdus = [
- self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
- ]
-
- auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
- for p in result.get("auth_chain", [])
- ]
-
- defer.returnValue((pdus, auth_chain))
-
- @defer.inlineCallbacks
- @log_function
- def get_event_auth(self, destination, room_id, event_id):
- res = yield self.transport_layer.get_event_auth(
- destination, room_id, event_id,
- )
-
- auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
- for p in res["auth_chain"]
- ]
-
- auth_chain.sort(key=lambda e: e.depth)
-
- defer.returnValue(auth_chain)
-
- @defer.inlineCallbacks
- @log_function
- def on_backfill_request(self, origin, room_id, versions, limit):
- pdus = yield self.handler.on_backfill_request(
- origin, room_id, versions, limit
- )
-
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
-
- @defer.inlineCallbacks
- @log_function
- def on_incoming_transaction(self, transaction_data):
- transaction = Transaction(**transaction_data)
-
- for p in transaction.pdus:
- if "unsigned" in p:
- unsigned = p["unsigned"]
- if "age" in unsigned:
- p["age"] = unsigned["age"]
- if "age" in p:
- p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
- del p["age"]
-
- pdu_list = [
- self.event_from_pdu_json(p) for p in transaction.pdus
- ]
-
- logger.debug("[%s] Got transaction", transaction.transaction_id)
-
- response = yield self.transaction_actions.have_responded(transaction)
-
- if response:
- logger.debug("[%s] We've already responed to this request",
- transaction.transaction_id)
- defer.returnValue(response)
- return
-
- logger.debug("[%s] Transaction is new", transaction.transaction_id)
-
- with PreserveLoggingContext():
- dl = []
- for pdu in pdu_list:
- dl.append(self._handle_new_pdu(transaction.origin, pdu))
-
- if hasattr(transaction, "edus"):
- for edu in [Edu(**x) for x in transaction.edus]:
- self.received_edu(
- transaction.origin,
- edu.edu_type,
- edu.content
- )
-
- results = yield defer.DeferredList(dl)
-
- ret = []
- for r in results:
- if r[0]:
- ret.append({})
- else:
- logger.exception(r[1])
- ret.append({"error": str(r[1])})
-
- logger.debug("Returning: %s", str(ret))
-
- yield self.transaction_actions.set_response(
- transaction,
- 200, response
- )
- defer.returnValue((200, response))
-
- def received_edu(self, origin, edu_type, content):
- if edu_type in self.edu_handlers:
- self.edu_handlers[edu_type](origin, content)
- else:
- logger.warn("Received EDU of type %s with no handler", edu_type)
-
- @defer.inlineCallbacks
- @log_function
- def on_context_state_request(self, origin, room_id, event_id):
- if event_id:
- pdus = yield self.handler.get_state_for_pdu(
- origin, room_id, event_id,
- )
- auth_chain = yield self.store.get_auth_chain(
- [pdu.event_id for pdu in pdus]
- )
- else:
- raise NotImplementedError("Specify an event")
-
- defer.returnValue((200, {
- "pdus": [pdu.get_pdu_json() for pdu in pdus],
- "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
- }))
-
- @defer.inlineCallbacks
- @log_function
- def on_pdu_request(self, origin, event_id):
- pdu = yield self._get_persisted_pdu(origin, event_id)
-
- if pdu:
- defer.returnValue(
- (200, self._transaction_from_pdus([pdu]).get_dict())
- )
- else:
- defer.returnValue((404, ""))
-
- @defer.inlineCallbacks
- @log_function
- def on_pull_request(self, origin, versions):
- raise NotImplementedError("Pull transactions not implemented")
-
- @defer.inlineCallbacks
- def on_query_request(self, query_type, args):
- if query_type in self.query_handlers:
- response = yield self.query_handlers[query_type](args)
- defer.returnValue((200, response))
- else:
- defer.returnValue(
- (404, "No handler for Query type '%s'" % (query_type,))
- )
-
- @defer.inlineCallbacks
- def on_make_join_request(self, room_id, user_id):
- pdu = yield self.handler.on_make_join_request(room_id, user_id)
- time_now = self._clock.time_msec()
- defer.returnValue({"event": pdu.get_pdu_json(time_now)})
-
- @defer.inlineCallbacks
- def on_invite_request(self, origin, content):
- pdu = self.event_from_pdu_json(content)
- ret_pdu = yield self.handler.on_invite_request(origin, pdu)
- time_now = self._clock.time_msec()
- defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
-
- @defer.inlineCallbacks
- def on_send_join_request(self, origin, content):
- logger.debug("on_send_join_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
- logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
- res_pdus = yield self.handler.on_send_join_request(origin, pdu)
- time_now = self._clock.time_msec()
- defer.returnValue((200, {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- }))
-
- @defer.inlineCallbacks
- def on_event_auth(self, origin, room_id, event_id):
- time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
- defer.returnValue((200, {
- "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
- }))
-
- @defer.inlineCallbacks
- def make_join(self, destination, room_id, user_id):
- ret = yield self.transport_layer.make_join(
- destination, room_id, user_id
- )
-
- pdu_dict = ret["event"]
-
- logger.debug("Got response to make_join: %s", pdu_dict)
-
- defer.returnValue(self.event_from_pdu_json(pdu_dict))
-
- @defer.inlineCallbacks
- def send_join(self, destination, pdu):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_join(
- destination=destination,
- room_id=pdu.room_id,
- event_id=pdu.event_id,
- content=pdu.get_pdu_json(time_now),
- )
-
- logger.debug("Got content: %s", content)
-
- state = [
- self.event_from_pdu_json(p, outlier=True)
- for p in content.get("state", [])
- ]
-
- auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
- for p in content.get("auth_chain", [])
- ]
-
- auth_chain.sort(key=lambda e: e.depth)
-
- defer.returnValue({
- "state": state,
- "auth_chain": auth_chain,
- })
-
- @defer.inlineCallbacks
- def send_invite(self, destination, room_id, event_id, pdu):
- time_now = self._clock.time_msec()
- code, content = yield self.transport_layer.send_invite(
- destination=destination,
- room_id=room_id,
- event_id=event_id,
- content=pdu.get_pdu_json(time_now),
- )
-
- pdu_dict = content["event"]
-
- logger.debug("Got response to send_invite: %s", pdu_dict)
-
- defer.returnValue(self.event_from_pdu_json(pdu_dict))
-
- @log_function
- def _get_persisted_pdu(self, origin, event_id, do_auth=True):
- """ Get a PDU from the database with given origin and id.
-
- Returns:
- Deferred: Results in a `Pdu`.
- """
- return self.handler.get_persisted_pdu(
- origin, event_id, do_auth=do_auth
- )
-
- def _transaction_from_pdus(self, pdu_list):
- """Returns a new Transaction containing the given PDUs suitable for
- transmission.
- """
- time_now = self._clock.time_msec()
- pdus = [p.get_pdu_json(time_now) for p in pdu_list]
- return Transaction(
- origin=self.server_name,
- pdus=pdus,
- origin_server_ts=int(time_now),
- destination=None,
- )
-
- @defer.inlineCallbacks
- @log_function
- def _handle_new_pdu(self, origin, pdu, backfilled=False):
- # We reprocess pdus when we have seen them only as outliers
- existing = yield self._get_persisted_pdu(
- origin, pdu.event_id, do_auth=False
- )
-
- already_seen = (
- existing and (
- not existing.internal_metadata.is_outlier()
- or pdu.internal_metadata.is_outlier()
- )
- )
- if already_seen:
- logger.debug("Already seen pdu %s", pdu.event_id)
- defer.returnValue({})
- return
-
- state = None
-
- auth_chain = []
-
- # We need to make sure we have all the auth events.
- # for e_id, _ in pdu.auth_events:
- # exists = yield self._get_persisted_pdu(
- # origin,
- # e_id,
- # do_auth=False
- # )
- #
- # if not exists:
- # try:
- # logger.debug(
- # "_handle_new_pdu fetch missing auth event %s from %s",
- # e_id,
- # origin,
- # )
- #
- # yield self.get_pdu(
- # origin,
- # event_id=e_id,
- # outlier=True,
- # )
- #
- # logger.debug("Processed pdu %s", e_id)
- # except:
- # logger.warn(
- # "Failed to get auth event %s from %s",
- # e_id,
- # origin
- # )
-
- # Get missing pdus if necessary.
- if not pdu.internal_metadata.is_outlier():
- # We only backfill backwards to the min depth.
- min_depth = yield self.handler.get_min_depth_for_context(
- pdu.room_id
- )
-
- logger.debug(
- "_handle_new_pdu min_depth for %s: %d",
- pdu.room_id, min_depth
- )
-
- if min_depth and pdu.depth > min_depth:
- for event_id, hashes in pdu.prev_events:
- exists = yield self._get_persisted_pdu(
- origin,
- event_id,
- do_auth=False
- )
-
- if not exists:
- logger.debug(
- "_handle_new_pdu requesting pdu %s",
- event_id
- )
-
- try:
- yield self.get_pdu(
- origin,
- event_id=event_id,
- )
- logger.debug("Processed pdu %s", event_id)
- except:
- # TODO(erikj): Do some more intelligent retries.
- logger.exception("Failed to get PDU")
- else:
- # We need to get the state at this event, since we have reached
- # a backward extremity edge.
- logger.debug(
- "_handle_new_pdu getting state for %s",
- pdu.room_id
- )
- state, auth_chain = yield self.get_state_for_room(
- origin, pdu.room_id, pdu.event_id,
- )
-
- if not backfilled:
- ret = yield self.handler.on_receive_pdu(
- origin,
- pdu,
- backfilled=backfilled,
- state=state,
- auth_chain=auth_chain,
- )
- else:
- ret = None
-
- # yield self.pdu_actions.mark_as_processed(pdu)
+ self.transaction_actions = TransactionActions(self.store)
+ self._transaction_queue = TransactionQueue(hs, transport_layer)
- defer.returnValue(ret)
+ self._order = 0
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
-
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
-
-class _TransactionQueue(object):
- """This class makes sure we only have one transaction in flight at
- a time for a given destination.
-
- It batches pending PDUs into single transactions.
- """
-
- def __init__(self, hs, transaction_actions, transport_layer):
- self.server_name = hs.hostname
- self.transaction_actions = transaction_actions
- self.transport_layer = transport_layer
-
- self._clock = hs.get_clock()
- self.store = hs.get_datastore()
-
- # Is a mapping from destinations -> deferreds. Used to keep track
- # of which destinations have transactions in flight and when they are
- # done
- self.pending_transactions = {}
-
- # Is a mapping from destination -> list of
- # tuple(pending pdus, deferred, order)
- self.pending_pdus_by_dest = {}
- # destination -> list of tuple(edu, deferred)
- self.pending_edus_by_dest = {}
-
- # destination -> list of tuple(failure, deferred)
- self.pending_failures_by_dest = {}
-
- # HACK to get unique tx id
- self._next_txn_id = int(self._clock.time_msec())
-
- @defer.inlineCallbacks
- @log_function
- def enqueue_pdu(self, pdu, destinations, order):
- # We loop through all destinations to see whether we already have
- # a transaction in progress. If we do, stick it in the pending_pdus
- # table and we'll get back to it later.
-
- destinations = set(destinations)
- destinations.discard(self.server_name)
- destinations.discard("localhost")
-
- logger.debug("Sending to: %s", str(destinations))
-
- if not destinations:
- return
-
- deferreds = []
-
- for destination in destinations:
- deferred = defer.Deferred()
- self.pending_pdus_by_dest.setdefault(destination, []).append(
- (pdu, deferred, order)
- )
-
- def eb(failure):
- if not deferred.called:
- deferred.errback(failure)
- else:
- logger.warn("Failed to send pdu", failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(eb)
-
- deferreds.append(deferred)
-
- yield defer.DeferredList(deferreds)
-
- # NO inlineCallbacks
- def enqueue_edu(self, edu):
- destination = edu.destination
-
- if destination == self.server_name:
- return
-
- deferred = defer.Deferred()
- self.pending_edus_by_dest.setdefault(destination, []).append(
- (edu, deferred)
- )
-
- def eb(failure):
- if not deferred.called:
- deferred.errback(failure)
- else:
- logger.warn("Failed to send edu", failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(eb)
-
- return deferred
-
- @defer.inlineCallbacks
- def enqueue_failure(self, failure, destination):
- deferred = defer.Deferred()
-
- self.pending_failures_by_dest.setdefault(
- destination, []
- ).append(
- (failure, deferred)
- )
-
- yield deferred
-
- @defer.inlineCallbacks
- @log_function
- def _attempt_new_transaction(self, destination):
-
- (retry_last_ts, retry_interval) = (0, 0)
- retry_timings = yield self.store.get_destination_retry_timings(
- destination
- )
- if retry_timings:
- (retry_last_ts, retry_interval) = (
- retry_timings.retry_last_ts, retry_timings.retry_interval
- )
- if retry_last_ts + retry_interval > int(self._clock.time_msec()):
- logger.info(
- "TX [%s] not ready for retry yet - "
- "dropping transaction for now",
- destination,
- )
- return
- else:
- logger.info("TX [%s] is ready for retry", destination)
-
- logger.info("TX [%s] _attempt_new_transaction", destination)
-
- if destination in self.pending_transactions:
- # XXX: pending_transactions can get stuck on by a never-ending
- # request at which point pending_pdus_by_dest just keeps growing.
- # we need application-layer timeouts of some flavour of these
- # requests
- return
-
- # list of (pending_pdu, deferred, order)
- pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
- pending_edus = self.pending_edus_by_dest.pop(destination, [])
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
-
- if pending_pdus:
- logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- destination, len(pending_pdus))
-
- if not pending_pdus and not pending_edus and not pending_failures:
- return
-
- logger.debug(
- "TX [%s] Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
- destination,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures)
- )
-
- # Sort based on the order field
- pending_pdus.sort(key=lambda t: t[2])
-
- pdus = [x[0] for x in pending_pdus]
- edus = [x[0] for x in pending_edus]
- failures = [x[0].get_dict() for x in pending_failures]
- deferreds = [
- x[1]
- for x in pending_pdus + pending_edus + pending_failures
- ]
-
- try:
- self.pending_transactions[destination] = 1
-
- logger.debug("TX [%s] Persisting transaction...", destination)
-
- transaction = Transaction.create_new(
- origin_server_ts=int(self._clock.time_msec()),
- transaction_id=str(self._next_txn_id),
- origin=self.server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- pdu_failures=failures,
- )
-
- self._next_txn_id += 1
-
- yield self.transaction_actions.prepare_to_send(transaction)
-
- logger.debug("TX [%s] Persisted transaction", destination)
- logger.info(
- "TX [%s] Sending transaction [%s]",
- destination,
- transaction.transaction_id,
- )
-
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self._clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- code, response = yield self.transport_layer.send_transaction(
- transaction, json_data_cb
- )
-
- logger.info("TX [%s] got %d response", destination, code)
-
- logger.debug("TX [%s] Sent transaction", destination)
- logger.debug("TX [%s] Marking as delivered...", destination)
-
- yield self.transaction_actions.delivered(
- transaction, code, response
- )
-
- logger.debug("TX [%s] Marked as delivered", destination)
- logger.debug("TX [%s] Yielding to callbacks...", destination)
-
- for deferred in deferreds:
- if code == 200:
- if retry_last_ts:
- # this host is alive! reset retry schedule
- yield self.store.set_destination_retry_timings(
- destination, 0, 0
- )
- deferred.callback(None)
- else:
- self.set_retrying(destination, retry_interval)
- deferred.errback(RuntimeError("Got status %d" % code))
-
- # Ensures we don't continue until all callbacks on that
- # deferred have fired
- try:
- yield deferred
- except:
- pass
-
- logger.debug("TX [%s] Yielded to callbacks", destination)
-
- except Exception as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
-
- self.set_retrying(destination, retry_interval)
-
- for deferred in deferreds:
- if not deferred.called:
- deferred.errback(e)
-
- finally:
- # We want to be *very* sure we delete this after we stop processing
- self.pending_transactions.pop(destination, None)
-
- # Check to see if there is anything else to send.
- self._attempt_new_transaction(destination)
-
- @defer.inlineCallbacks
- def set_retrying(self, destination, retry_interval):
- # track that this destination is having problems and we should
- # give it a chance to recover before trying it again
-
- if retry_interval:
- retry_interval *= 2
- # plateau at hourly retries for now
- if retry_interval >= 60 * 60 * 1000:
- retry_interval = 60 * 60 * 1000
- else:
- retry_interval = 2000 # try again at first after 2 seconds
-
- yield self.store.set_destination_retry_timings(
- destination,
- int(self._clock.time_msec()),
- retry_interval
- )
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
new file mode 100644
index 0000000000..9d4f2c09a2
--- /dev/null
+++ b/synapse/federation/transaction_queue.py
@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from .persistence import TransactionActions
+from .units import Transaction
+
+from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class TransactionQueue(object):
+ """This class makes sure we only have one transaction in flight at
+ a time for a given destination.
+
+ It batches pending PDUs into single transactions.
+ """
+
+ def __init__(self, hs, transport_layer):
+ self.server_name = hs.hostname
+
+ self.store = hs.get_datastore()
+ self.transaction_actions = TransactionActions(self.store)
+
+ self.transport_layer = transport_layer
+
+ self._clock = hs.get_clock()
+
+ # Is a mapping from destinations -> deferreds. Used to keep track
+ # of which destinations have transactions in flight and when they are
+ # done
+ self.pending_transactions = {}
+
+ # Is a mapping from destination -> list of
+ # tuple(pending pdus, deferred, order)
+ self.pending_pdus_by_dest = {}
+ # destination -> list of tuple(edu, deferred)
+ self.pending_edus_by_dest = {}
+
+ # destination -> list of tuple(failure, deferred)
+ self.pending_failures_by_dest = {}
+
+ # HACK to get unique tx id
+ self._next_txn_id = int(self._clock.time_msec())
+
+ @defer.inlineCallbacks
+ @log_function
+ def enqueue_pdu(self, pdu, destinations, order):
+ # We loop through all destinations to see whether we already have
+ # a transaction in progress. If we do, stick it in the pending_pdus
+ # table and we'll get back to it later.
+
+ destinations = set(destinations)
+ destinations.discard(self.server_name)
+ destinations.discard("localhost")
+
+ logger.debug("Sending to: %s", str(destinations))
+
+ if not destinations:
+ return
+
+ deferreds = []
+
+ for destination in destinations:
+ deferred = defer.Deferred()
+ self.pending_pdus_by_dest.setdefault(destination, []).append(
+ (pdu, deferred, order)
+ )
+
+ def eb(failure):
+ if not deferred.called:
+ deferred.errback(failure)
+ else:
+ logger.warn("Failed to send pdu", failure)
+
+ with PreserveLoggingContext():
+ self._attempt_new_transaction(destination).addErrback(eb)
+
+ deferreds.append(deferred)
+
+ yield defer.DeferredList(deferreds)
+
+ # NO inlineCallbacks
+ def enqueue_edu(self, edu):
+ destination = edu.destination
+
+ if destination == self.server_name:
+ return
+
+ deferred = defer.Deferred()
+ self.pending_edus_by_dest.setdefault(destination, []).append(
+ (edu, deferred)
+ )
+
+ def eb(failure):
+ if not deferred.called:
+ deferred.errback(failure)
+ else:
+ logger.warn("Failed to send edu", failure)
+
+ with PreserveLoggingContext():
+ self._attempt_new_transaction(destination).addErrback(eb)
+
+ return deferred
+
+ @defer.inlineCallbacks
+ def enqueue_failure(self, failure, destination):
+ deferred = defer.Deferred()
+
+ self.pending_failures_by_dest.setdefault(
+ destination, []
+ ).append(
+ (failure, deferred)
+ )
+
+ yield deferred
+
+ @defer.inlineCallbacks
+ @log_function
+ def _attempt_new_transaction(self, destination):
+
+ (retry_last_ts, retry_interval) = (0, 0)
+ retry_timings = yield self.store.get_destination_retry_timings(
+ destination
+ )
+ if retry_timings:
+ (retry_last_ts, retry_interval) = (
+ retry_timings.retry_last_ts, retry_timings.retry_interval
+ )
+ if retry_last_ts + retry_interval > int(self._clock.time_msec()):
+ logger.info(
+ "TX [%s] not ready for retry yet - "
+ "dropping transaction for now",
+ destination,
+ )
+ return
+ else:
+ logger.info("TX [%s] is ready for retry", destination)
+
+ logger.info("TX [%s] _attempt_new_transaction", destination)
+
+ if destination in self.pending_transactions:
+ # XXX: pending_transactions can get stuck on by a never-ending
+ # request at which point pending_pdus_by_dest just keeps growing.
+ # we need application-layer timeouts of some flavour of these
+ # requests
+ return
+
+ # list of (pending_pdu, deferred, order)
+ pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+ pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
+
+ if pending_pdus:
+ logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ destination, len(pending_pdus))
+
+ if not pending_pdus and not pending_edus and not pending_failures:
+ return
+
+ logger.debug(
+ "TX [%s] Attempting new transaction"
+ " (pdus: %d, edus: %d, failures: %d)",
+ destination,
+ len(pending_pdus),
+ len(pending_edus),
+ len(pending_failures)
+ )
+
+ # Sort based on the order field
+ pending_pdus.sort(key=lambda t: t[2])
+
+ pdus = [x[0] for x in pending_pdus]
+ edus = [x[0] for x in pending_edus]
+ failures = [x[0].get_dict() for x in pending_failures]
+ deferreds = [
+ x[1]
+ for x in pending_pdus + pending_edus + pending_failures
+ ]
+
+ try:
+ self.pending_transactions[destination] = 1
+
+ logger.debug("TX [%s] Persisting transaction...", destination)
+
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self._clock.time_msec()),
+ transaction_id=str(self._next_txn_id),
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ pdu_failures=failures,
+ )
+
+ self._next_txn_id += 1
+
+ yield self.transaction_actions.prepare_to_send(transaction)
+
+ logger.debug("TX [%s] Persisted transaction", destination)
+ logger.info(
+ "TX [%s] Sending transaction [%s]",
+ destination,
+ transaction.transaction_id,
+ )
+
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self._clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ code, response = yield self.transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+
+ logger.info("TX [%s] got %d response", destination, code)
+
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
+
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
+
+ logger.debug("TX [%s] Marked as delivered", destination)
+ logger.debug("TX [%s] Yielding to callbacks...", destination)
+
+ for deferred in deferreds:
+ if code == 200:
+ if retry_last_ts:
+ # this host is alive! reset retry schedule
+ yield self.store.set_destination_retry_timings(
+ destination, 0, 0
+ )
+ deferred.callback(None)
+ else:
+ self.set_retrying(destination, retry_interval)
+ deferred.errback(RuntimeError("Got status %d" % code))
+
+ # Ensures we don't continue until all callbacks on that
+ # deferred have fired
+ try:
+ yield deferred
+ except:
+ pass
+
+ logger.debug("TX [%s] Yielded to callbacks", destination)
+
+ except Exception as e:
+ # We capture this here as there as nothing actually listens
+ # for this finishing functions deferred.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
+
+ self.set_retrying(destination, retry_interval)
+
+ for deferred in deferreds:
+ if not deferred.called:
+ deferred.errback(e)
+
+ finally:
+ # We want to be *very* sure we delete this after we stop processing
+ self.pending_transactions.pop(destination, None)
+
+ # Check to see if there is anything else to send.
+ self._attempt_new_transaction(destination)
+
+ @defer.inlineCallbacks
+ def set_retrying(self, destination, retry_interval):
+ # track that this destination is having problems and we should
+ # give it a chance to recover before trying it again
+
+ if retry_interval:
+ retry_interval *= 2
+ # plateau at hourly retries for now
+ if retry_interval >= 60 * 60 * 1000:
+ retry_interval = 60 * 60 * 1000
+ else:
+ retry_interval = 2000 # try again at first after 2 seconds
+
+ yield self.store.set_destination_retry_timings(
+ destination,
+ int(self._clock.time_msec()),
+ retry_interval
+ )
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e634a3a213..4cb1dea2de 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -213,3 +213,19 @@ class TransportLayerClient(object):
)
defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_query_auth(self, destination, room_id, event_id, content):
+ path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+
+ code, content = yield self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_invite", code)
+
+ defer.returnValue(json.loads(content))
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a380a6910b..9c9f8d525b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -42,7 +42,7 @@ class TransportLayerServer(object):
content = None
origin = None
- if request.method == "PUT":
+ if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
@@ -234,6 +234,16 @@ class TransportLayerServer(object):
)
)
)
+ self.server.register_path(
+ "POST",
+ re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, event_id:
+ self._on_query_auth_request(
+ origin, content, event_id,
+ )
+ )
+ )
@defer.inlineCallbacks
@log_function
@@ -325,3 +335,12 @@ class TransportLayerServer(object):
)
defer.returnValue((200, content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_query_auth_request(self, origin, content, event_id):
+ new_content = yield self.request_handler.on_query_auth_request(
+ origin, content, event_id
+ )
+
+ defer.returnValue((200, new_content))
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bcdcc90a18..cc22f21cd1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,19 +17,16 @@
from ._base import BaseHandler
-from synapse.events.utils import prune_event
from synapse.api.errors import (
- AuthError, FederationError, SynapseError, StoreError,
+ AuthError, FederationError, StoreError,
)
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import (
- compute_event_signature, check_event_content_hash,
- add_hashes_and_signatures,
+ compute_event_signature, add_hashes_and_signatures,
)
from synapse.types import UserID
-from syutil.jsonutil import encode_canonical_json
from twisted.internet import defer
@@ -113,33 +110,6 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event.event_id)
- redacted_event = prune_event(event)
-
- redacted_pdu_json = redacted_event.get_pdu_json()
- try:
- yield self.keyring.verify_json_for_server(
- event.origin, redacted_pdu_json
- )
- except SynapseError as e:
- logger.warn(
- "Signature check failed for %s redacted to %s",
- encode_canonical_json(pdu.get_pdu_json()),
- encode_canonical_json(redacted_pdu_json),
- )
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
- )
-
- if not check_event_content_hash(event):
- logger.warn(
- "Event content has been tampered, redacting %s, %s",
- event.event_id, encode_canonical_json(event.get_dict())
- )
- event = redacted_event
-
logger.debug("Event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently
@@ -149,14 +119,14 @@ class FederationHandler(BaseHandler):
event.room_id,
self.server_name
)
- if not is_in_room and not event.internal_metadata.outlier:
+ if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.")
replication = self.replication_layer
if not state:
state, auth_chain = yield replication.get_state_for_room(
- origin, context=event.room_id, event_id=event.event_id,
+ origin, room_id=event.room_id, event_id=event.event_id,
)
if not auth_chain:
@@ -169,7 +139,7 @@ class FederationHandler(BaseHandler):
for e in auth_chain:
e.internal_metadata.outlier = True
try:
- yield self._handle_new_event(e, fetch_auth_from=origin)
+ yield self._handle_new_event(origin, e)
except:
logger.exception(
"Failed to handle auth event %s",
@@ -180,10 +150,9 @@ class FederationHandler(BaseHandler):
if state:
for e in state:
- logging.info("A :) %r", e)
e.internal_metadata.outlier = True
try:
- yield self._handle_new_event(e)
+ yield self._handle_new_event(origin, e)
except:
logger.exception(
"Failed to handle state event %s",
@@ -192,6 +161,7 @@ class FederationHandler(BaseHandler):
try:
yield self._handle_new_event(
+ origin,
event,
state=state,
backfilled=backfilled,
@@ -394,7 +364,14 @@ class FederationHandler(BaseHandler):
for e in auth_chain:
e.internal_metadata.outlier = True
try:
- yield self._handle_new_event(e)
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ yield self._handle_new_event(
+ target_host, e, auth_events=auth
+ )
except:
logger.exception(
"Failed to handle auth event %s",
@@ -405,8 +382,13 @@ class FederationHandler(BaseHandler):
# FIXME: Auth these.
e.internal_metadata.outlier = True
try:
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
yield self._handle_new_event(
- e, fetch_auth_from=target_host
+ target_host, e, auth_events=auth
)
except:
logger.exception(
@@ -415,6 +397,7 @@ class FederationHandler(BaseHandler):
)
yield self._handle_new_event(
+ target_host,
new_event,
state=state,
current_state=state,
@@ -481,7 +464,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
- context = yield self._handle_new_event(event)
+ context = yield self._handle_new_event(origin, event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -682,11 +665,12 @@ class FederationHandler(BaseHandler):
waiters.pop().callback(None)
@defer.inlineCallbacks
- def _handle_new_event(self, event, state=None, backfilled=False,
- current_state=None, fetch_auth_from=None):
+ @log_function
+ def _handle_new_event(self, origin, event, state=None, backfilled=False,
+ current_state=None, auth_events=None):
logger.debug(
- "_handle_new_event: Before annotate: %s, sigs: %s",
+ "_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures,
)
@@ -694,65 +678,44 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
+ if not auth_events:
+ auth_events = context.auth_events
+
logger.debug(
- "_handle_new_event: Before auth fetch: %s, sigs: %s",
- event.event_id, event.signatures,
+ "_handle_new_event: %s, auth_events: %s",
+ event.event_id, auth_events,
)
is_new_state = not event.internal_metadata.is_outlier()
- known_ids = set(
- [s.event_id for s in context.auth_events.values()]
- )
-
- for e_id, _ in event.auth_events:
- if e_id not in known_ids:
- e = yield self.store.get_event(e_id, allow_none=True)
-
- if not e and fetch_auth_from is not None:
- # Grab the auth_chain over federation if we are missing
- # auth events.
- auth_chain = yield self.replication_layer.get_event_auth(
- fetch_auth_from, event.event_id, event.room_id
- )
- for auth_event in auth_chain:
- yield self._handle_new_event(auth_event)
- e = yield self.store.get_event(e_id, allow_none=True)
-
- if not e:
- # TODO: Do some conflict res to make sure that we're
- # not the ones who are wrong.
- logger.info(
- "Rejecting %s as %s not in db or %s",
- event.event_id, e_id, known_ids,
- )
- # FIXME: How does raising AuthError work with federation?
- raise AuthError(403, "Cannot find auth event")
-
- context.auth_events[(e.type, e.state_key)] = e
-
- logger.debug(
- "_handle_new_event: Before hack: %s, sigs: %s",
- event.event_id, event.signatures,
- )
-
+ # This is a hack to fix some old rooms where the initial join event
+ # didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
- context.auth_events[(c.type, c.state_key)] = c
+ auth_events[(c.type, c.state_key)] = c
- logger.debug(
- "_handle_new_event: Before auth check: %s, sigs: %s",
- event.event_id, event.signatures,
- )
+ try:
+ yield self.do_auth(
+ origin, event, context, auth_events=auth_events
+ )
+ except AuthError as e:
+ logger.warn(
+ "Rejecting %s because %s",
+ event.event_id, e.msg
+ )
- self.auth.check(event, auth_events=context.auth_events)
+ context.rejected = RejectedReason.AUTH_ERROR
- logger.debug(
- "_handle_new_event: Before persist_event: %s, sigs: %s",
- event.event_id, event.signatures,
- )
+ yield self.store.persist_event(
+ event,
+ context=context,
+ backfilled=backfilled,
+ is_new_state=False,
+ current_state=current_state,
+ )
+ raise
yield self.store.persist_event(
event,
@@ -762,9 +725,250 @@ class FederationHandler(BaseHandler):
current_state=current_state,
)
- logger.debug(
- "_handle_new_event: After persist_event: %s, sigs: %s",
- event.event_id, event.signatures,
+ defer.returnValue(context)
+
+ @defer.inlineCallbacks
+ def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
+ missing):
+ # Just go through and process each event in `remote_auth_chain`. We
+ # don't want to fall into the trap of `missing` being wrong.
+ for e in remote_auth_chain:
+ try:
+ yield self._handle_new_event(origin, e)
+ except AuthError:
+ pass
+
+ # Now get the current auth_chain for the event.
+ local_auth_chain = yield self.store.get_auth_chain([event_id])
+
+ # TODO: Check if we would now reject event_id. If so we need to tell
+ # everyone.
+
+ ret = yield self.construct_auth_difference(
+ local_auth_chain, remote_auth_chain
)
- defer.returnValue(context)
+ logger.debug("on_query_auth reutrning: %s", ret)
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ @log_function
+ def do_auth(self, origin, event, context, auth_events):
+ # Check if we have all the auth events.
+ res = yield self.store.have_events(
+ [e_id for e_id, _ in event.auth_events]
+ )
+
+ event_auth_events = set(e_id for e_id, _ in event.auth_events)
+ seen_events = set(res.keys())
+
+ missing_auth = event_auth_events - seen_events
+
+ if missing_auth:
+ logger.debug("Missing auth: %s", missing_auth)
+ # If we don't have all the auth events, we need to get them.
+ remote_auth_chain = yield self.replication_layer.get_event_auth(
+ origin, event.room_id, event.event_id
+ )
+
+ for e in remote_auth_chain:
+ try:
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in remote_auth_chain
+ if e.event_id in auth_ids
+ }
+ e.internal_metadata.outlier = True
+ yield self._handle_new_event(
+ origin, e, auth_events=auth
+ )
+ auth_events[(e.type, e.state_key)] = e
+ except AuthError:
+ pass
+
+ # FIXME: Assumes we have and stored all the state for all the
+ # prev_events
+ current_state = set(e.event_id for e in auth_events.values())
+ different_auth = event_auth_events - current_state
+
+ if different_auth and not event.internal_metadata.is_outlier():
+ # Do auth conflict res.
+ logger.debug("Different auth: %s", different_auth)
+
+ # 1. Get what we think is the auth chain.
+ auth_ids = self.auth.compute_auth_events(event, context)
+ local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+
+ # 2. Get remote difference.
+ result = yield self.replication_layer.query_auth(
+ origin,
+ event.room_id,
+ event.event_id,
+ local_auth_chain,
+ )
+
+ # 3. Process any remote auth chain events we haven't seen.
+ for e in result.get("missing", []):
+ try:
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in result["auth_chain"]
+ if e.event_id in auth_ids
+ }
+ e.internal_metadata.outlier = True
+ yield self._handle_new_event(
+ origin, e, auth_events=auth
+ )
+ auth_events[(e.type, e.state_key)] = e
+ except AuthError:
+ pass
+
+ # 4. Look at rejects and their proofs.
+ # TODO.
+
+ context.current_state.update(auth_events)
+ context.state_group = None
+
+ try:
+ self.auth.check(event, auth_events=auth_events)
+ except AuthError:
+ raise
+
+ @defer.inlineCallbacks
+ def construct_auth_difference(self, local_auth, remote_auth):
+ """ Given a local and remote auth chain, find the differences. This
+ assumes that we have already processed all events in remote_auth
+
+ Params:
+ local_auth (list)
+ remote_auth (list)
+
+ Returns:
+ dict
+ """
+
+ logger.debug("construct_auth_difference Start!")
+
+ # TODO: Make sure we are OK with local_auth or remote_auth having more
+ # auth events in them than strictly necessary.
+
+ def sort_fun(ev):
+ return ev.depth, ev.event_id
+
+ logger.debug("construct_auth_difference after sort_fun!")
+
+ # We find the differences by starting at the "bottom" of each list
+ # and iterating up on both lists. The lists are ordered by depth and
+ # then event_id, we iterate up both lists until we find the event ids
+ # don't match. Then we look at depth/event_id to see which side is
+ # missing that event, and iterate only up that list. Repeat.
+
+ remote_list = list(remote_auth)
+ remote_list.sort(key=sort_fun)
+
+ local_list = list(local_auth)
+ local_list.sort(key=sort_fun)
+
+ local_iter = iter(local_list)
+ remote_iter = iter(remote_list)
+
+ logger.debug("construct_auth_difference before get_next!")
+
+ def get_next(it, opt=None):
+ try:
+ return it.next()
+ except:
+ return opt
+
+ current_local = get_next(local_iter)
+ current_remote = get_next(remote_iter)
+
+ logger.debug("construct_auth_difference before while")
+
+ missing_remotes = []
+ missing_locals = []
+ while current_local or current_remote:
+ if current_remote is None:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+ continue
+
+ if current_local is None:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ if current_local.event_id == current_remote.event_id:
+ current_local = get_next(local_iter)
+ current_remote = get_next(remote_iter)
+ continue
+
+ if current_local.depth < current_remote.depth:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+ continue
+
+ if current_local.depth > current_remote.depth:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ # They have the same depth, so we fall back to the event_id order
+ if current_local.event_id < current_remote.event_id:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+
+ if current_local.event_id > current_remote.event_id:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ logger.debug("construct_auth_difference after while")
+
+ # missing locals should be sent to the server
+ # We should find why we are missing remotes, as they will have been
+ # rejected.
+
+ # Remove events from missing_remotes if they are referencing a missing
+ # remote. We only care about the "root" rejected ones.
+ missing_remote_ids = [e.event_id for e in missing_remotes]
+ base_remote_rejected = list(missing_remotes)
+ for e in missing_remotes:
+ for e_id, _ in e.auth_events:
+ if e_id in missing_remote_ids:
+ base_remote_rejected.remove(e)
+
+ reason_map = {}
+
+ for e in base_remote_rejected:
+ reason = yield self.store.get_rejection_reason(e.event_id)
+ if reason is None:
+ # FIXME: ERRR?!
+ logger.warn("Could not find reason for %s", e.event_id)
+ raise RuntimeError("")
+
+ reason_map[e.event_id] = reason
+
+ if reason == RejectedReason.AUTH_ERROR:
+ pass
+ elif reason == RejectedReason.REPLACED:
+ # TODO: Get proof
+ pass
+ elif reason == RejectedReason.NOT_ANCESTOR:
+ # TODO: Get proof.
+ pass
+
+ logger.debug("construct_auth_difference returning")
+
+ defer.returnValue({
+ "auth_chain": local_auth,
+ "rejects": {
+ e.event_id: {
+ "reason": reason_map[e.event_id],
+ "proof": None,
+ }
+ for e in base_remote_rejected
+ },
+ "missing": missing_locals,
+ })
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 1dda3ba2c7..c7bf1b47b8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -245,6 +245,43 @@ class MatrixFederationHttpClient(object):
defer.returnValue((response.code, body))
@defer.inlineCallbacks
+ def post_json(self, destination, path, data={}):
+ """ Sends the specifed json data using POST
+
+ Args:
+ destination (str): The remote server to send the HTTP request
+ to.
+ path (str): The HTTP path.
+ data (dict): A dict containing the data that will be used as
+ the request body. This will be encoded as JSON.
+
+ Returns:
+ Deferred: Succeeds when we get a 2xx HTTP response. The result
+ will be the decoded JSON body. On a 4xx or 5xx error response a
+ CodeMessageException is raised.
+ """
+
+ def body_callback(method, url_bytes, headers_dict):
+ self.sign_request(
+ destination, method, url_bytes, headers_dict, data
+ )
+ return _JsonProducer(data)
+
+ response = yield self._create_request(
+ destination.encode("ascii"),
+ "POST",
+ path.encode("ascii"),
+ body_callback=body_callback,
+ headers_dict={"Content-Type": ["application/json"]},
+ )
+
+ logger.debug("Getting resp body")
+ body = yield readBody(response)
+ logger.debug("Got resp body")
+
+ defer.returnValue((response.code, body))
+
+ @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" GETs some json from the given host homeserver and path
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 4182ad990f..ba9308803c 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -6,6 +6,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil==0.0.2": ["syutil"],
"matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"],
+ "matrix_angular_sdk>=0.6.0": ["syweb>=0.6.0"],
"Twisted>=14.0.0": ["twisted>=14.0.0"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
diff --git a/synapse/state.py b/synapse/state.py
index 8144fa02b4..d9fdfb34be 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext
from collections import namedtuple
@@ -42,6 +43,8 @@ class StateHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
+ # self.auth = hs.get_auth()
+ self.hs = hs
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
@@ -210,64 +213,96 @@ class StateHandler(object):
else:
prev_states = []
+ auth_events = {
+ k: e for k, e in unconflicted_state.items()
+ if k[0] in (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
+ }
+
try:
- new_state = {}
- new_state.update(unconflicted_state)
- for key, events in conflicted_state.items():
- new_state[key] = self._resolve_state_events(events)
+ resolved_state = self._resolve_state_events(
+ conflicted_state, auth_events
+ )
except:
logger.exception("Failed to resolve state")
raise
+ new_state = unconflicted_state
+ new_state.update(resolved_state)
+
defer.returnValue((None, new_state, prev_states))
- def _get_power_level_from_event_state(self, event, user_id):
- if hasattr(event, "old_state_events") and event.old_state_events:
- key = (EventTypes.PowerLevels, "", )
- power_level_event = event.old_state_events.get(key)
- level = None
- if power_level_event:
- level = power_level_event.content.get("users", {}).get(
- user_id
+ @log_function
+ def _resolve_state_events(self, conflicted_state, auth_events):
+ """ This is where we actually decide which of the conflicted state to
+ use.
+
+ We resolve conflicts in the following order:
+ 1. power levels
+ 2. memberships
+ 3. other events.
+
+ :param conflicted_state:
+ :param auth_events:
+ :return:
+ """
+ resolved_state = {}
+ power_key = (EventTypes.PowerLevels, "")
+ if power_key in conflicted_state.items():
+ power_levels = conflicted_state[power_key]
+ resolved_state[power_key] = self._resolve_auth_events(power_levels)
+
+ auth_events.update(resolved_state)
+
+ for key, events in conflicted_state.items():
+ if key[0] == EventTypes.Member:
+ resolved_state[key] = self._resolve_auth_events(
+ events,
+ auth_events
)
- if not level:
- level = power_level_event.content.get("users_default", 0)
- return level
- else:
- return 0
+ auth_events.update(resolved_state)
- @log_function
- def _resolve_state_events(self, events):
- curr_events = events
+ for key, events in conflicted_state.items():
+ if key not in resolved_state:
+ resolved_state[key] = self._resolve_normal_events(
+ events, auth_events
+ )
- new_powers = [
- self._get_power_level_from_event_state(e, e.user_id)
- for e in curr_events
- ]
+ return resolved_state
- new_powers = [
- int(p) if p else 0 for p in new_powers
- ]
+ def _resolve_auth_events(self, events, auth_events):
+ reverse = [i for i in reversed(self._ordered_events(events))]
- max_power = max(new_powers)
+ auth_events = dict(auth_events)
- curr_events = [
- z[0] for z in zip(curr_events, new_powers)
- if z[1] == max_power
- ]
+ prev_event = reverse[0]
+ for event in reverse[1:]:
+ auth_events[(prev_event.type, prev_event.state_key)] = prev_event
+ try:
+ # FIXME: hs.get_auth() is bad style, but we need to do it to
+ # get around circular deps.
+ self.hs.get_auth().check(event, auth_events)
+ prev_event = event
+ except AuthError:
+ return prev_event
- if not curr_events:
- raise RuntimeError("Max didn't get a max?")
- elif len(curr_events) == 1:
- return curr_events[0]
-
- # TODO: For now, just choose the one with the largest event_id.
- return (
- sorted(
- curr_events,
- key=lambda e: hashlib.sha1(
- e.event_id + e.user_id + e.room_id + e.type
- ).hexdigest()
- )[0]
- )
+ return event
+
+ def _resolve_normal_events(self, events, auth_events):
+ for event in self._ordered_events(events):
+ try:
+ # FIXME: hs.get_auth() is bad style, but we need to do it to
+ # get around circular deps.
+ self.hs.get_auth().check(event, auth_events)
+ return event
+ except AuthError:
+ pass
+
+ # Oh dear.
+ return event
+
+ def _ordered_events(self, events):
+ def key_func(e):
+ return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
+
+ return sorted(events, key=key_func)
\ No newline at end of file
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 277581b4e2..adcb038020 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -32,6 +32,7 @@ from .event_federation import EventFederationStore
from .pusher import PusherStore
from .push_rule import PushRuleStore
from .media_repository import MediaRepositoryStore
+from .rejections import RejectionsStore
from .state import StateStore
from .signatures import SignatureStore
@@ -85,6 +86,7 @@ class DataStore(RoomMemberStore, RoomStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
EventFederationStore,
MediaRepositoryStore,
+ RejectionsStore,
PusherStore,
PushRuleStore
):
@@ -229,6 +231,9 @@ class DataStore(RoomMemberStore, RoomStore,
if not outlier:
self._store_state_groups_txn(txn, event, context)
+ if context.rejected:
+ self._store_rejections_txn(txn, event.event_id, context.rejected)
+
if current_state:
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
@@ -267,7 +272,7 @@ class DataStore(RoomMemberStore, RoomStore,
or_replace=True,
)
- if is_new_state:
+ if is_new_state and not context.rejected:
self._simple_insert_txn(
txn,
"current_state_events",
@@ -293,7 +298,7 @@ class DataStore(RoomMemberStore, RoomStore,
or_ignore=True,
)
- if not backfilled:
+ if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
@@ -457,6 +462,35 @@ class DataStore(RoomMemberStore, RoomStore,
],
)
+ def have_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Returns:
+ dict: Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps to
+ None.
+ """
+ def f(txn):
+ sql = (
+ "SELECT e.event_id, reason FROM events as e "
+ "LEFT JOIN rejections as r ON e.event_id = r.event_id "
+ "WHERE e.event_id = ?"
+ )
+
+ res = {}
+ for event_id in event_ids:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ _, rejected = row
+ res[event_id] = rejected
+
+ return res
+
+ return self.runInteraction(
+ "have_events", f,
+ )
+
def schema_path(schema):
""" Get a filesystem path for the named database schema
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 4e8bd3faa9..1f5e74a16a 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -502,10 +502,12 @@ class SQLBaseStore(object):
return [e for e in events if e]
def _get_event_txn(self, txn, event_id, check_redacted=True,
- get_prev_content=False):
+ get_prev_content=False, allow_rejected=False):
sql = (
- "SELECT internal_metadata, json, r.event_id FROM event_json as e "
+ "SELECT internal_metadata, json, r.event_id, reason "
+ "FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
+ "LEFT JOIN rejections as rej on rej.event_id = e.event_id "
"WHERE e.event_id = ? "
"LIMIT 1 "
)
@@ -517,13 +519,16 @@ class SQLBaseStore(object):
if not res:
return None
- internal_metadata, js, redacted = res
+ internal_metadata, js, redacted, rejected_reason = res
- return self._get_event_from_row_txn(
- txn, internal_metadata, js, redacted,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- )
+ if allow_rejected or not rejected_reason:
+ return self._get_event_from_row_txn(
+ txn, internal_metadata, js, redacted,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ )
+ else:
+ return None
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False):
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
new file mode 100644
index 0000000000..4e1a9a2783
--- /dev/null
+++ b/synapse/storage/rejections.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class RejectionsStore(SQLBaseStore):
+ def _store_rejections_txn(self, txn, event_id, reason):
+ self._simple_insert_txn(
+ txn,
+ table="rejections",
+ values={
+ "event_id": event_id,
+ "reason": reason,
+ "last_check": self._clock.time_msec(),
+ }
+ )
+
+ def get_rejection_reason(self, event_id):
+ return self._simple_select_one_onecol(
+ table="rejections",
+ retcol="reason",
+ keyvalues={
+ "event_id": event_id,
+ },
+ allow_none=True,
+ )
diff --git a/synapse/storage/schema/delta/v12.sql b/synapse/storage/schema/delta/v12.sql
index 8c4dfd5c1b..d83c3b049e 100644
--- a/synapse/storage/schema/delta/v12.sql
+++ b/synapse/storage/schema/delta/v12.sql
@@ -1,4 +1,8 @@
+<<<<<<< HEAD
+/* Copyright 2015 OpenMarket Ltd
+=======
/* Copyright 2014 OpenMarket Ltd
+>>>>>>> fc946f3b8da8c7f71a9c25bf542c04472147bc5b
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -12,6 +16,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
+CREATE TABLE IF NOT EXISTS rejections(
+ event_id TEXT NOT NULL,
+ reason TEXT NOT NULL,
+ last_check TEXT NOT NULL,
+ root_rejected TEXT,
+ CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
+);
+
-- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -44,3 +57,4 @@ CREATE TABLE IF NOT EXISTS push_rules (
);
CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);
+
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index dd00c1cd2f..5866a387f6 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -123,3 +123,11 @@ CREATE TABLE IF NOT EXISTS room_hosts(
);
CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id);
+
+CREATE TABLE IF NOT EXISTS rejections(
+ event_id TEXT NOT NULL,
+ reason TEXT NOT NULL,
+ last_check TEXT NOT NULL,
+ root_rejected TEXT,
+ CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
+);
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index ed21defd13..44dbce6bea 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
"get_room",
"get_destination_retry_timings",
"set_destination_retry_timings",
+ "have_events",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
@@ -90,6 +91,7 @@ class FederationTestCase(unittest.TestCase):
self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True)
+ self.datastore.have_events.return_value = defer.succeed({})
def annotate(ev, old_state=None):
context = Mock()
diff --git a/tests/test_state.py b/tests/test_state.py
index 98ad9e54cd..019e794aa2 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -16,11 +16,120 @@
from tests import unittest
from twisted.internet import defer
+from synapse.events import FrozenEvent
+from synapse.api.auth import Auth
+from synapse.api.constants import EventTypes, Membership
from synapse.state import StateHandler
from mock import Mock
+_next_event_id = 1000
+
+
+def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
+ prev_events=[], **kwargs):
+ global _next_event_id
+
+ if not event_id:
+ _next_event_id += 1
+ event_id = str(_next_event_id)
+
+ if not name:
+ if state_key is not None:
+ name = "<%s-%s, %s>" % (type, state_key, event_id,)
+ else:
+ name = "<%s, %s>" % (type, event_id,)
+
+ d = {
+ "event_id": event_id,
+ "type": type,
+ "sender": "@user_id:example.com",
+ "room_id": "!room_id:example.com",
+ "depth": depth,
+ "prev_events": prev_events,
+ }
+
+ if state_key is not None:
+ d["state_key"] = state_key
+
+ d.update(kwargs)
+
+ event = FrozenEvent(d)
+
+ return event
+
+
+class StateGroupStore(object):
+ def __init__(self):
+ self._event_to_state_group = {}
+ self._group_to_state = {}
+
+ self._next_group = 1
+
+ def get_state_groups(self, event_ids):
+ groups = {}
+ for event_id in event_ids:
+ group = self._event_to_state_group.get(event_id)
+ if group:
+ groups[group] = self._group_to_state[group]
+
+ return defer.succeed(groups)
+
+ def store_state_groups(self, event, context):
+ if context.current_state is None:
+ return
+
+ state_events = context.current_state
+
+ if event.is_state():
+ state_events[(event.type, event.state_key)] = event
+
+ state_group = context.state_group
+ if not state_group:
+ state_group = self._next_group
+ self._next_group += 1
+
+ self._group_to_state[state_group] = state_events.values()
+
+ self._event_to_state_group[event.event_id] = state_group
+
+
+class DictObj(dict):
+ def __init__(self, **kwargs):
+ super(DictObj, self).__init__(kwargs)
+ self.__dict__ = self
+
+
+class Graph(object):
+ def __init__(self, nodes, edges):
+ events = {}
+ clobbered = set(events.keys())
+
+ for event_id, fields in nodes.items():
+ refs = edges.get(event_id)
+ if refs:
+ clobbered.difference_update(refs)
+ prev_events = [(r, {}) for r in refs]
+ else:
+ prev_events = []
+
+ events[event_id] = create_event(
+ event_id=event_id,
+ prev_events=prev_events,
+ **fields
+ )
+
+ self._leaves = clobbered
+ self._events = sorted(events.values(), key=lambda e: e.depth)
+
+ def walk(self):
+ return iter(self._events)
+
+ def get_leaves(self):
+ return (self._events[i] for i in self._leaves)
+
+
class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = Mock(
@@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
"add_event_hashes",
]
)
- hs = Mock(spec=["get_datastore"])
+ hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"])
hs.get_datastore.return_value = self.store
+ hs.get_state_handler.return_value = None
+ hs.get_auth.return_value = Auth(hs)
self.state = StateHandler(hs)
self.event_id = 0
@defer.inlineCallbacks
+ def test_branch_no_conflict(self):
+ graph = Graph(
+ nodes={
+ "START": DictObj(
+ type=EventTypes.Create,
+ state_key="",
+ depth=1,
+ ),
+ "A": DictObj(
+ type=EventTypes.Message,
+ depth=2,
+ ),
+ "B": DictObj(
+ type=EventTypes.Message,
+ depth=3,
+ ),
+ "C": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=3,
+ ),
+ "D": DictObj(
+ type=EventTypes.Message,
+ depth=4,
+ ),
+ },
+ edges={
+ "A": ["START"],
+ "B": ["A"],
+ "C": ["A"],
+ "D": ["B", "C"]
+ }
+ )
+
+ store = StateGroupStore()
+ self.store.get_state_groups.side_effect = store.get_state_groups
+
+ context_store = {}
+
+ for event in graph.walk():
+ context = yield self.state.compute_event_context(event)
+ store.store_state_groups(event, context)
+ context_store[event.event_id] = context
+
+ self.assertEqual(2, len(context_store["D"].current_state))
+
+ @defer.inlineCallbacks
+ def test_branch_basic_conflict(self):
+ graph = Graph(
+ nodes={
+ "START": DictObj(
+ type=EventTypes.Create,
+ state_key="creator",
+ content={"membership": "@user_id:example.com"},
+ depth=1,
+ ),
+ "A": DictObj(
+ type=EventTypes.Member,
+ state_key="@user_id:example.com",
+ content={"membership": Membership.JOIN},
+ membership=Membership.JOIN,
+ depth=2,
+ ),
+ "B": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=3,
+ ),
+ "C": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=4,
+ ),
+ "D": DictObj(
+ type=EventTypes.Message,
+ depth=5,
+ ),
+ },
+ edges={
+ "A": ["START"],
+ "B": ["A"],
+ "C": ["A"],
+ "D": ["B", "C"]
+ }
+ )
+
+ store = StateGroupStore()
+ self.store.get_state_groups.side_effect = store.get_state_groups
+
+ context_store = {}
+
+ for event in graph.walk():
+ context = yield self.state.compute_event_context(event)
+ store.store_state_groups(event, context)
+ context_store[event.event_id] = context
+
+ self.assertSetEqual(
+ {"START", "A", "C"},
+ {e.event_id for e in context_store["D"].current_state.values()}
+ )
+
+ @defer.inlineCallbacks
+ def test_branch_have_banned_conflict(self):
+ graph = Graph(
+ nodes={
+ "START": DictObj(
+ type=EventTypes.Create,
+ state_key="creator",
+ content={"membership": "@user_id:example.com"},
+ depth=1,
+ ),
+ "A": DictObj(
+ type=EventTypes.Member,
+ state_key="@user_id:example.com",
+ content={"membership": Membership.JOIN},
+ membership=Membership.JOIN,
+ depth=2,
+ ),
+ "B": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=3,
+ ),
+ "C": DictObj(
+ type=EventTypes.Member,
+ state_key="@user_id_2:example.com",
+ content={"membership": Membership.BAN},
+ membership=Membership.BAN,
+ depth=4,
+ ),
+ "D": DictObj(
+ type=EventTypes.Name,
+ state_key="",
+ depth=4,
+ sender="@user_id_2:example.com",
+ ),
+ "E": DictObj(
+ type=EventTypes.Message,
+ depth=5,
+ ),
+ },
+ edges={
+ "A": ["START"],
+ "B": ["A"],
+ "C": ["B"],
+ "D": ["B"],
+ "E": ["C", "D"]
+ }
+ )
+
+ store = StateGroupStore()
+ self.store.get_state_groups.side_effect = store.get_state_groups
+
+ context_store = {}
+
+ for event in graph.walk():
+ context = yield self.state.compute_event_context(event)
+ store.store_state_groups(event, context)
+ context_store[event.event_id] = context
+
+ self.assertSetEqual(
+ {"START", "A", "B", "C"},
+ {e.event_id for e in context_store["E"].current_state.values()}
+ )
+
+ @defer.inlineCallbacks
def test_annotate_with_old_message(self):
- event = self.create_event(type="test_message", name="event")
+ event = create_event(type="test_message", name="event")
old_state = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
context = yield self.state.compute_event_context(
@@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
- event = self.create_event(type="state", state_key="", name="event")
+ event = create_event(type="state", state_key="", name="event")
old_state = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
context = yield self.state.compute_event_context(
@@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
- event = self.create_event(type="test_message", name="event")
- event.prev_events = []
+ event = create_event(type="test_message", name="event")
old_state = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
group_name = "group_name_1"
@@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_trivial_annotate_state(self):
- event = self.create_event(type="state", state_key="", name="event")
- event.prev_events = []
+ event = create_event(type="state", state_key="", name="event")
old_state = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
group_name = "group_name_1"
@@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve_message_conflict(self):
- event = self.create_event(type="test_message", name="event")
- event.prev_events = []
+ event = create_event(type="test_message", name="event")
old_state_1 = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
old_state_2 = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test3", state_key="2"),
- self.create_event(type="test4", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test3", state_key="2"),
+ create_event(type="test4", state_key=""),
]
- group_name_1 = "group_name_1"
- group_name_2 = "group_name_2"
-
- self.store.get_state_groups.return_value = {
- group_name_1: old_state_1,
- group_name_2: old_state_2,
- }
-
- context = yield self.state.compute_event_context(event)
+ context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 5)
@@ -181,56 +447,76 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve_state_conflict(self):
- event = self.create_event(type="test4", state_key="", name="event")
- event.prev_events = []
+ event = create_event(type="test4", state_key="", name="event")
old_state_1 = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test1", state_key="2"),
- self.create_event(type="test2", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test1", state_key="2"),
+ create_event(type="test2", state_key=""),
]
old_state_2 = [
- self.create_event(type="test1", state_key="1"),
- self.create_event(type="test3", state_key="2"),
- self.create_event(type="test4", state_key=""),
+ create_event(type="test1", state_key="1"),
+ create_event(type="test3", state_key="2"),
+ create_event(type="test4", state_key=""),
]
- group_name_1 = "group_name_1"
- group_name_2 = "group_name_2"
-
- self.store.get_state_groups.return_value = {
- group_name_1: old_state_1,
- group_name_2: old_state_2,
- }
-
- context = yield self.state.compute_event_context(event)
+ context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 5)
self.assertIsNone(context.state_group)
- def create_event(self, name=None, type=None, state_key=None):
- self.event_id += 1
- event_id = str(self.event_id)
+ @defer.inlineCallbacks
+ def test_standard_depth_conflict(self):
+ event = create_event(type="test4", name="event")
+
+ member_event = create_event(
+ type=EventTypes.Member,
+ state_key="@user_id:example.com",
+ content={
+ "membership": Membership.JOIN,
+ }
+ )
- if not name:
- if state_key is not None:
- name = "<%s-%s>" % (type, state_key)
- else:
- name = "<%s>" % (type, )
+ old_state_1 = [
+ member_event,
+ create_event(type="test1", state_key="1", depth=1),
+ ]
+
+ old_state_2 = [
+ member_event,
+ create_event(type="test1", state_key="1", depth=2),
+ ]
- event = Mock(name=name, spec=[])
- event.type = type
+ context = yield self._get_context(event, old_state_1, old_state_2)
- if state_key is not None:
- event.state_key = state_key
- event.event_id = event_id
+ self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+
+ # Reverse the depth to make sure we are actually using the depths
+ # during state resolution.
+
+ old_state_1 = [
+ member_event,
+ create_event(type="test1", state_key="1", depth=2),
+ ]
+
+ old_state_2 = [
+ member_event,
+ create_event(type="test1", state_key="1", depth=1),
+ ]
+
+ context = yield self._get_context(event, old_state_1, old_state_2)
+
+ self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
- event.is_state = lambda: (state_key is not None)
- event.unsigned = {}
+ def _get_context(self, event, old_state_1, old_state_2):
+ group_name_1 = "group_name_1"
+ group_name_2 = "group_name_2"
- event.user_id = "@user_id:example.com"
- event.room_id = "!room_id:example.com"
+ self.store.get_state_groups.return_value = {
+ group_name_1: old_state_1,
+ group_name_2: old_state_2,
+ }
- return event
+ return self.state.compute_event_context(event)
|