summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/__init__.py8
-rw-r--r--synapse/federation/federation_base.py185
-rw-r--r--synapse/federation/federation_client.py167
-rw-r--r--synapse/federation/federation_server.py544
-rw-r--r--synapse/federation/replication.py73
-rw-r--r--synapse/federation/send_queue.py384
-rw-r--r--synapse/federation/transaction_queue.py473
-rw-r--r--synapse/federation/transport/client.py556
-rw-r--r--synapse/federation/transport/server.py661
9 files changed, 2278 insertions, 773 deletions
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 2e32d245ba..f5f0bdfca3 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -15,11 +15,3 @@
 
 """ This package includes all the federation specific logic.
 """
-
-from .replication import ReplicationLayer
-
-
-def initialize_http_replication(hs):
-    transport = hs.get_federation_transport_client()
-
-    return ReplicationLayer(hs, transport)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2339cc9034..4cc98a3fe8 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -12,28 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
+import six
 
-from twisted.internet import defer
-
-from synapse.events.utils import prune_event
-
+from synapse.api.constants import MAX_DEPTH
+from synapse.api.errors import SynapseError, Codes
 from synapse.crypto.event_signing import check_event_content_hash
-
-from synapse.api.errors import SynapseError
-
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
-import logging
-
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+from synapse.http.servlet import assert_params_in_request
+from synapse.util import unwrapFirstError, logcontext
+from twisted.internet import defer
 
 logger = logging.getLogger(__name__)
 
 
 class FederationBase(object):
     def __init__(self, hs):
-        pass
+        self.hs = hs
+
+        self.server_name = hs.hostname
+        self.keyring = hs.get_keyring()
+        self.spam_checker = hs.get_spam_checker()
+        self.store = hs.get_datastore()
+        self._clock = hs.get_clock()
 
     @defer.inlineCallbacks
     def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -57,56 +60,52 @@ class FederationBase(object):
         """
         deferreds = self._check_sigs_and_hashes(pdus)
 
-        def callback(pdu):
-            return pdu
-
-        def errback(failure, pdu):
-            failure.trap(SynapseError)
-            return None
+        @defer.inlineCallbacks
+        def handle_check_result(pdu, deferred):
+            try:
+                res = yield logcontext.make_deferred_yieldable(deferred)
+            except SynapseError:
+                res = None
 
-        def try_local_db(res, pdu):
             if not res:
                 # Check local db.
-                return self.store.get_event(
+                res = yield self.store.get_event(
                     pdu.event_id,
                     allow_rejected=True,
                     allow_none=True,
                 )
-            return res
 
-        def try_remote(res, pdu):
             if not res and pdu.origin != origin:
-                return self.get_pdu(
-                    destinations=[pdu.origin],
-                    event_id=pdu.event_id,
-                    outlier=outlier,
-                    timeout=10000,
-                ).addErrback(lambda e: None)
-            return res
-
-        def warn(res, pdu):
+                try:
+                    res = yield self.get_pdu(
+                        destinations=[pdu.origin],
+                        event_id=pdu.event_id,
+                        outlier=outlier,
+                        timeout=10000,
+                    )
+                except SynapseError:
+                    pass
+
             if not res:
                 logger.warn(
                     "Failed to find copy of %s with valid signature",
                     pdu.event_id,
                 )
-            return res
 
-        for pdu, deferred in zip(pdus, deferreds):
-            deferred.addCallbacks(
-                callback, errback, errbackArgs=[pdu]
-            ).addCallback(
-                try_local_db, pdu
-            ).addCallback(
-                try_remote, pdu
-            ).addCallback(
-                warn, pdu
-            )
+            defer.returnValue(res)
 
-        valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
-            deferreds,
-            consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        handle = logcontext.preserve_fn(handle_check_result)
+        deferreds2 = [
+            handle(pdu, deferred)
+            for pdu, deferred in zip(pdus, deferreds)
+        ]
+
+        valid_pdus = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                deferreds2,
+                consumeErrors=True,
+            )
+        ).addErrback(unwrapFirstError)
 
         if include_none:
             defer.returnValue(valid_pdus)
@@ -114,15 +113,24 @@ class FederationBase(object):
             defer.returnValue([p for p in valid_pdus if p])
 
     def _check_sigs_and_hash(self, pdu):
-        return self._check_sigs_and_hashes([pdu])[0]
+        return logcontext.make_deferred_yieldable(
+            self._check_sigs_and_hashes([pdu])[0],
+        )
 
     def _check_sigs_and_hashes(self, pdus):
-        """Throws a SynapseError if a PDU does not have the correct
-        signatures.
+        """Checks that each of the received events is correctly signed by the
+        sending server.
+
+        Args:
+            pdus (list[FrozenEvent]): the events to be checked
 
         Returns:
-            FrozenEvent: Either the given event or it redacted if it failed the
-            content hash check.
+            list[Deferred]: for each input event, a deferred which:
+              * returns the original event if the checks pass
+              * returns a redacted version of the event (if the signature
+                matched but the hash did not)
+              * throws a SynapseError if the signature check failed.
+            The deferreds run their callbacks in the sentinel logcontext.
         """
 
         redacted_pdus = [
@@ -130,26 +138,38 @@ class FederationBase(object):
             for pdu in pdus
         ]
 
-        deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+        deferreds = self.keyring.verify_json_objects_for_server([
             (p.origin, p.get_pdu_json())
             for p in redacted_pdus
         ])
 
+        ctx = logcontext.LoggingContext.current_context()
+
         def callback(_, pdu, redacted):
-            if not check_event_content_hash(pdu):
-                logger.warn(
-                    "Event content has been tampered, redacting %s: %s",
-                    pdu.event_id, pdu.get_pdu_json()
-                )
-                return redacted
-            return pdu
+            with logcontext.PreserveLoggingContext(ctx):
+                if not check_event_content_hash(pdu):
+                    logger.warn(
+                        "Event content has been tampered, redacting %s: %s",
+                        pdu.event_id, pdu.get_pdu_json()
+                    )
+                    return redacted
+
+                if self.spam_checker.check_event_for_spam(pdu):
+                    logger.warn(
+                        "Event contains spam, redacting %s: %s",
+                        pdu.event_id, pdu.get_pdu_json()
+                    )
+                    return redacted
+
+                return pdu
 
         def errback(failure, pdu):
             failure.trap(SynapseError)
-            logger.warn(
-                "Signature check failed for %s",
-                pdu.event_id,
-            )
+            with logcontext.PreserveLoggingContext(ctx):
+                logger.warn(
+                    "Signature check failed for %s",
+                    pdu.event_id,
+                )
             return failure
 
         for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
@@ -160,3 +180,40 @@ class FederationBase(object):
             )
 
         return deferreds
+
+
+def event_from_pdu_json(pdu_json, outlier=False):
+    """Construct a FrozenEvent from an event json received over federation
+
+    Args:
+        pdu_json (object): pdu as received over federation
+        outlier (bool): True to mark this event as an outlier
+
+    Returns:
+        FrozenEvent
+
+    Raises:
+        SynapseError: if the pdu is missing required fields or is otherwise
+            not a valid matrix event
+    """
+    # we could probably enforce a bunch of other fields here (room_id, sender,
+    # origin, etc etc)
+    assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
+
+    depth = pdu_json['depth']
+    if not isinstance(depth, six.integer_types):
+        raise SynapseError(400, "Depth %r not an intger" % (depth, ),
+                           Codes.BAD_JSON)
+
+    if depth < 0:
+        raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
+    elif depth > MAX_DEPTH:
+        raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
+
+    event = FrozenEvent(
+        pdu_json
+    )
+
+    event.internal_metadata.outlier = outlier
+
+    return event
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b5bcfd705a..6163f7c466 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -14,28 +14,30 @@
 # limitations under the License.
 
 
+import copy
+import itertools
+import logging
+import random
+
+from six.moves import range
+
 from twisted.internet import defer
 
-from .federation_base import FederationBase
 from synapse.api.constants import Membership
-
 from synapse.api.errors import (
-    CodeMessageException, HttpResponseException, SynapseError,
+    CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
 )
-from synapse.util import unwrapFirstError
+from synapse.events import builder
+from synapse.federation.federation_base import (
+    FederationBase,
+    event_from_pdu_json,
+)
+import synapse.metrics
+from synapse.util import logcontext, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.events import FrozenEvent, builder
-import synapse.metrics
-
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
-
-import copy
-import itertools
-import logging
-import random
-
+from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
 
@@ -58,6 +60,7 @@ class FederationClient(FederationBase):
             self._clear_tried_cache, 60 * 1000,
         )
         self.state = hs.get_state_handler()
+        self.transport_layer = hs.get_federation_transport_client()
 
     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""
@@ -88,7 +91,7 @@ class FederationClient(FederationBase):
 
     @log_function
     def make_query(self, destination, query_type, args,
-                   retry_on_dns_fail=False):
+                   retry_on_dns_fail=False, ignore_backoff=False):
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
@@ -98,6 +101,8 @@ class FederationClient(FederationBase):
                 handler name used in register_query_handler().
             args (dict): Mapping of strings to strings containing the details
                 of the query request.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
 
         Returns:
             a Deferred which will eventually yield a JSON object from the
@@ -106,7 +111,8 @@ class FederationClient(FederationBase):
         sent_queries_counter.inc(query_type)
 
         return self.transport_layer.make_query(
-            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
+            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
         )
 
     @log_function
@@ -181,15 +187,15 @@ class FederationClient(FederationBase):
         logger.debug("backfill transaction_data=%s", repr(transaction_data))
 
         pdus = [
-            self.event_from_pdu_json(p, outlier=False)
+            event_from_pdu_json(p, outlier=False)
             for p in transaction_data["pdus"]
         ]
 
         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+        pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(pdus)
 
@@ -206,8 +212,7 @@ class FederationClient(FederationBase):
 
         Args:
             destinations (list): Which home servers to query
-            pdu_origin (str): The home server that originally sent the pdu.
-            event_id (str)
+            event_id (str): event to fetch
             outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                 it's from an arbitary point in the context as opposed to part
                 of the current block of PDUs. Defaults to `False`
@@ -235,31 +240,24 @@ class FederationClient(FederationBase):
                 continue
 
             try:
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self._clock,
-                    self.store,
+                transaction_data = yield self.transport_layer.get_event(
+                    destination, event_id, timeout=timeout,
                 )
 
-                with limiter:
-                    transaction_data = yield self.transport_layer.get_event(
-                        destination, event_id, timeout=timeout,
-                    )
-
-                    logger.debug("transaction_data %r", transaction_data)
+                logger.debug("transaction_data %r", transaction_data)
 
-                    pdu_list = [
-                        self.event_from_pdu_json(p, outlier=outlier)
-                        for p in transaction_data["pdus"]
-                    ]
+                pdu_list = [
+                    event_from_pdu_json(p, outlier=outlier)
+                    for p in transaction_data["pdus"]
+                ]
 
-                    if pdu_list and pdu_list[0]:
-                        pdu = pdu_list[0]
+                if pdu_list and pdu_list[0]:
+                    pdu = pdu_list[0]
 
-                        # Check signatures are correct.
-                        signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+                    # Check signatures are correct.
+                    signed_pdu = yield self._check_sigs_and_hash(pdu)
 
-                        break
+                    break
 
                 pdu_attempts[destination] = now
 
@@ -271,6 +269,9 @@ class FederationClient(FederationBase):
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
+            except FederationDeniedError as e:
+                logger.info(e.message)
+                continue
             except Exception as e:
                 pdu_attempts[destination] = now
 
@@ -341,11 +342,11 @@ class FederationClient(FederationBase):
         )
 
         pdus = [
-            self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+            event_from_pdu_json(p, outlier=True) for p in result["pdus"]
         ]
 
         auth_chain = [
-            self.event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, outlier=True)
             for p in result.get("auth_chain", [])
         ]
 
@@ -395,7 +396,7 @@ class FederationClient(FederationBase):
             seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
             signed_events = seen_events.values()
         else:
-            seen_events = yield self.store.have_events(event_ids)
+            seen_events = yield self.store.have_seen_events(event_ids)
             signed_events = []
 
         failed_to_fetch = set()
@@ -414,18 +415,19 @@ class FederationClient(FederationBase):
 
         batch_size = 20
         missing_events = list(missing_events)
-        for i in xrange(0, len(missing_events), batch_size):
+        for i in range(0, len(missing_events), batch_size):
             batch = set(missing_events[i:i + batch_size])
 
             deferreds = [
-                preserve_fn(self.get_pdu)(
+                run_in_background(
+                    self.get_pdu,
                     destinations=random_server_list(),
                     event_id=e_id,
                 )
                 for e_id in batch
             ]
 
-            res = yield preserve_context_over_deferred(
+            res = yield make_deferred_yieldable(
                 defer.DeferredList(deferreds, consumeErrors=True)
             )
             for success, result in res:
@@ -446,7 +448,7 @@ class FederationClient(FederationBase):
         )
 
         auth_chain = [
-            self.event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, outlier=True)
             for p in res["auth_chain"]
         ]
 
@@ -479,8 +481,13 @@ class FederationClient(FederationBase):
             content (object): Any additional data to put into the content field
                 of the event.
         Return:
-            A tuple of (origin (str), event (object)) where origin is the remote
-            homeserver which generated the event.
+            Deferred: resolves to a tuple of (origin (str), event (object))
+            where origin is the remote homeserver which generated the event.
+
+            Fails with a ``CodeMessageException`` if the chosen remote server
+            returns a 300/400 code.
+
+            Fails with a ``RuntimeError`` if no servers were reachable.
         """
         valid_memberships = {Membership.JOIN, Membership.LEAVE}
         if membership not in valid_memberships:
@@ -533,6 +540,27 @@ class FederationClient(FederationBase):
 
     @defer.inlineCallbacks
     def send_join(self, destinations, pdu):
+        """Sends a join event to one of a list of homeservers.
+
+        Doing so will cause the remote server to add the event to the graph,
+        and send the event out to the rest of the federation.
+
+        Args:
+            destinations (str): Candidate homeservers which are probably
+                participating in the room.
+            pdu (BaseEvent): event to be sent
+
+        Return:
+            Deferred: resolves to a dict with members ``origin`` (a string
+            giving the serer the event was sent to, ``state`` (?) and
+            ``auth_chain``.
+
+            Fails with a ``CodeMessageException`` if the chosen remote server
+            returns a 300/400 code.
+
+            Fails with a ``RuntimeError`` if no servers were reachable.
+        """
+
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -549,12 +577,12 @@ class FederationClient(FederationBase):
                 logger.debug("Got content: %s", content)
 
                 state = [
-                    self.event_from_pdu_json(p, outlier=True)
+                    event_from_pdu_json(p, outlier=True)
                     for p in content.get("state", [])
                 ]
 
                 auth_chain = [
-                    self.event_from_pdu_json(p, outlier=True)
+                    event_from_pdu_json(p, outlier=True)
                     for p in content.get("auth_chain", [])
                 ]
 
@@ -629,7 +657,7 @@ class FederationClient(FederationBase):
 
         logger.debug("Got response to send_invite: %s", pdu_dict)
 
-        pdu = self.event_from_pdu_json(pdu_dict)
+        pdu = event_from_pdu_json(pdu_dict)
 
         # Check signatures are correct.
         pdu = yield self._check_sigs_and_hash(pdu)
@@ -640,6 +668,26 @@ class FederationClient(FederationBase):
 
     @defer.inlineCallbacks
     def send_leave(self, destinations, pdu):
+        """Sends a leave event to one of a list of homeservers.
+
+        Doing so will cause the remote server to add the event to the graph,
+        and send the event out to the rest of the federation.
+
+        This is mostly useful to reject received invites.
+
+        Args:
+            destinations (str): Candidate homeservers which are probably
+                participating in the room.
+            pdu (BaseEvent): event to be sent
+
+        Return:
+            Deferred: resolves to None.
+
+            Fails with a ``CodeMessageException`` if the chosen remote server
+            returns a non-200 code.
+
+            Fails with a ``RuntimeError`` if no servers were reachable.
+        """
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -699,7 +747,7 @@ class FederationClient(FederationBase):
         )
 
         auth_chain = [
-            self.event_from_pdu_json(e)
+            event_from_pdu_json(e)
             for e in content["auth_chain"]
         ]
 
@@ -747,7 +795,7 @@ class FederationClient(FederationBase):
             )
 
             events = [
-                self.event_from_pdu_json(e)
+                event_from_pdu_json(e)
                 for e in content.get("events", [])
             ]
 
@@ -764,15 +812,6 @@ class FederationClient(FederationBase):
 
         defer.returnValue(signed_events)
 
-    def event_from_pdu_json(self, pdu_json, outlier=False):
-        event = FrozenEvent(
-            pdu_json
-        )
-
-        event.internal_metadata.outlier = outlier
-
-        return event
-
     @defer.inlineCallbacks
     def forward_third_party_invite(self, destinations, room_id, event_dict):
         for destination in destinations:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e922b7ff4a..247ddc89d5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,27 +13,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
-
+import simplejson as json
 from twisted.internet import defer
 
-from .federation_base import FederationBase
-from .units import Transaction, Edu
+from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
+from synapse.crypto.event_signing import compute_event_signature
+from synapse.federation.federation_base import (
+    FederationBase,
+    event_from_pdu_json,
+)
 
-from synapse.util.async import Linearizer
-from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
+from synapse.federation.persistence import TransactionActions
+from synapse.federation.units import Edu, Transaction
 import synapse.metrics
+from synapse.types import get_domain_from_id
+from synapse.util import async
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.logutils import log_function
 
-from synapse.api.errors import AuthError, FederationError, SynapseError
-
-from synapse.crypto.event_signing import compute_event_signature
-
-import simplejson as json
-import logging
+from six import iteritems
 
+# when processing incoming transactions, we try to handle multiple rooms in
+# parallel, up to this limit.
+TRANSACTION_CONCURRENCY_LIMIT = 10
 
 logger = logging.getLogger(__name__)
 
@@ -51,49 +56,18 @@ class FederationServer(FederationBase):
         super(FederationServer, self).__init__(hs)
 
         self.auth = hs.get_auth()
+        self.handler = hs.get_handlers().federation_handler
 
-        self._room_pdu_linearizer = Linearizer("fed_room_pdu")
-        self._server_linearizer = Linearizer("fed_server")
-
-        # We cache responses to state queries, as they take a while and often
-        # come in waves.
-        self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
-
-    def set_handler(self, handler):
-        """Sets the handler that the replication layer will use to communicate
-        receipt of new PDUs from other home servers. The required methods are
-        documented on :py:class:`.ReplicationHandler`.
-        """
-        self.handler = handler
-
-    def register_edu_handler(self, edu_type, handler):
-        if edu_type in self.edu_handlers:
-            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
-
-        self.edu_handlers[edu_type] = handler
-
-    def register_query_handler(self, query_type, handler):
-        """Sets the handler callable that will be used to handle an incoming
-        federation Query of the given type.
-
-        Args:
-            query_type (str): Category name of the query, which should match
-                the string used by make_query.
-            handler (callable): Invoked to handle incoming queries of this type
+        self._server_linearizer = async.Linearizer("fed_server")
+        self._transaction_linearizer = async.Linearizer("fed_txn_handler")
 
-        handler is invoked as:
-            result = handler(args)
+        self.transaction_actions = TransactionActions(self.store)
 
-        where 'args' is a dict mapping strings to strings of the query
-          arguments. It should return a Deferred that will eventually yield an
-          object to encode as JSON.
-        """
-        if query_type in self.query_handlers:
-            raise KeyError(
-                "Already have a Query handler for %s" % (query_type,)
-            )
+        self.registry = hs.get_federation_registry()
 
-        self.query_handlers[query_type] = handler
+        # We cache responses to state queries, as they take a while and often
+        # come in waves.
+        self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
 
     @defer.inlineCallbacks
     @log_function
@@ -110,25 +84,41 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_incoming_transaction(self, transaction_data):
+        # keep this as early as possible to make the calculated origin ts as
+        # accurate as possible.
+        request_time = self._clock.time_msec()
+
         transaction = Transaction(**transaction_data)
 
-        received_pdus_counter.inc_by(len(transaction.pdus))
+        if not transaction.transaction_id:
+            raise Exception("Transaction missing transaction_id")
+        if not transaction.origin:
+            raise Exception("Transaction missing origin")
 
-        for p in transaction.pdus:
-            if "unsigned" in p:
-                unsigned = p["unsigned"]
-                if "age" in unsigned:
-                    p["age"] = unsigned["age"]
-            if "age" in p:
-                p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
-                del p["age"]
+        logger.debug("[%s] Got transaction", transaction.transaction_id)
 
-        pdu_list = [
-            self.event_from_pdu_json(p) for p in transaction.pdus
-        ]
+        # use a linearizer to ensure that we don't process the same transaction
+        # multiple times in parallel.
+        with (yield self._transaction_linearizer.queue(
+                (transaction.origin, transaction.transaction_id),
+        )):
+            result = yield self._handle_incoming_transaction(
+                transaction, request_time,
+            )
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _handle_incoming_transaction(self, transaction, request_time):
+        """ Process an incoming transaction and return the HTTP response
 
+        Args:
+            transaction (Transaction): incoming transaction
+            request_time (int): timestamp that the HTTP request arrived at
+
+        Returns:
+            Deferred[(int, object)]: http response code and body
+        """
         response = yield self.transaction_actions.have_responded(transaction)
 
         if response:
@@ -141,38 +131,49 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
-        results = []
-
-        for pdu in pdu_list:
-            # check that it's actually being sent from a valid destination to
-            # workaround bug #1753 in 0.18.5 and 0.18.6
-            if transaction.origin != get_domain_from_id(pdu.event_id):
-                if not (
-                    pdu.type == 'm.room.member' and
-                    pdu.content and
-                    pdu.content.get("membership", None) == 'join' and
-                    self.hs.is_mine_id(pdu.state_key)
-                ):
-                    logger.info(
-                        "Discarding PDU %s from invalid origin %s",
-                        pdu.event_id, transaction.origin
-                    )
-                    continue
-                else:
-                    logger.info(
-                        "Accepting join PDU %s from %s",
-                        pdu.event_id, transaction.origin
-                    )
+        received_pdus_counter.inc_by(len(transaction.pdus))
+
+        pdus_by_room = {}
+
+        for p in transaction.pdus:
+            if "unsigned" in p:
+                unsigned = p["unsigned"]
+                if "age" in unsigned:
+                    p["age"] = unsigned["age"]
+            if "age" in p:
+                p["age_ts"] = request_time - int(p["age"])
+                del p["age"]
 
-            try:
-                yield self._handle_new_pdu(transaction.origin, pdu)
-                results.append({})
-            except FederationError as e:
-                self.send_failure(e, transaction.origin)
-                results.append({"error": str(e)})
-            except Exception as e:
-                results.append({"error": str(e)})
-                logger.exception("Failed to handle PDU")
+            event = event_from_pdu_json(p)
+            room_id = event.room_id
+            pdus_by_room.setdefault(room_id, []).append(event)
+
+        pdu_results = {}
+
+        # we can process different rooms in parallel (which is useful if they
+        # require callouts to other servers to fetch missing events), but
+        # impose a limit to avoid going too crazy with ram/cpu.
+        @defer.inlineCallbacks
+        def process_pdus_for_room(room_id):
+            logger.debug("Processing PDUs for %s", room_id)
+            for pdu in pdus_by_room[room_id]:
+                event_id = pdu.event_id
+                try:
+                    yield self._handle_received_pdu(
+                        transaction.origin, pdu
+                    )
+                    pdu_results[event_id] = {}
+                except FederationError as e:
+                    logger.warn("Error handling PDU %s: %s", event_id, e)
+                    pdu_results[event_id] = {"error": str(e)}
+                except Exception as e:
+                    pdu_results[event_id] = {"error": str(e)}
+                    logger.exception("Failed to handle PDU %s", event_id)
+
+        yield async.concurrently_execute(
+            process_pdus_for_room, pdus_by_room.keys(),
+            TRANSACTION_CONCURRENCY_LIMIT,
+        )
 
         if hasattr(transaction, "edus"):
             for edu in (Edu(**x) for x in transaction.edus):
@@ -182,17 +183,16 @@ class FederationServer(FederationBase):
                     edu.content
                 )
 
-            for failure in getattr(transaction, "pdu_failures", []):
-                logger.info("Got failure %r", failure)
-
-        logger.debug("Returning: %s", str(results))
+        pdu_failures = getattr(transaction, "pdu_failures", [])
+        for failure in pdu_failures:
+            logger.info("Got failure %r", failure)
 
         response = {
-            "pdus": dict(zip(
-                (p.event_id for p in pdu_list), results
-            )),
+            "pdus": pdu_results,
         }
 
+        logger.debug("Returning: %s", str(response))
+
         yield self.transaction_actions.set_response(
             transaction,
             200, response
@@ -202,16 +202,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def received_edu(self, origin, edu_type, content):
         received_edus_counter.inc()
-
-        if edu_type in self.edu_handlers:
-            try:
-                yield self.edu_handlers[edu_type](origin, content)
-            except SynapseError as e:
-                logger.info("Failed to handle edu %r: %r", edu_type, e)
-            except Exception as e:
-                logger.exception("Failed to handle edu %r", edu_type)
-        else:
-            logger.warn("Received EDU of type %s with no handler", edu_type)
+        yield self.registry.on_edu(edu_type, origin, content)
 
     @defer.inlineCallbacks
     @log_function
@@ -223,15 +214,17 @@ class FederationServer(FederationBase):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        result = self._state_resp_cache.get((room_id, event_id))
-        if not result:
-            with (yield self._server_linearizer.queue((origin, room_id))):
-                resp = yield self._state_resp_cache.set(
-                    (room_id, event_id),
-                    self._on_context_state_request_compute(room_id, event_id)
-                )
-        else:
-            resp = yield result
+        # we grab the linearizer to protect ourselves from servers which hammer
+        # us. In theory we might already have the response to this query
+        # in the cache so we could return it without waiting for the linearizer
+        # - but that's non-trivial to get right, and anyway somewhat defeats
+        # the point of the linearizer.
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            resp = yield self._state_resp_cache.wrap(
+                (room_id, event_id),
+                self._on_context_state_request_compute,
+                room_id, event_id,
+            )
 
         defer.returnValue((200, resp))
 
@@ -300,14 +293,8 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_query_request(self, query_type, args):
         received_queries_counter.inc(query_type)
-
-        if query_type in self.query_handlers:
-            response = yield self.query_handlers[query_type](args)
-            defer.returnValue((200, response))
-        else:
-            defer.returnValue(
-                (404, "No handler for Query type '%s'" % (query_type,))
-            )
+        resp = yield self.registry.on_query(query_type, args)
+        defer.returnValue((200, resp))
 
     @defer.inlineCallbacks
     def on_make_join_request(self, room_id, user_id):
@@ -317,7 +304,7 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     def on_invite_request(self, origin, content):
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         ret_pdu = yield self.handler.on_invite_request(origin, pdu)
         time_now = self._clock.time_msec()
         defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -325,7 +312,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_send_join_request(self, origin, content):
         logger.debug("on_send_join_request: content: %s", content)
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
         res_pdus = yield self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
@@ -345,7 +332,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_send_leave_request(self, origin, content):
         logger.debug("on_send_leave_request: content: %s", content)
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
         yield self.handler.on_send_leave_request(origin, pdu)
         defer.returnValue((200, {}))
@@ -382,7 +369,7 @@ class FederationServer(FederationBase):
         """
         with (yield self._server_linearizer.queue((origin, room_id))):
             auth_chain = [
-                self.event_from_pdu_json(e)
+                event_from_pdu_json(e)
                 for e in content["auth_chain"]
             ]
 
@@ -437,6 +424,16 @@ class FederationServer(FederationBase):
                         key_id: json.loads(json_bytes)
                     }
 
+        logger.info(
+            "Claimed one-time-keys: %s",
+            ",".join((
+                "%s for %s:%s" % (key_id, user_id, device_id)
+                for user_id, user_keys in iteritems(json_result)
+                for device_id, device_keys in iteritems(user_keys)
+                for key_id, _ in iteritems(device_keys)
+            )),
+        )
+
         defer.returnValue({"one_time_keys": json_result})
 
     @defer.inlineCallbacks
@@ -497,26 +494,59 @@ class FederationServer(FederationBase):
         )
 
     @defer.inlineCallbacks
-    @log_function
-    def _handle_new_pdu(self, origin, pdu, get_missing=True):
+    def _handle_received_pdu(self, origin, pdu):
+        """ Process a PDU received in a federation /send/ transaction.
+
+        If the event is invalid, then this method throws a FederationError.
+        (The error will then be logged and sent back to the sender (which
+        probably won't do anything with it), and other events in the
+        transaction will be processed as normal).
+
+        It is likely that we'll then receive other events which refer to
+        this rejected_event in their prev_events, etc.  When that happens,
+        we'll attempt to fetch the rejected event again, which will presumably
+        fail, so those second-generation events will also get rejected.
+
+        Eventually, we get to the point where there are more than 10 events
+        between any new events and the original rejected event. Since we
+        only try to backfill 10 events deep on received pdu, we then accept the
+        new event, possibly introducing a discontinuity in the DAG, with new
+        forward extremities, so normal service is approximately returned,
+        until we try to backfill across the discontinuity.
 
-        # We reprocess pdus when we have seen them only as outliers
-        existing = yield self._get_persisted_pdu(
-            origin, pdu.event_id, do_auth=False
-        )
+        Args:
+            origin (str): server which sent the pdu
+            pdu (FrozenEvent): received pdu
 
-        # FIXME: Currently we fetch an event again when we already have it
-        # if it has been marked as an outlier.
+        Returns (Deferred): completes with None
 
-        already_seen = (
-            existing and (
-                not existing.internal_metadata.is_outlier()
-                or pdu.internal_metadata.is_outlier()
-            )
-        )
-        if already_seen:
-            logger.debug("Already seen pdu %s", pdu.event_id)
-            return
+        Raises: FederationError if the signatures / hash do not match, or
+            if the event was unacceptable for any other reason (eg, too large,
+            too many prev_events, couldn't find the prev_events)
+        """
+        # check that it's actually being sent from a valid destination to
+        # workaround bug #1753 in 0.18.5 and 0.18.6
+        if origin != get_domain_from_id(pdu.event_id):
+            # We continue to accept join events from any server; this is
+            # necessary for the federation join dance to work correctly.
+            # (When we join over federation, the "helper" server is
+            # responsible for sending out the join event, rather than the
+            # origin. See bug #1893).
+            if not (
+                pdu.type == 'm.room.member' and
+                pdu.content and
+                pdu.content.get("membership", None) == 'join'
+            ):
+                logger.info(
+                    "Discarding PDU %s from invalid origin %s",
+                    pdu.event_id, origin
+                )
+                return
+            else:
+                logger.info(
+                    "Accepting join PDU %s from %s",
+                    pdu.event_id, origin
+                )
 
         # Check signature.
         try:
@@ -529,156 +559,11 @@ class FederationServer(FederationBase):
                 affected=pdu.event_id,
             )
 
-        state = None
-
-        auth_chain = []
-
-        have_seen = yield self.store.have_events(
-            [ev for ev, _ in pdu.prev_events]
-        )
-
-        fetch_state = False
-
-        # Get missing pdus if necessary.
-        if not pdu.internal_metadata.is_outlier():
-            # We only backfill backwards to the min depth.
-            min_depth = yield self.handler.get_min_depth_for_context(
-                pdu.room_id
-            )
-
-            logger.debug(
-                "_handle_new_pdu min_depth for %s: %d",
-                pdu.room_id, min_depth
-            )
-
-            prevs = {e_id for e_id, _ in pdu.prev_events}
-            seen = set(have_seen.keys())
-
-            if min_depth and pdu.depth < min_depth:
-                # This is so that we don't notify the user about this
-                # message, to work around the fact that some events will
-                # reference really really old events we really don't want to
-                # send to the clients.
-                pdu.internal_metadata.outlier = True
-            elif min_depth and pdu.depth > min_depth:
-                if get_missing and prevs - seen:
-                    # If we're missing stuff, ensure we only fetch stuff one
-                    # at a time.
-                    logger.info(
-                        "Acquiring lock for room %r to fetch %d missing events: %r...",
-                        pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
-                    )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
-                        logger.info(
-                            "Acquired lock for room %r to fetch %d missing events",
-                            pdu.room_id, len(prevs - seen),
-                        )
-
-                        # We recalculate seen, since it may have changed.
-                        have_seen = yield self.store.have_events(prevs)
-                        seen = set(have_seen.keys())
-
-                        if prevs - seen:
-                            latest = yield self.store.get_latest_event_ids_in_room(
-                                pdu.room_id
-                            )
-
-                            # We add the prev events that we have seen to the latest
-                            # list to ensure the remote server doesn't give them to us
-                            latest = set(latest)
-                            latest |= seen
-
-                            logger.info(
-                                "Missing %d events for room %r: %r...",
-                                len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
-                            )
-
-                            # XXX: we set timeout to 10s to help workaround
-                            # https://github.com/matrix-org/synapse/issues/1733.
-                            # The reason is to avoid holding the linearizer lock
-                            # whilst processing inbound /send transactions, causing
-                            # FDs to stack up and block other inbound transactions
-                            # which empirically can currently take up to 30 minutes.
-                            #
-                            # N.B. this explicitly disables retry attempts.
-                            #
-                            # N.B. this also increases our chances of falling back to
-                            # fetching fresh state for the room if the missing event
-                            # can't be found, which slightly reduces our security.
-                            # it may also increase our DAG extremity count for the room,
-                            # causing additional state resolution?  See #1760.
-                            # However, fetching state doesn't hold the linearizer lock
-                            # apparently.
-                            #
-                            # see https://github.com/matrix-org/synapse/pull/1744
-
-                            missing_events = yield self.get_missing_events(
-                                origin,
-                                pdu.room_id,
-                                earliest_events_ids=list(latest),
-                                latest_events=[pdu],
-                                limit=10,
-                                min_depth=min_depth,
-                                timeout=10000,
-                            )
-
-                            # We want to sort these by depth so we process them and
-                            # tell clients about them in order.
-                            missing_events.sort(key=lambda x: x.depth)
-
-                            for e in missing_events:
-                                yield self._handle_new_pdu(
-                                    origin,
-                                    e,
-                                    get_missing=False
-                                )
-
-                            have_seen = yield self.store.have_events(
-                                [ev for ev, _ in pdu.prev_events]
-                            )
-
-            prevs = {e_id for e_id, _ in pdu.prev_events}
-            seen = set(have_seen.keys())
-            if prevs - seen:
-                logger.info(
-                    "Still missing %d events for room %r: %r...",
-                    len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
-                )
-                fetch_state = True
-
-        if fetch_state:
-            # We need to get the state at this event, since we haven't
-            # processed all the prev events.
-            logger.debug(
-                "_handle_new_pdu getting state for %s",
-                pdu.room_id
-            )
-            try:
-                state, auth_chain = yield self.get_state_for_room(
-                    origin, pdu.room_id, pdu.event_id,
-                )
-            except:
-                logger.exception("Failed to get state for event: %s", pdu.event_id)
-
-        yield self.handler.on_receive_pdu(
-            origin,
-            pdu,
-            state=state,
-            auth_chain=auth_chain,
-        )
+        yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
 
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name
 
-    def event_from_pdu_json(self, pdu_json, outlier=False):
-        event = FrozenEvent(
-            pdu_json
-        )
-
-        event.internal_metadata.outlier = outlier
-
-        return event
-
     @defer.inlineCallbacks
     def exchange_third_party_invite(
             self,
@@ -701,3 +586,66 @@ class FederationServer(FederationBase):
             origin, room_id, event_dict
         )
         defer.returnValue(ret)
+
+
+class FederationHandlerRegistry(object):
+    """Allows classes to register themselves as handlers for a given EDU or
+    query type for incoming federation traffic.
+    """
+    def __init__(self):
+        self.edu_handlers = {}
+        self.query_handlers = {}
+
+    def register_edu_handler(self, edu_type, handler):
+        """Sets the handler callable that will be used to handle an incoming
+        federation EDU of the given type.
+
+        Args:
+            edu_type (str): The type of the incoming EDU to register handler for
+            handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+                of the given type. The arguments are the origin server name and
+                the EDU contents.
+        """
+        if edu_type in self.edu_handlers:
+            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+
+        self.edu_handlers[edu_type] = handler
+
+    def register_query_handler(self, query_type, handler):
+        """Sets the handler callable that will be used to handle an incoming
+        federation query of the given type.
+
+        Args:
+            query_type (str): Category name of the query, which should match
+                the string used by make_query.
+            handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+                incoming queries of this type. The return will be yielded
+                on and the result used as the response to the query request.
+        """
+        if query_type in self.query_handlers:
+            raise KeyError(
+                "Already have a Query handler for %s" % (query_type,)
+            )
+
+        self.query_handlers[query_type] = handler
+
+    @defer.inlineCallbacks
+    def on_edu(self, edu_type, origin, content):
+        handler = self.edu_handlers.get(edu_type)
+        if not handler:
+            logger.warn("No handler registered for EDU type %s", edu_type)
+
+        try:
+            yield handler(origin, content)
+        except SynapseError as e:
+            logger.info("Failed to handle edu %r: %r", edu_type, e)
+        except Exception as e:
+            logger.exception("Failed to handle edu %r", edu_type)
+
+    def on_query(self, query_type, args):
+        handler = self.query_handlers.get(query_type)
+        if not handler:
+            logger.warn("No handler registered for query type %s", query_type)
+            raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+
+        return handler(args)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
deleted file mode 100644
index 62d865ec4b..0000000000
--- a/synapse/federation/replication.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This layer is responsible for replicating with remote home servers using
-a given transport.
-"""
-
-from .federation_client import FederationClient
-from .federation_server import FederationServer
-
-from .persistence import TransactionActions
-
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-class ReplicationLayer(FederationClient, FederationServer):
-    """This layer is responsible for replicating with remote home servers over
-    the given transport. I.e., does the sending and receiving of PDUs to
-    remote home servers.
-
-    The layer communicates with the rest of the server via a registered
-    ReplicationHandler.
-
-    In more detail, the layer:
-        * Receives incoming data and processes it into transactions and pdus.
-        * Fetches any PDUs it thinks it might have missed.
-        * Keeps the current state for contexts up to date by applying the
-          suitable conflict resolution.
-        * Sends outgoing pdus wrapped in transactions.
-        * Fills out the references to previous pdus/transactions appropriately
-          for outgoing data.
-    """
-
-    def __init__(self, hs, transport_layer):
-        self.server_name = hs.hostname
-
-        self.keyring = hs.get_keyring()
-
-        self.transport_layer = transport_layer
-
-        self.federation_client = self
-
-        self.store = hs.get_datastore()
-
-        self.handler = None
-        self.edu_handlers = {}
-        self.query_handlers = {}
-
-        self._clock = hs.get_clock()
-
-        self.transaction_actions = TransactionActions(self.store)
-
-        self.hs = hs
-
-        super(ReplicationLayer, self).__init__(hs)
-
-    def __str__(self):
-        return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5c9f7a86f0..0f0c687b37 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,21 +31,21 @@ Events are replicated via a separate events stream.
 
 from .units import Edu
 
+from synapse.storage.presence import UserPresenceState
 from synapse.util.metrics import Measure
 import synapse.metrics
 
 from blist import sorteddict
-import ujson
+from collections import namedtuple
 
+import logging
 
-metrics = synapse.metrics.get_metrics_for(__name__)
+from six import itervalues, iteritems
+
+logger = logging.getLogger(__name__)
 
 
-PRESENCE_TYPE = "p"
-KEYED_EDU_TYPE = "k"
-EDU_TYPE = "e"
-FAILURE_TYPE = "f"
-DEVICE_MESSAGE_TYPE = "d"
+metrics = synapse.metrics.get_metrics_for(__name__)
 
 
 class FederationRemoteSendQueue(object):
@@ -54,18 +54,20 @@ class FederationRemoteSendQueue(object):
     def __init__(self, hs):
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
+        self.notifier = hs.get_notifier()
+        self.is_mine_id = hs.is_mine_id
 
-        self.presence_map = {}
-        self.presence_changed = sorteddict()
+        self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
+        self.presence_changed = sorteddict()  # Stream position -> user_id
 
-        self.keyed_edu = {}
-        self.keyed_edu_changed = sorteddict()
+        self.keyed_edu = {}  # (destination, key) -> EDU
+        self.keyed_edu_changed = sorteddict()  # stream position -> (destination, key)
 
-        self.edus = sorteddict()
+        self.edus = sorteddict()  # stream position -> Edu
 
-        self.failures = sorteddict()
+        self.failures = sorteddict()  # stream position -> (destination, Failure)
 
-        self.device_messages = sorteddict()
+        self.device_messages = sorteddict()  # stream position -> destination
 
         self.pos = 1
         self.pos_time = sorteddict()
@@ -121,7 +123,9 @@ class FederationRemoteSendQueue(object):
                 del self.presence_changed[key]
 
             user_ids = set(
-                user_id for uids in self.presence_changed.values() for _, user_id in uids
+                user_id
+                for uids in itervalues(self.presence_changed)
+                for user_id in uids
             )
 
             to_del = [
@@ -186,37 +190,50 @@ class FederationRemoteSendQueue(object):
         else:
             self.edus[pos] = edu
 
-    def send_presence(self, destination, states):
-        """As per TransactionQueue"""
+        self.notifier.on_new_replication_data()
+
+    def send_presence(self, states):
+        """As per TransactionQueue
+
+        Args:
+            states (list(UserPresenceState))
+        """
         pos = self._next_pos()
 
-        self.presence_map.update({
-            state.user_id: state
-            for state in states
-        })
+        # We only want to send presence for our own users, so lets always just
+        # filter here just in case.
+        local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
 
-        self.presence_changed[pos] = [
-            (destination, state.user_id) for state in states
-        ]
+        self.presence_map.update({state.user_id: state for state in local_states})
+        self.presence_changed[pos] = [state.user_id for state in local_states]
+
+        self.notifier.on_new_replication_data()
 
     def send_failure(self, failure, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
 
         self.failures[pos] = (destination, str(failure))
+        self.notifier.on_new_replication_data()
 
     def send_device_messages(self, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
         self.device_messages[pos] = destination
+        self.notifier.on_new_replication_data()
 
     def get_current_token(self):
         return self.pos - 1
 
-    def get_replication_rows(self, token, limit, federation_ack=None):
-        """
+    def federation_ack(self, token):
+        self._clear_queue_before_pos(token)
+
+    def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+        """Get rows to be sent over federation between the two tokens
+
         Args:
-            token (int)
+            from_token (int)
+            to_token(int)
             limit (int)
             federation_ack (int): Optional. The position where the worker is
                 explicitly acknowledged it has handled. Allows us to drop
@@ -225,9 +242,11 @@ class FederationRemoteSendQueue(object):
         # TODO: Handle limit.
 
         # To handle restarts where we wrap around
-        if token > self.pos:
-            token = -1
+        if from_token > self.pos:
+            from_token = -1
 
+        # list of tuple(int, BaseFederationRow), where the first is the position
+        # of the federation stream.
         rows = []
 
         # There should be only one reader, so lets delete everything its
@@ -237,62 +256,295 @@ class FederationRemoteSendQueue(object):
 
         # Fetch changed presence
         keys = self.presence_changed.keys()
-        i = keys.bisect_right(token)
-        dest_user_ids = set(
-            (pos, dest_user_id)
-            for pos in keys[i:]
-            for dest_user_id in self.presence_changed[pos]
-        )
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        dest_user_ids = [
+            (pos, user_id)
+            for pos in keys[i:j]
+            for user_id in self.presence_changed[pos]
+        ]
 
-        for (key, (dest, user_id)) in dest_user_ids:
-            rows.append((key, PRESENCE_TYPE, ujson.dumps({
-                "destination": dest,
-                "state": self.presence_map[user_id].as_dict(),
-            })))
+        for (key, user_id) in dest_user_ids:
+            rows.append((key, PresenceRow(
+                state=self.presence_map[user_id],
+            )))
 
         # Fetch changes keyed edus
         keys = self.keyed_edu_changed.keys()
-        i = keys.bisect_right(token)
-        keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
-
-        for (pos, (destination, edu_key)) in keyed_edus:
-            rows.append(
-                (pos, KEYED_EDU_TYPE, ujson.dumps({
-                    "key": edu_key,
-                    "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
-                }))
-            )
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        # We purposefully clobber based on the key here, python dict comprehensions
+        # always use the last value, so this will correctly point to the last
+        # stream position.
+        keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]}
+
+        for ((destination, edu_key), pos) in iteritems(keyed_edus):
+            rows.append((pos, KeyedEduRow(
+                key=edu_key,
+                edu=self.keyed_edu[(destination, edu_key)],
+            )))
 
         # Fetch changed edus
         keys = self.edus.keys()
-        i = keys.bisect_right(token)
-        edus = set((k, self.edus[k]) for k in keys[i:])
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        edus = ((k, self.edus[k]) for k in keys[i:j])
 
         for (pos, edu) in edus:
-            rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
+            rows.append((pos, EduRow(edu)))
 
         # Fetch changed failures
         keys = self.failures.keys()
-        i = keys.bisect_right(token)
-        failures = set((k, self.failures[k]) for k in keys[i:])
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        failures = ((k, self.failures[k]) for k in keys[i:j])
 
         for (pos, (destination, failure)) in failures:
-            rows.append((pos, FAILURE_TYPE, ujson.dumps({
-                "destination": destination,
-                "failure": failure,
-            })))
+            rows.append((pos, FailureRow(
+                destination=destination,
+                failure=failure,
+            )))
 
         # Fetch changed device messages
         keys = self.device_messages.keys()
-        i = keys.bisect_right(token)
-        device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        device_messages = {self.device_messages[k]: k for k in keys[i:j]}
 
-        for (pos, destination) in device_messages:
-            rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
-                "destination": destination,
-            })))
+        for (destination, pos) in iteritems(device_messages):
+            rows.append((pos, DeviceRow(
+                destination=destination,
+            )))
 
         # Sort rows based on pos
         rows.sort()
 
-        return rows
+        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+
+
+class BaseFederationRow(object):
+    """Base class for rows to be sent in the federation stream.
+
+    Specifies how to identify, serialize and deserialize the different types.
+    """
+
+    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+
+    @staticmethod
+    def from_data(data):
+        """Parse the data from the federation stream into a row.
+
+        Args:
+            data: The value of ``data`` from FederationStreamRow.data, type
+                depends on the type of stream
+        """
+        raise NotImplementedError()
+
+    def to_data(self):
+        """Serialize this row to be sent over the federation stream.
+
+        Returns:
+            The value to be sent in FederationStreamRow.data. The type depends
+            on the type of stream.
+        """
+        raise NotImplementedError()
+
+    def add_to_buffer(self, buff):
+        """Add this row to the appropriate field in the buffer ready for this
+        to be sent over federation.
+
+        We use a buffer so that we can batch up events that have come in at
+        the same time and send them all at once.
+
+        Args:
+            buff (BufferedToSend)
+        """
+        raise NotImplementedError()
+
+
+class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
+    "state",  # UserPresenceState
+))):
+    TypeId = "p"
+
+    @staticmethod
+    def from_data(data):
+        return PresenceRow(
+            state=UserPresenceState.from_dict(data)
+        )
+
+    def to_data(self):
+        return self.state.as_dict()
+
+    def add_to_buffer(self, buff):
+        buff.presence.append(self.state)
+
+
+class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
+    "key",  # tuple(str) - the edu key passed to send_edu
+    "edu",  # Edu
+))):
+    """Streams EDUs that have an associated key that is ued to clobber. For example,
+    typing EDUs clobber based on room_id.
+    """
+
+    TypeId = "k"
+
+    @staticmethod
+    def from_data(data):
+        return KeyedEduRow(
+            key=tuple(data["key"]),
+            edu=Edu(**data["edu"]),
+        )
+
+    def to_data(self):
+        return {
+            "key": self.key,
+            "edu": self.edu.get_internal_dict(),
+        }
+
+    def add_to_buffer(self, buff):
+        buff.keyed_edus.setdefault(
+            self.edu.destination, {}
+        )[self.key] = self.edu
+
+
+class EduRow(BaseFederationRow, namedtuple("EduRow", (
+    "edu",  # Edu
+))):
+    """Streams EDUs that don't have keys. See KeyedEduRow
+    """
+    TypeId = "e"
+
+    @staticmethod
+    def from_data(data):
+        return EduRow(Edu(**data))
+
+    def to_data(self):
+        return self.edu.get_internal_dict()
+
+    def add_to_buffer(self, buff):
+        buff.edus.setdefault(self.edu.destination, []).append(self.edu)
+
+
+class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
+    "destination",  # str
+    "failure",
+))):
+    """Streams failures to a remote server. Failures are issued when there was
+    something wrong with a transaction the remote sent us, e.g. it included
+    an event that was invalid.
+    """
+
+    TypeId = "f"
+
+    @staticmethod
+    def from_data(data):
+        return FailureRow(
+            destination=data["destination"],
+            failure=data["failure"],
+        )
+
+    def to_data(self):
+        return {
+            "destination": self.destination,
+            "failure": self.failure,
+        }
+
+    def add_to_buffer(self, buff):
+        buff.failures.setdefault(self.destination, []).append(self.failure)
+
+
+class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
+    "destination",  # str
+))):
+    """Streams the fact that either a) there is pending to device messages for
+    users on the remote, or b) a local users device has changed and needs to
+    be sent to the remote.
+    """
+    TypeId = "d"
+
+    @staticmethod
+    def from_data(data):
+        return DeviceRow(destination=data["destination"])
+
+    def to_data(self):
+        return {"destination": self.destination}
+
+    def add_to_buffer(self, buff):
+        buff.device_destinations.add(self.destination)
+
+
+TypeToRow = {
+    Row.TypeId: Row
+    for Row in (
+        PresenceRow,
+        KeyedEduRow,
+        EduRow,
+        FailureRow,
+        DeviceRow,
+    )
+}
+
+
+ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
+    "presence",  # list(UserPresenceState)
+    "keyed_edus",  # dict of destination -> { key -> Edu }
+    "edus",  # dict of destination -> [Edu]
+    "failures",  # dict of destination -> [failures]
+    "device_destinations",  # set of destinations
+))
+
+
+def process_rows_for_federation(transaction_queue, rows):
+    """Parse a list of rows from the federation stream and put them in the
+    transaction queue ready for sending to the relevant homeservers.
+
+    Args:
+        transaction_queue (TransactionQueue)
+        rows (list(synapse.replication.tcp.streams.FederationStreamRow))
+    """
+
+    # The federation stream contains a bunch of different types of
+    # rows that need to be handled differently. We parse the rows, put
+    # them into the appropriate collection and then send them off.
+
+    buff = ParsedFederationStreamData(
+        presence=[],
+        keyed_edus={},
+        edus={},
+        failures={},
+        device_destinations=set(),
+    )
+
+    # Parse the rows in the stream and add to the buffer
+    for row in rows:
+        if row.type not in TypeToRow:
+            logger.error("Unrecognized federation row type %r", row.type)
+            continue
+
+        RowType = TypeToRow[row.type]
+        parsed_row = RowType.from_data(row.data)
+        parsed_row.add_to_buffer(buff)
+
+    if buff.presence:
+        transaction_queue.send_presence(buff.presence)
+
+    for destination, edu_map in iteritems(buff.keyed_edus):
+        for key, edu in edu_map.items():
+            transaction_queue.send_edu(
+                edu.destination, edu.edu_type, edu.content, key=key,
+            )
+
+    for destination, edu_list in iteritems(buff.edus):
+        for edu in edu_list:
+            transaction_queue.send_edu(
+                edu.destination, edu.edu_type, edu.content, key=None,
+            )
+
+    for destination, failure_list in iteritems(buff.failures):
+        for failure in failure_list:
+            transaction_queue.send_failure(destination, failure)
+
+    for destination in buff.device_destinations:
+        transaction_queue.send_device_messages(destination)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index bb3d9258a6..ded2b1871a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -12,22 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import datetime
 
 from twisted.internet import defer
 
 from .persistence import TransactionActions
 from .units import Transaction, Edu
 
-from synapse.api.errors import HttpResponseException
+from synapse.api.errors import HttpResponseException, FederationDeniedError
+from synapse.util import logcontext, PreserveLoggingContext
 from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.retryutils import (
-    get_retry_limiter, NotRetryingDestination,
-)
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
-from synapse.types import get_domain_from_id
-from synapse.handlers.presence import format_user_presence_state
+from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
 import synapse.metrics
 
 import logging
@@ -43,6 +40,10 @@ sent_pdus_destination_dist = client_metrics.register_distribution(
 )
 sent_edus_counter = client_metrics.register_counter("sent_edus")
 
+sent_transactions_counter = client_metrics.register_counter("sent_transactions")
+
+events_processed_counter = client_metrics.register_counter("events_processed")
+
 
 class TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
@@ -79,8 +80,18 @@ class TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = edus = {}
 
-        # Presence needs to be separate as we send single aggragate EDUs
+        # Map of user_id -> UserPresenceState for all the pending presence
+        # to be sent out by user_id. Entries here get processed and put in
+        # pending_presence_by_dest
+        self.pending_presence = {}
+
+        # Map of destination -> user_id -> UserPresenceState of pending presence
+        # to be sent to each destinations
         self.pending_presence_by_dest = presence = {}
+
+        # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
+        # based on their key (e.g. typing events by room_id)
+        # Map of destination -> (edu_type, key) -> Edu
         self.pending_edus_keyed_by_dest = edus_keyed = {}
 
         metrics.register_callback(
@@ -99,7 +110,12 @@ class TransactionQueue(object):
         # destination -> list of tuple(failure, deferred)
         self.pending_failures_by_dest = {}
 
+        # destination -> stream_id of last successfully sent to-device message.
+        # NB: may be a long or an int.
         self.last_device_stream_id_by_dest = {}
+
+        # destination -> stream_id of last successfully sent device list
+        # update.
         self.last_device_list_stream_id_by_dest = {}
 
         # HACK to get unique tx id
@@ -110,6 +126,8 @@ class TransactionQueue(object):
         self._is_processing = False
         self._last_poked_id = -1
 
+        self._processing_pending_presence = False
+
     def can_send_to(self, destination):
         """Can we send messages to the given server?
 
@@ -130,7 +148,6 @@ class TransactionQueue(object):
         else:
             return not destination.startswith("localhost")
 
-    @defer.inlineCallbacks
     def notify_new_events(self, current_id):
         """This gets called when we have some new events we might want to
         send out to other servers.
@@ -140,12 +157,19 @@ class TransactionQueue(object):
         if self._is_processing:
             return
 
+        # fire off a processing loop in the background. It's likely it will
+        # outlast the current request, so run it in the sentinel logcontext.
+        with PreserveLoggingContext():
+            self._process_event_queue_loop()
+
+    @defer.inlineCallbacks
+    def _process_event_queue_loop(self):
         try:
             self._is_processing = True
             while True:
                 last_token = yield self.store.get_federation_out_pos("events")
                 next_token, events = yield self.store.get_all_new_events_stream(
-                    last_token, self._last_poked_id, limit=20,
+                    last_token, self._last_poked_id, limit=100,
                 )
 
                 logger.debug("Handling %s -> %s", last_token, next_token)
@@ -153,28 +177,35 @@ class TransactionQueue(object):
                 if not events and next_token >= self._last_poked_id:
                     break
 
-                for event in events:
+                @defer.inlineCallbacks
+                def handle_event(event):
                     # Only send events for this server.
                     send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
                     is_mine = self.is_mine_id(event.event_id)
                     if not is_mine and send_on_behalf_of is None:
-                        continue
-
-                    # Get the state from before the event.
-                    # We need to make sure that this is the state from before
-                    # the event and not from after it.
-                    # Otherwise if the last member on a server in a room is
-                    # banned then it won't receive the event because it won't
-                    # be in the room after the ban.
-                    users_in_room = yield self.state.get_current_user_in_room(
-                        event.room_id, latest_event_ids=[
-                            prev_id for prev_id, _ in event.prev_events
-                        ],
-                    )
+                        return
+
+                    try:
+                        # Get the state from before the event.
+                        # We need to make sure that this is the state from before
+                        # the event and not from after it.
+                        # Otherwise if the last member on a server in a room is
+                        # banned then it won't receive the event because it won't
+                        # be in the room after the ban.
+                        destinations = yield self.state.get_current_hosts_in_room(
+                            event.room_id, latest_event_ids=[
+                                prev_id for prev_id, _ in event.prev_events
+                            ],
+                        )
+                    except Exception:
+                        logger.exception(
+                            "Failed to calculate hosts in room for event: %s",
+                            event.event_id,
+                        )
+                        return
+
+                    destinations = set(destinations)
 
-                    destinations = set(
-                        get_domain_from_id(user_id) for user_id in users_in_room
-                    )
                     if send_on_behalf_of is not None:
                         # If we are sending the event on behalf of another server
                         # then it already has the event and there is no reason to
@@ -185,10 +216,44 @@ class TransactionQueue(object):
 
                     self._send_pdu(event, destinations)
 
+                @defer.inlineCallbacks
+                def handle_room_events(events):
+                    for event in events:
+                        yield handle_event(event)
+
+                events_by_room = {}
+                for event in events:
+                    events_by_room.setdefault(event.room_id, []).append(event)
+
+                yield logcontext.make_deferred_yieldable(defer.gatherResults(
+                    [
+                        logcontext.run_in_background(handle_room_events, evs)
+                        for evs in events_by_room.itervalues()
+                    ],
+                    consumeErrors=True
+                ))
+
                 yield self.store.update_federation_out_pos(
                     "events", next_token
                 )
 
+                if events:
+                    now = self.clock.time_msec()
+                    ts = yield self.store.get_received_ts(events[-1].event_id)
+
+                    synapse.metrics.event_processing_lag.set(
+                        now - ts, "federation_sender",
+                    )
+                    synapse.metrics.event_processing_last_ts.set(
+                        ts, "federation_sender",
+                    )
+
+                events_processed_counter.inc_by(len(events))
+
+                synapse.metrics.event_processing_positions.set(
+                    next_token, "federation_sender",
+                )
+
         finally:
             self._is_processing = False
 
@@ -217,21 +282,75 @@ class TransactionQueue(object):
                 (pdu, order)
             )
 
-            preserve_context_over_fn(
-                self._attempt_new_transaction, destination
-            )
+            self._attempt_new_transaction(destination)
 
-    def send_presence(self, destination, states):
-        if not self.can_send_to(destination):
-            return
+    @logcontext.preserve_fn  # the caller should not yield on this
+    @defer.inlineCallbacks
+    def send_presence(self, states):
+        """Send the new presence states to the appropriate destinations.
+
+        This actually queues up the presence states ready for sending and
+        triggers a background task to process them and send out the transactions.
+
+        Args:
+            states (list(UserPresenceState))
+        """
 
-        self.pending_presence_by_dest.setdefault(destination, {}).update({
+        # First we queue up the new presence by user ID, so multiple presence
+        # updates in quick successtion are correctly handled
+        # We only want to send presence for our own users, so lets always just
+        # filter here just in case.
+        self.pending_presence.update({
             state.user_id: state for state in states
+            if self.is_mine_id(state.user_id)
         })
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        # We then handle the new pending presence in batches, first figuring
+        # out the destinations we need to send each state to and then poking it
+        # to attempt a new transaction. We linearize this so that we don't
+        # accidentally mess up the ordering and send multiple presence updates
+        # in the wrong order
+        if self._processing_pending_presence:
+            return
+
+        self._processing_pending_presence = True
+        try:
+            while True:
+                states_map = self.pending_presence
+                self.pending_presence = {}
+
+                if not states_map:
+                    break
+
+                yield self._process_presence_inner(states_map.values())
+        except Exception:
+            logger.exception("Error sending presence states to servers")
+        finally:
+            self._processing_pending_presence = False
+
+    @measure_func("txnqueue._process_presence")
+    @defer.inlineCallbacks
+    def _process_presence_inner(self, states):
+        """Given a list of states populate self.pending_presence_by_dest and
+        poke to send a new transaction to each destination
+
+        Args:
+            states (list(UserPresenceState))
+        """
+        hosts_and_states = yield get_interested_remotes(self.store, states, self.state)
+
+        for destinations, states in hosts_and_states:
+            for destination in destinations:
+                if not self.can_send_to(destination):
+                    continue
+
+                self.pending_presence_by_dest.setdefault(
+                    destination, {}
+                ).update({
+                    state.user_id: state for state in states
+                })
+
+                self._attempt_new_transaction(destination)
 
     def send_edu(self, destination, edu_type, content, key=None):
         edu = Edu(
@@ -253,9 +372,7 @@ class TransactionQueue(object):
         else:
             self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
@@ -268,9 +385,7 @@ class TransactionQueue(object):
             destination, []
         ).append(failure)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
@@ -279,15 +394,24 @@ class TransactionQueue(object):
         if not self.can_send_to(destination):
             return
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def get_current_token(self):
         return 0
 
-    @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
+        """Try to start a new transaction to this destination
+
+        If there is already a transaction in progress to this destination,
+        returns immediately. Otherwise kicks off the process of sending a
+        transaction in the background.
+
+        Args:
+            destination (str):
+
+        Returns:
+            None
+        """
         # list of (pending_pdu, deferred, order)
         if destination in self.pending_transactions:
             # XXX: pending_transactions can get stuck on by a never-ending
@@ -300,12 +424,46 @@ class TransactionQueue(object):
             )
             return
 
+        logger.debug("TX [%s] Starting transaction loop", destination)
+
+        # Drop the logcontext before starting the transaction. It doesn't
+        # really make sense to log all the outbound transactions against
+        # whatever path led us to this point: that's pretty arbitrary really.
+        #
+        # (this also means we can fire off _perform_transaction without
+        # yielding)
+        with logcontext.PreserveLoggingContext():
+            self._transaction_transmission_loop(destination)
+
+    @defer.inlineCallbacks
+    def _transaction_transmission_loop(self, destination):
+        pending_pdus = []
         try:
             self.pending_transactions[destination] = 1
 
+            # This will throw if we wouldn't retry. We do this here so we fail
+            # quickly, but we will later check this again in the http client,
+            # hence why we throw the result away.
+            yield get_retry_limiter(destination, self.clock, self.store)
+
+            # XXX: what's this for?
             yield run_on_reactor()
 
+            pending_pdus = []
             while True:
+                device_message_edus, device_stream_id, dev_list_id = (
+                    yield self._get_new_device_messages(destination)
+                )
+
+                # BEGIN CRITICAL SECTION
+                #
+                # In order to avoid a race condition, we need to make sure that
+                # the following code (from popping the queues up to the point
+                # where we decide if we actually have any pending messages) is
+                # atomic - otherwise new PDUs or EDUs might arrive in the
+                # meantime, but not get sent because we hold the
+                # pending_transactions flag.
+
                 pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
                 pending_edus = self.pending_edus_by_dest.pop(destination, [])
                 pending_presence = self.pending_presence_by_dest.pop(destination, {})
@@ -315,17 +473,6 @@ class TransactionQueue(object):
                     self.pending_edus_keyed_by_dest.pop(destination, {}).values()
                 )
 
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self.clock,
-                    self.store,
-                    backoff_on_404=True,  # If we get a 404 the other side has gone
-                )
-
-                device_message_edus, device_stream_id, dev_list_id = (
-                    yield self._get_new_device_messages(destination)
-                )
-
                 pending_edus.extend(device_message_edus)
                 if pending_presence:
                     pending_edus.append(
@@ -355,11 +502,13 @@ class TransactionQueue(object):
                     )
                     return
 
+                # END CRITICAL SECTION
+
                 success = yield self._send_new_transaction(
                     destination, pending_pdus, pending_edus, pending_failures,
-                    limiter=limiter,
                 )
                 if success:
+                    sent_transactions_counter.inc()
                     # Remove the acknowledged device messages from the database
                     # Only bother if we actually sent some device messages
                     if device_message_edus:
@@ -375,12 +524,26 @@ class TransactionQueue(object):
                     self.last_device_list_stream_id_by_dest[destination] = dev_list_id
                 else:
                     break
-        except NotRetryingDestination:
+        except NotRetryingDestination as e:
             logger.debug(
-                "TX [%s] not ready for retry yet - "
+                "TX [%s] not ready for retry yet (next retry at %s) - "
                 "dropping transaction for now",
                 destination,
+                datetime.datetime.fromtimestamp(
+                    (e.retry_last_ts + e.retry_interval) / 1000.0
+                ),
             )
+        except FederationDeniedError as e:
+            logger.info(e)
+        except Exception as e:
+            logger.warn(
+                "TX [%s] Failed to send transaction: %s",
+                destination,
+                e,
+            )
+            for p, _ in pending_pdus:
+                logger.info("Failed to send event %s to %s", p.event_id,
+                            destination)
         finally:
             # We want to be *very* sure we delete this after we stop processing
             self.pending_transactions.pop(destination, None)
@@ -420,7 +583,7 @@ class TransactionQueue(object):
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures, limiter):
+                              pending_failures):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
@@ -430,132 +593,104 @@ class TransactionQueue(object):
 
         success = True
 
-        try:
-            logger.debug("TX [%s] _attempt_new_transaction", destination)
+        logger.debug("TX [%s] _attempt_new_transaction", destination)
 
-            txn_id = str(self._next_txn_id)
+        txn_id = str(self._next_txn_id)
 
-            logger.debug(
-                "TX [%s] {%s} Attempting new transaction"
-                " (pdus: %d, edus: %d, failures: %d)",
-                destination, txn_id,
-                len(pdus),
-                len(edus),
-                len(failures)
-            )
+        logger.debug(
+            "TX [%s] {%s} Attempting new transaction"
+            " (pdus: %d, edus: %d, failures: %d)",
+            destination, txn_id,
+            len(pdus),
+            len(edus),
+            len(failures)
+        )
 
-            logger.debug("TX [%s] Persisting transaction...", destination)
+        logger.debug("TX [%s] Persisting transaction...", destination)
 
-            transaction = Transaction.create_new(
-                origin_server_ts=int(self.clock.time_msec()),
-                transaction_id=txn_id,
-                origin=self.server_name,
-                destination=destination,
-                pdus=pdus,
-                edus=edus,
-                pdu_failures=failures,
-            )
+        transaction = Transaction.create_new(
+            origin_server_ts=int(self.clock.time_msec()),
+            transaction_id=txn_id,
+            origin=self.server_name,
+            destination=destination,
+            pdus=pdus,
+            edus=edus,
+            pdu_failures=failures,
+        )
 
-            self._next_txn_id += 1
+        self._next_txn_id += 1
 
-            yield self.transaction_actions.prepare_to_send(transaction)
+        yield self.transaction_actions.prepare_to_send(transaction)
 
-            logger.debug("TX [%s] Persisted transaction", destination)
-            logger.info(
-                "TX [%s] {%s} Sending transaction [%s],"
-                " (PDUs: %d, EDUs: %d, failures: %d)",
-                destination, txn_id,
-                transaction.transaction_id,
-                len(pdus),
-                len(edus),
-                len(failures),
-            )
+        logger.debug("TX [%s] Persisted transaction", destination)
+        logger.info(
+            "TX [%s] {%s} Sending transaction [%s],"
+            " (PDUs: %d, EDUs: %d, failures: %d)",
+            destination, txn_id,
+            transaction.transaction_id,
+            len(pdus),
+            len(edus),
+            len(failures),
+        )
 
-            with limiter:
-                # Actually send the transaction
-
-                # FIXME (erikj): This is a bit of a hack to make the Pdu age
-                # keys work
-                def json_data_cb():
-                    data = transaction.get_dict()
-                    now = int(self.clock.time_msec())
-                    if "pdus" in data:
-                        for p in data["pdus"]:
-                            if "age_ts" in p:
-                                unsigned = p.setdefault("unsigned", {})
-                                unsigned["age"] = now - int(p["age_ts"])
-                                del p["age_ts"]
-                    return data
-
-                try:
-                    response = yield self.transport_layer.send_transaction(
-                        transaction, json_data_cb
-                    )
-                    code = 200
-
-                    if response:
-                        for e_id, r in response.get("pdus", {}).items():
-                            if "error" in r:
-                                logger.warn(
-                                    "Transaction returned error for %s: %s",
-                                    e_id, r,
-                                )
-                except HttpResponseException as e:
-                    code = e.code
-                    response = e.response
-
-                    if e.code in (401, 404, 429) or 500 <= e.code:
-                        logger.info(
-                            "TX [%s] {%s} got %d response",
-                            destination, txn_id, code
+        # Actually send the transaction
+
+        # FIXME (erikj): This is a bit of a hack to make the Pdu age
+        # keys work
+        def json_data_cb():
+            data = transaction.get_dict()
+            now = int(self.clock.time_msec())
+            if "pdus" in data:
+                for p in data["pdus"]:
+                    if "age_ts" in p:
+                        unsigned = p.setdefault("unsigned", {})
+                        unsigned["age"] = now - int(p["age_ts"])
+                        del p["age_ts"]
+            return data
+
+        try:
+            response = yield self.transport_layer.send_transaction(
+                transaction, json_data_cb
+            )
+            code = 200
+
+            if response:
+                for e_id, r in response.get("pdus", {}).items():
+                    if "error" in r:
+                        logger.warn(
+                            "Transaction returned error for %s: %s",
+                            e_id, r,
                         )
-                        raise e
+        except HttpResponseException as e:
+            code = e.code
+            response = e.response
 
+            if e.code in (401, 404, 429) or 500 <= e.code:
                 logger.info(
                     "TX [%s] {%s} got %d response",
                     destination, txn_id, code
                 )
+                raise e
 
-                logger.debug("TX [%s] Sent transaction", destination)
-                logger.debug("TX [%s] Marking as delivered...", destination)
-
-            yield self.transaction_actions.delivered(
-                transaction, code, response
-            )
+        logger.info(
+            "TX [%s] {%s} got %d response",
+            destination, txn_id, code
+        )
 
-            logger.debug("TX [%s] Marked as delivered", destination)
+        logger.debug("TX [%s] Sent transaction", destination)
+        logger.debug("TX [%s] Marking as delivered...", destination)
 
-            if code != 200:
-                for p in pdus:
-                    logger.info(
-                        "Failed to send event %s to %s", p.event_id, destination
-                    )
-                success = False
-        except RuntimeError as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
+        yield self.transaction_actions.delivered(
+            transaction, code, response
+        )
 
-            success = False
+        logger.debug("TX [%s] Marked as delivered", destination)
 
+        if code != 200:
             for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-        except Exception as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-
+                logger.info(
+                    "Failed to send event %s to %s", p.event_id, destination
+                )
             success = False
 
-            for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-
         defer.returnValue(success)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index f49e8a2cc4..6db8efa6dd 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,6 +21,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.util.logutils import log_function
 
 import logging
+import urllib
 
 
 logger = logging.getLogger(__name__)
@@ -49,7 +51,7 @@ class TransportLayerClient(object):
         logger.debug("get_room_state dest=%s, room=%s",
                      destination, room_id)
 
-        path = PREFIX + "/state/%s/" % room_id
+        path = _create_path(PREFIX, "/state/%s/", room_id)
         return self.client.get_json(
             destination, path=path, args={"event_id": event_id},
         )
@@ -71,7 +73,7 @@ class TransportLayerClient(object):
         logger.debug("get_room_state_ids dest=%s, room=%s",
                      destination, room_id)
 
-        path = PREFIX + "/state_ids/%s/" % room_id
+        path = _create_path(PREFIX, "/state_ids/%s/", room_id)
         return self.client.get_json(
             destination, path=path, args={"event_id": event_id},
         )
@@ -93,7 +95,7 @@ class TransportLayerClient(object):
         logger.debug("get_pdu dest=%s, event_id=%s",
                      destination, event_id)
 
-        path = PREFIX + "/event/%s/" % (event_id, )
+        path = _create_path(PREFIX, "/event/%s/", event_id)
         return self.client.get_json(destination, path=path, timeout=timeout)
 
     @log_function
@@ -119,7 +121,7 @@ class TransportLayerClient(object):
             # TODO: raise?
             return
 
-        path = PREFIX + "/backfill/%s/" % (room_id,)
+        path = _create_path(PREFIX, "/backfill/%s/", room_id)
 
         args = {
             "v": event_tuples,
@@ -157,12 +159,15 @@ class TransportLayerClient(object):
         # generated by the json_data_callback.
         json_data = transaction.get_dict()
 
+        path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
+
         response = yield self.client.put_json(
             transaction.destination,
-            path=PREFIX + "/send/%s/" % transaction.transaction_id,
+            path=path,
             data=json_data,
             json_data_callback=json_data_callback,
             long_retries=True,
+            backoff_on_404=True,  # If we get a 404 the other side has gone
         )
 
         logger.debug(
@@ -174,8 +179,9 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def make_query(self, destination, query_type, args, retry_on_dns_fail):
-        path = PREFIX + "/query/%s" % query_type
+    def make_query(self, destination, query_type, args, retry_on_dns_fail,
+                   ignore_backoff=False):
+        path = _create_path(PREFIX, "/query/%s", query_type)
 
         content = yield self.client.get_json(
             destination=destination,
@@ -183,6 +189,7 @@ class TransportLayerClient(object):
             args=args,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=10000,
+            ignore_backoff=ignore_backoff,
         )
 
         defer.returnValue(content)
@@ -190,19 +197,54 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def make_membership_event(self, destination, room_id, user_id, membership):
+        """Asks a remote server to build and sign us a membership event
+
+        Note that this does not append any events to any graphs.
+
+        Args:
+            destination (str): address of remote homeserver
+            room_id (str): room to join/leave
+            user_id (str): user to be joined/left
+            membership (str): one of join/leave
+
+        Returns:
+            Deferred: Succeeds when we get a 2xx HTTP response. The result
+            will be the decoded JSON body (ie, the new event).
+
+            Fails with ``HTTPRequestException`` if we get an HTTP response
+            code >= 300.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
+
+            Fails with ``FederationDeniedError`` if the remote destination
+            is not in our federation whitelist
+        """
         valid_memberships = {Membership.JOIN, Membership.LEAVE}
         if membership not in valid_memberships:
             raise RuntimeError(
                 "make_membership_event called with membership='%s', must be one of %s" %
                 (membership, ",".join(valid_memberships))
             )
-        path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
+        path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
+
+        ignore_backoff = False
+        retry_on_dns_fail = False
+
+        if membership == Membership.LEAVE:
+            # we particularly want to do our best to send leave events. The
+            # problem is that if it fails, we won't retry it later, so if the
+            # remote server was just having a momentary blip, the room will be
+            # out of sync.
+            ignore_backoff = True
+            retry_on_dns_fail = True
 
         content = yield self.client.get_json(
             destination=destination,
             path=path,
-            retry_on_dns_fail=False,
+            retry_on_dns_fail=retry_on_dns_fail,
             timeout=20000,
+            ignore_backoff=ignore_backoff,
         )
 
         defer.returnValue(content)
@@ -210,7 +252,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_join(self, destination, room_id, event_id, content):
-        path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
             destination=destination,
@@ -223,12 +265,18 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_leave(self, destination, room_id, event_id, content):
-        path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
             destination=destination,
             path=path,
             data=content,
+
+            # we want to do our best to send this through. The problem is
+            # that if it fails, we won't retry it later, so if the remote
+            # server was just having a momentary blip, the room will be out of
+            # sync.
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
@@ -236,12 +284,13 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_invite(self, destination, room_id, event_id, content):
-        path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
             destination=destination,
             path=path,
             data=content,
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
@@ -269,6 +318,7 @@ class TransportLayerClient(object):
             destination=remote_server,
             path=path,
             args=args,
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
@@ -276,7 +326,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def exchange_third_party_invite(self, destination, room_id, event_dict):
-        path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
+        path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
 
         response = yield self.client.put_json(
             destination=destination,
@@ -289,7 +339,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
-        path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
 
         content = yield self.client.get_json(
             destination=destination,
@@ -301,7 +351,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_query_auth(self, destination, room_id, event_id, content):
-        path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
 
         content = yield self.client.post_json(
             destination=destination,
@@ -363,7 +413,7 @@ class TransportLayerClient(object):
         Returns:
             A dict containg the device keys.
         """
-        path = PREFIX + "/user/devices/" + user_id
+        path = _create_path(PREFIX, "/user/devices/%s", user_id)
 
         content = yield self.client.get_json(
             destination=destination,
@@ -413,7 +463,7 @@ class TransportLayerClient(object):
     @log_function
     def get_missing_events(self, destination, room_id, earliest_events,
                            latest_events, limit, min_depth, timeout):
-        path = PREFIX + "/get_missing_events/%s" % (room_id,)
+        path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
 
         content = yield self.client.post_json(
             destination=destination,
@@ -428,3 +478,475 @@ class TransportLayerClient(object):
         )
 
         defer.returnValue(content)
+
+    @log_function
+    def get_group_profile(self, destination, group_id, requester_user_id):
+        """Get a group profile
+        """
+        path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_profile(self, destination, group_id, requester_user_id, content):
+        """Update a remote group profile
+
+        Args:
+            destination (str)
+            group_id (str)
+            requester_user_id (str)
+            content (dict): The new profile of the group
+        """
+        path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_summary(self, destination, group_id, requester_user_id):
+        """Get a group summary
+        """
+        path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_rooms_in_group(self, destination, group_id, requester_user_id):
+        """Get all rooms in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
+                          content):
+        """Add a room to a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
+                             config_key, content):
+        """Update room in group
+        """
+        path = _create_path(
+            PREFIX, "/groups/%s/room/%s/config/%s",
+            group_id, room_id, config_key,
+        )
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
+        """Remove a room from a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_users_in_group(self, destination, group_id, requester_user_id):
+        """Get users in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/users", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_invited_users_in_group(self, destination, group_id, requester_user_id):
+        """Get users that have been invited to a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def accept_group_invite(self, destination, group_id, user_id, content):
+        """Accept a group invite
+        """
+        path = _create_path(
+            PREFIX, "/groups/%s/users/%s/accept_invite",
+            group_id, user_id,
+        )
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def join_group(self, destination, group_id, user_id, content):
+        """Attempts to join a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
+        """Invite a user to a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def invite_to_group_notification(self, destination, group_id, user_id, content):
+        """Sent by group server to inform a user's server that they have been
+        invited.
+        """
+
+        path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def remove_user_from_group(self, destination, group_id, requester_user_id,
+                               user_id, content):
+        """Remove a user fron a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def remove_user_from_group_notification(self, destination, group_id, user_id,
+                                            content):
+        """Sent by group server to inform a user's server that they have been
+        kicked from the group.
+        """
+
+        path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def renew_group_attestation(self, destination, group_id, user_id, content):
+        """Sent by either a group server or a user's server to periodically update
+        the attestations
+        """
+
+        path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_summary_room(self, destination, group_id, user_id, room_id,
+                                  category_id, content):
+        """Update a room entry in a group summary
+        """
+        if category_id:
+            path = _create_path(
+                PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
+                group_id, category_id, room_id,
+            )
+        else:
+            path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def delete_group_summary_room(self, destination, group_id, user_id, room_id,
+                                  category_id):
+        """Delete a room entry in a group summary
+        """
+        if category_id:
+            path = _create_path(
+                PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
+                group_id, category_id, room_id,
+            )
+        else:
+            path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_categories(self, destination, group_id, requester_user_id):
+        """Get all categories in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_category(self, destination, group_id, requester_user_id, category_id):
+        """Get category info in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_category(self, destination, group_id, requester_user_id, category_id,
+                              content):
+        """Update a category in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def delete_group_category(self, destination, group_id, requester_user_id,
+                              category_id):
+        """Delete a category in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_roles(self, destination, group_id, requester_user_id):
+        """Get all roles in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_role(self, destination, group_id, requester_user_id, role_id):
+        """Get a roles info
+        """
+        path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_role(self, destination, group_id, requester_user_id, role_id,
+                          content):
+        """Update a role in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def delete_group_role(self, destination, group_id, requester_user_id, role_id):
+        """Delete a role in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_summary_user(self, destination, group_id, requester_user_id,
+                                  user_id, role_id, content):
+        """Update a users entry in a group
+        """
+        if role_id:
+            path = _create_path(
+                PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+                group_id, role_id, user_id,
+            )
+        else:
+            path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def set_group_join_policy(self, destination, group_id, requester_user_id,
+                              content):
+        """Sets the join policy for a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
+
+        return self.client.put_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def delete_group_summary_user(self, destination, group_id, requester_user_id,
+                                  user_id, role_id):
+        """Delete a users entry in a group
+        """
+        if role_id:
+            path = _create_path(
+                PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+                group_id, role_id, user_id,
+            )
+        else:
+            path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    def bulk_get_publicised_groups(self, destination, user_ids):
+        """Get the groups a list of users are publicising
+        """
+
+        path = PREFIX + "/get_groups_publicised"
+
+        content = {"user_ids": user_ids}
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+
+def _create_path(prefix, path, *args):
+    """Creates a path from the prefix, path template and args. Ensures that
+    all args are url encoded.
+
+    Example:
+
+        _create_path(PREFIX, "/event/%s/", event_id)
+
+    Args:
+        prefix (str)
+        path (str): String template for the path
+        args: ([str]): Args to insert into path. Each arg will be url encoded
+
+    Returns:
+        str
+    """
+    return prefix + path % tuple(urllib.quote(arg, "") for arg in args)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index c840da834c..19d09f5422 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +17,7 @@
 from twisted.internet import defer
 
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, SynapseError, FederationDeniedError
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
     parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -24,7 +25,8 @@ from synapse.http.servlet import (
 )
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
-from synapse.types import ThirdPartyInstanceID
+from synapse.util.logcontext import run_in_background
+from synapse.types import ThirdPartyInstanceID, get_domain_from_id
 
 import functools
 import logging
@@ -79,6 +81,8 @@ class Authenticator(object):
     def __init__(self, hs):
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
@@ -110,7 +114,7 @@ class Authenticator(object):
                 key = strip_quotes(param_dict["key"])
                 sig = strip_quotes(param_dict["sig"])
                 return (origin, key, sig)
-            except:
+            except Exception:
                 raise AuthenticationError(
                     400, "Malformed Authorization header", Codes.UNAUTHORIZED
                 )
@@ -128,6 +132,12 @@ class Authenticator(object):
                 json_request["origin"] = origin
                 json_request["signatures"].setdefault(origin, {})[key] = sig
 
+        if (
+            self.federation_domain_whitelist is not None and
+            origin not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(origin)
+
         if not json_request["signatures"]:
             raise NoAuthenticationError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
@@ -138,18 +148,30 @@ class Authenticator(object):
         logger.info("Request from %s", origin)
         request.authenticated_entity = origin
 
+        # If we get a valid signed request from the other side, its probably
+        # alive
+        retry_timings = yield self.store.get_destination_retry_timings(origin)
+        if retry_timings and retry_timings["retry_last_ts"]:
+            run_in_background(self._reset_retry_timings, origin)
+
         defer.returnValue(origin)
 
+    @defer.inlineCallbacks
+    def _reset_retry_timings(self, origin):
+        try:
+            logger.info("Marking origin %r as up", origin)
+            yield self.store.set_destination_retry_timings(origin, 0, 0)
+        except Exception:
+            logger.exception("Error resetting retry timings on %s", origin)
+
 
 class BaseFederationServlet(object):
     REQUIRE_AUTH = True
 
-    def __init__(self, handler, authenticator, ratelimiter, server_name,
-                 room_list_handler):
+    def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
         self.ratelimiter = ratelimiter
-        self.room_list_handler = room_list_handler
 
     def _wrap(self, func):
         authenticator = self.authenticator
@@ -170,7 +192,7 @@ class BaseFederationServlet(object):
                 if self.REQUIRE_AUTH:
                     logger.exception("authenticate_request failed")
                     raise
-            except:
+            except Exception:
                 logger.exception("authenticate_request failed")
                 raise
 
@@ -263,7 +285,7 @@ class FederationSendServlet(BaseFederationServlet):
             code, response = yield self.handler.on_incoming_transaction(
                 transaction_data
             )
-        except:
+        except Exception:
             logger.exception("on_incoming_transaction failed")
             raise
 
@@ -581,7 +603,7 @@ class PublicRoomList(BaseFederationServlet):
         else:
             network_tuple = ThirdPartyInstanceID(None, None)
 
-        data = yield self.room_list_handler.get_local_public_room_list(
+        data = yield self.handler.get_local_public_room_list(
             limit, since_token,
             network_tuple=network_tuple
         )
@@ -602,7 +624,550 @@ class FederationVersionServlet(BaseFederationServlet):
         }))
 
 
-SERVLET_CLASSES = (
+class FederationGroupsProfileServlet(BaseFederationServlet):
+    """Get/set the basic profile of a group on behalf of a user
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/profile$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_group_profile(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.update_group_profile(
+            group_id, requester_user_id, content
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryServlet(BaseFederationServlet):
+    PATH = "/groups/(?P<group_id>[^/]*)/summary$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_group_summary(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsRoomsServlet(BaseFederationServlet):
+    """Get the rooms in a group on behalf of a user
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_rooms_in_group(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsServlet(BaseFederationServlet):
+    """Add/remove room from group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, room_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.add_room_to_group(
+            group_id, requester_user_id, room_id, content
+        )
+
+        defer.returnValue((200, new_content))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, room_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.remove_room_from_group(
+            group_id, requester_user_id, room_id,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
+    """Update room config in group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
+        "/config/(?P<config_key>[^/]*)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, room_id, config_key):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        result = yield self.groups_handler.update_room_in_group(
+            group_id, requester_user_id, room_id, config_key, content,
+        )
+
+        defer.returnValue((200, result))
+
+
+class FederationGroupsUsersServlet(BaseFederationServlet):
+    """Get the users in a group on behalf of a user
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_users_in_group(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
+    """Get the users that have been invited to a group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_invited_users_in_group(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsInviteServlet(BaseFederationServlet):
+    """Ask a group server to invite someone to the group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.invite_to_group(
+            group_id, user_id, requester_user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
+    """Accept an invitation from the group server
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        if get_domain_from_id(user_id) != origin:
+            raise SynapseError(403, "user_id doesn't match origin")
+
+        new_content = yield self.handler.accept_invite(
+            group_id, user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsJoinServlet(BaseFederationServlet):
+    """Attempt to join a group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        if get_domain_from_id(user_id) != origin:
+            raise SynapseError(403, "user_id doesn't match origin")
+
+        new_content = yield self.handler.join_group(
+            group_id, user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveUserServlet(BaseFederationServlet):
+    """Leave or kick a user from the group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.remove_user_from_group(
+            group_id, user_id, requester_user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsLocalInviteServlet(BaseFederationServlet):
+    """A group server has invited a local user
+    """
+    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        if get_domain_from_id(group_id) != origin:
+            raise SynapseError(403, "group_id doesn't match origin")
+
+        new_content = yield self.handler.on_invite(
+            group_id, user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
+    """A group server has removed a local user
+    """
+    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        if get_domain_from_id(group_id) != origin:
+            raise SynapseError(403, "user_id doesn't match origin")
+
+        new_content = yield self.handler.user_removed_from_group(
+            group_id, user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
+    """A group or user's server renews their attestation
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        # We don't need to check auth here as we check the attestation signatures
+
+        new_content = yield self.handler.on_renew_attestation(
+            group_id, user_id, content
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
+    """Add/remove a room from the group summary, with optional category.
+
+    Matches both:
+        - /groups/:group/summary/rooms/:room_id
+        - /groups/:group/summary/categories/:category/rooms/:room_id
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/summary"
+        "(/categories/(?P<category_id>[^/]+))?"
+        "/rooms/(?P<room_id>[^/]*)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, category_id, room_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = yield self.handler.update_group_summary_room(
+            group_id, requester_user_id,
+            room_id=room_id,
+            category_id=category_id,
+            content=content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = yield self.handler.delete_group_summary_room(
+            group_id, requester_user_id,
+            room_id=room_id,
+            category_id=category_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoriesServlet(BaseFederationServlet):
+    """Get all categories for a group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/categories/$"
+    )
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = yield self.handler.get_group_categories(
+            group_id, requester_user_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoryServlet(BaseFederationServlet):
+    """Add/remove/get a category in a group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id, category_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = yield self.handler.get_group_category(
+            group_id, requester_user_id, category_id
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, category_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = yield self.handler.upsert_group_category(
+            group_id, requester_user_id, category_id, content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, category_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = yield self.handler.delete_group_category(
+            group_id, requester_user_id, category_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsRolesServlet(BaseFederationServlet):
+    """Get roles in a group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/roles/$"
+    )
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = yield self.handler.get_group_roles(
+            group_id, requester_user_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsRoleServlet(BaseFederationServlet):
+    """Add/remove/get a role in a group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id, role_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = yield self.handler.get_group_role(
+            group_id, requester_user_id, role_id
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, role_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty string")
+
+        resp = yield self.handler.update_group_role(
+            group_id, requester_user_id, role_id, content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, role_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty string")
+
+        resp = yield self.handler.delete_group_role(
+            group_id, requester_user_id, role_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
+    """Add/remove a user from the group summary, with optional role.
+
+    Matches both:
+        - /groups/:group/summary/users/:user_id
+        - /groups/:group/summary/roles/:role/users/:user_id
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/summary"
+        "(/roles/(?P<role_id>[^/]+))?"
+        "/users/(?P<user_id>[^/]*)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, role_id, user_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty string")
+
+        resp = yield self.handler.update_group_summary_user(
+            group_id, requester_user_id,
+            user_id=user_id,
+            role_id=role_id,
+            content=content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty string")
+
+        resp = yield self.handler.delete_group_summary_user(
+            group_id, requester_user_id,
+            user_id=user_id,
+            role_id=role_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
+    """Get roles in a group
+    """
+    PATH = (
+        "/get_groups_publicised$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query):
+        resp = yield self.handler.bulk_get_publicised_groups(
+            content["user_ids"], proxy=False,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
+    """Sets whether a group is joinable without an invite or knock
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$"
+
+    @defer.inlineCallbacks
+    def on_PUT(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.set_group_join_policy(
+            group_id, requester_user_id, content
+        )
+
+        defer.returnValue((200, new_content))
+
+
+FEDERATION_SERVLET_CLASSES = (
     FederationSendServlet,
     FederationPullServlet,
     FederationEventServlet,
@@ -625,17 +1190,85 @@ SERVLET_CLASSES = (
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
     OpenIdUserInfo,
-    PublicRoomList,
     FederationVersionServlet,
 )
 
 
+ROOM_LIST_CLASSES = (
+    PublicRoomList,
+)
+
+GROUP_SERVER_SERVLET_CLASSES = (
+    FederationGroupsProfileServlet,
+    FederationGroupsSummaryServlet,
+    FederationGroupsRoomsServlet,
+    FederationGroupsUsersServlet,
+    FederationGroupsInvitedUsersServlet,
+    FederationGroupsInviteServlet,
+    FederationGroupsAcceptInviteServlet,
+    FederationGroupsJoinServlet,
+    FederationGroupsRemoveUserServlet,
+    FederationGroupsSummaryRoomsServlet,
+    FederationGroupsCategoriesServlet,
+    FederationGroupsCategoryServlet,
+    FederationGroupsRolesServlet,
+    FederationGroupsRoleServlet,
+    FederationGroupsSummaryUsersServlet,
+    FederationGroupsAddRoomsServlet,
+    FederationGroupsAddRoomsConfigServlet,
+    FederationGroupsSettingJoinPolicyServlet,
+)
+
+
+GROUP_LOCAL_SERVLET_CLASSES = (
+    FederationGroupsLocalInviteServlet,
+    FederationGroupsRemoveLocalUserServlet,
+    FederationGroupsBulkPublicisedServlet,
+)
+
+
+GROUP_ATTESTATION_SERVLET_CLASSES = (
+    FederationGroupsRenewAttestaionServlet,
+)
+
+
 def register_servlets(hs, resource, authenticator, ratelimiter):
-    for servletclass in SERVLET_CLASSES:
+    for servletclass in FEDERATION_SERVLET_CLASSES:
+        servletclass(
+            handler=hs.get_federation_server(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
+
+    for servletclass in ROOM_LIST_CLASSES:
+        servletclass(
+            handler=hs.get_room_list_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
+
+    for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+        servletclass(
+            handler=hs.get_groups_server_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
+
+    for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+        servletclass(
+            handler=hs.get_groups_local_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
+
+    for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
         servletclass(
-            handler=hs.get_replication_layer(),
+            handler=hs.get_groups_attestation_renewer(),
             authenticator=authenticator,
             ratelimiter=ratelimiter,
             server_name=hs.hostname,
-            room_list_handler=hs.get_room_list_handler(),
         ).register(resource)