summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/_base.py4
-rw-r--r--synapse/storage/databases/main/events.py99
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py10
-rw-r--r--synapse/storage/databases/main/events_worker.py167
-rw-r--r--synapse/storage/databases/main/purge_events.py27
-rw-r--r--synapse/storage/databases/main/rejections.py23
6 files changed, 237 insertions, 93 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index abfc56b061..b15c37679b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -36,6 +36,10 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
+    # if set to False, we will query the `state_events` and `rejections` tables when
+    # fetching event data. When True, we rely on it all being in the `events` table.
+    STATE_KEY_IN_EVENTS = False
+
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2ff3d21305..6456a951bf 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -257,17 +257,30 @@ class PersistEventsStore:
         def _get_events_which_are_prevs_txn(
             txn: LoggingTransaction, batch: Collection[str]
         ) -> None:
-            sql = """
-            SELECT prev_event_id, internal_metadata
-            FROM event_edges
-                INNER JOIN events USING (event_id)
-                LEFT JOIN rejections USING (event_id)
-                LEFT JOIN event_json USING (event_id)
-            WHERE
-                NOT events.outlier
-                AND rejections.event_id IS NULL
-                AND
-            """
+            if self.store.STATE_KEY_IN_EVENTS:
+                sql = """
+                SELECT prev_event_id, internal_metadata
+                FROM event_edges
+                    INNER JOIN events USING (event_id)
+                    LEFT JOIN event_json USING (event_id)
+                WHERE
+                    NOT events.outlier
+                    AND events.rejection_reason IS NULL
+                    AND
+                """
+
+            else:
+                sql = """
+                SELECT prev_event_id, internal_metadata
+                FROM event_edges
+                    INNER JOIN events USING (event_id)
+                    LEFT JOIN rejections USING (event_id)
+                    LEFT JOIN event_json USING (event_id)
+                WHERE
+                    NOT events.outlier
+                    AND rejections.event_id IS NULL
+                    AND
+                """
 
             clause, args = make_in_list_sql_clause(
                 self.database_engine, "prev_event_id", batch
@@ -311,7 +324,19 @@ class PersistEventsStore:
         ) -> None:
             to_recursively_check = batch
 
-            while to_recursively_check:
+            if self.store.STATE_KEY_IN_EVENTS:
+                sql = """
+                SELECT
+                    event_id, prev_event_id, internal_metadata,
+                    events.rejection_reason IS NOT NULL
+                FROM event_edges
+                    INNER JOIN events USING (event_id)
+                    LEFT JOIN event_json USING (event_id)
+                WHERE
+                    NOT events.outlier
+                    AND
+                """
+            else:
                 sql = """
                 SELECT
                     event_id, prev_event_id, internal_metadata,
@@ -325,6 +350,7 @@ class PersistEventsStore:
                     AND
                 """
 
+            while to_recursively_check:
                 clause, args = make_in_list_sql_clause(
                     self.database_engine, "event_id", to_recursively_check
                 )
@@ -530,6 +556,7 @@ class PersistEventsStore:
         event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
 
         self._add_chain_cover_index(
+            self.store.STATE_KEY_IN_EVENTS,
             txn,
             self.db_pool,
             self.store.event_chain_id_gen,
@@ -541,6 +568,7 @@ class PersistEventsStore:
     @classmethod
     def _add_chain_cover_index(
         cls,
+        state_key_in_events: bool,
         txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
@@ -551,6 +579,8 @@ class PersistEventsStore:
         """Calculate the chain cover index for the given events.
 
         Args:
+            state_key_in_events: whether to use the `state_key` column in the `events`
+                table in preference to the `state_events` table
             event_to_room_id: Event ID to the room ID of the event
             event_to_types: Event ID to type and state_key of the event
             event_to_auth_chain: Event ID to list of auth event IDs of the
@@ -610,7 +640,15 @@ class PersistEventsStore:
 
         # We loop here in case we find an out of band membership and need to
         # fetch their auth event info.
-        while missing_auth_chains:
+        if state_key_in_events:
+            sql = """
+                SELECT event_id, events.type, events.state_key, chain_id, sequence_number
+                FROM events
+                LEFT JOIN event_auth_chains USING (event_id)
+                WHERE
+                    events.state_key IS NOT NULL AND
+            """
+        else:
             sql = """
                 SELECT event_id, events.type, se.state_key, chain_id, sequence_number
                 FROM events
@@ -618,6 +656,8 @@ class PersistEventsStore:
                 LEFT JOIN event_auth_chains USING (event_id)
                 WHERE
             """
+
+        while missing_auth_chains:
             clause, args = make_in_list_sql_clause(
                 txn.database_engine,
                 "event_id",
@@ -1641,22 +1681,31 @@ class PersistEventsStore:
     ) -> None:
         to_prefill = []
 
-        rows = []
-
         ev_map = {e.event_id: e for e, _ in events_and_contexts}
         if not ev_map:
             return
 
-        sql = (
-            "SELECT "
-            " e.event_id as event_id, "
-            " r.redacts as redacts,"
-            " rej.event_id as rejects "
-            " FROM events as e"
-            " LEFT JOIN rejections as rej USING (event_id)"
-            " LEFT JOIN redactions as r ON e.event_id = r.redacts"
-            " WHERE "
-        )
+        if self.store.STATE_KEY_IN_EVENTS:
+            sql = (
+                "SELECT "
+                " e.event_id as event_id, "
+                " r.redacts as redacts,"
+                " e.rejection_reason as rejects "
+                " FROM events as e"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE "
+            )
+        else:
+            sql = (
+                "SELECT "
+                " e.event_id as event_id, "
+                " r.redacts as redacts,"
+                " rej.event_id as rejects "
+                " FROM events as e"
+                " LEFT JOIN rejections as rej USING (event_id)"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE "
+            )
 
         clause, args = make_in_list_sql_clause(
             self.database_engine, "e.event_id", list(ev_map)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7d951d85d6..8746bc493e 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -448,6 +448,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             # First, we get `batch_size` events from the table, pulling out
             # their successor events, if any, and the successor events'
             # rejection status.
+
+            # this should happen before the bg update which drops 'rejections'
+            assert not self.STATE_KEY_IN_EVENTS
+
             txn.execute(
                 """SELECT prev_event_id, event_id, internal_metadata,
                     rejections.event_id IS NOT NULL, events.outlier
@@ -973,6 +977,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             extra_clause = "AND events.room_id = ?"
             tuple_args.append(last_room_id)
 
+        # this should happen before the bg update which drops 'state_events'
+        assert not self.STATE_KEY_IN_EVENTS
+
         sql = """
             SELECT
                 event_id, state_events.type, state_events.state_key,
@@ -1041,9 +1048,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         # Calculate and persist the chain cover index for this set of events.
         #
-        # Annoyingly we need to gut wrench into the persit event store so that
+        # Annoyingly we need to gut wrench into the persist event store so that
         # we can reuse the function to calculate the chain cover for rooms.
         PersistEventsStore._add_chain_cover_index(
+            False,
             txn,
             self.db_pool,
             self.event_chain_id_gen,  # type: ignore[attr-defined]
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b99b107784..7cd94f3964 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1482,20 +1482,35 @@ class EventsWorkerStore(SQLBaseStore):
         def get_all_new_forward_event_rows(
             txn: LoggingTransaction,
         ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events AS se USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " LEFT JOIN room_memberships USING (event_id)"
-                " LEFT JOIN rejections USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " AND instance_name = ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
+            if self.STATE_KEY_IN_EVENTS:
+                sql = (
+                    "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                    " e.state_key, redacts, relates_to_id, membership, e.rejection_reason IS NOT NULL"
+                    " FROM events AS e"
+                    " LEFT JOIN redactions USING (event_id)"
+                    " LEFT JOIN event_relations USING (event_id)"
+                    " LEFT JOIN room_memberships USING (event_id)"
+                    " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                    " AND instance_name = ?"
+                    " ORDER BY stream_ordering ASC"
+                    " LIMIT ?"
+                )
+            else:
+                sql = (
+                    "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                    " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                    " FROM events AS e"
+                    " LEFT JOIN redactions USING (event_id)"
+                    " LEFT JOIN state_events AS se USING (event_id)"
+                    " LEFT JOIN event_relations USING (event_id)"
+                    " LEFT JOIN room_memberships USING (event_id)"
+                    " LEFT JOIN rejections USING (event_id)"
+                    " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                    " AND instance_name = ?"
+                    " ORDER BY stream_ordering ASC"
+                    " LIMIT ?"
+                )
+
             txn.execute(sql, (last_id, current_id, instance_name, limit))
             return cast(
                 List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
@@ -1523,21 +1538,36 @@ class EventsWorkerStore(SQLBaseStore):
         def get_ex_outlier_stream_rows_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
-            sql = (
-                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream AS out USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events AS se USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " LEFT JOIN room_memberships USING (event_id)"
-                " LEFT JOIN rejections USING (event_id)"
-                " WHERE ? < event_stream_ordering"
-                " AND event_stream_ordering <= ?"
-                " AND out.instance_name = ?"
-                " ORDER BY event_stream_ordering ASC"
-            )
+            if self.STATE_KEY_IN_EVENTS:
+                sql = (
+                    "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                    " e.state_key, redacts, relates_to_id, membership, e.rejection_reason IS NOT NULL"
+                    " FROM events AS e"
+                    " INNER JOIN ex_outlier_stream AS out USING (event_id)"
+                    " LEFT JOIN redactions USING (event_id)"
+                    " LEFT JOIN event_relations USING (event_id)"
+                    " LEFT JOIN room_memberships USING (event_id)"
+                    " WHERE ? < event_stream_ordering"
+                    " AND event_stream_ordering <= ?"
+                    " AND out.instance_name = ?"
+                    " ORDER BY event_stream_ordering ASC"
+                )
+            else:
+                sql = (
+                    "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                    " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                    " FROM events AS e"
+                    " INNER JOIN ex_outlier_stream AS out USING (event_id)"
+                    " LEFT JOIN redactions USING (event_id)"
+                    " LEFT JOIN state_events AS se USING (event_id)"
+                    " LEFT JOIN event_relations USING (event_id)"
+                    " LEFT JOIN room_memberships USING (event_id)"
+                    " LEFT JOIN rejections USING (event_id)"
+                    " WHERE ? < event_stream_ordering"
+                    " AND event_stream_ordering <= ?"
+                    " AND out.instance_name = ?"
+                    " ORDER BY event_stream_ordering ASC"
+                )
 
             txn.execute(sql, (last_id, current_id, instance_name))
             return cast(
@@ -1581,18 +1611,32 @@ class EventsWorkerStore(SQLBaseStore):
         def get_all_new_backfill_event_rows(
             txn: LoggingTransaction,
         ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events AS se USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                "  AND instance_name = ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
+            if self.STATE_KEY_IN_EVENTS:
+                sql = (
+                    "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                    " e.state_key, redacts, relates_to_id"
+                    " FROM events AS e"
+                    " LEFT JOIN redactions USING (event_id)"
+                    " LEFT JOIN event_relations USING (event_id)"
+                    " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                    "  AND instance_name = ?"
+                    " ORDER BY stream_ordering ASC"
+                    " LIMIT ?"
+                )
+            else:
+                sql = (
+                    "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                    " se.state_key, redacts, relates_to_id"
+                    " FROM events AS e"
+                    " LEFT JOIN redactions USING (event_id)"
+                    " LEFT JOIN state_events AS se USING (event_id)"
+                    " LEFT JOIN event_relations USING (event_id)"
+                    " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                    "  AND instance_name = ?"
+                    " ORDER BY stream_ordering ASC"
+                    " LIMIT ?"
+                )
+
             txn.execute(sql, (-last_id, -current_id, instance_name, limit))
             new_event_updates: List[
                 Tuple[int, Tuple[str, str, str, str, str, str]]
@@ -1611,19 +1655,34 @@ class EventsWorkerStore(SQLBaseStore):
             else:
                 upper_bound = current_id
 
-            sql = (
-                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream AS out USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events AS se USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > event_stream_ordering"
-                " AND event_stream_ordering >= ?"
-                " AND out.instance_name = ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
+            if self.STATE_KEY_IN_EVENTS:
+                sql = (
+                    "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+                    " e.state_key, redacts, relates_to_id"
+                    " FROM events AS e"
+                    " INNER JOIN ex_outlier_stream AS out USING (event_id)"
+                    " LEFT JOIN redactions USING (event_id)"
+                    " LEFT JOIN event_relations USING (event_id)"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " AND out.instance_name = ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+            else:
+                sql = (
+                    "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+                    " se.state_key, redacts, relates_to_id"
+                    " FROM events AS e"
+                    " INNER JOIN ex_outlier_stream AS out USING (event_id)"
+                    " LEFT JOIN redactions USING (event_id)"
+                    " LEFT JOIN state_events AS se USING (event_id)"
+                    " LEFT JOIN event_relations USING (event_id)"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " AND out.instance_name = ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+
             txn.execute(sql, (-last_id, -upper_bound, instance_name))
             # Type safety: iterating over `txn` yields `Tuple`, i.e.
             # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 87b0d09039..37b53de971 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -122,7 +122,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
 
         logger.info("[purge] looking for events to delete")
 
-        should_delete_expr = "state_events.state_key IS NULL"
+        should_delete_expr = (
+            "e.state_key IS NULL"
+            if self.STATE_KEY_IN_EVENTS
+            else "state_events.state_key IS NULL"
+        )
         should_delete_params: Tuple[Any, ...] = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
@@ -134,12 +138,23 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
 
         # Note that we insert events that are outliers and aren't going to be
         # deleted, as nothing will happen to them.
+        if self.STATE_KEY_IN_EVENTS:
+            sqlf = """
+            INSERT INTO events_to_purge
+                SELECT event_id, %s
+                FROM events AS e
+                WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?
+            """
+        else:
+            sqlf = """
+            INSERT INTO events_to_purge
+                SELECT event_id, %s
+                FROM events AS e LEFT JOIN state_events USING (event_id)
+                WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?
+            """
+
         txn.execute(
-            "INSERT INTO events_to_purge"
-            " SELECT event_id, %s"
-            " FROM events AS e LEFT JOIN state_events USING (event_id)"
-            " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
-            % (should_delete_expr, should_delete_expr),
+            sqlf % (should_delete_expr, should_delete_expr),
             should_delete_params,
         )
 
diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py
index 167318b314..f57262c51a 100644
--- a/synapse/storage/databases/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -22,10 +22,19 @@ logger = logging.getLogger(__name__)
 
 class RejectionsStore(SQLBaseStore):
     async def get_rejection_reason(self, event_id: str) -> Optional[str]:
-        return await self.db_pool.simple_select_one_onecol(
-            table="rejections",
-            retcol="reason",
-            keyvalues={"event_id": event_id},
-            allow_none=True,
-            desc="get_rejection_reason",
-        )
+        if self.STATE_KEY_IN_EVENTS:
+            return await self.db_pool.simple_select_one_onecol(
+                table="events",
+                retcol="rejection_reason",
+                keyvalues={"event_id": event_id},
+                allow_none=True,
+                desc="get_rejection_reason",
+            )
+        else:
+            return await self.db_pool.simple_select_one_onecol(
+                table="rejections",
+                retcol="reason",
+                keyvalues={"event_id": event_id},
+                allow_none=True,
+                desc="get_rejection_reason",
+            )