summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/events/utils.py2
-rw-r--r--synapse/federation/pdu_codec.py48
-rw-r--r--synapse/federation/replication.py72
-rw-r--r--synapse/federation/transport.py184
-rw-r--r--synapse/federation/units.py78
-rw-r--r--synapse/handlers/federation.py12
6 files changed, 80 insertions, 316 deletions
diff --git a/synapse/api/events/utils.py b/synapse/api/events/utils.py
index 7fdf45a264..31601fd3a9 100644
--- a/synapse/api/events/utils.py
+++ b/synapse/api/events/utils.py
@@ -32,7 +32,7 @@ def prune_event(event):
 def prune_pdu(pdu):
     """Removes keys that contain unrestricted and non-essential data from a PDU
     """
-    return _prune_event_or_pdu(pdu.pdu_type, pdu)
+    return _prune_event_or_pdu(pdu.type, pdu)
 
 def _prune_event_or_pdu(event_type, event):
     # Remove all extraneous fields.
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
index d4c896e163..5ec97a698e 100644
--- a/synapse/federation/pdu_codec.py
+++ b/synapse/federation/pdu_codec.py
@@ -31,39 +31,16 @@ class PduCodec(object):
         self.clock = hs.get_clock()
         self.hs = hs
 
-    def encode_event_id(self, local, domain):
-        return local
-
-    def decode_event_id(self, event_id):
-        e_id = self.hs.parse_eventid(event_id)
-        return event_id, e_id.domain
-
     def event_from_pdu(self, pdu):
         kwargs = {}
 
-        kwargs["event_id"] = self.encode_event_id(pdu.pdu_id, pdu.origin)
-        kwargs["room_id"] = pdu.context
-        kwargs["etype"] = pdu.pdu_type
-        kwargs["prev_events"] = [
-            (self.encode_event_id(i, o), s)
-            for i, o, s in pdu.prev_pdus
-        ]
-
-        if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
-            kwargs["prev_state"] = self.encode_event_id(
-                pdu.prev_state_id, pdu.prev_state_origin
-            )
+        kwargs["etype"] = pdu.type
 
         kwargs.update({
             k: v
             for k, v in pdu.get_full_dict().items()
             if k not in [
-                "pdu_id",
-                "context",
-                "pdu_type",
-                "prev_pdus",
-                "prev_state_id",
-                "prev_state_origin",
+                "type",
             ]
         })
 
@@ -72,33 +49,12 @@ class PduCodec(object):
     def pdu_from_event(self, event):
         d = event.get_full_dict()
 
-        d["pdu_id"], d["origin"] = self.decode_event_id(
-            event.event_id
-        )
-        d["context"] = event.room_id
-        d["pdu_type"] = event.type
-
-        if hasattr(event, "prev_events"):
-            def f(e, s):
-                i, o = self.decode_event_id(e)
-                return i, o, s
-            d["prev_pdus"] = [
-                f(e, s)
-                for e, s in event.prev_events
-            ]
-
-        if hasattr(event, "prev_state"):
-            d["prev_state_id"], d["prev_state_origin"] = (
-                self.decode_event_id(event.prev_state)
-            )
-
         if hasattr(event, "state_key"):
             d["is_state"] = True
 
         kwargs = copy.deepcopy(event.unrecognized_keys)
         kwargs.update({
             k: v for k, v in d.items()
-            if k not in ["event_id", "room_id", "type", "prev_events"]
         })
 
         if "origin_server_ts" not in kwargs:
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 159af4eed7..838e660a46 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -111,14 +111,6 @@ class ReplicationLayer(object):
         """Informs the replication layer about a new PDU generated within the
         home server that should be transmitted to others.
 
-        This will fill out various attributes on the PDU object, e.g. the
-        `prev_pdus` key.
-
-        *Note:* The home server should always call `send_pdu` even if it knows
-        that it does not need to be replicated to other home servers. This is
-        in case e.g. someone else joins via a remote home server and then
-        backfills.
-
         TODO: Figure out when we should actually resolve the deferred.
 
         Args:
@@ -131,18 +123,12 @@ class ReplicationLayer(object):
         order = self._order
         self._order += 1
 
-        logger.debug("[%s] Persisting PDU", pdu.pdu_id)
-
-        # Save *before* trying to send
-        # yield self.store.persist_event(pdu=pdu)
-
-        logger.debug("[%s] Persisted PDU", pdu.pdu_id)
-        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
 
         # TODO, add errback, etc.
         self._transaction_queue.enqueue_pdu(pdu, order)
 
-        logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+        logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.event_id)
 
     @log_function
     def send_edu(self, destination, edu_type, content):
@@ -215,7 +201,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+    def get_pdu(self, destination, event_id, outlier=False):
         """Requests the PDU with given origin and ID from the remote home
         server.
 
@@ -224,7 +210,7 @@ class ReplicationLayer(object):
         Args:
             destination (str): Which home server to query
             pdu_origin (str): The home server that originally sent the pdu.
-            pdu_id (str)
+            event_id (str)
             outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                 it's from an arbitary point in the context as opposed to part
                 of the current block of PDUs. Defaults to `False`
@@ -233,8 +219,9 @@ class ReplicationLayer(object):
             Deferred: Results in the requested PDU.
         """
 
-        transaction_data = yield self.transport_layer.get_pdu(
-            destination, pdu_origin, pdu_id)
+        transaction_data = yield self.transport_layer.get_event(
+            destination, event_id
+        )
 
         transaction = Transaction(**transaction_data)
 
@@ -249,8 +236,7 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def get_state_for_context(self, destination, context, pdu_id=None,
-                              pdu_origin=None):
+    def get_state_for_context(self, destination, context, event_id=None):
         """Requests all of the `current` state PDUs for a given context from
         a remote home server.
 
@@ -263,7 +249,9 @@ class ReplicationLayer(object):
         """
 
         transaction_data = yield self.transport_layer.get_context_state(
-            destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin,
+            destination,
+            context,
+            event_id=event_id,
         )
 
         transaction = Transaction(**transaction_data)
@@ -352,10 +340,10 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def on_context_state_request(self, context, pdu_id, pdu_origin):
-        if pdu_id and pdu_origin:
+    def on_context_state_request(self, context, event_id):
+        if event_id:
             pdus = yield self.handler.get_state_for_pdu(
-                pdu_id, pdu_origin
+                event_id
             )
         else:
             raise NotImplementedError("Specify an event")
@@ -370,8 +358,8 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def on_pdu_request(self, pdu_origin, pdu_id):
-        pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+    def on_pdu_request(self, event_id):
+        pdu = yield self._get_persisted_pdu(event_id)
 
         if pdu:
             defer.returnValue(
@@ -443,9 +431,8 @@ class ReplicationLayer(object):
     def send_join(self, destination, pdu):
         _, content = yield self.transport_layer.send_join(
             destination,
-            pdu.context,
-            pdu.pdu_id,
-            pdu.origin,
+            pdu.room_id,
+            pdu.event_id,
             pdu.get_dict(),
         )
 
@@ -457,13 +444,13 @@ class ReplicationLayer(object):
         defer.returnValue(pdus)
 
     @log_function
-    def _get_persisted_pdu(self, pdu_id, pdu_origin):
+    def _get_persisted_pdu(self, event_id):
         """ Get a PDU from the database with given origin and id.
 
         Returns:
             Deferred: Results in a `Pdu`.
         """
-        return self.handler.get_persisted_pdu(pdu_id, pdu_origin)
+        return self.handler.get_persisted_pdu(event_id)
 
     def _transaction_from_pdus(self, pdu_list):
         """Returns a new Transaction containing the given PDUs suitable for
@@ -487,10 +474,10 @@ class ReplicationLayer(object):
     @log_function
     def _handle_new_pdu(self, origin, pdu, backfilled=False):
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+        existing = yield self._get_persisted_pdu(pdu.event_id)
 
         if existing and (not existing.outlier or pdu.outlier):
-            logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+            logger.debug("Already seen pdu %s", pdu.event_id)
             defer.returnValue({})
             return
 
@@ -500,23 +487,22 @@ class ReplicationLayer(object):
         if not pdu.outlier:
             # We only backfill backwards to the min depth.
             min_depth = yield self.handler.get_min_depth_for_context(
-                pdu.context
+                pdu.room_id
             )
 
             if min_depth and pdu.depth > min_depth:
-                for pdu_id, origin, hashes in pdu.prev_pdus:
-                    exists = yield self._get_persisted_pdu(pdu_id, origin)
+                for event_id, hashes in pdu.prev_events:
+                    exists = yield self._get_persisted_pdu(event_id)
 
                     if not exists:
-                        logger.debug("Requesting pdu %s %s", pdu_id, origin)
+                        logger.debug("Requesting pdu %s", event_id)
 
                         try:
                             yield self.get_pdu(
                                 pdu.origin,
-                                pdu_id=pdu_id,
-                                pdu_origin=origin
+                                event_id=event_id,
                             )
-                            logger.debug("Processed pdu %s %s", pdu_id, origin)
+                            logger.debug("Processed pdu %s", event_id)
                         except:
                             # TODO(erikj): Do some more intelligent retries.
                             logger.exception("Failed to get PDU")
@@ -524,7 +510,7 @@ class ReplicationLayer(object):
                 # We need to get the state at this event, since we have reached
                 # a backward extremity edge.
                 state = yield self.get_state_for_context(
-                    origin, pdu.context, pdu.pdu_id, pdu.origin,
+                    origin, pdu.room_id, pdu.event_id,
                 )
 
         # Persist the Pdu, but don't mark it as processed yet.
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index 7f01b4faaf..04ad7e63ae 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,8 +72,7 @@ class TransportLayer(object):
         self.received_handler = None
 
     @log_function
-    def get_context_state(self, destination, context, pdu_id=None,
-                          pdu_origin=None):
+    def get_context_state(self, destination, context, event_id=None):
         """ Requests all state for a given context (i.e. room) from the
         given server.
 
@@ -91,60 +90,59 @@ class TransportLayer(object):
         subpath = "/state/%s/" % context
 
         args = {}
-        if pdu_id and pdu_origin:
-            args["pdu_id"] = pdu_id
-            args["pdu_origin"] = pdu_origin
+        if event_id:
+            args["event_id"] = event_id
 
         return self._do_request_for_transaction(
             destination, subpath, args=args
         )
 
     @log_function
-    def get_pdu(self, destination, pdu_origin, pdu_id):
+    def get_event(self, destination, event_id):
         """ Requests the pdu with give id and origin from the given server.
 
         Args:
             destination (str): The host name of the remote home server we want
                 to get the state from.
-            pdu_origin (str): The home server which created the PDU.
-            pdu_id (str): The id of the PDU being requested.
+            event_id (str): The id of the event being requested.
 
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
         """
-        logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
-                     destination, pdu_origin, pdu_id)
+        logger.debug("get_pdu dest=%s, event_id=%s",
+                     destination, event_id)
 
-        subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
+        subpath = "/event/%s/" % (event_id, )
 
         return self._do_request_for_transaction(destination, subpath)
 
     @log_function
-    def backfill(self, dest, context, pdu_tuples, limit):
+    def backfill(self, dest, context, event_tuples, limit):
         """ Requests `limit` previous PDUs in a given context before list of
         PDUs.
 
         Args:
             dest (str)
             context (str)
-            pdu_tuples (list)
+            event_tuples (list)
             limt (int)
 
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
         """
         logger.debug(
-            "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
-            dest, context, repr(pdu_tuples), str(limit)
+            "backfill dest=%s, context=%s, event_tuples=%s, limit=%s",
+            dest, context, repr(event_tuples), str(limit)
         )
 
-        if not pdu_tuples:
+        if not event_tuples:
+            # TODO: raise?
             return
 
-        subpath = "/backfill/%s/" % context
+        subpath = "/backfill/%s/" % (context,)
 
         args = {
-            "v": ["%s,%s" % (i, o) for i, o in pdu_tuples],
+            "v": event_tuples,
             "limit": limit,
         }
 
@@ -222,11 +220,10 @@ class TransportLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_join(self, destination, context, pdu_id, origin, content):
-        path = PREFIX + "/send_join/%s/%s/%s" % (
+    def send_join(self, destination, context, event_id, content):
+        path = PREFIX + "/send_join/%s/%s" % (
             context,
-            origin,
-            pdu_id,
+            event_id,
         )
 
         code, content = yield self.client.put_json(
@@ -242,11 +239,10 @@ class TransportLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_invite(self, destination, context, pdu_id, origin, content):
-        path = PREFIX + "/invite/%s/%s/%s" % (
+    def send_invite(self, destination, context, event_id, content):
+        path = PREFIX + "/invite/%s/%s" % (
             context,
-            origin,
-            pdu_id,
+            event_id,
         )
 
         code, content = yield self.client.put_json(
@@ -376,10 +372,10 @@ class TransportLayer(object):
         # data_id pair.
         self.server.register_path(
             "GET",
-            re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
+            re.compile("^" + PREFIX + "/event/([^/]*)/$"),
             self._with_authentication(
-                lambda origin, content, query, pdu_origin, pdu_id:
-                handler.on_pdu_request(pdu_origin, pdu_id)
+                lambda origin, content, query, event_id:
+                handler.on_pdu_request(event_id)
             )
         )
 
@@ -391,8 +387,7 @@ class TransportLayer(object):
                 lambda origin, content, query, context:
                 handler.on_context_state_request(
                     context,
-                    query.get("pdu_id", [None])[0],
-                    query.get("pdu_origin", [None])[0]
+                    query.get("event_id", [None])[0],
                 )
             )
         )
@@ -442,9 +437,9 @@ class TransportLayer(object):
 
         self.server.register_path(
             "PUT",
-            re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"),
+            re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
             self._with_authentication(
-                lambda origin, content, query, context, pdu_origin, pdu_id:
+                lambda origin, content, query, context, event_id:
                 self._on_send_join_request(
                     origin, content, query,
                 )
@@ -453,9 +448,9 @@ class TransportLayer(object):
 
         self.server.register_path(
             "PUT",
-            re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"),
+            re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
             self._with_authentication(
-                lambda origin, content, query, context, pdu_origin, pdu_id:
+                lambda origin, content, query, context, event_id:
                 self._on_invite_request(
                     origin, content, query,
                 )
@@ -548,7 +543,7 @@ class TransportLayer(object):
 
         limit = int(limits[-1])
 
-        versions = [v.split(",", 1) for v in v_list]
+        versions = v_list
 
         return self.request_handler.on_backfill_request(
             context, versions, limit
@@ -579,120 +574,3 @@ class TransportLayer(object):
         )
 
         defer.returnValue((200, content))
-
-
-class TransportReceivedHandler(object):
-    """ Callbacks used when we receive a transaction
-    """
-    def on_incoming_transaction(self, transaction):
-        """ Called on PUT /send/<transaction_id>, or on response to a request
-        that we sent (e.g. a backfill request)
-
-        Args:
-            transaction (synapse.transaction.Transaction): The transaction that
-                was sent to us.
-
-        Returns:
-            twisted.internet.defer.Deferred: A deferred that gets fired when
-            the transaction has finished being processed.
-
-            The result should be a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-
-class TransportRequestHandler(object):
-    """ Handlers used when someone want's data from us
-    """
-    def on_pull_request(self, versions):
-        """ Called on GET /pull/?v=...
-
-        This is hit when a remote home server wants to get all data
-        after a given transaction. Mainly used when a home server comes back
-        online and wants to get everything it has missed.
-
-        Args:
-            versions (list): A list of transaction_ids that should be used to
-                determine what PDUs the remote side have not yet seen.
-
-        Returns:
-            Deferred: Resultsin a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_pdu_request(self, pdu_origin, pdu_id):
-        """ Called on GET /pdu/<pdu_origin>/<pdu_id>/
-
-        Someone wants a particular PDU. This PDU may or may not have originated
-        from us.
-
-        Args:
-            pdu_origin (str)
-            pdu_id (str)
-
-        Returns:
-            Deferred: Resultsin a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_context_state_request(self, context):
-        """ Called on GET /state/<context>/
-
-        Gets hit when someone wants all the *current* state for a given
-        contexts.
-
-        Args:
-            context (str): The name of the context that we're interested in.
-
-        Returns:
-            twisted.internet.defer.Deferred: A deferred that gets fired when
-            the transaction has finished being processed.
-
-            The result should be a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_backfill_request(self, context, versions, limit):
-        """ Called on GET /backfill/<context>/?v=...&limit=...
-
-        Gets hit when we want to backfill backwards on a given context from
-        the given point.
-
-        Args:
-            context (str): The context to backfill
-            versions (list): A list of 2-tuples representing where to backfill
-                from, in the form `(pdu_id, origin)`
-            limit (int): How many pdus to return.
-
-        Returns:
-            Deferred: Results in a tuple in the form of
-            `(response_code, respond_body)`, where `response_body` is a python
-            dict that will get serialized to JSON.
-
-            On errors, the dict should have an `error` key with a brief message
-            of what went wrong.
-        """
-        pass
-
-    def on_query_request(self):
-        """ Called on a GET /query/<query_type> request. """
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index adc3385644..c94dcf64cf 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -34,13 +34,13 @@ class Pdu(JsonEncodedObject):
 
     A Pdu can be classified as "state". For a given context, we can efficiently
     retrieve all state pdu's that haven't been clobbered. Clobbering is done
-    via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
+    via a unique constraint on the tuple (context, type, state_key). A pdu
     is a state pdu if `is_state` is True.
 
     Example pdu::
 
         {
-            "pdu_id": "78c",
+            "event_id": "$78c:example.com",
             "origin_server_ts": 1404835423000,
             "origin": "bar",
             "prev_ids": [
@@ -53,14 +53,14 @@ class Pdu(JsonEncodedObject):
     """
 
     valid_keys = [
-        "pdu_id",
-        "context",
+        "event_id",
+        "room_id",
         "origin",
         "origin_server_ts",
-        "pdu_type",
+        "type",
         "destinations",
         "transaction_id",
-        "prev_pdus",
+        "prev_events",
         "depth",
         "content",
         "outlier",
@@ -68,8 +68,7 @@ class Pdu(JsonEncodedObject):
         "signatures",
         "is_state",  # Below this are keys valid only for State Pdus.
         "state_key",
-        "prev_state_id",
-        "prev_state_origin",
+        "prev_state",
         "required_power_level",
         "user_id",
     ]
@@ -81,18 +80,18 @@ class Pdu(JsonEncodedObject):
     ]
 
     required_keys = [
-        "pdu_id",
-        "context",
+        "event_id",
+        "room_id",
         "origin",
         "origin_server_ts",
-        "pdu_type",
+        "type",
         "content",
     ]
 
     # TODO: We need to make this properly load content rather than
     # just leaving it as a dict. (OR DO WE?!)
 
-    def __init__(self, destinations=[], is_state=False, prev_pdus=[],
+    def __init__(self, destinations=[], is_state=False, prev_events=[],
                  outlier=False, hashes={}, signatures={}, **kwargs):
         if is_state:
             for required_key in ["state_key"]:
@@ -102,66 +101,13 @@ class Pdu(JsonEncodedObject):
         super(Pdu, self).__init__(
             destinations=destinations,
             is_state=bool(is_state),
-            prev_pdus=prev_pdus,
+            prev_events=prev_events,
             outlier=outlier,
             hashes=hashes,
             signatures=signatures,
             **kwargs
         )
 
-    @classmethod
-    def from_pdu_tuple(cls, pdu_tuple):
-        """ Converts a PduTuple to a Pdu
-
-        Args:
-            pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
-                convert
-
-        Returns:
-            Pdu
-        """
-        if pdu_tuple:
-            d = copy.copy(pdu_tuple.pdu_entry._asdict())
-            d["origin_server_ts"] = d.pop("ts")
-
-            for k in d.keys():
-                if d[k] is None:
-                    del d[k]
-
-            d["content"] = json.loads(d["content_json"])
-            del d["content_json"]
-
-            args = {f: d[f] for f in cls.valid_keys if f in d}
-            if "unrecognized_keys" in d and d["unrecognized_keys"]:
-                args.update(json.loads(d["unrecognized_keys"]))
-
-            hashes = {
-                alg: encode_base64(hsh)
-                for alg, hsh in pdu_tuple.hashes.items()
-            }
-
-            signatures = {
-                kid: encode_base64(sig)
-                for kid, sig in pdu_tuple.signatures.items()
-            }
-
-            prev_pdus = []
-            for prev_pdu in pdu_tuple.prev_pdu_list:
-                prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {})
-                prev_hashes = {
-                    alg: encode_base64(hsh) for alg, hsh in prev_hashes.items()
-                }
-                prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes))
-
-            return Pdu(
-                prev_pdus=prev_pdus,
-                hashes=hashes,
-                signatures=signatures,
-                **args
-            )
-        else:
-            return None
-
     def __str__(self):
         return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 18cb1d4e97..bdd28f04bb 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -139,7 +139,7 @@ class FederationHandler(BaseHandler):
             # Huh, let's try and get the current state
             try:
                 yield self.replication_layer.get_state_for_context(
-                    event.origin, event.room_id, pdu.pdu_id, pdu.origin,
+                    event.origin, event.room_id, event.event_id,
                 )
 
                 hosts = yield self.store.get_joined_hosts_for_room(
@@ -368,11 +368,9 @@ class FederationHandler(BaseHandler):
         ])
 
     @defer.inlineCallbacks
-    def get_state_for_pdu(self, pdu_id, pdu_origin):
+    def get_state_for_pdu(self, event_id):
         yield run_on_reactor()
 
-        event_id = EventID.create(pdu_id, pdu_origin, self.hs).to_string()
-
         state_groups = yield self.store.get_state_groups(
             [event_id]
         )
@@ -406,7 +404,7 @@ class FederationHandler(BaseHandler):
 
         events = yield self.store.get_backfill_events(
             context,
-            [self.pdu_codec.encode_event_id(i, o) for i, o in pdu_list],
+            pdu_list,
             limit
         )
 
@@ -417,14 +415,14 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     @log_function
-    def get_persisted_pdu(self, pdu_id, origin):
+    def get_persisted_pdu(self, event_id):
         """ Get a PDU from the database with given origin and id.
 
         Returns:
             Deferred: Results in a `Pdu`.
         """
         event = yield self.store.get_event(
-            self.pdu_codec.encode_event_id(pdu_id, origin),
+            event_id,
             allow_none=True,
         )