summary refs log tree commit diff
path: root/synapse/handlers/message.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/message.py')
-rw-r--r--synapse/handlers/message.py93
1 files changed, 70 insertions, 23 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4c3cd9d12e..3577db0595 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
 from synapse.util.caches.snapshot_cache import SnapshotCache
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.util.metrics import measure_func
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
 
         if event.is_state():
-            prev_state = self.deduplicate_state_event(event, context)
+            prev_state = yield self.deduplicate_state_event(event, context)
             if prev_state is not None:
                 defer.returnValue(prev_state)
 
@@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
             presence = self.hs.get_presence_handler()
             yield presence.bump_presence_active_time(user)
 
+    @defer.inlineCallbacks
     def deduplicate_state_event(self, event, context):
         """
         Checks whether event is in the latest resolved state in context.
@@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_event = context.current_state.get((event.type, event.state_key))
+        prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
+        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+        if not prev_event:
+            return
+
         if prev_event and event.user_id == prev_event.user_id:
             prev_content = encode_canonical_json(prev_event.content)
             next_content = encode_canonical_json(event.content)
             if prev_content == next_content:
-                return prev_event
-        return None
+                defer.returnValue(prev_event)
+        return
 
     @defer.inlineCallbacks
     def create_and_send_nonmember_event(
@@ -802,8 +808,8 @@ class MessageHandler(BaseHandler):
         event = builder.build()
 
         logger.debug(
-            "Created event %s with current state: %s",
-            event.event_id, context.current_state,
+            "Created event %s with state: %s",
+            event.event_id, context.prev_state_ids,
         )
 
         defer.returnValue(
@@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
             self.ratelimit(requester)
 
         try:
-            self.auth.check(event, auth_events=context.current_state)
+            yield self.auth.check_from_context(event, context)
         except AuthError as err:
             logger.warn("Denying new event %r because %s", event, err)
             raise err
 
-        yield self.maybe_kick_guest_users(event, context.current_state.values())
+        yield self.maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
             # Check the alias is acually valid (at this time at least)
@@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
                         e.sender == event.sender
                     )
 
+                state_to_include_ids = [
+                    e_id
+                    for k, e_id in context.current_state_ids.items()
+                    if k[0] in self.hs.config.room_invite_state_types
+                    or k[0] == EventTypes.Member and k[1] == event.sender
+                ]
+
+                state_to_include = yield self.store.get_events(state_to_include_ids)
+
                 event.unsigned["invite_room_state"] = [
                     {
                         "type": e.type,
@@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
                         "content": e.content,
                         "sender": e.sender,
                     }
-                    for k, e in context.current_state.items()
-                    if e.type in self.hs.config.room_invite_state_types
-                    or is_inviter_member_event(e)
+                    for e in state_to_include.values()
                 ]
 
                 invitee = UserID.from_string(event.state_key)
@@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
                     )
 
         if event.type == EventTypes.Redaction:
-            if self.auth.check_redaction(event, auth_events=context.current_state):
+            auth_events_ids = yield self.auth.compute_auth_events(
+                event, context.prev_state_ids, for_verification=True,
+            )
+            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_events.values()
+            }
+            if self.auth.check_redaction(event, auth_events=auth_events):
                 original_event = yield self.store.get_event(
                     event.redacts,
                     check_redacted=False,
@@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
                         "You don't have permission to redact events"
                     )
 
-        if event.type == EventTypes.Create and context.current_state:
+        if event.type == EventTypes.Create and context.prev_state_ids:
             raise AuthError(
                 403,
                 "Changing the room create event is forbidden",
@@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
             event_stream_id, max_stream_id
         )
 
-        destinations = set()
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.JOIN:
-                        destinations.add(get_domain_from_id(s.state_key))
-            except SynapseError:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
+        destinations = yield self.get_joined_hosts_for_room_from_state(context)
 
         @defer.inlineCallbacks
         def _notify():
@@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
         preserve_fn(federation_handler.handle_new_event)(
             event, destinations=destinations,
         )
+
+    def get_joined_hosts_for_room_from_state(self, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_hosts_for_room_from_state(
+            state_group, context.current_state_ids
+        )
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
+                                              cache_context):
+
+        # Don't bother getting state for people on the same HS
+        current_state = yield self.store.get_events([
+            e_id for key, e_id in current_state_ids.items()
+            if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
+        ])
+
+        destinations = set()
+        for e in current_state.itervalues():
+            try:
+                if e.type == EventTypes.Member:
+                    if e.content["membership"] == Membership.JOIN:
+                        destinations.add(get_domain_from_id(e.state_key))
+            except SynapseError:
+                logger.warn(
+                    "Failed to get destination from event %s", e.event_id
+                )
+
+        defer.returnValue(destinations)