summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py156
1 files changed, 106 insertions, 50 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 01a761715b..a7ea8fb98f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError
 from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
 )
+from synapse.util.metrics import measure_func
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.util.frozenutils import unfreeze
@@ -217,17 +218,28 @@ class FederationHandler(BaseHandler):
 
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
-                prev_state = context.current_state.get((event.type, event.state_key))
-                if not prev_state or prev_state.membership != Membership.JOIN:
-                    # Only fire user_joined_room if the user has acutally
-                    # joined the room. Don't bother if the user is just
-                    # changing their profile info.
+                # Only fire user_joined_room if the user has acutally
+                # joined the room. Don't bother if the user is just
+                # changing their profile info.
+                newly_joined = True
+                prev_state_id = context.current_state_ids.get(
+                    (event.type, event.state_key)
+                )
+                if prev_state_id:
+                    prev_state = yield self.store.get_event(
+                        prev_state_id, allow_none=True,
+                    )
+                    if prev_state and prev_state.membership == Membership.JOIN:
+                        newly_joined = False
+
+                if newly_joined:
                     user = UserID.from_string(event.state_key)
                     yield user_joined_room(self.distributor, user, event.room_id)
 
+    @measure_func("_filter_events_for_server")
     @defer.inlineCallbacks
     def _filter_events_for_server(self, server_name, room_id, events):
-        event_to_state = yield self.store.get_state_for_events(
+        event_to_state_ids = yield self.store.get_state_ids_for_events(
             frozenset(e.event_id for e in events),
             types=(
                 (EventTypes.RoomHistoryVisibility, ""),
@@ -235,6 +247,30 @@ class FederationHandler(BaseHandler):
             )
         )
 
+        # We only want to pull out member events that correspond to the
+        # server's domain.
+
+        def check_match(id):
+            try:
+                return server_name == get_domain_from_id(id)
+            except:
+                return False
+
+        event_map = yield self.store.get_events([
+            e_id for key_to_eid in event_to_state_ids.values()
+            for key, e_id in key_to_eid
+            if key[0] != EventTypes.Member or check_match(key[1])
+        ])
+
+        event_to_state = {
+            e_id: {
+                key: event_map[inner_e_id]
+                for key, inner_e_id in key_to_eid.items()
+                if inner_e_id in event_map
+            }
+            for e_id, key_to_eid in event_to_state_ids.items()
+        }
+
         def redact_disallowed(event, state):
             if not state:
                 return event
@@ -377,7 +413,9 @@ class FederationHandler(BaseHandler):
                 )).addErrback(unwrapFirstError)
                 auth_events.update({a.event_id: a for a in results if a})
                 required_auth.update(
-                    a_id for event in results for a_id, _ in event.auth_events if event
+                    a_id
+                    for event in results if event
+                    for a_id, _ in event.auth_events
                 )
                 missing_auth = required_auth - set(auth_events)
 
@@ -560,6 +598,18 @@ class FederationHandler(BaseHandler):
         ]))
         states = dict(zip(event_ids, [s[1] for s in states]))
 
+        state_map = yield self.store.get_events(
+            [e_id for ids in states.values() for e_id in ids],
+            get_prev_content=False
+        )
+        states = {
+            key: {
+                k: state_map[e_id]
+                for k, e_id in state_dict.items()
+                if e_id in state_map
+            } for key, state_dict in states.items()
+        }
+
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
 
@@ -722,7 +772,7 @@ class FederationHandler(BaseHandler):
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+        yield self.auth.check_from_context(event, context, do_sig_check=False)
 
         defer.returnValue(event)
 
@@ -770,18 +820,11 @@ class FederationHandler(BaseHandler):
 
         new_pdu = event
 
-        destinations = set()
-
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.JOIN:
-                        destinations.add(get_domain_from_id(s.state_key))
-            except:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
-
+        message_handler = self.hs.get_handlers().message_handler
+        destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+            context
+        )
+        destinations = set(destinations)
         destinations.discard(origin)
 
         logger.debug(
@@ -792,13 +835,15 @@ class FederationHandler(BaseHandler):
 
         self.replication_layer.send_pdu(new_pdu, destinations)
 
-        state_ids = [e.event_id for e in context.current_state.values()]
+        state_ids = context.current_state_ids.values()
         auth_chain = yield self.store.get_auth_chain(set(
             [event.event_id] + state_ids
         ))
 
+        state = yield self.store.get_events(context.current_state_ids.values())
+
         defer.returnValue({
-            "state": context.current_state.values(),
+            "state": state.values(),
             "auth_chain": auth_chain,
         })
 
@@ -954,7 +999,7 @@ class FederationHandler(BaseHandler):
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+            yield self.auth.check_from_context(event, context, do_sig_check=False)
         except AuthError as e:
             logger.warn("Failed to create new leave %r because %s", event, e)
             raise e
@@ -998,18 +1043,11 @@ class FederationHandler(BaseHandler):
 
         new_pdu = event
 
-        destinations = set()
-
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.LEAVE:
-                        destinations.add(get_domain_from_id(s.state_key))
-            except:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
-
+        message_handler = self.hs.get_handlers().message_handler
+        destinations = yield message_handler.get_joined_hosts_for_room_from_state(
+            context
+        )
+        destinations = set(destinations)
         destinations.discard(origin)
 
         logger.debug(
@@ -1294,7 +1332,13 @@ class FederationHandler(BaseHandler):
         )
 
         if not auth_events:
-            auth_events = context.current_state
+            auth_events_ids = yield self.auth.compute_auth_events(
+                event, context.current_state_ids, for_verification=True,
+            )
+            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_events.values()
+            }
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -1320,8 +1364,7 @@ class FederationHandler(BaseHandler):
             context.rejected = RejectedReason.AUTH_ERROR
 
         if event.type == EventTypes.GuestAccess:
-            full_context = yield self.store.get_current_state(room_id=event.room_id)
-            yield self.maybe_kick_guest_users(event, full_context)
+            yield self.maybe_kick_guest_users(event)
 
         defer.returnValue(context)
 
@@ -1492,7 +1535,9 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                context.current_state.update(auth_events)
+                context.current_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                })
                 context.state_group = None
 
         if different_auth and not event.internal_metadata.is_outlier():
@@ -1514,8 +1559,8 @@ class FederationHandler(BaseHandler):
 
             if do_resolution:
                 # 1. Get what we think is the auth chain.
-                auth_ids = self.auth.compute_auth_events(
-                    event, context.current_state
+                auth_ids = yield self.auth.compute_auth_events(
+                    event, context.current_state_ids
                 )
                 local_auth_chain = yield self.store.get_auth_chain(auth_ids)
 
@@ -1571,7 +1616,9 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
-                context.current_state.update(auth_events)
+                context.current_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                })
                 context.state_group = None
 
         try:
@@ -1758,12 +1805,12 @@ class FederationHandler(BaseHandler):
             )
 
             try:
-                self.auth.check(event, context.current_state)
+                yield self.auth.check_from_context(event, context)
             except AuthError as e:
                 logger.warn("Denying new third party invite %r because %s", event, e)
                 raise e
 
-            yield self._check_signature(event, auth_events=context.current_state)
+            yield self._check_signature(event, context)
             member_handler = self.hs.get_handlers().room_member_handler
             yield member_handler.send_membership_event(None, event, context)
         else:
@@ -1789,11 +1836,11 @@ class FederationHandler(BaseHandler):
         )
 
         try:
-            self.auth.check(event, auth_events=context.current_state)
+            self.auth.check_from_context(event, context)
         except AuthError as e:
             logger.warn("Denying third party invite %r because %s", event, e)
             raise e
-        yield self._check_signature(event, auth_events=context.current_state)
+        yield self._check_signature(event, context)
 
         returned_invite = yield self.send_invite(origin, event)
         # TODO: Make sure the signatures actually are correct.
@@ -1807,7 +1854,12 @@ class FederationHandler(BaseHandler):
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"]
         )
-        original_invite = context.current_state.get(key)
+        original_invite = None
+        original_invite_id = context.current_state_ids.get(key)
+        if original_invite_id:
+            original_invite = yield self.store.get_event(
+                original_invite_id, allow_none=True
+            )
         if not original_invite:
             logger.info(
                 "Could not find invite event for third_party_invite - "
@@ -1824,13 +1876,13 @@ class FederationHandler(BaseHandler):
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
-    def _check_signature(self, event, auth_events):
+    def _check_signature(self, event, context):
         """
         Checks that the signature in the event is consistent with its invite.
 
         Args:
             event (Event): The m.room.member event to check
-            auth_events (dict<(event type, state_key), event>):
+            context (EventContext):
 
         Raises:
             AuthError: if signature didn't match any keys, or key has been
@@ -1841,10 +1893,14 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        invite_event = auth_events.get(
+        invite_event_id = context.current_state_ids.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
 
+        invite_event = None
+        if invite_event_id:
+            invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+
         if not invite_event:
             raise AuthError(403, "Could not find invite")