summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/federation/pdu_codec.py102
-rw-r--r--synapse/federation/persistence.py73
-rw-r--r--synapse/federation/replication.py354
-rw-r--r--synapse/federation/transport.py320
-rw-r--r--synapse/federation/units.py134
5 files changed, 422 insertions, 561 deletions
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
deleted file mode 100644
index e8180d94fd..0000000000
--- a/synapse/federation/pdu_codec.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .units import Pdu
-
-import copy
-
-
-def decode_event_id(event_id, server_name):
-    parts = event_id.split("@")
-    if len(parts) < 2:
-        return (event_id, server_name)
-    else:
-        return (parts[0], "".join(parts[1:]))
-
-
-def encode_event_id(pdu_id, origin):
-    return "%s@%s" % (pdu_id, origin)
-
-
-class PduCodec(object):
-
-    def __init__(self, hs):
-        self.server_name = hs.hostname
-        self.event_factory = hs.get_event_factory()
-        self.clock = hs.get_clock()
-
-    def event_from_pdu(self, pdu):
-        kwargs = {}
-
-        kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
-        kwargs["room_id"] = pdu.context
-        kwargs["etype"] = pdu.pdu_type
-        kwargs["prev_events"] = [
-            encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
-        ]
-
-        if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
-            kwargs["prev_state"] = encode_event_id(
-                pdu.prev_state_id, pdu.prev_state_origin
-            )
-
-        kwargs.update({
-            k: v
-            for k, v in pdu.get_full_dict().items()
-            if k not in [
-                "pdu_id",
-                "context",
-                "pdu_type",
-                "prev_pdus",
-                "prev_state_id",
-                "prev_state_origin",
-            ]
-        })
-
-        return self.event_factory.create_event(**kwargs)
-
-    def pdu_from_event(self, event):
-        d = event.get_full_dict()
-
-        d["pdu_id"], d["origin"] = decode_event_id(
-            event.event_id, self.server_name
-        )
-        d["context"] = event.room_id
-        d["pdu_type"] = event.type
-
-        if hasattr(event, "prev_events"):
-            d["prev_pdus"] = [
-                decode_event_id(e, self.server_name)
-                for e in event.prev_events
-            ]
-
-        if hasattr(event, "prev_state"):
-            d["prev_state_id"], d["prev_state_origin"] = (
-                decode_event_id(event.prev_state, self.server_name)
-            )
-
-        if hasattr(event, "state_key"):
-            d["is_state"] = True
-
-        kwargs = copy.deepcopy(event.unrecognized_keys)
-        kwargs.update({
-            k: v for k, v in d.items()
-            if k not in ["event_id", "room_id", "type", "prev_events"]
-        })
-
-        if "origin_server_ts" not in kwargs:
-            kwargs["origin_server_ts"] = int(self.clock.time_msec())
-
-        return Pdu(**kwargs)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 7043fcc504..73dc844d59 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -21,8 +21,6 @@ These actions are mostly only used by the :py:mod:`.replication` module.
 
 from twisted.internet import defer
 
-from .units import Pdu
-
 from synapse.util.logutils import log_function
 
 import json
@@ -32,76 +30,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class PduActions(object):
-    """ Defines persistence actions that relate to handling PDUs.
-    """
-
-    def __init__(self, datastore):
-        self.store = datastore
-
-    @log_function
-    def mark_as_processed(self, pdu):
-        """ Persist the fact that we have fully processed the given `Pdu`
-
-        Returns:
-            Deferred
-        """
-        return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
-
-    @defer.inlineCallbacks
-    @log_function
-    def after_transaction(self, transaction_id, destination, origin):
-        """ Returns all `Pdu`s that we sent to the given remote home server
-        after a given transaction id.
-
-        Returns:
-            Deferred: Results in a list of `Pdu`s
-        """
-        results = yield self.store.get_pdus_after_transaction(
-            transaction_id,
-            destination
-        )
-
-        defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
-    @defer.inlineCallbacks
-    @log_function
-    def get_all_pdus_from_context(self, context):
-        results = yield self.store.get_all_pdus_from_context(context)
-        defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
-    @defer.inlineCallbacks
-    @log_function
-    def backfill(self, context, pdu_list, limit):
-        """ For a given list of PDU id and origins return the proceeding
-        `limit` `Pdu`s in the given `context`.
-
-        Returns:
-            Deferred: Results in a list of `Pdu`s.
-        """
-        results = yield self.store.get_backfill(
-            context, pdu_list, limit
-        )
-
-        defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
-    @log_function
-    def is_new(self, pdu):
-        """ When we receive a `Pdu` from a remote home server, we want to
-        figure out whether it is `new`, i.e. it is not some historic PDU that
-        we haven't seen simply because we haven't backfilled back that far.
-
-        Returns:
-            Deferred: Results in a `bool`
-        """
-        return self.store.is_pdu_new(
-            pdu_id=pdu.pdu_id,
-            origin=pdu.origin,
-            context=pdu.context,
-            depth=pdu.depth
-        )
-
-
 class TransactionActions(object):
     """ Defines persistence actions that relate to handling Transactions.
     """
@@ -158,7 +86,6 @@ class TransactionActions(object):
             transaction.transaction_id,
             transaction.destination,
             transaction.origin_server_ts,
-            [(p["pdu_id"], p["origin"]) for p in transaction.pdus]
         )
 
     @log_function
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 092411eaf9..8ee74de005 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -19,9 +19,9 @@ a given transport.
 
 from twisted.internet import defer
 
-from .units import Transaction, Pdu, Edu
+from .units import Transaction, Edu
 
-from .persistence import PduActions, TransactionActions
+from .persistence import TransactionActions
 
 from synapse.util.logutils import log_function
 
@@ -57,7 +57,7 @@ class ReplicationLayer(object):
         self.transport_layer.register_request_handler(self)
 
         self.store = hs.get_datastore()
-        self.pdu_actions = PduActions(self.store)
+        # self.pdu_actions = PduActions(self.store)
         self.transaction_actions = TransactionActions(self.store)
 
         self._transaction_queue = _TransactionQueue(
@@ -72,6 +72,8 @@ class ReplicationLayer(object):
 
         self._clock = hs.get_clock()
 
+        self.event_factory = hs.get_event_factory()
+
     def set_handler(self, handler):
         """Sets the handler that the replication layer will use to communicate
         receipt of new PDUs from other home servers. The required methods are
@@ -81,7 +83,7 @@ class ReplicationLayer(object):
 
     def register_edu_handler(self, edu_type, handler):
         if edu_type in self.edu_handlers:
-            raise KeyError("Already have an EDU handler for %s" % (edu_type))
+            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
 
         self.edu_handlers[edu_type] = handler
 
@@ -102,24 +104,17 @@ class ReplicationLayer(object):
           object to encode as JSON.
         """
         if query_type in self.query_handlers:
-            raise KeyError("Already have a Query handler for %s" % (query_type))
+            raise KeyError(
+                "Already have a Query handler for %s" % (query_type,)
+            )
 
         self.query_handlers[query_type] = handler
 
-    @defer.inlineCallbacks
     @log_function
     def send_pdu(self, pdu):
         """Informs the replication layer about a new PDU generated within the
         home server that should be transmitted to others.
 
-        This will fill out various attributes on the PDU object, e.g. the
-        `prev_pdus` key.
-
-        *Note:* The home server should always call `send_pdu` even if it knows
-        that it does not need to be replicated to other home servers. This is
-        in case e.g. someone else joins via a remote home server and then
-        backfills.
-
         TODO: Figure out when we should actually resolve the deferred.
 
         Args:
@@ -132,18 +127,15 @@ class ReplicationLayer(object):
         order = self._order
         self._order += 1
 
-        logger.debug("[%s] Persisting PDU", pdu.pdu_id)
-
-        # Save *before* trying to send
-        yield self.store.persist_event(pdu=pdu)
-
-        logger.debug("[%s] Persisted PDU", pdu.pdu_id)
-        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
 
         # TODO, add errback, etc.
         self._transaction_queue.enqueue_pdu(pdu, order)
 
-        logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+        logger.debug(
+            "[%s] transaction_layer.enqueue_pdu... done",
+            pdu.event_id
+        )
 
     @log_function
     def send_edu(self, destination, edu_type, content):
@@ -159,6 +151,11 @@ class ReplicationLayer(object):
         return defer.succeed(None)
 
     @log_function
+    def send_failure(self, failure, destination):
+        self._transaction_queue.enqueue_failure(failure, destination)
+        return defer.succeed(None)
+
+    @log_function
     def make_query(self, destination, query_type, args,
                    retry_on_dns_fail=True):
         """Sends a federation Query to a remote homeserver of the given type
@@ -181,7 +178,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def backfill(self, dest, context, limit):
+    def backfill(self, dest, context, limit, extremities):
         """Requests some more historic PDUs for the given context from the
         given destination server.
 
@@ -189,12 +186,12 @@ class ReplicationLayer(object):
             dest (str): The remote home server to ask.
             context (str): The context to backfill.
             limit (int): The maximum number of PDUs to return.
+            extremities (list): List of PDU id and origins of the first pdus
+                we have seen from the context
 
         Returns:
             Deferred: Results in the received PDUs.
         """
-        extremities = yield self.store.get_oldest_pdus_in_context(context)
-
         logger.debug("backfill extrem=%s", extremities)
 
         # If there are no extremeties then we've (probably) reached the start.
@@ -208,15 +205,18 @@ class ReplicationLayer(object):
 
         transaction = Transaction(**transaction_data)
 
-        pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
+        pdus = [
+            self.event_from_pdu_json(p, outlier=False)
+            for p in transaction.pdus
+        ]
         for pdu in pdus:
-            yield self._handle_new_pdu(pdu, backfilled=True)
+            yield self._handle_new_pdu(dest, pdu, backfilled=True)
 
         defer.returnValue(pdus)
 
     @defer.inlineCallbacks
     @log_function
-    def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+    def get_pdu(self, destination, event_id, outlier=False):
         """Requests the PDU with given origin and ID from the remote home
         server.
 
@@ -225,7 +225,7 @@ class ReplicationLayer(object):
         Args:
             destination (str): Which home server to query
             pdu_origin (str): The home server that originally sent the pdu.
-            pdu_id (str)
+            event_id (str)
             outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                 it's from an arbitary point in the context as opposed to part
                 of the current block of PDUs. Defaults to `False`
@@ -234,23 +234,27 @@ class ReplicationLayer(object):
             Deferred: Results in the requested PDU.
         """
 
-        transaction_data = yield self.transport_layer.get_pdu(
-            destination, pdu_origin, pdu_id)
+        transaction_data = yield self.transport_layer.get_event(
+            destination, event_id
+        )
 
         transaction = Transaction(**transaction_data)
 
-        pdu_list = [Pdu(outlier=outlier, **p) for p in transaction.pdus]
+        pdu_list = [
+            self.event_from_pdu_json(p, outlier=outlier)
+            for p in transaction.pdus
+        ]
 
         pdu = None
         if pdu_list:
             pdu = pdu_list[0]
-            yield self._handle_new_pdu(pdu)
+            yield self._handle_new_pdu(destination, pdu)
 
         defer.returnValue(pdu)
 
     @defer.inlineCallbacks
     @log_function
-    def get_state_for_context(self, destination, context):
+    def get_state_for_context(self, destination, context, event_id=None):
         """Requests all of the `current` state PDUs for a given context from
         a remote home server.
 
@@ -263,29 +267,25 @@ class ReplicationLayer(object):
         """
 
         transaction_data = yield self.transport_layer.get_context_state(
-            destination, context)
+            destination,
+            context,
+            event_id=event_id,
+        )
 
         transaction = Transaction(**transaction_data)
-
-        pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
-        for pdu in pdus:
-            yield self._handle_new_pdu(pdu)
+        pdus = [
+            self.event_from_pdu_json(p, outlier=True)
+            for p in transaction.pdus
+        ]
 
         defer.returnValue(pdus)
 
     @defer.inlineCallbacks
     @log_function
-    def on_context_pdus_request(self, context):
-        pdus = yield self.pdu_actions.get_all_pdus_from_context(
-            context
+    def on_backfill_request(self, origin, context, versions, limit):
+        pdus = yield self.handler.on_backfill_request(
+            origin, context, versions, limit
         )
-        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
-
-    @defer.inlineCallbacks
-    @log_function
-    def on_backfill_request(self, context, versions, limit):
-
-        pdus = yield self.pdu_actions.backfill(context, versions, limit)
 
         defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
@@ -295,11 +295,17 @@ class ReplicationLayer(object):
         transaction = Transaction(**transaction_data)
 
         for p in transaction.pdus:
+            if "unsigned" in p:
+                unsigned = p["unsigned"]
+                if "age" in unsigned:
+                    p["age"] = unsigned["age"]
             if "age" in p:
                 p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
                 del p["age"]
 
-        pdu_list = [Pdu(**p) for p in transaction.pdus]
+        pdu_list = [
+            self.event_from_pdu_json(p) for p in transaction.pdus
+        ]
 
         logger.debug("[%s] Got transaction", transaction.transaction_id)
 
@@ -315,11 +321,15 @@ class ReplicationLayer(object):
 
         dl = []
         for pdu in pdu_list:
-            dl.append(self._handle_new_pdu(pdu))
+            dl.append(self._handle_new_pdu(transaction.origin, pdu))
 
         if hasattr(transaction, "edus"):
             for edu in [Edu(**x) for x in transaction.edus]:
-                self.received_edu(transaction.origin, edu.edu_type, edu.content)
+                self.received_edu(
+                    transaction.origin,
+                    edu.edu_type,
+                    edu.content
+                )
 
         results = yield defer.DeferredList(dl)
 
@@ -347,20 +357,22 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def on_context_state_request(self, context):
-        results = yield self.store.get_current_state_for_context(
-            context
-        )
-
-        logger.debug("Context returning %d results", len(results))
+    def on_context_state_request(self, origin, context, event_id):
+        if event_id:
+            pdus = yield self.handler.get_state_for_pdu(
+                origin,
+                context,
+                event_id,
+            )
+        else:
+            raise NotImplementedError("Specify an event")
 
-        pdus = [Pdu.from_pdu_tuple(p) for p in results]
         defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
 
     @defer.inlineCallbacks
     @log_function
-    def on_pdu_request(self, pdu_origin, pdu_id):
-        pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+    def on_pdu_request(self, origin, event_id):
+        pdu = yield self._get_persisted_pdu(origin, event_id)
 
         if pdu:
             defer.returnValue(
@@ -372,116 +384,207 @@ class ReplicationLayer(object):
     @defer.inlineCallbacks
     @log_function
     def on_pull_request(self, origin, versions):
-        transaction_id = max([int(v) for v in versions])
+        raise NotImplementedError("Pull transacions not implemented")
 
-        response = yield self.pdu_actions.after_transaction(
-            transaction_id,
-            origin,
-            self.server_name
+    @defer.inlineCallbacks
+    def on_query_request(self, query_type, args):
+        if query_type in self.query_handlers:
+            response = yield self.query_handlers[query_type](args)
+            defer.returnValue((200, response))
+        else:
+            defer.returnValue(
+                (404, "No handler for Query type '%s'" % (query_type, ))
+            )
+
+    @defer.inlineCallbacks
+    def on_make_join_request(self, context, user_id):
+        pdu = yield self.handler.on_make_join_request(context, user_id)
+        defer.returnValue({
+            "event": pdu.get_pdu_json(),
+        })
+
+    @defer.inlineCallbacks
+    def on_invite_request(self, origin, content):
+        pdu = self.event_from_pdu_json(content)
+        ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+        defer.returnValue(
+            (
+                200,
+                {
+                    "event": ret_pdu.get_pdu_json(),
+                }
+            )
         )
 
-        if not response:
-            response = []
+    @defer.inlineCallbacks
+    def on_send_join_request(self, origin, content):
+        pdu = self.event_from_pdu_json(content)
+        res_pdus = yield self.handler.on_send_join_request(origin, pdu)
 
+        defer.returnValue((200, {
+            "state": [p.get_pdu_json() for p in res_pdus["state"]],
+            "auth_chain": [p.get_pdu_json() for p in res_pdus["auth_chain"]],
+        }))
+
+    @defer.inlineCallbacks
+    def on_event_auth(self, origin, context, event_id):
+        auth_pdus = yield self.handler.on_event_auth(event_id)
         defer.returnValue(
-            (200, self._transaction_from_pdus(response).get_dict())
+            (
+                200,
+                {
+                    "auth_chain": [a.get_pdu_json() for a in auth_pdus],
+                }
+            )
         )
 
     @defer.inlineCallbacks
-    def on_query_request(self, query_type, args):
-        if query_type in self.query_handlers:
-            response = yield self.query_handlers[query_type](args)
-            defer.returnValue((200, response))
-        else:
-            defer.returnValue((404, "No handler for Query type '%s'"
-                % (query_type)
-            ))
+    def make_join(self, destination, context, user_id):
+        ret = yield self.transport_layer.make_join(
+            destination=destination,
+            context=context,
+            user_id=user_id,
+        )
+
+        pdu_dict = ret["event"]
+
+        logger.debug("Got response to make_join: %s", pdu_dict)
+
+        defer.returnValue(self.event_from_pdu_json(pdu_dict))
 
     @defer.inlineCallbacks
+    def send_join(self, destination, pdu):
+        _, content = yield self.transport_layer.send_join(
+            destination,
+            pdu.room_id,
+            pdu.event_id,
+            pdu.get_pdu_json(),
+        )
+
+        logger.debug("Got content: %s", content)
+
+        state = [
+            self.event_from_pdu_json(p, outlier=True)
+            for p in content.get("state", [])
+        ]
+
+        # FIXME: We probably want to do something with the auth_chain given
+        # to us
+
+        # auth_chain = [
+        #    Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
+        # ]
+
+        defer.returnValue(state)
+
+    @defer.inlineCallbacks
+    def send_invite(self, destination, context, event_id, pdu):
+        code, content = yield self.transport_layer.send_invite(
+            destination=destination,
+            context=context,
+            event_id=event_id,
+            content=pdu.get_pdu_json(),
+        )
+
+        pdu_dict = content["event"]
+
+        logger.debug("Got response to send_invite: %s", pdu_dict)
+
+        defer.returnValue(self.event_from_pdu_json(pdu_dict))
+
     @log_function
-    def _get_persisted_pdu(self, pdu_id, pdu_origin):
+    def _get_persisted_pdu(self, origin, event_id):
         """ Get a PDU from the database with given origin and id.
 
         Returns:
             Deferred: Results in a `Pdu`.
         """
-        pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
-
-        defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+        return self.handler.get_persisted_pdu(origin, event_id)
 
     def _transaction_from_pdus(self, pdu_list):
         """Returns a new Transaction containing the given PDUs suitable for
         transmission.
         """
-        pdus = [p.get_dict() for p in pdu_list]
+        pdus = [p.get_pdu_json() for p in pdu_list]
+        time_now = self._clock.time_msec()
         for p in pdus:
-            if "age_ts" in pdus:
-                p["age"] = int(self.clock.time_msec()) - p["age_ts"]
-
+            if "age_ts" in p:
+                age = time_now - p["age_ts"]
+                p.setdefault("unsigned", {})["age"] = int(age)
+                del p["age_ts"]
         return Transaction(
             origin=self.server_name,
             pdus=pdus,
-            origin_server_ts=int(self._clock.time_msec()),
+            origin_server_ts=int(time_now),
             destination=None,
         )
 
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_pdu(self, pdu, backfilled=False):
+    def _handle_new_pdu(self, origin, pdu, backfilled=False):
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+        existing = yield self._get_persisted_pdu(origin, pdu.event_id)
 
         if existing and (not existing.outlier or pdu.outlier):
-            logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+            logger.debug("Already seen pdu %s", pdu.event_id)
             defer.returnValue({})
             return
 
+        state = None
+
         # Get missing pdus if necessary.
-        is_new = yield self.pdu_actions.is_new(pdu)
-        if is_new and not pdu.outlier:
+        if not pdu.outlier:
             # We only backfill backwards to the min depth.
-            min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+            min_depth = yield self.handler.get_min_depth_for_context(
+                pdu.room_id
+            )
 
             if min_depth and pdu.depth > min_depth:
-                for pdu_id, origin in pdu.prev_pdus:
-                    exists = yield self._get_persisted_pdu(pdu_id, origin)
+                for event_id, hashes in pdu.prev_events:
+                    exists = yield self._get_persisted_pdu(origin, event_id)
 
                     if not exists:
-                        logger.debug("Requesting pdu %s %s", pdu_id, origin)
+                        logger.debug("Requesting pdu %s", event_id)
 
                         try:
                             yield self.get_pdu(
                                 pdu.origin,
-                                pdu_id=pdu_id,
-                                pdu_origin=origin
+                                event_id=event_id,
                             )
-                            logger.debug("Processed pdu %s %s", pdu_id, origin)
+                            logger.debug("Processed pdu %s", event_id)
                         except:
                             # TODO(erikj): Do some more intelligent retries.
                             logger.exception("Failed to get PDU")
-
-        # Persist the Pdu, but don't mark it as processed yet.
-        yield self.store.persist_event(pdu=pdu)
+            else:
+                # We need to get the state at this event, since we have reached
+                # a backward extremity edge.
+                state = yield self.get_state_for_context(
+                    origin, pdu.room_id, pdu.event_id,
+                )
 
         if not backfilled:
-            ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+            ret = yield self.handler.on_receive_pdu(
+                pdu,
+                backfilled=backfilled,
+                state=state,
+            )
         else:
             ret = None
 
-        yield self.pdu_actions.mark_as_processed(pdu)
+        # yield self.pdu_actions.mark_as_processed(pdu)
 
         defer.returnValue(ret)
 
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name
 
-
-class ReplicationHandler(object):
-    """This defines the methods that the :py:class:`.ReplicationLayer` will
-    use to communicate with the rest of the home server.
-    """
-    def on_receive_pdu(self, pdu):
-        raise NotImplementedError("on_receive_pdu")
+    def event_from_pdu_json(self, pdu_json, outlier=False):
+        #TODO: Check we have all the PDU keys here
+        pdu_json.setdefault("hashes", {})
+        pdu_json.setdefault("signatures", {})
+        return self.event_factory.create_event(
+            pdu_json["type"], outlier=outlier, **pdu_json
+        )
 
 
 class _TransactionQueue(object):
@@ -509,6 +612,9 @@ class _TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = {}
 
+        # destination -> list of tuple(failure, deferred)
+        self.pending_failures_by_dest = {}
+
         # HACK to get unique tx id
         self._next_txn_id = int(self._clock.time_msec())
 
@@ -562,6 +668,18 @@ class _TransactionQueue(object):
         return deferred
 
     @defer.inlineCallbacks
+    def enqueue_failure(self, failure, destination):
+        deferred = defer.Deferred()
+
+        self.pending_failures_by_dest.setdefault(
+            destination, []
+        ).append(
+            (failure, deferred)
+        )
+
+        yield deferred
+
+    @defer.inlineCallbacks
     @log_function
     def _attempt_new_transaction(self, destination):
         if destination in self.pending_transactions:
@@ -570,8 +688,9 @@ class _TransactionQueue(object):
         #  list of (pending_pdu, deferred, order)
         pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
         pending_edus = self.pending_edus_by_dest.pop(destination, [])
+        pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-        if not pending_pdus and not pending_edus:
+        if not pending_pdus and not pending_edus and not pending_failures:
             return
 
         logger.debug("TX [%s] Attempting new transaction", destination)
@@ -581,7 +700,11 @@ class _TransactionQueue(object):
 
         pdus = [x[0] for x in pending_pdus]
         edus = [x[0] for x in pending_edus]
-        deferreds = [x[1] for x in pending_pdus + pending_edus]
+        failures = [x[0].get_dict() for x in pending_failures]
+        deferreds = [
+            x[1]
+            for x in pending_pdus + pending_edus + pending_failures
+        ]
 
         try:
             self.pending_transactions[destination] = 1
@@ -589,12 +712,13 @@ class _TransactionQueue(object):
             logger.debug("TX [%s] Persisting transaction...", destination)
 
             transaction = Transaction.create_new(
-                origin_server_ts=self._clock.time_msec(),
+                origin_server_ts=int(self._clock.time_msec()),
                 transaction_id=str(self._next_txn_id),
                 origin=self.server_name,
                 destination=destination,
                 pdus=pdus,
                 edus=edus,
+                pdu_failures=failures,
             )
 
             self._next_txn_id += 1
@@ -614,7 +738,9 @@ class _TransactionQueue(object):
                 if "pdus" in data:
                     for p in data["pdus"]:
                         if "age_ts" in p:
-                            p["age"] = now - int(p["age_ts"])
+                            unsigned = p.setdefault("unsigned", {})
+                            unsigned["age"] = now - int(p["age_ts"])
+                            del p["age_ts"]
                 return data
 
             code, response = yield self.transport_layer.send_transaction(
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index e7517cac4d..95c40c6c1b 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,7 +72,7 @@ class TransportLayer(object):
         self.received_handler = None
 
     @log_function
-    def get_context_state(self, destination, context):
+    def get_context_state(self, destination, context, event_id=None):
         """ Requests all state for a given context (i.e. room) from the
         given server.
 
@@ -89,54 +89,62 @@ class TransportLayer(object):
 
         subpath = "/state/%s/" % context
 
-        return self._do_request_for_transaction(destination, subpath)
+        args = {}
+        if event_id:
+            args["event_id"] = event_id
+
+        return self._do_request_for_transaction(
+            destination, subpath, args=args
+        )
 
     @log_function
-    def get_pdu(self, destination, pdu_origin, pdu_id):
+    def get_event(self, destination, event_id):
         """ Requests the pdu with give id and origin from the given server.
 
         Args:
             destination (str): The host name of the remote home server we want
                 to get the state from.
-            pdu_origin (str): The home server which created the PDU.
-            pdu_id (str): The id of the PDU being requested.
+            event_id (str): The id of the event being requested.
 
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
         """
-        logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
-                     destination, pdu_origin, pdu_id)
+        logger.debug("get_pdu dest=%s, event_id=%s",
+                     destination, event_id)
 
-        subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
+        subpath = "/event/%s/" % (event_id, )
 
         return self._do_request_for_transaction(destination, subpath)
 
     @log_function
-    def backfill(self, dest, context, pdu_tuples, limit):
+    def backfill(self, dest, context, event_tuples, limit):
         """ Requests `limit` previous PDUs in a given context before list of
         PDUs.
 
         Args:
             dest (str)
             context (str)
-            pdu_tuples (list)
+            event_tuples (list)
             limt (int)
 
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
         """
         logger.debug(
-            "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
-            dest, context, repr(pdu_tuples), str(limit)
+            "backfill dest=%s, context=%s, event_tuples=%s, limit=%s",
+            dest, context, repr(event_tuples), str(limit)
         )
 
-        if not pdu_tuples:
+        if not event_tuples:
+            # TODO: raise?
             return
 
-        subpath = "/backfill/%s/" % context
+        subpath = "/backfill/%s/" % (context,)
 
-        args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
-        args["limit"] = limit
+        args = {
+            "v": event_tuples,
+            "limit": limit,
+        }
 
         return self._do_request_for_transaction(
             dest,
@@ -198,6 +206,72 @@ class TransportLayer(object):
         defer.returnValue(response)
 
     @defer.inlineCallbacks
+    @log_function
+    def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
+        path = PREFIX + "/make_join/%s/%s" % (context, user_id,)
+
+        response = yield self.client.get_json(
+            destination=destination,
+            path=path,
+            retry_on_dns_fail=retry_on_dns_fail,
+        )
+
+        defer.returnValue(response)
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_join(self, destination, context, event_id, content):
+        path = PREFIX + "/send_join/%s/%s" % (
+            context,
+            event_id,
+        )
+
+        code, content = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        if not 200 <= code < 300:
+            raise RuntimeError("Got %d from send_join", code)
+
+        defer.returnValue(json.loads(content))
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_invite(self, destination, context, event_id, content):
+        path = PREFIX + "/invite/%s/%s" % (
+            context,
+            event_id,
+        )
+
+        code, content = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        if not 200 <= code < 300:
+            raise RuntimeError("Got %d from send_invite", code)
+
+        defer.returnValue(json.loads(content))
+
+    @defer.inlineCallbacks
+    @log_function
+    def get_event_auth(self, destination, context, event_id):
+        path = PREFIX + "/event_auth/%s/%s" % (
+            context,
+            event_id,
+        )
+
+        response = yield self.client.get_json(
+            destination=destination,
+            path=path,
+        )
+
+        defer.returnValue(response)
+
+    @defer.inlineCallbacks
     def _authenticate_request(self, request):
         json_request = {
             "method": request.method,
@@ -210,7 +284,7 @@ class TransportLayer(object):
         origin = None
 
         if request.method == "PUT":
-            #TODO: Handle other method types? other content types?
+            # TODO: Handle other method types? other content types?
             try:
                 content_bytes = request.content.read()
                 content = json.loads(content_bytes)
@@ -222,11 +296,13 @@ class TransportLayer(object):
             try:
                 params = auth.split(" ")[1].split(",")
                 param_dict = dict(kv.split("=") for kv in params)
+
                 def strip_quotes(value):
                     if value.startswith("\""):
                         return value[1:-1]
                     else:
                         return value
+
                 origin = strip_quotes(param_dict["origin"])
                 key = strip_quotes(param_dict["key"])
                 sig = strip_quotes(param_dict["sig"])
@@ -247,7 +323,7 @@ class TransportLayer(object):
             if auth.startswith("X-Matrix"):
                 (origin, key, sig) = parse_auth_header(auth)
                 json_request["origin"] = origin
-                json_request["signatures"].setdefault(origin,{})[key] = sig
+                json_request["signatures"].setdefault(origin, {})[key] = sig
 
         if not json_request["signatures"]:
             raise SynapseError(
@@ -313,10 +389,10 @@ class TransportLayer(object):
         # data_id pair.
         self.server.register_path(
             "GET",
-            re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
+            re.compile("^" + PREFIX + "/event/([^/]*)/$"),
             self._with_authentication(
-                lambda origin, content, query, pdu_origin, pdu_id:
-                handler.on_pdu_request(pdu_origin, pdu_id)
+                lambda origin, content, query, event_id:
+                handler.on_pdu_request(origin, event_id)
             )
         )
 
@@ -326,7 +402,11 @@ class TransportLayer(object):
             re.compile("^" + PREFIX + "/state/([^/]*)/$"),
             self._with_authentication(
                 lambda origin, content, query, context:
-                handler.on_context_state_request(context)
+                handler.on_context_state_request(
+                    origin,
+                    context,
+                    query.get("event_id", [None])[0],
+                )
             )
         )
 
@@ -336,28 +416,63 @@ class TransportLayer(object):
             self._with_authentication(
                 lambda origin, content, query, context:
                 self._on_backfill_request(
-                    context, query["v"], query["limit"]
+                    origin, context, query["v"], query["limit"]
                 )
             )
         )
 
+        # This is when we receive a server-server Query
         self.server.register_path(
             "GET",
-            re.compile("^" + PREFIX + "/context/([^/]*)/$"),
+            re.compile("^" + PREFIX + "/query/([^/]*)$"),
             self._with_authentication(
-                lambda origin, content, query, context:
-                handler.on_context_pdus_request(context)
+                lambda origin, content, query, query_type:
+                handler.on_query_request(
+                    query_type, {k: v[0] for k, v in query.items()}
+                )
             )
         )
 
-        # This is when we receive a server-server Query
         self.server.register_path(
             "GET",
-            re.compile("^" + PREFIX + "/query/([^/]*)$"),
+            re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
             self._with_authentication(
-                lambda origin, content, query, query_type:
-                handler.on_query_request(
-                    query_type, {k: v[0] for k, v in query.items()}
+                lambda origin, content, query, context, user_id:
+                self._on_make_join_request(
+                    origin, content, query, context, user_id
+                )
+            )
+        )
+
+        self.server.register_path(
+            "GET",
+            re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                handler.on_event_auth(
+                    origin, context, event_id,
+                )
+            )
+        )
+
+        self.server.register_path(
+            "PUT",
+            re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                self._on_send_join_request(
+                    origin, content, query,
+                )
+            )
+        )
+
+        self.server.register_path(
+            "PUT",
+            re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                self._on_invite_request(
+                    origin, content, query,
                 )
             )
         )
@@ -402,7 +517,8 @@ class TransportLayer(object):
             return
 
         try:
-            code, response = yield self.received_handler.on_incoming_transaction(
+            handler = self.received_handler
+            code, response = yield handler.on_incoming_transaction(
                 transaction_data
             )
         except:
@@ -440,7 +556,7 @@ class TransportLayer(object):
         defer.returnValue(data)
 
     @log_function
-    def _on_backfill_request(self, context, v_list, limits):
+    def _on_backfill_request(self, origin, context, v_list, limits):
         if not limits:
             return defer.succeed(
                 (400, {"error": "Did not include limit param"})
@@ -448,124 +564,34 @@ class TransportLayer(object):
 
         limit = int(limits[-1])
 
-        versions = [v.split(",", 1) for v in v_list]
+        versions = v_list
 
         return self.request_handler.on_backfill_request(
-            context, versions, limit)
-
-
-class TransportReceivedHandler(object):
-    """ Callbacks used when we receive a transaction
-    """
-    def on_incoming_transaction(self, transaction):
-        """ Called on PUT /send/<transaction_id>, or on response to a request
-        that we sent (e.g. a backfill request)
-
-        Args:
-            transaction (synapse.transaction.Transaction): The transaction that
-                was sent to us.
-
-        Returns:
-            twisted.internet.defer.Deferred: A deferred that gets fired when
-            the transaction has finished being processed.
-
-            The result should be a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-
-class TransportRequestHandler(object):
-    """ Handlers used when someone want's data from us
-    """
-    def on_pull_request(self, versions):
-        """ Called on GET /pull/?v=...
-
-        This is hit when a remote home server wants to get all data
-        after a given transaction. Mainly used when a home server comes back
-        online and wants to get everything it has missed.
-
-        Args:
-            versions (list): A list of transaction_ids that should be used to
-                determine what PDUs the remote side have not yet seen.
-
-        Returns:
-            Deferred: Resultsin a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_pdu_request(self, pdu_origin, pdu_id):
-        """ Called on GET /pdu/<pdu_origin>/<pdu_id>/
-
-        Someone wants a particular PDU. This PDU may or may not have originated
-        from us.
-
-        Args:
-            pdu_origin (str)
-            pdu_id (str)
-
-        Returns:
-            Deferred: Resultsin a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_context_state_request(self, context):
-        """ Called on GET /state/<context>/
-
-        Gets hit when someone wants all the *current* state for a given
-        contexts.
-
-        Args:
-            context (str): The name of the context that we're interested in.
-
-        Returns:
-            twisted.internet.defer.Deferred: A deferred that gets fired when
-            the transaction has finished being processed.
-
-            The result should be a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_backfill_request(self, context, versions, limit):
-        """ Called on GET /backfill/<context>/?v=...&limit=...
+            origin, context, versions, limit
+        )
 
-        Gets hit when we want to backfill backwards on a given context from
-        the given point.
+    @defer.inlineCallbacks
+    @log_function
+    def _on_make_join_request(self, origin, content, query, context, user_id):
+        content = yield self.request_handler.on_make_join_request(
+            context, user_id,
+        )
+        defer.returnValue((200, content))
 
-        Args:
-            context (str): The context to backfill
-            versions (list): A list of 2-tuples representing where to backfill
-                from, in the form `(pdu_id, origin)`
-            limit (int): How many pdus to return.
+    @defer.inlineCallbacks
+    @log_function
+    def _on_send_join_request(self, origin, content, query):
+        content = yield self.request_handler.on_send_join_request(
+            origin, content,
+        )
 
-        Returns:
-            Deferred: Results in a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
+        defer.returnValue((200, content))
 
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
+    @defer.inlineCallbacks
+    @log_function
+    def _on_invite_request(self, origin, content, query):
+        content = yield self.request_handler.on_invite_request(
+            origin, content,
+        )
 
-    def on_query_request(self):
-        """ Called on a GET /query/<query_type> request. """
+        defer.returnValue((200, content))
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index b2fb964180..6e708edb8c 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -20,126 +20,11 @@ server protocol.
 from synapse.util.jsonobject import JsonEncodedObject
 
 import logging
-import json
-import copy
 
 
 logger = logging.getLogger(__name__)
 
 
-class Pdu(JsonEncodedObject):
-    """ A Pdu represents a piece of data sent from a server and is associated
-    with a context.
-
-    A Pdu can be classified as "state". For a given context, we can efficiently
-    retrieve all state pdu's that haven't been clobbered. Clobbering is done
-    via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
-    is a state pdu if `is_state` is True.
-
-    Example pdu::
-
-        {
-            "pdu_id": "78c",
-            "origin_server_ts": 1404835423000,
-            "origin": "bar",
-            "prev_ids": [
-                ["23b", "foo"],
-                ["56a", "bar"],
-            ],
-            "content": { ... },
-        }
-
-    """
-
-    valid_keys = [
-        "pdu_id",
-        "context",
-        "origin",
-        "origin_server_ts",
-        "pdu_type",
-        "destinations",
-        "transaction_id",
-        "prev_pdus",
-        "depth",
-        "content",
-        "outlier",
-        "is_state",  # Below this are keys valid only for State Pdus.
-        "state_key",
-        "power_level",
-        "prev_state_id",
-        "prev_state_origin",
-        "required_power_level",
-        "user_id",
-    ]
-
-    internal_keys = [
-        "destinations",
-        "transaction_id",
-        "outlier",
-    ]
-
-    required_keys = [
-        "pdu_id",
-        "context",
-        "origin",
-        "origin_server_ts",
-        "pdu_type",
-        "content",
-    ]
-
-    # TODO: We need to make this properly load content rather than
-    # just leaving it as a dict. (OR DO WE?!)
-
-    def __init__(self, destinations=[], is_state=False, prev_pdus=[],
-                 outlier=False, **kwargs):
-        if is_state:
-            for required_key in ["state_key"]:
-                if required_key not in kwargs:
-                    raise RuntimeError("Key %s is required" % required_key)
-
-        super(Pdu, self).__init__(
-            destinations=destinations,
-            is_state=is_state,
-            prev_pdus=prev_pdus,
-            outlier=outlier,
-            **kwargs
-        )
-
-    @classmethod
-    def from_pdu_tuple(cls, pdu_tuple):
-        """ Converts a PduTuple to a Pdu
-
-        Args:
-            pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
-                convert
-
-        Returns:
-            Pdu
-        """
-        if pdu_tuple:
-            d = copy.copy(pdu_tuple.pdu_entry._asdict())
-            d["origin_server_ts"] = d.pop("ts")
-
-            d["content"] = json.loads(d["content_json"])
-            del d["content_json"]
-
-            args = {f: d[f] for f in cls.valid_keys if f in d}
-            if "unrecognized_keys" in d and d["unrecognized_keys"]:
-                args.update(json.loads(d["unrecognized_keys"]))
-
-            return Pdu(
-                prev_pdus=pdu_tuple.prev_pdu_list,
-                **args
-            )
-        else:
-            return None
-
-    def __str__(self):
-        return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
-
-    def __repr__(self):
-        return "<%s, %s>" % (self.__class__.__name__, repr(self.__dict__))
-
 
 class Edu(JsonEncodedObject):
     """ An Edu represents a piece of data sent from one homeserver to another.
@@ -160,11 +45,10 @@ class Edu(JsonEncodedObject):
         "edu_type",
     ]
 
-#    TODO: SYN-103: Remove "origin" and "destination" keys.
-#    internal_keys = [
-#        "origin",
-#        "destination",
-#    ]
+    internal_keys = [
+        "origin",
+        "destination",
+    ]
 
 
 class Transaction(JsonEncodedObject):
@@ -193,6 +77,7 @@ class Transaction(JsonEncodedObject):
         "edus",
         "transaction_id",
         "destination",
+        "pdu_failures",
     ]
 
     internal_keys = [
@@ -229,7 +114,9 @@ class Transaction(JsonEncodedObject):
         transaction_id and origin_server_ts keys.
         """
         if "origin_server_ts" not in kwargs:
-            raise KeyError("Require 'origin_server_ts' to construct a Transaction")
+            raise KeyError(
+                "Require 'origin_server_ts' to construct a Transaction"
+            )
         if "transaction_id" not in kwargs:
             raise KeyError(
                 "Require 'transaction_id' to construct a Transaction"
@@ -238,9 +125,6 @@ class Transaction(JsonEncodedObject):
         for p in pdus:
             p.transaction_id = kwargs["transaction_id"]
 
-        kwargs["pdus"] = [p.get_dict() for p in pdus]
+        kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
 
         return Transaction(**kwargs)
-
-
-