summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_base.py96
-rw-r--r--synapse/federation/federation_client.py184
-rw-r--r--synapse/federation/federation_server.py102
-rw-r--r--synapse/federation/transaction_queue.py52
-rw-r--r--synapse/federation/transport/client.py169
-rw-r--r--synapse/federation/transport/server.py196
-rw-r--r--synapse/federation/units.py3
7 files changed, 561 insertions, 241 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index b7ad729c63..a7a2ec4523 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -20,10 +20,10 @@ import six
 from twisted.internet import defer
 from twisted.internet.defer import DeferredList
 
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.event_signing import check_event_content_hash
-from synapse.events import FrozenEvent
+from synapse.events import event_type_from_format_version
 from synapse.events.utils import prune_event
 from synapse.http.servlet import assert_params_in_dict
 from synapse.types import get_domain_from_id
@@ -43,8 +43,8 @@ class FederationBase(object):
         self._clock = hs.get_clock()
 
     @defer.inlineCallbacks
-    def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
-                                       include_none=False):
+    def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
+                                       outlier=False, include_none=False):
         """Takes a list of PDUs and checks the signatures and hashs of each
         one. If a PDU fails its signature check then we check if we have it in
         the database and if not then request if from the originating server of
@@ -56,13 +56,17 @@ class FederationBase(object):
         a new list.
 
         Args:
+            origin (str)
             pdu (list)
-            outlier (bool)
+            room_version (str)
+            outlier (bool): Whether the events are outliers or not
+            include_none (str): Whether to include None in the returned list
+                for events that have failed their checks
 
         Returns:
             Deferred : A list of PDUs that have valid signatures and hashes.
         """
-        deferreds = self._check_sigs_and_hashes(pdus)
+        deferreds = self._check_sigs_and_hashes(room_version, pdus)
 
         @defer.inlineCallbacks
         def handle_check_result(pdu, deferred):
@@ -84,6 +88,7 @@ class FederationBase(object):
                     res = yield self.get_pdu(
                         destinations=[pdu.origin],
                         event_id=pdu.event_id,
+                        room_version=room_version,
                         outlier=outlier,
                         timeout=10000,
                     )
@@ -116,16 +121,17 @@ class FederationBase(object):
         else:
             defer.returnValue([p for p in valid_pdus if p])
 
-    def _check_sigs_and_hash(self, pdu):
+    def _check_sigs_and_hash(self, room_version, pdu):
         return logcontext.make_deferred_yieldable(
-            self._check_sigs_and_hashes([pdu])[0],
+            self._check_sigs_and_hashes(room_version, [pdu])[0],
         )
 
-    def _check_sigs_and_hashes(self, pdus):
+    def _check_sigs_and_hashes(self, room_version, pdus):
         """Checks that each of the received events is correctly signed by the
         sending server.
 
         Args:
+            room_version (str): The room version of the PDUs
             pdus (list[FrozenEvent]): the events to be checked
 
         Returns:
@@ -136,7 +142,7 @@ class FederationBase(object):
               * throws a SynapseError if the signature check failed.
             The deferreds run their callbacks in the sentinel logcontext.
         """
-        deferreds = _check_sigs_on_pdus(self.keyring, pdus)
+        deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
 
         ctx = logcontext.LoggingContext.current_context()
 
@@ -198,16 +204,17 @@ class FederationBase(object):
 
 
 class PduToCheckSig(namedtuple("PduToCheckSig", [
-    "pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
+    "pdu", "redacted_pdu_json", "sender_domain", "deferreds",
 ])):
     pass
 
 
-def _check_sigs_on_pdus(keyring, pdus):
+def _check_sigs_on_pdus(keyring, room_version, pdus):
     """Check that the given events are correctly signed
 
     Args:
         keyring (synapse.crypto.Keyring): keyring object to do the checks
+        room_version (str): the room version of the PDUs
         pdus (Collection[EventBase]): the events to be checked
 
     Returns:
@@ -220,9 +227,7 @@ def _check_sigs_on_pdus(keyring, pdus):
 
     # we want to check that the event is signed by:
     #
-    # (a) the server which created the event_id
-    #
-    # (b) the sender's server.
+    # (a) the sender's server
     #
     #     - except in the case of invites created from a 3pid invite, which are exempt
     #     from this check, because the sender has to match that of the original 3pid
@@ -236,34 +241,26 @@ def _check_sigs_on_pdus(keyring, pdus):
     #     and signatures are *supposed* to be valid whether or not an event has been
     #     redacted. But this isn't the worst of the ways that 3pid invites are broken.
     #
+    # (b) for V1 and V2 rooms, the server which created the event_id
+    #
     # let's start by getting the domain for each pdu, and flattening the event back
     # to JSON.
+
     pdus_to_check = [
         PduToCheckSig(
             pdu=p,
             redacted_pdu_json=prune_event(p).get_pdu_json(),
-            event_id_domain=get_domain_from_id(p.event_id),
             sender_domain=get_domain_from_id(p.sender),
             deferreds=[],
         )
         for p in pdus
     ]
 
-    # first make sure that the event is signed by the event_id's domain
-    deferreds = keyring.verify_json_objects_for_server([
-        (p.event_id_domain, p.redacted_pdu_json)
-        for p in pdus_to_check
-    ])
-
-    for p, d in zip(pdus_to_check, deferreds):
-        p.deferreds.append(d)
-
-    # now let's look for events where the sender's domain is different to the
-    # event id's domain (normally only the case for joins/leaves), and add additional
-    # checks.
+    # First we check that the sender event is signed by the sender's domain
+    # (except if its a 3pid invite, in which case it may be sent by any server)
     pdus_to_check_sender = [
         p for p in pdus_to_check
-        if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
+        if not _is_invite_via_3pid(p.pdu)
     ]
 
     more_deferreds = keyring.verify_json_objects_for_server([
@@ -274,19 +271,43 @@ def _check_sigs_on_pdus(keyring, pdus):
     for p, d in zip(pdus_to_check_sender, more_deferreds):
         p.deferreds.append(d)
 
+    # now let's look for events where the sender's domain is different to the
+    # event id's domain (normally only the case for joins/leaves), and add additional
+    # checks. Only do this if the room version has a concept of event ID domain
+    if room_version in (
+        RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
+    ):
+        pdus_to_check_event_id = [
+            p for p in pdus_to_check
+            if p.sender_domain != get_domain_from_id(p.pdu.event_id)
+        ]
+
+        more_deferreds = keyring.verify_json_objects_for_server([
+            (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
+            for p in pdus_to_check_event_id
+        ])
+
+        for p, d in zip(pdus_to_check_event_id, more_deferreds):
+            p.deferreds.append(d)
+    elif room_version in (RoomVersions.V3,):
+        pass  # No further checks needed, as event IDs are hashes here
+    else:
+        raise RuntimeError("Unrecognized room version %s" % (room_version,))
+
     # replace lists of deferreds with single Deferreds
     return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
 
 
 def _flatten_deferred_list(deferreds):
-    """Given a list of one or more deferreds, either return the single deferred, or
-    combine into a DeferredList.
+    """Given a list of deferreds, either return the single deferred,
+    combine into a DeferredList, or return an already resolved deferred.
     """
     if len(deferreds) > 1:
         return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
-    else:
-        assert len(deferreds) == 1
+    elif len(deferreds) == 1:
         return deferreds[0]
+    else:
+        return defer.succeed(None)
 
 
 def _is_invite_via_3pid(event):
@@ -297,11 +318,12 @@ def _is_invite_via_3pid(event):
     )
 
 
-def event_from_pdu_json(pdu_json, outlier=False):
+def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
     """Construct a FrozenEvent from an event json received over federation
 
     Args:
         pdu_json (object): pdu as received over federation
+        event_format_version (int): The event format version
         outlier (bool): True to mark this event as an outlier
 
     Returns:
@@ -313,7 +335,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
     """
     # we could probably enforce a bunch of other fields here (room_id, sender,
     # origin, etc etc)
-    assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
+    assert_params_in_dict(pdu_json, ('type', 'depth'))
 
     depth = pdu_json['depth']
     if not isinstance(depth, six.integer_types):
@@ -325,8 +347,8 @@ def event_from_pdu_json(pdu_json, outlier=False):
     elif depth > MAX_DEPTH:
         raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
 
-    event = FrozenEvent(
-        pdu_json
+    event = event_type_from_format_version(event_format_version)(
+        pdu_json,
     )
 
     event.internal_metadata.outlier = outlier
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d05ed91d64..4e4f58b418 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -25,14 +25,19 @@ from prometheus_client import Counter
 
 from twisted.internet import defer
 
-from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
+from synapse.api.constants import (
+    KNOWN_ROOM_VERSIONS,
+    EventTypes,
+    Membership,
+    RoomVersions,
+)
 from synapse.api.errors import (
     CodeMessageException,
     FederationDeniedError,
     HttpResponseException,
     SynapseError,
 )
-from synapse.events import builder
+from synapse.events import builder, room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -66,6 +71,9 @@ class FederationClient(FederationBase):
         self.state = hs.get_state_handler()
         self.transport_layer = hs.get_federation_transport_client()
 
+        self.hostname = hs.hostname
+        self.signing_key = hs.config.signing_key[0]
+
         self._get_pdu_cache = ExpiringCache(
             cache_name="get_pdu_cache",
             clock=self._clock,
@@ -162,13 +170,13 @@ class FederationClient(FederationBase):
 
     @defer.inlineCallbacks
     @log_function
-    def backfill(self, dest, context, limit, extremities):
+    def backfill(self, dest, room_id, limit, extremities):
         """Requests some more historic PDUs for the given context from the
         given destination server.
 
         Args:
             dest (str): The remote home server to ask.
-            context (str): The context to backfill.
+            room_id (str): The room_id to backfill.
             limit (int): The maximum number of PDUs to return.
             extremities (list): List of PDU id and origins of the first pdus
                 we have seen from the context
@@ -183,18 +191,21 @@ class FederationClient(FederationBase):
             return
 
         transaction_data = yield self.transport_layer.backfill(
-            dest, context, extremities, limit)
+            dest, room_id, extremities, limit)
 
         logger.debug("backfill transaction_data=%s", repr(transaction_data))
 
+        room_version = yield self.store.get_room_version(room_id)
+        format_ver = room_version_to_event_format(room_version)
+
         pdus = [
-            event_from_pdu_json(p, outlier=False)
+            event_from_pdu_json(p, format_ver, outlier=False)
             for p in transaction_data["pdus"]
         ]
 
         # FIXME: We should handle signature failures more gracefully.
         pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            self._check_sigs_and_hashes(pdus),
+            self._check_sigs_and_hashes(room_version, pdus),
             consumeErrors=True,
         ).addErrback(unwrapFirstError))
 
@@ -202,7 +213,8 @@ class FederationClient(FederationBase):
 
     @defer.inlineCallbacks
     @log_function
-    def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
+    def get_pdu(self, destinations, event_id, room_version, outlier=False,
+                timeout=None):
         """Requests the PDU with given origin and ID from the remote home
         servers.
 
@@ -212,6 +224,7 @@ class FederationClient(FederationBase):
         Args:
             destinations (list): Which home servers to query
             event_id (str): event to fetch
+            room_version (str): version of the room
             outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                 it's from an arbitary point in the context as opposed to part
                 of the current block of PDUs. Defaults to `False`
@@ -230,6 +243,8 @@ class FederationClient(FederationBase):
 
         pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
 
+        format_ver = room_version_to_event_format(room_version)
+
         signed_pdu = None
         for destination in destinations:
             now = self._clock.time_msec()
@@ -245,7 +260,7 @@ class FederationClient(FederationBase):
                 logger.debug("transaction_data %r", transaction_data)
 
                 pdu_list = [
-                    event_from_pdu_json(p, outlier=outlier)
+                    event_from_pdu_json(p, format_ver, outlier=outlier)
                     for p in transaction_data["pdus"]
                 ]
 
@@ -253,7 +268,7 @@ class FederationClient(FederationBase):
                     pdu = pdu_list[0]
 
                     # Check signatures are correct.
-                    signed_pdu = yield self._check_sigs_and_hash(pdu)
+                    signed_pdu = yield self._check_sigs_and_hash(room_version, pdu)
 
                     break
 
@@ -339,12 +354,16 @@ class FederationClient(FederationBase):
             destination, room_id, event_id=event_id,
         )
 
+        room_version = yield self.store.get_room_version(room_id)
+        format_ver = room_version_to_event_format(room_version)
+
         pdus = [
-            event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+            event_from_pdu_json(p, format_ver, outlier=True)
+            for p in result["pdus"]
         ]
 
         auth_chain = [
-            event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, format_ver, outlier=True)
             for p in result.get("auth_chain", [])
         ]
 
@@ -355,7 +374,8 @@ class FederationClient(FederationBase):
         signed_pdus = yield self._check_sigs_and_hash_and_fetch(
             destination,
             [p for p in pdus if p.event_id not in seen_events],
-            outlier=True
+            outlier=True,
+            room_version=room_version,
         )
         signed_pdus.extend(
             seen_events[p.event_id] for p in pdus if p.event_id in seen_events
@@ -364,7 +384,8 @@ class FederationClient(FederationBase):
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
             destination,
             [p for p in auth_chain if p.event_id not in seen_events],
-            outlier=True
+            outlier=True,
+            room_version=room_version,
         )
         signed_auth.extend(
             seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
@@ -411,6 +432,8 @@ class FederationClient(FederationBase):
             random.shuffle(srvs)
             return srvs
 
+        room_version = yield self.store.get_room_version(room_id)
+
         batch_size = 20
         missing_events = list(missing_events)
         for i in range(0, len(missing_events), batch_size):
@@ -421,6 +444,7 @@ class FederationClient(FederationBase):
                     self.get_pdu,
                     destinations=random_server_list(),
                     event_id=e_id,
+                    room_version=room_version,
                 )
                 for e_id in batch
             ]
@@ -445,13 +469,17 @@ class FederationClient(FederationBase):
             destination, room_id, event_id,
         )
 
+        room_version = yield self.store.get_room_version(room_id)
+        format_ver = room_version_to_event_format(room_version)
+
         auth_chain = [
-            event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, format_ver, outlier=True)
             for p in res["auth_chain"]
         ]
 
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain, outlier=True
+            destination, auth_chain,
+            outlier=True, room_version=room_version,
         )
 
         signed_auth.sort(key=lambda e: e.depth)
@@ -522,6 +550,8 @@ class FederationClient(FederationBase):
         Does so by asking one of the already participating servers to create an
         event with proper context.
 
+        Returns a fully signed and hashed event.
+
         Note that this does not append any events to any graphs.
 
         Args:
@@ -536,8 +566,10 @@ class FederationClient(FederationBase):
             params (dict[str, str|Iterable[str]]): Query parameters to include in the
                 request.
         Return:
-            Deferred: resolves to a tuple of (origin (str), event (object))
-            where origin is the remote homeserver which generated the event.
+            Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of
+            `(origin, event, event_format)` where origin is the remote
+            homeserver which generated the event, and event_format is one of
+            `synapse.api.constants.EventFormatVersions`.
 
             Fails with a ``SynapseError`` if the chosen remote server
             returns a 300/400 code.
@@ -557,6 +589,11 @@ class FederationClient(FederationBase):
                 destination, room_id, user_id, membership, params,
             )
 
+            # Note: If not supplied, the room version may be either v1 or v2,
+            # however either way the event format version will be v1.
+            room_version = ret.get("room_version", RoomVersions.V1)
+            event_format = room_version_to_event_format(room_version)
+
             pdu_dict = ret.get("event", None)
             if not isinstance(pdu_dict, dict):
                 raise InvalidResponseError("Bad 'event' field in response")
@@ -571,17 +608,20 @@ class FederationClient(FederationBase):
             if "prev_state" not in pdu_dict:
                 pdu_dict["prev_state"] = []
 
-            ev = builder.EventBuilder(pdu_dict)
+            ev = builder.create_local_event_from_event_dict(
+                self._clock, self.hostname, self.signing_key,
+                format_version=event_format, event_dict=pdu_dict,
+            )
 
             defer.returnValue(
-                (destination, ev)
+                (destination, ev, event_format)
             )
 
         return self._try_destination_list(
             "make_" + membership, destinations, send_request,
         )
 
-    def send_join(self, destinations, pdu):
+    def send_join(self, destinations, pdu, event_format_version):
         """Sends a join event to one of a list of homeservers.
 
         Doing so will cause the remote server to add the event to the graph,
@@ -591,6 +631,7 @@ class FederationClient(FederationBase):
             destinations (str): Candidate homeservers which are probably
                 participating in the room.
             pdu (BaseEvent): event to be sent
+            event_format_version (int): The event format version
 
         Return:
             Deferred: resolves to a dict with members ``origin`` (a string
@@ -636,12 +677,12 @@ class FederationClient(FederationBase):
             logger.debug("Got content: %s", content)
 
             state = [
-                event_from_pdu_json(p, outlier=True)
+                event_from_pdu_json(p, event_format_version, outlier=True)
                 for p in content.get("state", [])
             ]
 
             auth_chain = [
-                event_from_pdu_json(p, outlier=True)
+                event_from_pdu_json(p, event_format_version, outlier=True)
                 for p in content.get("auth_chain", [])
             ]
 
@@ -650,9 +691,21 @@ class FederationClient(FederationBase):
                 for p in itertools.chain(state, auth_chain)
             }
 
+            room_version = None
+            for e in state:
+                if (e.type, e.state_key) == (EventTypes.Create, ""):
+                    room_version = e.content.get("room_version", RoomVersions.V1)
+                    break
+
+            if room_version is None:
+                # If the state doesn't have a create event then the room is
+                # invalid, and it would fail auth checks anyway.
+                raise SynapseError(400, "No create event in state")
+
             valid_pdus = yield self._check_sigs_and_hash_and_fetch(
                 destination, list(pdus.values()),
                 outlier=True,
+                room_version=room_version,
             )
 
             valid_pdus_map = {
@@ -690,32 +743,75 @@ class FederationClient(FederationBase):
 
     @defer.inlineCallbacks
     def send_invite(self, destination, room_id, event_id, pdu):
-        time_now = self._clock.time_msec()
-        try:
-            code, content = yield self.transport_layer.send_invite(
-                destination=destination,
-                room_id=room_id,
-                event_id=event_id,
-                content=pdu.get_pdu_json(time_now),
-            )
-        except HttpResponseException as e:
-            if e.code == 403:
-                raise e.to_synapse_error()
-            raise
+        room_version = yield self.store.get_room_version(room_id)
+
+        content = yield self._do_send_invite(destination, pdu, room_version)
 
         pdu_dict = content["event"]
 
         logger.debug("Got response to send_invite: %s", pdu_dict)
 
-        pdu = event_from_pdu_json(pdu_dict)
+        room_version = yield self.store.get_room_version(room_id)
+        format_ver = room_version_to_event_format(room_version)
+
+        pdu = event_from_pdu_json(pdu_dict, format_ver)
 
         # Check signatures are correct.
-        pdu = yield self._check_sigs_and_hash(pdu)
+        pdu = yield self._check_sigs_and_hash(room_version, pdu)
 
         # FIXME: We should handle signature failures more gracefully.
 
         defer.returnValue(pdu)
 
+    @defer.inlineCallbacks
+    def _do_send_invite(self, destination, pdu, room_version):
+        """Actually sends the invite, first trying v2 API and falling back to
+        v1 API if necessary.
+
+        Args:
+            destination (str): Target server
+            pdu (FrozenEvent)
+            room_version (str)
+
+        Returns:
+            dict: The event as a dict as returned by the remote server
+        """
+        time_now = self._clock.time_msec()
+
+        try:
+            content = yield self.transport_layer.send_invite_v2(
+                destination=destination,
+                room_id=pdu.room_id,
+                event_id=pdu.event_id,
+                content={
+                    "event": pdu.get_pdu_json(time_now),
+                    "room_version": room_version,
+                    "invite_room_state": pdu.unsigned.get("invite_room_state", []),
+                },
+            )
+            defer.returnValue(content)
+        except HttpResponseException as e:
+            if e.code in [400, 404]:
+                if room_version in (RoomVersions.V1, RoomVersions.V2):
+                    pass  # We'll fall through
+                else:
+                    raise Exception("Remote server is too old")
+            elif e.code == 403:
+                raise e.to_synapse_error()
+            else:
+                raise
+
+        # Didn't work, try v1 API.
+        # Note the v1 API returns a tuple of `(200, content)`
+
+        _, content = yield self.transport_layer.send_invite_v1(
+            destination=destination,
+            room_id=pdu.room_id,
+            event_id=pdu.event_id,
+            content=pdu.get_pdu_json(time_now),
+        )
+        defer.returnValue(content)
+
     def send_leave(self, destinations, pdu):
         """Sends a leave event to one of a list of homeservers.
 
@@ -785,13 +881,16 @@ class FederationClient(FederationBase):
             content=send_content,
         )
 
+        room_version = yield self.store.get_room_version(room_id)
+        format_ver = room_version_to_event_format(room_version)
+
         auth_chain = [
-            event_from_pdu_json(e)
+            event_from_pdu_json(e, format_ver)
             for e in content["auth_chain"]
         ]
 
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain, outlier=True
+            destination, auth_chain, outlier=True, room_version=room_version,
         )
 
         signed_auth.sort(key=lambda e: e.depth)
@@ -833,13 +932,16 @@ class FederationClient(FederationBase):
                 timeout=timeout,
             )
 
+            room_version = yield self.store.get_room_version(room_id)
+            format_ver = room_version_to_event_format(room_version)
+
             events = [
-                event_from_pdu_json(e)
+                event_from_pdu_json(e, format_ver)
                 for e in content.get("events", [])
             ]
 
             signed_events = yield self._check_sigs_and_hash_and_fetch(
-                destination, events, outlier=False
+                destination, events, outlier=False, room_version=room_version,
             )
         except HttpResponseException as e:
             if not e.code == 400:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 0f9302a6a8..3da86d4ba6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
 from twisted.python import failure
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     FederationError,
@@ -34,6 +34,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.crypto.event_signing import compute_event_signature
+from synapse.events import room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
@@ -147,6 +148,22 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
+        # Reject if PDU count > 50 and EDU count > 100
+        if (len(transaction.pdus) > 50
+                or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
+
+            logger.info(
+                "Transaction PDU or EDU count too large. Returning 400",
+            )
+
+            response = {}
+            yield self.transaction_actions.set_response(
+                origin,
+                transaction,
+                400, response
+            )
+            defer.returnValue((400, response))
+
         received_pdus_counter.inc(len(transaction.pdus))
 
         origin_host, _ = parse_server_name(origin)
@@ -162,8 +179,29 @@ class FederationServer(FederationBase):
                 p["age_ts"] = request_time - int(p["age"])
                 del p["age"]
 
-            event = event_from_pdu_json(p)
-            room_id = event.room_id
+            # We try and pull out an event ID so that if later checks fail we
+            # can log something sensible. We don't mandate an event ID here in
+            # case future event formats get rid of the key.
+            possible_event_id = p.get("event_id", "<Unknown>")
+
+            # Now we get the room ID so that we can check that we know the
+            # version of the room.
+            room_id = p.get("room_id")
+            if not room_id:
+                logger.info(
+                    "Ignoring PDU as does not have a room_id. Event ID: %s",
+                    possible_event_id,
+                )
+                continue
+
+            try:
+                room_version = yield self.store.get_room_version(room_id)
+                format_ver = room_version_to_event_format(room_version)
+            except NotFoundError:
+                logger.info("Ignoring PDU for unknown room_id: %s", room_id)
+                continue
+
+            event = event_from_pdu_json(p, format_ver)
             pdus_by_room.setdefault(room_id, []).append(event)
 
         pdu_results = {}
@@ -300,7 +338,7 @@ class FederationServer(FederationBase):
             if self.hs.is_mine_id(event.event_id):
                 event.signatures.update(
                     compute_event_signature(
-                        event,
+                        event.get_pdu_json(),
                         self.hs.hostname,
                         self.hs.config.signing_key[0]
                     )
@@ -324,11 +362,6 @@ class FederationServer(FederationBase):
             defer.returnValue((404, ""))
 
     @defer.inlineCallbacks
-    @log_function
-    def on_pull_request(self, origin, versions):
-        raise NotImplementedError("Pull transactions not implemented")
-
-    @defer.inlineCallbacks
     def on_query_request(self, query_type, args):
         received_queries_counter.labels(query_type).inc()
         resp = yield self.registry.on_query(query_type, args)
@@ -352,18 +385,23 @@ class FederationServer(FederationBase):
         })
 
     @defer.inlineCallbacks
-    def on_invite_request(self, origin, content):
-        pdu = event_from_pdu_json(content)
+    def on_invite_request(self, origin, content, room_version):
+        format_ver = room_version_to_event_format(room_version)
+
+        pdu = event_from_pdu_json(content, format_ver)
         origin_host, _ = parse_server_name(origin)
         yield self.check_server_matches_acl(origin_host, pdu.room_id)
         ret_pdu = yield self.handler.on_invite_request(origin, pdu)
         time_now = self._clock.time_msec()
-        defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
+        defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})
 
     @defer.inlineCallbacks
-    def on_send_join_request(self, origin, content):
+    def on_send_join_request(self, origin, content, room_id):
         logger.debug("on_send_join_request: content: %s", content)
-        pdu = event_from_pdu_json(content)
+
+        room_version = yield self.store.get_room_version(room_id)
+        format_ver = room_version_to_event_format(room_version)
+        pdu = event_from_pdu_json(content, format_ver)
 
         origin_host, _ = parse_server_name(origin)
         yield self.check_server_matches_acl(origin_host, pdu.room_id)
@@ -383,13 +421,22 @@ class FederationServer(FederationBase):
         origin_host, _ = parse_server_name(origin)
         yield self.check_server_matches_acl(origin_host, room_id)
         pdu = yield self.handler.on_make_leave_request(room_id, user_id)
+
+        room_version = yield self.store.get_room_version(room_id)
+
         time_now = self._clock.time_msec()
-        defer.returnValue({"event": pdu.get_pdu_json(time_now)})
+        defer.returnValue({
+            "event": pdu.get_pdu_json(time_now),
+            "room_version": room_version,
+        })
 
     @defer.inlineCallbacks
-    def on_send_leave_request(self, origin, content):
+    def on_send_leave_request(self, origin, content, room_id):
         logger.debug("on_send_leave_request: content: %s", content)
-        pdu = event_from_pdu_json(content)
+
+        room_version = yield self.store.get_room_version(room_id)
+        format_ver = room_version_to_event_format(room_version)
+        pdu = event_from_pdu_json(content, format_ver)
 
         origin_host, _ = parse_server_name(origin)
         yield self.check_server_matches_acl(origin_host, pdu.room_id)
@@ -435,13 +482,16 @@ class FederationServer(FederationBase):
             origin_host, _ = parse_server_name(origin)
             yield self.check_server_matches_acl(origin_host, room_id)
 
+            room_version = yield self.store.get_room_version(room_id)
+            format_ver = room_version_to_event_format(room_version)
+
             auth_chain = [
-                event_from_pdu_json(e)
+                event_from_pdu_json(e, format_ver)
                 for e in content["auth_chain"]
             ]
 
             signed_auth = yield self._check_sigs_and_hash_and_fetch(
-                origin, auth_chain, outlier=True
+                origin, auth_chain, outlier=True, room_version=room_version,
             )
 
             ret = yield self.handler.on_query_auth(
@@ -586,16 +636,19 @@ class FederationServer(FederationBase):
         """
         # check that it's actually being sent from a valid destination to
         # workaround bug #1753 in 0.18.5 and 0.18.6
-        if origin != get_domain_from_id(pdu.event_id):
+        if origin != get_domain_from_id(pdu.sender):
             # We continue to accept join events from any server; this is
             # necessary for the federation join dance to work correctly.
             # (When we join over federation, the "helper" server is
             # responsible for sending out the join event, rather than the
-            # origin. See bug #1893).
+            # origin. See bug #1893. This is also true for some third party
+            # invites).
             if not (
                 pdu.type == 'm.room.member' and
                 pdu.content and
-                pdu.content.get("membership", None) == 'join'
+                pdu.content.get("membership", None) in (
+                    Membership.JOIN, Membership.INVITE,
+                )
             ):
                 logger.info(
                     "Discarding PDU %s from invalid origin %s",
@@ -608,9 +661,12 @@ class FederationServer(FederationBase):
                     pdu.event_id, origin
                 )
 
+        # We've already checked that we know the room version by this point
+        room_version = yield self.store.get_room_version(pdu.room_id)
+
         # Check signature.
         try:
-            pdu = yield self._check_sigs_and_hash(pdu)
+            pdu = yield self._check_sigs_and_hash(room_version, pdu)
         except SynapseError as e:
             raise FederationError(
                 "ERROR",
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 3fdd63be95..30941f5ad6 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -22,14 +22,17 @@ from prometheus_client import Counter
 from twisted.internet import defer
 
 import synapse.metrics
-from synapse.api.errors import FederationDeniedError, HttpResponseException
+from synapse.api.errors import (
+    FederationDeniedError,
+    HttpResponseException,
+    RequestSendFailed,
+)
 from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
 from synapse.metrics import (
     LaterGauge,
     event_processing_loop_counter,
     event_processing_loop_room_count,
     events_processed_counter,
-    sent_edus_counter,
     sent_transactions_counter,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -43,10 +46,24 @@ from .units import Edu, Transaction
 logger = logging.getLogger(__name__)
 
 sent_pdus_destination_dist_count = Counter(
-    "synapse_federation_client_sent_pdu_destinations:count", ""
+    "synapse_federation_client_sent_pdu_destinations:count",
+    "Number of PDUs queued for sending to one or more destinations",
 )
+
 sent_pdus_destination_dist_total = Counter(
     "synapse_federation_client_sent_pdu_destinations:total", ""
+    "Total number of PDUs queued for sending across all destinations",
+)
+
+sent_edus_counter = Counter(
+    "synapse_federation_client_sent_edus",
+    "Total number of EDUs successfully sent",
+)
+
+sent_edus_by_type = Counter(
+    "synapse_federation_client_sent_edus_by_type",
+    "Number of sent EDUs successfully sent, by event type",
+    ["type"],
 )
 
 
@@ -171,7 +188,7 @@ class TransactionQueue(object):
                 def handle_event(event):
                     # Only send events for this server.
                     send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
-                    is_mine = self.is_mine_id(event.event_id)
+                    is_mine = self.is_mine_id(event.sender)
                     if not is_mine and send_on_behalf_of is None:
                         return
 
@@ -183,9 +200,7 @@ class TransactionQueue(object):
                         # banned then it won't receive the event because it won't
                         # be in the room after the ban.
                         destinations = yield self.state.get_current_hosts_in_room(
-                            event.room_id, latest_event_ids=[
-                                prev_id for prev_id, _ in event.prev_events
-                            ],
+                            event.room_id, latest_event_ids=event.prev_event_ids(),
                         )
                     except Exception:
                         logger.exception(
@@ -358,8 +373,6 @@ class TransactionQueue(object):
             logger.info("Not sending EDU to ourselves")
             return
 
-        sent_edus_counter.inc()
-
         if key:
             self.pending_edus_keyed_by_dest.setdefault(
                 destination, {}
@@ -494,6 +507,9 @@ class TransactionQueue(object):
                 )
                 if success:
                     sent_transactions_counter.inc()
+                    sent_edus_counter.inc(len(pending_edus))
+                    for edu in pending_edus:
+                        sent_edus_by_type.labels(edu.edu_type).inc()
                     # Remove the acknowledged device messages from the database
                     # Only bother if we actually sent some device messages
                     if device_message_edus:
@@ -520,11 +536,21 @@ class TransactionQueue(object):
             )
         except FederationDeniedError as e:
             logger.info(e)
-        except Exception as e:
-            logger.warn(
-                "TX [%s] Failed to send transaction: %s",
+        except HttpResponseException as e:
+            logger.warning(
+                "TX [%s] Received %d response to transaction: %s",
+                destination, e.code, e,
+            )
+        except RequestSendFailed as e:
+            logger.warning("TX [%s] Failed to send transaction: %s", destination, e)
+
+            for p, _ in pending_pdus:
+                logger.info("Failed to send event %s to %s", p.event_id,
+                            destination)
+        except Exception:
+            logger.exception(
+                "TX [%s] Failed to send transaction",
                 destination,
-                e,
             )
             for p, _ in pending_pdus:
                 logger.info("Failed to send event %s to %s", p.event_id,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index edba5a9808..8e2be218e2 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -21,7 +21,7 @@ from six.moves import urllib
 from twisted.internet import defer
 
 from synapse.api.constants import Membership
-from synapse.api.urls import FEDERATION_PREFIX as PREFIX
+from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
 from synapse.util.logutils import log_function
 
 logger = logging.getLogger(__name__)
@@ -51,7 +51,7 @@ class TransportLayerClient(object):
         logger.debug("get_room_state dest=%s, room=%s",
                      destination, room_id)
 
-        path = _create_path(PREFIX, "/state/%s/", room_id)
+        path = _create_v1_path("/state/%s/", room_id)
         return self.client.get_json(
             destination, path=path, args={"event_id": event_id},
         )
@@ -73,7 +73,7 @@ class TransportLayerClient(object):
         logger.debug("get_room_state_ids dest=%s, room=%s",
                      destination, room_id)
 
-        path = _create_path(PREFIX, "/state_ids/%s/", room_id)
+        path = _create_v1_path("/state_ids/%s/", room_id)
         return self.client.get_json(
             destination, path=path, args={"event_id": event_id},
         )
@@ -95,7 +95,7 @@ class TransportLayerClient(object):
         logger.debug("get_pdu dest=%s, event_id=%s",
                      destination, event_id)
 
-        path = _create_path(PREFIX, "/event/%s/", event_id)
+        path = _create_v1_path("/event/%s/", event_id)
         return self.client.get_json(destination, path=path, timeout=timeout)
 
     @log_function
@@ -121,7 +121,7 @@ class TransportLayerClient(object):
             # TODO: raise?
             return
 
-        path = _create_path(PREFIX, "/backfill/%s/", room_id)
+        path = _create_v1_path("/backfill/%s/", room_id)
 
         args = {
             "v": event_tuples,
@@ -167,7 +167,7 @@ class TransportLayerClient(object):
         # generated by the json_data_callback.
         json_data = transaction.get_dict()
 
-        path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
+        path = _create_v1_path("/send/%s/", transaction.transaction_id)
 
         response = yield self.client.put_json(
             transaction.destination,
@@ -184,7 +184,7 @@ class TransportLayerClient(object):
     @log_function
     def make_query(self, destination, query_type, args, retry_on_dns_fail,
                    ignore_backoff=False):
-        path = _create_path(PREFIX, "/query/%s", query_type)
+        path = _create_v1_path("/query/%s", query_type)
 
         content = yield self.client.get_json(
             destination=destination,
@@ -231,7 +231,7 @@ class TransportLayerClient(object):
                 "make_membership_event called with membership='%s', must be one of %s" %
                 (membership, ",".join(valid_memberships))
             )
-        path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
+        path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
 
         ignore_backoff = False
         retry_on_dns_fail = False
@@ -258,7 +258,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_join(self, destination, room_id, event_id, content):
-        path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
+        path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
             destination=destination,
@@ -271,7 +271,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_leave(self, destination, room_id, event_id, content):
-        path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
+        path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
             destination=destination,
@@ -289,8 +289,22 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_invite(self, destination, room_id, event_id, content):
-        path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
+    def send_invite_v1(self, destination, room_id, event_id, content):
+        path = _create_v1_path("/invite/%s/%s", room_id, event_id)
+
+        response = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+        defer.returnValue(response)
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_invite_v2(self, destination, room_id, event_id, content):
+        path = _create_v2_path("/invite/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
             destination=destination,
@@ -306,7 +320,7 @@ class TransportLayerClient(object):
     def get_public_rooms(self, remote_server, limit, since_token,
                          search_filter=None, include_all_networks=False,
                          third_party_instance_id=None):
-        path = PREFIX + "/publicRooms"
+        path = _create_v1_path("/publicRooms")
 
         args = {
             "include_all_networks": "true" if include_all_networks else "false",
@@ -332,7 +346,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def exchange_third_party_invite(self, destination, room_id, event_dict):
-        path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
+        path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
 
         response = yield self.client.put_json(
             destination=destination,
@@ -345,7 +359,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
-        path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
+        path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
 
         content = yield self.client.get_json(
             destination=destination,
@@ -357,7 +371,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_query_auth(self, destination, room_id, event_id, content):
-        path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
+        path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
 
         content = yield self.client.post_json(
             destination=destination,
@@ -392,7 +406,7 @@ class TransportLayerClient(object):
         Returns:
             A dict containg the device keys.
         """
-        path = PREFIX + "/user/keys/query"
+        path = _create_v1_path("/user/keys/query")
 
         content = yield self.client.post_json(
             destination=destination,
@@ -419,7 +433,7 @@ class TransportLayerClient(object):
         Returns:
             A dict containg the device keys.
         """
-        path = _create_path(PREFIX, "/user/devices/%s", user_id)
+        path = _create_v1_path("/user/devices/%s", user_id)
 
         content = yield self.client.get_json(
             destination=destination,
@@ -455,7 +469,7 @@ class TransportLayerClient(object):
             A dict containg the one-time keys.
         """
 
-        path = PREFIX + "/user/keys/claim"
+        path = _create_v1_path("/user/keys/claim")
 
         content = yield self.client.post_json(
             destination=destination,
@@ -469,7 +483,7 @@ class TransportLayerClient(object):
     @log_function
     def get_missing_events(self, destination, room_id, earliest_events,
                            latest_events, limit, min_depth, timeout):
-        path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
+        path = _create_v1_path("/get_missing_events/%s", room_id,)
 
         content = yield self.client.post_json(
             destination=destination,
@@ -489,7 +503,7 @@ class TransportLayerClient(object):
     def get_group_profile(self, destination, group_id, requester_user_id):
         """Get a group profile
         """
-        path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+        path = _create_v1_path("/groups/%s/profile", group_id,)
 
         return self.client.get_json(
             destination=destination,
@@ -508,7 +522,7 @@ class TransportLayerClient(object):
             requester_user_id (str)
             content (dict): The new profile of the group
         """
-        path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+        path = _create_v1_path("/groups/%s/profile", group_id,)
 
         return self.client.post_json(
             destination=destination,
@@ -522,7 +536,7 @@ class TransportLayerClient(object):
     def get_group_summary(self, destination, group_id, requester_user_id):
         """Get a group summary
         """
-        path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
+        path = _create_v1_path("/groups/%s/summary", group_id,)
 
         return self.client.get_json(
             destination=destination,
@@ -535,7 +549,7 @@ class TransportLayerClient(object):
     def get_rooms_in_group(self, destination, group_id, requester_user_id):
         """Get all rooms in a group
         """
-        path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
+        path = _create_v1_path("/groups/%s/rooms", group_id,)
 
         return self.client.get_json(
             destination=destination,
@@ -548,7 +562,7 @@ class TransportLayerClient(object):
                           content):
         """Add a room to a group
         """
-        path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
 
         return self.client.post_json(
             destination=destination,
@@ -562,8 +576,8 @@ class TransportLayerClient(object):
                              config_key, content):
         """Update room in group
         """
-        path = _create_path(
-            PREFIX, "/groups/%s/room/%s/config/%s",
+        path = _create_v1_path(
+            "/groups/%s/room/%s/config/%s",
             group_id, room_id, config_key,
         )
 
@@ -578,7 +592,7 @@ class TransportLayerClient(object):
     def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
         """Remove a room from a group
         """
-        path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
 
         return self.client.delete_json(
             destination=destination,
@@ -591,7 +605,7 @@ class TransportLayerClient(object):
     def get_users_in_group(self, destination, group_id, requester_user_id):
         """Get users in a group
         """
-        path = _create_path(PREFIX, "/groups/%s/users", group_id,)
+        path = _create_v1_path("/groups/%s/users", group_id,)
 
         return self.client.get_json(
             destination=destination,
@@ -604,7 +618,7 @@ class TransportLayerClient(object):
     def get_invited_users_in_group(self, destination, group_id, requester_user_id):
         """Get users that have been invited to a group
         """
-        path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
+        path = _create_v1_path("/groups/%s/invited_users", group_id,)
 
         return self.client.get_json(
             destination=destination,
@@ -617,8 +631,8 @@ class TransportLayerClient(object):
     def accept_group_invite(self, destination, group_id, user_id, content):
         """Accept a group invite
         """
-        path = _create_path(
-            PREFIX, "/groups/%s/users/%s/accept_invite",
+        path = _create_v1_path(
+            "/groups/%s/users/%s/accept_invite",
             group_id, user_id,
         )
 
@@ -633,7 +647,7 @@ class TransportLayerClient(object):
     def join_group(self, destination, group_id, user_id, content):
         """Attempts to join a group
         """
-        path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
+        path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
 
         return self.client.post_json(
             destination=destination,
@@ -646,7 +660,7 @@ class TransportLayerClient(object):
     def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
         """Invite a user to a group
         """
-        path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
+        path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
 
         return self.client.post_json(
             destination=destination,
@@ -662,7 +676,7 @@ class TransportLayerClient(object):
         invited.
         """
 
-        path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
+        path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
 
         return self.client.post_json(
             destination=destination,
@@ -676,7 +690,7 @@ class TransportLayerClient(object):
                                user_id, content):
         """Remove a user fron a group
         """
-        path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
+        path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
 
         return self.client.post_json(
             destination=destination,
@@ -693,7 +707,7 @@ class TransportLayerClient(object):
         kicked from the group.
         """
 
-        path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
+        path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
 
         return self.client.post_json(
             destination=destination,
@@ -708,7 +722,7 @@ class TransportLayerClient(object):
         the attestations
         """
 
-        path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
+        path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
 
         return self.client.post_json(
             destination=destination,
@@ -723,12 +737,12 @@ class TransportLayerClient(object):
         """Update a room entry in a group summary
         """
         if category_id:
-            path = _create_path(
-                PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
+            path = _create_v1_path(
+                "/groups/%s/summary/categories/%s/rooms/%s",
                 group_id, category_id, room_id,
             )
         else:
-            path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
 
         return self.client.post_json(
             destination=destination,
@@ -744,12 +758,12 @@ class TransportLayerClient(object):
         """Delete a room entry in a group summary
         """
         if category_id:
-            path = _create_path(
-                PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
+            path = _create_v1_path(
+                "/groups/%s/summary/categories/%s/rooms/%s",
                 group_id, category_id, room_id,
             )
         else:
-            path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
 
         return self.client.delete_json(
             destination=destination,
@@ -762,7 +776,7 @@ class TransportLayerClient(object):
     def get_group_categories(self, destination, group_id, requester_user_id):
         """Get all categories in a group
         """
-        path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
+        path = _create_v1_path("/groups/%s/categories", group_id,)
 
         return self.client.get_json(
             destination=destination,
@@ -775,7 +789,7 @@ class TransportLayerClient(object):
     def get_group_category(self, destination, group_id, requester_user_id, category_id):
         """Get category info in a group
         """
-        path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
 
         return self.client.get_json(
             destination=destination,
@@ -789,7 +803,7 @@ class TransportLayerClient(object):
                               content):
         """Update a category in a group
         """
-        path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
 
         return self.client.post_json(
             destination=destination,
@@ -804,7 +818,7 @@ class TransportLayerClient(object):
                               category_id):
         """Delete a category in a group
         """
-        path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
 
         return self.client.delete_json(
             destination=destination,
@@ -817,7 +831,7 @@ class TransportLayerClient(object):
     def get_group_roles(self, destination, group_id, requester_user_id):
         """Get all roles in a group
         """
-        path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
+        path = _create_v1_path("/groups/%s/roles", group_id,)
 
         return self.client.get_json(
             destination=destination,
@@ -830,7 +844,7 @@ class TransportLayerClient(object):
     def get_group_role(self, destination, group_id, requester_user_id, role_id):
         """Get a roles info
         """
-        path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
 
         return self.client.get_json(
             destination=destination,
@@ -844,7 +858,7 @@ class TransportLayerClient(object):
                           content):
         """Update a role in a group
         """
-        path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
 
         return self.client.post_json(
             destination=destination,
@@ -858,7 +872,7 @@ class TransportLayerClient(object):
     def delete_group_role(self, destination, group_id, requester_user_id, role_id):
         """Delete a role in a group
         """
-        path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
 
         return self.client.delete_json(
             destination=destination,
@@ -873,12 +887,12 @@ class TransportLayerClient(object):
         """Update a users entry in a group
         """
         if role_id:
-            path = _create_path(
-                PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+            path = _create_v1_path(
+                "/groups/%s/summary/roles/%s/users/%s",
                 group_id, role_id, user_id,
             )
         else:
-            path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
 
         return self.client.post_json(
             destination=destination,
@@ -893,7 +907,7 @@ class TransportLayerClient(object):
                               content):
         """Sets the join policy for a group
         """
-        path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
+        path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
 
         return self.client.put_json(
             destination=destination,
@@ -909,12 +923,12 @@ class TransportLayerClient(object):
         """Delete a users entry in a group
         """
         if role_id:
-            path = _create_path(
-                PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+            path = _create_v1_path(
+                "/groups/%s/summary/roles/%s/users/%s",
                 group_id, role_id, user_id,
             )
         else:
-            path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
 
         return self.client.delete_json(
             destination=destination,
@@ -927,7 +941,7 @@ class TransportLayerClient(object):
         """Get the groups a list of users are publicising
         """
 
-        path = PREFIX + "/get_groups_publicised"
+        path = _create_v1_path("/get_groups_publicised")
 
         content = {"user_ids": user_ids}
 
@@ -939,20 +953,43 @@ class TransportLayerClient(object):
         )
 
 
-def _create_path(prefix, path, *args):
-    """Creates a path from the prefix, path template and args. Ensures that
-    all args are url encoded.
+def _create_v1_path(path, *args):
+    """Creates a path against V1 federation API from the path template and
+    args. Ensures that all args are url encoded.
+
+    Example:
+
+        _create_v1_path("/event/%s/", event_id)
+
+    Args:
+        path (str): String template for the path
+        args: ([str]): Args to insert into path. Each arg will be url encoded
+
+    Returns:
+        str
+    """
+    return (
+        FEDERATION_V1_PREFIX
+        + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+    )
+
+
+def _create_v2_path(path, *args):
+    """Creates a path against V2 federation API from the path template and
+    args. Ensures that all args are url encoded.
 
     Example:
 
-        _create_path(PREFIX, "/event/%s/", event_id)
+        _create_v2_path("/event/%s/", event_id)
 
     Args:
-        prefix (str)
         path (str): String template for the path
         args: ([str]): Args to insert into path. Each arg will be url encoded
 
     Returns:
         str
     """
-    return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+    return (
+        FEDERATION_V2_PREFIX
+        + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+    )
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 6d4a26f595..5ba94be2ec 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -21,8 +21,9 @@ import re
 from twisted.internet import defer
 
 import synapse
+from synapse.api.constants import RoomVersions
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
-from synapse.api.urls import FEDERATION_PREFIX as PREFIX
+from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
 from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
@@ -42,9 +43,20 @@ logger = logging.getLogger(__name__)
 class TransportLayerServer(JsonResource):
     """Handles incoming federation HTTP requests"""
 
-    def __init__(self, hs):
+    def __init__(self, hs, servlet_groups=None):
+        """Initialize the TransportLayerServer
+
+        Will by default register all servlets. For custom behaviour, pass in
+        a list of servlet_groups to register.
+
+        Args:
+            hs (synapse.server.HomeServer): homeserver
+            servlet_groups (list[str], optional): List of servlet groups to register.
+                Defaults to ``DEFAULT_SERVLET_GROUPS``.
+        """
         self.hs = hs
         self.clock = hs.get_clock()
+        self.servlet_groups = servlet_groups
 
         super(TransportLayerServer, self).__init__(hs, canonical_json=False)
 
@@ -66,6 +78,7 @@ class TransportLayerServer(JsonResource):
             resource=self,
             ratelimiter=self.ratelimiter,
             authenticator=self.authenticator,
+            servlet_groups=self.servlet_groups,
         )
 
 
@@ -227,6 +240,8 @@ class BaseFederationServlet(object):
     """
     REQUIRE_AUTH = True
 
+    PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
+
     def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
@@ -286,7 +301,7 @@ class BaseFederationServlet(object):
         return new_func
 
     def register(self, server):
-        pattern = re.compile("^" + PREFIX + self.PATH + "$")
+        pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
 
         for method in ("GET", "PUT", "POST"):
             code = getattr(self, "on_%s" % (method), None)
@@ -362,14 +377,6 @@ class FederationSendServlet(BaseFederationServlet):
         defer.returnValue((code, response))
 
 
-class FederationPullServlet(BaseFederationServlet):
-    PATH = "/pull/"
-
-    # This is for when someone asks us for everything since version X
-    def on_GET(self, origin, content, query):
-        return self.handler.on_pull_request(query["origin"][0], query["v"])
-
-
 class FederationEventServlet(BaseFederationServlet):
     PATH = "/event/(?P<event_id>[^/]*)/"
 
@@ -474,7 +481,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_PUT(self, origin, content, query, room_id, event_id):
-        content = yield self.handler.on_send_leave_request(origin, content)
+        content = yield self.handler.on_send_leave_request(origin, content, room_id)
         defer.returnValue((200, content))
 
 
@@ -492,18 +499,50 @@ class FederationSendJoinServlet(BaseFederationServlet):
     def on_PUT(self, origin, content, query, context, event_id):
         # TODO(paul): assert that context/event_id parsed from path actually
         #   match those given in content
-        content = yield self.handler.on_send_join_request(origin, content)
+        content = yield self.handler.on_send_join_request(origin, content, context)
         defer.returnValue((200, content))
 
 
-class FederationInviteServlet(BaseFederationServlet):
+class FederationV1InviteServlet(BaseFederationServlet):
     PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
 
     @defer.inlineCallbacks
     def on_PUT(self, origin, content, query, context, event_id):
+        # We don't get a room version, so we have to assume its EITHER v1 or
+        # v2. This is "fine" as the only difference between V1 and V2 is the
+        # state resolution algorithm, and we don't use that for processing
+        # invites
+        content = yield self.handler.on_invite_request(
+            origin, content, room_version=RoomVersions.V1,
+        )
+
+        # V1 federation API is defined to return a content of `[200, {...}]`
+        # due to a historical bug.
+        defer.returnValue((200, (200, content)))
+
+
+class FederationV2InviteServlet(BaseFederationServlet):
+    PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    @defer.inlineCallbacks
+    def on_PUT(self, origin, content, query, context, event_id):
         # TODO(paul): assert that context/event_id parsed from path actually
         #   match those given in content
-        content = yield self.handler.on_invite_request(origin, content)
+
+        room_version = content["room_version"]
+        event = content["event"]
+        invite_room_state = content["invite_room_state"]
+
+        # Synapse expects invite_room_state to be in unsigned, as it is in v1
+        # API
+
+        event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
+
+        content = yield self.handler.on_invite_request(
+            origin, event, room_version=room_version,
+        )
         defer.returnValue((200, content))
 
 
@@ -1262,7 +1301,6 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
 
 FEDERATION_SERVLET_CLASSES = (
     FederationSendServlet,
-    FederationPullServlet,
     FederationEventServlet,
     FederationStateServlet,
     FederationStateIdsServlet,
@@ -1273,7 +1311,8 @@ FEDERATION_SERVLET_CLASSES = (
     FederationEventServlet,
     FederationSendJoinServlet,
     FederationSendLeaveServlet,
-    FederationInviteServlet,
+    FederationV1InviteServlet,
+    FederationV2InviteServlet,
     FederationQueryAuthServlet,
     FederationGetMissingEventsServlet,
     FederationEventAuthServlet,
@@ -1282,10 +1321,12 @@ FEDERATION_SERVLET_CLASSES = (
     FederationClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
-    OpenIdUserInfo,
     FederationVersionServlet,
 )
 
+OPENID_SERVLET_CLASSES = (
+    OpenIdUserInfo,
+)
 
 ROOM_LIST_CLASSES = (
     PublicRoomList,
@@ -1324,44 +1365,83 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
     FederationGroupsRenewAttestaionServlet,
 )
 
+DEFAULT_SERVLET_GROUPS = (
+    "federation",
+    "room_list",
+    "group_server",
+    "group_local",
+    "group_attestation",
+    "openid",
+)
+
 
-def register_servlets(hs, resource, authenticator, ratelimiter):
-    for servletclass in FEDERATION_SERVLET_CLASSES:
-        servletclass(
-            handler=hs.get_federation_server(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
-
-    for servletclass in ROOM_LIST_CLASSES:
-        servletclass(
-            handler=hs.get_room_list_handler(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
-
-    for servletclass in GROUP_SERVER_SERVLET_CLASSES:
-        servletclass(
-            handler=hs.get_groups_server_handler(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
-
-    for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
-        servletclass(
-            handler=hs.get_groups_local_handler(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
-
-    for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
-        servletclass(
-            handler=hs.get_groups_attestation_renewer(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
+def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None):
+    """Initialize and register servlet classes.
+
+    Will by default register all servlets. For custom behaviour, pass in
+    a list of servlet_groups to register.
+
+    Args:
+        hs (synapse.server.HomeServer): homeserver
+        resource (TransportLayerServer): resource class to register to
+        authenticator (Authenticator): authenticator to use
+        ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
+        servlet_groups (list[str], optional): List of servlet groups to register.
+            Defaults to ``DEFAULT_SERVLET_GROUPS``.
+    """
+    if not servlet_groups:
+        servlet_groups = DEFAULT_SERVLET_GROUPS
+
+    if "federation" in servlet_groups:
+        for servletclass in FEDERATION_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_federation_server(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "openid" in servlet_groups:
+        for servletclass in OPENID_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_federation_server(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "room_list" in servlet_groups:
+        for servletclass in ROOM_LIST_CLASSES:
+            servletclass(
+                handler=hs.get_room_list_handler(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "group_server" in servlet_groups:
+        for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_groups_server_handler(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "group_local" in servlet_groups:
+        for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_groups_local_handler(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "group_attestation" in servlet_groups:
+        for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_groups_attestation_renewer(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index c5ab14314e..025a79c022 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -117,9 +117,6 @@ class Transaction(JsonEncodedObject):
                 "Require 'transaction_id' to construct a Transaction"
             )
 
-        for p in pdus:
-            p.transaction_id = kwargs["transaction_id"]
-
         kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
 
         return Transaction(**kwargs)