summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py238
1 files changed, 144 insertions, 94 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0e904f2da0..bc26921768 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,11 +19,13 @@
 
 import itertools
 import logging
+from typing import Dict, Iterable, Optional, Sequence, Tuple
 
 import six
 from six import iteritems, itervalues
 from six.moves import http_client, zip
 
+import attr
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -45,6 +47,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import auth_types_for_event
+from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.logging.context import (
@@ -72,6 +75,23 @@ from ._base import BaseHandler
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class _NewEventInfo:
+    """Holds information about a received event, ready for passing to _handle_new_events
+
+    Attributes:
+        event: the received event
+
+        state: the state at that event
+
+        auth_events: the auth_event map for that event
+    """
+
+    event = attr.ib(type=EventBase)
+    state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
+    auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
+
+
 def shortstr(iterable, maxitems=5):
     """If iterable has maxitems or fewer, return the stringification of a list
     containing those items.
@@ -121,6 +141,7 @@ class FederationHandler(BaseHandler):
         self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self._message_handler = hs.get_message_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
         self.http_client = hs.get_simple_http_client()
@@ -141,6 +162,8 @@ class FederationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
     @defer.inlineCallbacks
     def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
         """ Process a PDU received via a federation /send/ transaction, or
@@ -594,14 +617,14 @@ class FederationHandler(BaseHandler):
                     for e in auth_chain
                     if e.event_id in auth_ids or e.type == EventTypes.Create
                 }
-                event_infos.append({"event": e, "auth_events": auth})
+                event_infos.append(_NewEventInfo(event=e, auth_events=auth))
                 seen_ids.add(e.event_id)
 
             logger.info(
                 "[%s %s] persisting newly-received auth/state events %s",
                 room_id,
                 event_id,
-                [e["event"].event_id for e in event_infos],
+                [e.event.event_id for e in event_infos],
             )
             yield self._handle_new_events(origin, event_infos)
 
@@ -792,9 +815,9 @@ class FederationHandler(BaseHandler):
 
             a.internal_metadata.outlier = True
             ev_infos.append(
-                {
-                    "event": a,
-                    "auth_events": {
+                _NewEventInfo(
+                    event=a,
+                    auth_events={
                         (
                             auth_events[a_id].type,
                             auth_events[a_id].state_key,
@@ -802,7 +825,7 @@ class FederationHandler(BaseHandler):
                         for a_id in a.auth_event_ids()
                         if a_id in auth_events
                     },
-                }
+                )
             )
 
         # Step 1b: persist the events in the chunk we fetched state for (i.e.
@@ -814,10 +837,10 @@ class FederationHandler(BaseHandler):
             assert not ev.internal_metadata.is_outlier()
 
             ev_infos.append(
-                {
-                    "event": ev,
-                    "state": events_to_state[e_id],
-                    "auth_events": {
+                _NewEventInfo(
+                    event=ev,
+                    state=events_to_state[e_id],
+                    auth_events={
                         (
                             auth_events[a_id].type,
                             auth_events[a_id].state_key,
@@ -825,7 +848,7 @@ class FederationHandler(BaseHandler):
                         for a_id in ev.auth_event_ids()
                         if a_id in auth_events
                     },
-                }
+                )
             )
 
         yield self._handle_new_events(dest, ev_infos, backfilled=True)
@@ -1428,9 +1451,9 @@ class FederationHandler(BaseHandler):
         return event
 
     @defer.inlineCallbacks
-    def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
+    def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
         origin, event, event_format_version = yield self._make_and_verify_event(
-            target_hosts, room_id, user_id, "leave"
+            target_hosts, room_id, user_id, "leave", content=content,
         )
         # Mark as outlier as we don't have any state for this event; we're not
         # even in the room.
@@ -1710,7 +1733,12 @@ class FederationHandler(BaseHandler):
         return context
 
     @defer.inlineCallbacks
-    def _handle_new_events(self, origin, event_infos, backfilled=False):
+    def _handle_new_events(
+        self,
+        origin: str,
+        event_infos: Iterable[_NewEventInfo],
+        backfilled: bool = False,
+    ):
         """Creates the appropriate contexts and persists events. The events
         should not depend on one another, e.g. this should be used to persist
         a bunch of outliers, but not a chunk of individual events that depend
@@ -1720,14 +1748,14 @@ class FederationHandler(BaseHandler):
         """
 
         @defer.inlineCallbacks
-        def prep(ev_info):
-            event = ev_info["event"]
+        def prep(ev_info: _NewEventInfo):
+            event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
                 res = yield self._prep_event(
                     origin,
                     event,
-                    state=ev_info.get("state"),
-                    auth_events=ev_info.get("auth_events"),
+                    state=ev_info.state,
+                    auth_events=ev_info.auth_events,
                     backfilled=backfilled,
                 )
             return res
@@ -1741,7 +1769,7 @@ class FederationHandler(BaseHandler):
 
         yield self.persist_events_and_notify(
             [
-                (ev_info["event"], context)
+                (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
             ],
             backfilled=backfilled,
@@ -1843,7 +1871,14 @@ class FederationHandler(BaseHandler):
         yield self.persist_events_and_notify([(event, new_event_context)])
 
     @defer.inlineCallbacks
-    def _prep_event(self, origin, event, state, auth_events, backfilled):
+    def _prep_event(
+        self,
+        origin: str,
+        event: EventBase,
+        state: Optional[Iterable[EventBase]],
+        auth_events: Optional[Dict[Tuple[str, str], EventBase]],
+        backfilled: bool,
+    ):
         """
 
         Args:
@@ -1851,7 +1886,7 @@ class FederationHandler(BaseHandler):
             event:
             state:
             auth_events:
-            backfilled (bool)
+            backfilled:
 
         Returns:
             Deferred, which resolves to synapse.events.snapshot.EventContext
@@ -1887,15 +1922,16 @@ class FederationHandler(BaseHandler):
         return context
 
     @defer.inlineCallbacks
-    def _check_for_soft_fail(self, event, state, backfilled):
+    def _check_for_soft_fail(
+        self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
+    ):
         """Checks if we should soft fail the event, if so marks the event as
         such.
 
         Args:
-            event (FrozenEvent)
-            state (dict|None): The state at the event if we don't have all the
-                event's prev events
-            backfilled (bool): Whether the event is from backfill
+            event
+            state: The state at the event if we don't have all the event's prev events
+            backfilled: Whether the event is from backfill
 
         Returns:
             Deferred
@@ -2040,8 +2076,10 @@ class FederationHandler(BaseHandler):
             auth_events (dict[(str, str)->synapse.events.EventBase]):
                 Map from (event_type, state_key) to event
 
-                What we expect the event's auth_events to be, based on the event's
-                position in the dag. I think? maybe??
+                Normally, our calculated auth_events based on the state of the room
+                at the event's position in the DAG, though occasionally (eg if the
+                event is an outlier), may be the auth events claimed by the remote
+                server.
 
                 Also NB that this function adds entries to it.
         Returns:
@@ -2091,35 +2129,35 @@ class FederationHandler(BaseHandler):
             origin (str):
             event (synapse.events.EventBase):
             context (synapse.events.snapshot.EventContext):
+
             auth_events (dict[(str, str)->synapse.events.EventBase]):
+                Map from (event_type, state_key) to event
+
+                Normally, our calculated auth_events based on the state of the room
+                at the event's position in the DAG, though occasionally (eg if the
+                event is an outlier), may be the auth events claimed by the remote
+                server.
+
+                Also NB that this function adds entries to it.
 
         Returns:
             defer.Deferred[EventContext]: updated context
         """
         event_auth_events = set(event.auth_event_ids())
 
-        if event.is_state():
-            event_key = (event.type, event.state_key)
-        else:
-            event_key = None
-
-        # if the event's auth_events refers to events which are not in our
-        # calculated auth_events, we need to fetch those events from somewhere.
-        #
-        # we start by fetching them from the store, and then try calling /event_auth/.
+        # missing_auth is the set of the event's auth_events which we don't yet have
+        # in auth_events.
         missing_auth = event_auth_events.difference(
             e.event_id for e in auth_events.values()
         )
 
+        # if we have missing events, we need to fetch those events from somewhere.
+        #
+        # we start by checking if they are in the store, and then try calling /event_auth/.
         if missing_auth:
-            # TODO: can we use store.have_seen_events here instead?
-            have_events = yield self.store.get_seen_events_with_rejections(missing_auth)
-            logger.debug("Got events %s from store", have_events)
-            missing_auth.difference_update(have_events.keys())
-        else:
-            have_events = {}
-
-        have_events.update({e.event_id: "" for e in auth_events.values()})
+            have_events = yield self.store.have_seen_events(missing_auth)
+            logger.debug("Events %s are in the store", have_events)
+            missing_auth.difference_update(have_events)
 
         if missing_auth:
             # If we don't have all the auth events, we need to get them.
@@ -2165,19 +2203,18 @@ class FederationHandler(BaseHandler):
                     except AuthError:
                         pass
 
-                have_events = yield self.store.get_seen_events_with_rejections(
-                    event.auth_event_ids()
-                )
             except Exception:
-                # FIXME:
                 logger.exception("Failed to get auth chain")
 
         if event.internal_metadata.is_outlier():
+            # XXX: given that, for an outlier, we'll be working with the
+            # event's *claimed* auth events rather than those we calculated:
+            # (a) is there any point in this test, since different_auth below will
+            # obviously be empty
+            # (b) alternatively, why don't we do it earlier?
             logger.info("Skipping auth_event fetch for outlier")
             return context
 
-        # FIXME: Assumes we have and stored all the state for all the
-        # prev_events
         different_auth = event_auth_events.difference(
             e.event_id for e in auth_events.values()
         )
@@ -2191,53 +2228,58 @@ class FederationHandler(BaseHandler):
             different_auth,
         )
 
-        room_version = yield self.store.get_room_version(event.room_id)
+        # XXX: currently this checks for redactions but I'm not convinced that is
+        # necessary?
+        different_events = yield self.store.get_events_as_list(different_auth)
 
-        different_events = yield make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self.store.get_event, d, allow_none=True, allow_rejected=False
-                    )
-                    for d in different_auth
-                    if d in have_events and not have_events[d]
-                ],
-                consumeErrors=True,
-            )
-        ).addErrback(unwrapFirstError)
+        for d in different_events:
+            if d.room_id != event.room_id:
+                logger.warning(
+                    "Event %s refers to auth_event %s which is in a different room",
+                    event.event_id,
+                    d.event_id,
+                )
 
-        if different_events:
-            local_view = dict(auth_events)
-            remote_view = dict(auth_events)
-            remote_view.update(
-                {(d.type, d.state_key): d for d in different_events if d}
-            )
+                # don't attempt to resolve the claimed auth events against our own
+                # in this case: just use our own auth events.
+                #
+                # XXX: should we reject the event in this case? It feels like we should,
+                # but then shouldn't we also do so if we've failed to fetch any of the
+                # auth events?
+                return context
 
-            new_state = yield self.state_handler.resolve_events(
-                room_version,
-                [list(local_view.values()), list(remote_view.values())],
-                event,
-            )
+        # now we state-resolve between our own idea of the auth events, and the remote's
+        # idea of them.
 
-            logger.info(
-                "After state res: updating auth_events with new state %s",
-                {
-                    (d.type, d.state_key): d.event_id
-                    for d in new_state.values()
-                    if auth_events.get((d.type, d.state_key)) != d
-                },
-            )
+        local_state = auth_events.values()
+        remote_auth_events = dict(auth_events)
+        remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
+        remote_state = remote_auth_events.values()
+
+        room_version = yield self.store.get_room_version(event.room_id)
+        new_state = yield self.state_handler.resolve_events(
+            room_version, (local_state, remote_state), event
+        )
+
+        logger.info(
+            "After state res: updating auth_events with new state %s",
+            {
+                (d.type, d.state_key): d.event_id
+                for d in new_state.values()
+                if auth_events.get((d.type, d.state_key)) != d
+            },
+        )
 
-            auth_events.update(new_state)
+        auth_events.update(new_state)
 
-            context = yield self._update_context_for_auth_events(
-                event, context, auth_events, event_key
-            )
+        context = yield self._update_context_for_auth_events(
+            event, context, auth_events
+        )
 
         return context
 
     @defer.inlineCallbacks
-    def _update_context_for_auth_events(self, event, context, auth_events, event_key):
+    def _update_context_for_auth_events(self, event, context, auth_events):
         """Update the state_ids in an event context after auth event resolution,
         storing the changes as a new state group.
 
@@ -2246,18 +2288,21 @@ class FederationHandler(BaseHandler):
 
             context (synapse.events.snapshot.EventContext): initial event context
 
-            auth_events (dict[(str, str)->str]): Events to update in the event
+            auth_events (dict[(str, str)->EventBase]): Events to update in the event
                 context.
 
-            event_key ((str, str)): (type, state_key) for the current event.
-                this will not be included in the current_state in the context.
-
         Returns:
             Deferred[EventContext]: new event context
         """
+        # exclude the state key of the new event from the current_state in the context.
+        if event.is_state():
+            event_key = (event.type, event.state_key)
+        else:
+            event_key = None
         state_updates = {
             k: a.event_id for k, a in iteritems(auth_events) if k != event_key
         }
+
         current_state_ids = yield context.get_current_state_ids(self.store)
         current_state_ids = dict(current_state_ids)
 
@@ -2459,7 +2504,7 @@ class FederationHandler(BaseHandler):
                 room_version, event_dict, event, context
             )
 
-            EventValidator().validate_new(event)
+            EventValidator().validate_new(event, self.config)
 
             # We need to tell the transaction queue to send this out, even
             # though the sender isn't a local user.
@@ -2574,7 +2619,7 @@ class FederationHandler(BaseHandler):
         event, context = yield self.event_creation_handler.create_new_client_event(
             builder=builder
         )
-        EventValidator().validate_new(event)
+        EventValidator().validate_new(event, self.config)
         return (event, context)
 
     @defer.inlineCallbacks
@@ -2708,6 +2753,11 @@ class FederationHandler(BaseHandler):
                 event_and_contexts, backfilled=backfilled
             )
 
+            if self._ephemeral_messages_enabled:
+                for (event, context) in event_and_contexts:
+                    # If there's an expiry timestamp on the event, schedule its expiry.
+                    self._message_handler.maybe_schedule_expiry(event)
+
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
                     yield self._notify_persisted_event(event, max_stream_id)