summary refs log tree commit diff
path: root/synapse/handlers/message.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/message.py')
-rw-r--r--synapse/handlers/message.py153
1 files changed, 88 insertions, 65 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ff800f8af1..5c50c611ba 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,12 +16,11 @@
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError, AuthError, Codes
+from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.streams.config import PaginationConfig
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.util import unwrapFirstError
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.caches.snapshot_cache import SnapshotCache
 from synapse.types import UserID, RoomStreamToken, StreamToken
 
@@ -105,8 +104,6 @@ class MessageHandler(BaseHandler):
             room_token = pagin_config.from_token.room_key
 
         room_token = RoomStreamToken.parse(room_token)
-        if room_token.topological is None:
-            raise SynapseError(400, "Invalid token")
 
         pagin_config.from_token = pagin_config.from_token.copy_and_replace(
             "room_key", str(room_token)
@@ -117,27 +114,31 @@ class MessageHandler(BaseHandler):
         membership, member_event_id = yield self._check_in_room_or_world_readable(
             room_id, user_id
         )
-        if membership == Membership.LEAVE:
-            # If they have left the room then clamp the token to be before
-            # they left the room.
-            leave_token = yield self.store.get_topological_token_for_event(
-                member_event_id
+
+        if source_config.direction == 'b':
+            # if we're going backwards, we might need to backfill. This
+            # requires that we have a topo token.
+            if room_token.topological:
+                max_topo = room_token.topological
+            else:
+                max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
+                    room_id, room_token.stream
+                )
+
+            if membership == Membership.LEAVE:
+                # If they have left the room then clamp the token to be before
+                # they left the room, to save the effort of loading from the
+                # database.
+                leave_token = yield self.store.get_topological_token_for_event(
+                    member_event_id
+                )
+                leave_token = RoomStreamToken.parse(leave_token)
+                if leave_token.topological < max_topo:
+                    source_config.from_key = str(leave_token)
+
+            yield self.hs.get_handlers().federation_handler.maybe_backfill(
+                room_id, max_topo
             )
-            leave_token = RoomStreamToken.parse(leave_token)
-            if leave_token.topological < room_token.topological:
-                source_config.from_key = str(leave_token)
-
-            if source_config.direction == "f":
-                if source_config.to_key is None:
-                    source_config.to_key = str(leave_token)
-                else:
-                    to_token = RoomStreamToken.parse(source_config.to_key)
-                    if leave_token.topological < to_token.topological:
-                        source_config.to_key = str(leave_token)
-
-        yield self.hs.get_handlers().federation_handler.maybe_backfill(
-            room_id, room_token.topological
-        )
 
         events, next_key = yield data_source.get_pagination_rows(
             requester.user, source_config, room_id
@@ -195,12 +196,25 @@ class MessageHandler(BaseHandler):
 
         if builder.type == EventTypes.Member:
             membership = builder.content.get("membership", None)
+            target = UserID.from_string(builder.state_key)
+
             if membership == Membership.JOIN:
-                joinee = UserID.from_string(builder.state_key)
                 # If event doesn't include a display name, add one.
                 yield collect_presencelike_data(
-                    self.distributor, joinee, builder.content
+                    self.distributor, target, builder.content
                 )
+            elif membership == Membership.INVITE:
+                profile = self.hs.get_handlers().profile_handler
+                content = builder.content
+
+                try:
+                    content["displayname"] = yield profile.get_displayname(target)
+                    content["avatar_url"] = yield profile.get_avatar_url(target)
+                except Exception as e:
+                    logger.info(
+                        "Failed to get profile information for %r: %s",
+                        target, e
+                    )
 
         if token_id is not None:
             builder.internal_metadata.token_id = token_id
@@ -214,7 +228,7 @@ class MessageHandler(BaseHandler):
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
-    def send_event(self, event, context, ratelimit=True, is_guest=False):
+    def send_nonmember_event(self, requester, event, context, ratelimit=True):
         """
         Persists and notifies local clients and federation of an event.
 
@@ -224,55 +238,70 @@ class MessageHandler(BaseHandler):
             ratelimit (bool): Whether to rate limit this send.
             is_guest (bool): Whether the sender is a guest.
         """
+        if event.type == EventTypes.Member:
+            raise SynapseError(
+                500,
+                "Tried to send member event through non-member codepath"
+            )
+
         user = UserID.from_string(event.sender)
 
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
 
-        if ratelimit:
-            self.ratelimit(event.sender)
-
         if event.is_state():
-            prev_state = context.current_state.get((event.type, event.state_key))
-            if prev_state and event.user_id == prev_state.user_id:
-                prev_content = encode_canonical_json(prev_state.content)
-                next_content = encode_canonical_json(event.content)
-                if prev_content == next_content:
-                    # Duplicate suppression for state updates with same sender
-                    # and content.
-                    defer.returnValue(prev_state)
-
-        if event.type == EventTypes.Member:
-            member_handler = self.hs.get_handlers().room_member_handler
-            yield member_handler.send_membership_event(event, context, is_guest=is_guest)
-        else:
-            yield self.handle_new_client_event(
-                event=event,
-                context=context,
-            )
+            prev_state = self.deduplicate_state_event(event, context)
+            if prev_state is not None:
+                defer.returnValue(prev_state)
+
+        yield self.handle_new_client_event(
+            requester=requester,
+            event=event,
+            context=context,
+            ratelimit=ratelimit,
+        )
 
         if event.type == EventTypes.Message:
             presence = self.hs.get_handlers().presence_handler
-            with PreserveLoggingContext():
-                presence.bump_presence_active_time(user)
+            yield presence.bump_presence_active_time(user)
+
+    def deduplicate_state_event(self, event, context):
+        """
+        Checks whether event is in the latest resolved state in context.
+
+        If so, returns the version of the event in context.
+        Otherwise, returns None.
+        """
+        prev_event = context.current_state.get((event.type, event.state_key))
+        if prev_event and event.user_id == prev_event.user_id:
+            prev_content = encode_canonical_json(prev_event.content)
+            next_content = encode_canonical_json(event.content)
+            if prev_content == next_content:
+                return prev_event
+        return None
 
     @defer.inlineCallbacks
-    def create_and_send_event(self, event_dict, ratelimit=True,
-                              token_id=None, txn_id=None, is_guest=False):
+    def create_and_send_nonmember_event(
+        self,
+        requester,
+        event_dict,
+        ratelimit=True,
+        txn_id=None
+    ):
         """
         Creates an event, then sends it.
 
-        See self.create_event and self.send_event.
+        See self.create_event and self.send_nonmember_event.
         """
         event, context = yield self.create_event(
             event_dict,
-            token_id=token_id,
+            token_id=requester.access_token_id,
             txn_id=txn_id
         )
-        yield self.send_event(
+        yield self.send_nonmember_event(
+            requester,
             event,
             context,
             ratelimit=ratelimit,
-            is_guest=is_guest
         )
         defer.returnValue(event)
 
@@ -633,8 +662,8 @@ class MessageHandler(BaseHandler):
             user_id, messages, is_peeking=is_peeking
         )
 
-        start_token = StreamToken(token[0], 0, 0, 0, 0)
-        end_token = StreamToken(token[1], 0, 0, 0, 0)
+        start_token = StreamToken.START.copy_and_replace("room_key", token[0])
+        end_token = StreamToken.START.copy_and_replace("room_key", token[1])
 
         time_now = self.clock.time_msec()
 
@@ -658,10 +687,6 @@ class MessageHandler(BaseHandler):
             room_id=room_id,
         )
 
-        # TODO(paul): I wish I was called with user objects not user_id
-        #   strings...
-        auth_user = UserID.from_string(user_id)
-
         # TODO: These concurrently
         time_now = self.clock.time_msec()
         state = [
@@ -686,13 +711,11 @@ class MessageHandler(BaseHandler):
         @defer.inlineCallbacks
         def get_presence():
             states = yield presence_handler.get_states(
-                target_users=[UserID.from_string(m.user_id) for m in room_members],
-                auth_user=auth_user,
+                [m.user_id for m in room_members],
                 as_event=True,
-                check_auth=False,
             )
 
-            defer.returnValue(states.values())
+            defer.returnValue(states)
 
         @defer.inlineCallbacks
         def get_receipts():