summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/events_worker.py376
1 files changed, 234 insertions, 142 deletions
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 874d0a56bc..858fc755a1 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -37,6 +37,7 @@ from synapse.logging.context import (
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import get_domain_from_id
+from synapse.util import batch_iter
 from synapse.util.metrics import Measure
 
 from ._base import SQLBaseStore
@@ -218,9 +219,108 @@ class EventsWorkerStore(SQLBaseStore):
         if not event_ids:
             defer.returnValue([])
 
-        event_id_list = event_ids
-        event_ids = set(event_ids)
+        # there may be duplicates so we cast the list to a set
+        event_entry_map = yield self._get_events_from_cache_or_db(
+            set(event_ids), allow_rejected=allow_rejected
+        )
+
+        events = []
+        for event_id in event_ids:
+            entry = event_entry_map.get(event_id, None)
+            if not entry:
+                continue
+
+            if not allow_rejected:
+                assert not entry.event.rejected_reason, (
+                    "rejected event returned from _get_events_from_cache_or_db despite "
+                    "allow_rejected=False"
+                )
+
+            # We may not have had the original event when we received a redaction, so
+            # we have to recheck auth now.
+
+            if not allow_rejected and entry.event.type == EventTypes.Redaction:
+                redacted_event_id = entry.event.redacts
+                event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+                original_event_entry = event_map.get(redacted_event_id)
+                if not original_event_entry:
+                    # we don't have the redacted event (or it was rejected).
+                    #
+                    # We assume that the redaction isn't authorized for now; if the
+                    # redacted event later turns up, the redaction will be re-checked,
+                    # and if it is found valid, the original will get redacted before it
+                    # is served to the client.
+                    logger.debug(
+                        "Withholding redaction event %s since we don't (yet) have the "
+                        "original %s",
+                        event_id,
+                        redacted_event_id,
+                    )
+                    continue
+
+                original_event = original_event_entry.event
+                if original_event.type == EventTypes.Create:
+                    # we never serve redactions of Creates to clients.
+                    logger.info(
+                        "Withholding redaction %s of create event %s",
+                        event_id,
+                        redacted_event_id,
+                    )
+                    continue
+
+                if entry.event.internal_metadata.need_to_check_redaction():
+                    original_domain = get_domain_from_id(original_event.sender)
+                    redaction_domain = get_domain_from_id(entry.event.sender)
+                    if original_domain != redaction_domain:
+                        # the senders don't match, so this is forbidden
+                        logger.info(
+                            "Withholding redaction %s whose sender domain %s doesn't "
+                            "match that of redacted event %s %s",
+                            event_id,
+                            redaction_domain,
+                            redacted_event_id,
+                            original_domain,
+                        )
+                        continue
+
+                    # Update the cache to save doing the checks again.
+                    entry.event.internal_metadata.recheck_redaction = False
+
+            if check_redacted and entry.redacted_event:
+                event = entry.redacted_event
+            else:
+                event = entry.event
+
+            events.append(event)
+
+            if get_prev_content:
+                if "replaces_state" in event.unsigned:
+                    prev = yield self.get_event(
+                        event.unsigned["replaces_state"],
+                        get_prev_content=False,
+                        allow_none=True,
+                    )
+                    if prev:
+                        event.unsigned = dict(event.unsigned)
+                        event.unsigned["prev_content"] = prev.content
+                        event.unsigned["prev_sender"] = prev.sender
+
+        defer.returnValue(events)
+
+    @defer.inlineCallbacks
+    def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+        """Fetch a bunch of events from the cache or the database.
+
+        If events are pulled from the database, they will be cached for future lookups.
 
+        Args:
+            event_ids (Iterable[str]): The event_ids of the events to fetch
+            allow_rejected (bool): Whether to include rejected events
+
+        Returns:
+            Deferred[Dict[str, _EventCacheEntry]]:
+                map from event id to result
+        """
         event_entry_map = self._get_events_from_cache(
             event_ids, allow_rejected=allow_rejected
         )
@@ -243,81 +343,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             event_entry_map.update(missing_events)
 
-        events = []
-        for event_id in event_id_list:
-            entry = event_entry_map.get(event_id, None)
-            if not entry:
-                continue
-
-            # Starting in room version v3, some redactions need to be rechecked if we
-            # didn't have the redacted event at the time, so we recheck on read
-            # instead.
-            if not allow_rejected and entry.event.type == EventTypes.Redaction:
-                if entry.event.internal_metadata.need_to_check_redaction():
-                    # XXX: we need to avoid calling get_event here.
-                    #
-                    # The problem is that we end up at this point when an event
-                    # which has been redacted is pulled out of the database by
-                    # _enqueue_events, because _enqueue_events needs to check
-                    # the redaction before it can cache the redacted event. So
-                    # obviously, calling get_event to get the redacted event out
-                    # of the database gives us an infinite loop.
-                    #
-                    # For now (quick hack to fix during 0.99 release cycle), we
-                    # just go and fetch the relevant row from the db, but it
-                    # would be nice to think about how we can cache this rather
-                    # than hit the db every time we access a redaction event.
-                    #
-                    # One thought on how to do this:
-                    #  1. split get_events_as_list up so that it is divided into
-                    #     (a) get the rawish event from the db/cache, (b) do the
-                    #     redaction/rejection filtering
-                    #  2. have _get_event_from_row just call the first half of
-                    #     that
-
-                    orig_sender = yield self._simple_select_one_onecol(
-                        table="events",
-                        keyvalues={"event_id": entry.event.redacts},
-                        retcol="sender",
-                        allow_none=True,
-                    )
-
-                    expected_domain = get_domain_from_id(entry.event.sender)
-                    if (
-                        orig_sender
-                        and get_domain_from_id(orig_sender) == expected_domain
-                    ):
-                        # This redaction event is allowed. Mark as not needing a
-                        # recheck.
-                        entry.event.internal_metadata.recheck_redaction = False
-                    else:
-                        # We don't have the event that is being redacted, so we
-                        # assume that the event isn't authorized for now. (If we
-                        # later receive the event, then we will always redact
-                        # it anyway, since we have this redaction)
-                        continue
-
-            if allow_rejected or not entry.event.rejected_reason:
-                if check_redacted and entry.redacted_event:
-                    event = entry.redacted_event
-                else:
-                    event = entry.event
-
-                events.append(event)
-
-                if get_prev_content:
-                    if "replaces_state" in event.unsigned:
-                        prev = yield self.get_event(
-                            event.unsigned["replaces_state"],
-                            get_prev_content=False,
-                            allow_none=True,
-                        )
-                        if prev:
-                            event.unsigned = dict(event.unsigned)
-                            event.unsigned["prev_content"] = prev.content
-                            event.unsigned["prev_sender"] = prev.sender
-
-        defer.returnValue(events)
+        return event_entry_map
 
     def _invalidate_get_event_cache(self, event_id):
         self._get_event_cache.invalidate((event_id,))
@@ -326,7 +352,7 @@ class EventsWorkerStore(SQLBaseStore):
         """Fetch events from the caches
 
         Args:
-            events (list(str)): list of event_ids to fetch
+            events (Iterable[str]): list of event_ids to fetch
             allow_rejected (bool): Whether to return events that were rejected
             update_metrics (bool): Whether to update the cache hit ratio metrics
 
@@ -384,19 +410,16 @@ class EventsWorkerStore(SQLBaseStore):
                 The fetch requests. Each entry consists of a list of event
                 ids to be fetched, and a deferred to be completed once the
                 events have been fetched.
-
         """
         with Measure(self._clock, "_fetch_event_list"):
             try:
                 event_id_lists = list(zip(*event_list))[0]
                 event_ids = [item for sublist in event_id_lists for item in sublist]
 
-                rows = self._new_transaction(
+                row_dict = self._new_transaction(
                     conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
                 )
 
-                row_dict = {r["event_id"]: r for r in rows}
-
                 # We only want to resolve deferreds from the main thread
                 def fire(lst, res):
                     for ids, d in lst:
@@ -454,7 +477,7 @@ class EventsWorkerStore(SQLBaseStore):
         logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
 
         if not allow_rejected:
-            rows[:] = [r for r in rows if not r["rejects"]]
+            rows[:] = [r for r in rows if r["rejected_reason"] is None]
 
         res = yield make_deferred_yieldable(
             defer.gatherResults(
@@ -463,8 +486,8 @@ class EventsWorkerStore(SQLBaseStore):
                         self._get_event_from_row,
                         row["internal_metadata"],
                         row["json"],
-                        row["redacts"],
-                        rejected_reason=row["rejects"],
+                        row["redactions"],
+                        rejected_reason=row["rejected_reason"],
                         format_version=row["format_version"],
                     )
                     for row in rows
@@ -475,49 +498,98 @@ class EventsWorkerStore(SQLBaseStore):
 
         defer.returnValue({e.event.event_id: e for e in res if e})
 
-    def _fetch_event_rows(self, txn, events):
-        rows = []
-        N = 200
-        for i in range(1 + len(events) // N):
-            evs = events[i * N : (i + 1) * N]
-            if not evs:
-                break
+    def _fetch_event_rows(self, txn, event_ids):
+        """Fetch event rows from the database
+
+        Events which are not found are omitted from the result.
+
+        The returned per-event dicts contain the following keys:
+
+         * event_id (str)
+
+         * json (str): json-encoded event structure
+
+         * internal_metadata (str): json-encoded internal metadata dict
+
+         * format_version (int|None): The format of the event. Hopefully one
+           of EventFormatVersions. 'None' means the event predates
+           EventFormatVersions (so the event is format V1).
+
+         * rejected_reason (str|None): if the event was rejected, the reason
+           why.
 
+         * redactions (List[str]): a list of event-ids which (claim to) redact
+           this event.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection):
+            event_ids (Iterable[str]): event IDs to fetch
+
+        Returns:
+            Dict[str, Dict]: a map from event id to event info.
+        """
+        event_dict = {}
+        for evs in batch_iter(event_ids, 200):
             sql = (
                 "SELECT "
-                " e.event_id as event_id, "
+                " e.event_id, "
                 " e.internal_metadata,"
                 " e.json,"
                 " e.format_version, "
-                " r.redacts as redacts,"
-                " rej.event_id as rejects "
+                " rej.reason "
                 " FROM event_json as e"
                 " LEFT JOIN rejections as rej USING (event_id)"
-                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
                 " WHERE e.event_id IN (%s)"
             ) % (",".join(["?"] * len(evs)),)
 
             txn.execute(sql, evs)
-            rows.extend(self.cursor_to_dict(txn))
 
-        return rows
+            for row in txn:
+                event_id = row[0]
+                event_dict[event_id] = {
+                    "event_id": event_id,
+                    "internal_metadata": row[1],
+                    "json": row[2],
+                    "format_version": row[3],
+                    "rejected_reason": row[4],
+                    "redactions": [],
+                }
+
+            # check for redactions
+            redactions_sql = (
+                "SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)"
+            ) % (",".join(["?"] * len(evs)),)
+
+            txn.execute(redactions_sql, evs)
+
+            for (redacter, redacted) in txn:
+                d = event_dict.get(redacted)
+                if d:
+                    d["redactions"].append(redacter)
+
+        return event_dict
 
     @defer.inlineCallbacks
     def _get_event_from_row(
-        self, internal_metadata, js, redacted, format_version, rejected_reason=None
+        self, internal_metadata, js, redactions, format_version, rejected_reason=None
     ):
+        """Parse an event row which has been read from the database
+
+        Args:
+            internal_metadata (str): json-encoded internal_metadata column
+            js (str): json-encoded event body from event_json
+            redactions (list[str]): a list of the events which claim to have redacted
+                this event, from the redactions table
+            format_version: (str): the 'format_version' column
+            rejected_reason (str|None): the reason this event was rejected, if any
+
+        Returns:
+            _EventCacheEntry
+        """
         with Measure(self._clock, "_get_event_from_row"):
             d = json.loads(js)
             internal_metadata = json.loads(internal_metadata)
 
-            if rejected_reason:
-                rejected_reason = yield self._simple_select_one_onecol(
-                    table="rejections",
-                    keyvalues={"event_id": rejected_reason},
-                    retcol="reason",
-                    desc="_get_event_from_row_rejected_reason",
-                )
-
             if format_version is None:
                 # This means that we stored the event before we had the concept
                 # of a event format version, so it must be a V1 event.
@@ -529,41 +601,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected_reason=rejected_reason,
             )
 
-            redacted_event = None
-            if redacted:
-                redacted_event = prune_event(original_ev)
-
-                redaction_id = yield self._simple_select_one_onecol(
-                    table="redactions",
-                    keyvalues={"redacts": redacted_event.event_id},
-                    retcol="event_id",
-                    desc="_get_event_from_row_redactions",
-                )
-
-                redacted_event.unsigned["redacted_by"] = redaction_id
-                # Get the redaction event.
-
-                because = yield self.get_event(
-                    redaction_id, check_redacted=False, allow_none=True
-                )
-
-                if because:
-                    # It's fine to do add the event directly, since get_pdu_json
-                    # will serialise this field correctly
-                    redacted_event.unsigned["redacted_because"] = because
-
-                    # Starting in room version v3, some redactions need to be
-                    # rechecked if we didn't have the redacted event at the
-                    # time, so we recheck on read instead.
-                    if because.internal_metadata.need_to_check_redaction():
-                        expected_domain = get_domain_from_id(original_ev.sender)
-                        if get_domain_from_id(because.sender) == expected_domain:
-                            # This redaction event is allowed. Mark as not needing a
-                            # recheck.
-                            because.internal_metadata.recheck_redaction = False
-                        else:
-                            # Senders don't match, so the event isn't actually redacted
-                            redacted_event = None
+            redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)
 
             cache_entry = _EventCacheEntry(
                 event=original_ev, redacted_event=redacted_event
@@ -574,6 +612,60 @@ class EventsWorkerStore(SQLBaseStore):
         defer.returnValue(cache_entry)
 
     @defer.inlineCallbacks
+    def _maybe_redact_event_row(self, original_ev, redactions):
+        """Given an event object and a list of possible redacting event ids,
+        determine whether to honour any of those redactions and if so return a redacted
+        event.
+
+        Args:
+             original_ev (EventBase):
+             redactions (iterable[str]): list of event ids of potential redaction events
+
+        Returns:
+            Deferred[EventBase|None]: if the event should be redacted, a pruned
+                event object. Otherwise, None.
+        """
+        if original_ev.type == "m.room.create":
+            # we choose to ignore redactions of m.room.create events.
+            return None
+
+        redaction_map = yield self._get_events_from_cache_or_db(redactions)
+
+        for redaction_id in redactions:
+            redaction_entry = redaction_map.get(redaction_id)
+            if not redaction_entry:
+                # we don't have the redaction event, or the redaction event was not
+                # authorized.
+                continue
+
+            redaction_event = redaction_entry.event
+
+            # Starting in room version v3, some redactions need to be
+            # rechecked if we didn't have the redacted event at the
+            # time, so we recheck on read instead.
+            if redaction_event.internal_metadata.need_to_check_redaction():
+                expected_domain = get_domain_from_id(original_ev.sender)
+                if get_domain_from_id(redaction_event.sender) == expected_domain:
+                    # This redaction event is allowed. Mark as not needing a recheck.
+                    redaction_event.internal_metadata.recheck_redaction = False
+                else:
+                    # Senders don't match, so the event isn't actually redacted
+                    continue
+
+            # we found a good redaction event. Redact!
+            redacted_event = prune_event(original_ev)
+            redacted_event.unsigned["redacted_by"] = redaction_id
+
+            # It's fine to add the event directly, since get_pdu_json
+            # will serialise this field correctly
+            redacted_event.unsigned["redacted_because"] = redaction_event
+
+            return redacted_event
+
+        # no valid redaction found for this event
+        return None
+
+    @defer.inlineCallbacks
     def have_events_in_timeline(self, event_ids):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.