summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2019-12-11 16:37:51 +0000
committerRichard van der Hoff <richard@matrix.org>2019-12-16 13:26:12 +0000
commit20d5ba16e626aa4217492c83dda9fabd36bd5d2b (patch)
tree67048fd39134b72a93773a6d407ec5a0004e3adb /synapse/handlers
parentMove get_state methods into FederationHandler (#6503) (diff)
downloadsynapse-20d5ba16e626aa4217492c83dda9fabd36bd5d2b.tar.xz
Add `include_event_in_state` to _get_state_for_room (#6521)
Make it return the state *after* the requested event, rather than the one
before it. This is a bit easier and requires fewer calls to
get_events_from_store_or_dest.
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/federation.py39
1 files changed, 21 insertions, 18 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0dcf9abf8..31c9132ae9 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -379,22 +379,10 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = yield self._get_state_for_room(origin, room_id, p)
-
-                            # we want the state *after* p; _get_state_for_room returns the
-                            # state *before* p.
-                            remote_event = yield self.federation_client.get_pdu(
-                                [origin], p, room_version, outlier=True
+                            ) = yield self._get_state_for_room(
+                                origin, room_id, p, include_event_in_state=True
                             )
 
-                            if remote_event is None:
-                                raise Exception(
-                                    "Unable to get missing prev_event %s" % (p,)
-                                )
-
-                            if remote_event.is_state():
-                                remote_state.append(remote_event)
-
                             # XXX hrm I'm not convinced that duplicate events will compare
                             # for equality, so I'm not sure this does what the author
                             # hoped.
@@ -583,13 +571,15 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     @log_function
-    def _get_state_for_room(self, destination, room_id, event_id):
+    def _get_state_for_room(self, destination, room_id, event_id, include_event_in_state):
         """Requests all of the room state at a given event from a remote homeserver.
 
         Args:
             destination (str): The remote homeserver to query for the state.
             room_id (str): The id of the room we're interested in.
             event_id (str): The id of the event we want the state at.
+            include_event_in_state: if true, the event itself will be included in the
+                returned state event list.
 
         Returns:
             Deferred[Tuple[List[EventBase], List[EventBase]]]:
@@ -604,6 +594,10 @@ class FederationHandler(BaseHandler):
         )
 
         desired_events = set(state_event_ids + auth_event_ids)
+
+        if include_event_in_state:
+            desired_events.add(event_id)
+
         event_map = yield self._get_events_from_store_or_dest(
             destination, room_id, desired_events
         )
@@ -616,12 +610,21 @@ class FederationHandler(BaseHandler):
                 failed_to_fetch,
             )
 
-        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+        remote_state = [
+            event_map[e_id] for e_id in state_event_ids if e_id in event_map
+        ]
 
+        if include_event_in_state:
+            remote_event = event_map.get(event_id)
+            if not remote_event:
+                raise Exception("Unable to get missing prev_event %s" % (event_id,))
+            if remote_event.is_state():
+                remote_state.append(remote_event)
+
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
         auth_chain.sort(key=lambda e: e.depth)
 
-        return pdus, auth_chain
+        return remote_state, auth_chain
 
     @defer.inlineCallbacks
     def _get_events_from_store_or_dest(self, destination, room_id, event_ids):