diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1b6ccd51c8..c128889bf9 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -513,21 +514,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
- txn.database_engine, "k.user_id", user_chunk
- )
- sql = (
- """
- SELECT k.user_id, k.keytype, k.keydata, k.stream_id
- FROM e2e_cross_signing_keys k
- INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
- FROM e2e_cross_signing_keys
- GROUP BY user_id, keytype) s
- USING (user_id, stream_id, keytype)
- WHERE
- """
- + clause
+ txn.database_engine, "user_id", user_chunk
)
+ # Fetch the latest key for each type per user.
+ if isinstance(self.database_engine, PostgresEngine):
+ # The `DISTINCT ON` clause will pick the *first* row it
+ # encounters, so ordering by stream ID desc will ensure we get
+ # the latest key.
+ sql = """
+ SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
+ FROM e2e_cross_signing_keys
+ WHERE %(clause)s
+ ORDER BY user_id, keytype, stream_id DESC
+ """ % {
+ "clause": clause
+ }
+ else:
+ # SQLite has special handling for bare columns when using
+ # MIN/MAX with a `GROUP BY` clause: it picks the values of bare
+ # columns from a row that matches the MIN/MAX.
+ sql = """
+ SELECT user_id, keytype, keydata, MAX(stream_id)
+ FROM e2e_cross_signing_keys
+ WHERE %(clause)s
+ GROUP BY user_id, keytype
+ """ % {
+ "clause": clause
+ }
+
txn.execute(sql, params)
rows = self.db_pool.cursor_to_dict(txn)
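
The two branches above lean on engine-specific behaviour to return one row per `(user_id, keytype)`: Postgres's `DISTINCT ON` keeps the first row of each group (hence the `ORDER BY ... stream_id DESC`), while SQLite guarantees that bare columns selected alongside `MAX()` come from the row holding the maximum. A minimal standalone sketch of the SQLite behaviour (not Synapse code; the table and values are made up):

```python
# Demonstrates SQLite's bare-column guarantee (SQLite 3.7.11+): with MAX()
# and GROUP BY, ungrouped columns are taken from the row holding the maximum.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE e2e_keys (user_id TEXT, keytype TEXT, keydata TEXT, stream_id INTEGER)"
)
conn.executemany(
    "INSERT INTO e2e_keys VALUES (?, ?, ?, ?)",
    [
        ("@alice:example.com", "master", "old-key", 1),
        ("@alice:example.com", "master", "new-key", 2),
    ],
)
row = conn.execute(
    """
    SELECT user_id, keytype, keydata, MAX(stream_id)
    FROM e2e_keys
    GROUP BY user_id, keytype
    """
).fetchone()
print(row)  # ('@alice:example.com', 'master', 'new-key', 2)
```
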
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 186f064036..3216b3f3c8 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -43,7 +43,6 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -100,14 +99,6 @@ class PersistEventsStore:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
- def get_chain_id_txn(txn):
- txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
-
- self._event_chain_id_gen = build_sequence_generator(
- db.engine, get_chain_id_txn, "event_auth_chain_id"
- )
-
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@@ -466,9 +457,6 @@ class PersistEventsStore:
if not state_events:
return
- # Map from event ID to chain ID/sequence number.
- chain_map = {} # type: Dict[str, Tuple[int, int]]
-
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
@@ -479,19 +467,44 @@ class PersistEventsStore:
event_to_auth_chain = {
e.event_id: e.auth_event_ids() for e in state_events.values()
}
+ event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+ self._add_chain_cover_index(
+ txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+ )
+
+ @staticmethod
+ def _add_chain_cover_index(
+ txn: LoggingTransaction,
+ db_pool: DatabasePool,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, List[str]],
+ ) -> None:
+ """Calculate the chain cover index for the given events.
+
+ Args:
+ event_to_room_id: Event ID to the room ID of the event
+ event_to_types: Event ID to type and state_key of the event
+ event_to_auth_chain: Event ID to list of auth event IDs of the
+ event (events with no auth events can be excluded).
+ """
+
+ # Map from event ID to chain ID/sequence number.
+ chain_map = {} # type: Dict[str, Tuple[int, int]]
# Set of event IDs to calculate chain ID/seq numbers for.
- events_to_calc_chain_id_for = set(state_events)
+ events_to_calc_chain_id_for = set(event_to_room_id)
# We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where
# we didn't have the auth chain when we first persisted.
- rows = self.db_pool.simple_select_many_txn(
+ rows = db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="room_id",
- iterable={e.room_id for e in state_events.values()},
+ iterable=set(event_to_room_id.values()),
retcols=("event_id", "type", "state_key"),
)
for row in rows:
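
Turning `_add_chain_cover_index` into a static method with an explicit `db_pool` is what lets the background update in `events_bg_updates.py` (below) reuse it outside of event persistence. A rough sketch of the argument shapes the docstring describes, with hypothetical event and room IDs:

```python
# Hypothetical inputs for _add_chain_cover_index; real callers build these
# from the events being persisted, or from rows fetched by the background
# update.
event_to_room_id = {
    "$create": "!room:example.com",
    "$alice_join": "!room:example.com",
}
event_to_types = {
    "$create": ("m.room.create", ""),
    "$alice_join": ("m.room.member", "@alice:example.com"),
}
# The create event has no auth events, so it may be omitted entirely.
event_to_auth_chain = {
    "$alice_join": ["$create"],
}
```
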
@@ -502,7 +515,7 @@ class PersistEventsStore:
# (We could pull out the auth events for all rows at once using
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
- auth_events = self.db_pool.simple_select_onecol_txn(
+ auth_events = db_pool.simple_select_onecol_txn(
txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
)
@@ -551,9 +564,7 @@ class PersistEventsStore:
events_to_calc_chain_id_for.add(auth_id)
- event_to_auth_chain[
- auth_id
- ] = self.db_pool.simple_select_onecol_txn(
+ event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
txn,
"event_auth",
keyvalues={"event_id": auth_id},
@@ -582,16 +593,17 @@ class PersistEventsStore:
# the list of events to calculate chain IDs for next time
# around. (Otherwise we will have already added it to the
# table).
- event = state_events.get(event_id)
- if event:
- self.db_pool.simple_insert_txn(
+ room_id = event_to_room_id.get(event_id)
+ if room_id:
+ e_type, state_key = event_to_types[event_id]
+ db_pool.simple_insert_txn(
txn,
table="event_auth_chain_to_calculate",
values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
+ "event_id": event_id,
+ "room_id": room_id,
+ "type": e_type,
+ "state_key": state_key,
},
)
@@ -617,7 +629,7 @@ class PersistEventsStore:
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
- for auth_id in event_to_auth_chain[event_id]:
+ for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break
@@ -629,7 +641,7 @@ class PersistEventsStore:
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
- already_allocated = self.db_pool.simple_select_one_onecol_txn(
+ already_allocated = db_pool.simple_select_one_onecol_txn(
txn,
table="event_auth_chains",
keyvalues={
@@ -650,14 +662,14 @@ class PersistEventsStore:
)
if not new_chain_tuple:
- new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
+ new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
chains_tuples_allocated.add(new_chain_tuple)
chain_map[event_id] = new_chain_tuple
new_chain_tuples[event_id] = new_chain_tuple
- self.db_pool.simple_insert_many_txn(
+ db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
values=[
@@ -666,7 +678,7 @@ class PersistEventsStore:
],
)
- self.db_pool.simple_delete_many_txn(
+ db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
@@ -699,7 +711,7 @@ class PersistEventsStore:
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
- rows = self.db_pool.simple_select_many_txn(
+ rows = db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_links",
column="origin_chain_id",
@@ -730,11 +742,11 @@ class PersistEventsStore:
# auth events (A, B) to check if B is reachable from A.
reduction = {
a_id
- for a_id in event_to_auth_chain[event_id]
+ for a_id in event_to_auth_chain.get(event_id, [])
if chain_map[a_id][0] != chain_id
}
for start_auth_id, end_auth_id in itertools.permutations(
- event_to_auth_chain[event_id], r=2,
+ event_to_auth_chain.get(event_id, []), r=2,
):
if chain_links.exists_path_from(
chain_map[start_auth_id], chain_map[end_auth_id]
@@ -763,7 +775,7 @@ class PersistEventsStore:
(chain_id, sequence_number), (target_id, target_seq)
)
- self.db_pool.simple_insert_many_txn(
+ db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
values=[
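
The reduction step in this file avoids writing redundant rows to `event_auth_chain_links`: for each ordered pair of an event's auth chains, if one is already reachable from the other, a direct link to it adds no information. A toy sketch of that idea, where `exists_path_from` stands in for the `_LinkMap` method of the same name:

```python
import itertools
from typing import Callable, List, Set, Tuple

ChainTuple = Tuple[int, int]  # (chain ID, sequence number)


def reduce_links(
    auth_chain_tuples: List[ChainTuple],
    exists_path_from: Callable[[ChainTuple, ChainTuple], bool],
) -> List[ChainTuple]:
    """Drop chain tuples already reachable via another auth chain."""
    redundant = set()  # type: Set[ChainTuple]
    for start, end in itertools.permutations(auth_chain_tuples, r=2):
        if exists_path_from(start, end):
            # `end` is reachable via `start`, so a direct link from the
            # new event to `end` would be redundant.
            redundant.add(end)
    return [t for t in auth_chain_tuples if t not in redundant]
```

Only the surviving tuples need new link rows, which keeps the links table close to a transitive reduction of the auth graph.
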
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7e4b175d08..7128dc1742 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -14,13 +14,14 @@
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict
@@ -108,6 +109,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"rejected_events_metadata", self._rejected_events_metadata,
)
+ self.db_pool.updates.register_background_update_handler(
+ "chain_cover", self._chain_cover_index,
+ )
+
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -706,3 +711,191 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
return len(results)
+
+ async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
+ """A background updates that iterates over all rooms and generates the
+ chain cover index for them.
+ """
+
+ current_room_id = progress.get("current_room_id", "")
+
+ # Have we finished processing the current room?
+ finished = progress.get("finished", True)
+
+ # Where we've processed up to in the room; defaults to the start of
+ # the room.
+ last_depth = progress.get("last_depth", -1)
+ last_stream = progress.get("last_stream", -1)
+
+ # Have we set the `has_auth_chain_index` for the room yet?
+ has_set_room_has_chain_index = progress.get(
+ "has_set_room_has_chain_index", False
+ )
+
+ if finished:
+ # If we've finished with the previous room (or it's our first
+ # iteration), we move on to the next room.
+
+ def _get_next_room(txn: Cursor) -> Optional[str]:
+ sql = """
+ SELECT room_id FROM rooms
+ WHERE room_id > ?
+ AND (
+ NOT has_auth_chain_index
+ OR has_auth_chain_index IS NULL
+ )
+ ORDER BY room_id
+ LIMIT 1
+ """
+ txn.execute(sql, (current_room_id,))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ return None
+
+ current_room_id = await self.db_pool.runInteraction(
+ "_chain_cover_index", _get_next_room
+ )
+ if not current_room_id:
+ await self.db_pool.updates._end_background_update("chain_cover")
+ return 0
+
+ logger.debug("Adding chain cover to %s", current_room_id)
+
+ def _calculate_auth_chain(
+ txn: Cursor, last_depth: int, last_stream: int
+ ) -> Tuple[int, int, int]:
+ # Get the next set of events in the room (that we haven't already
+ # computed chain cover for). We do this in topological order.
+
+ # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+ # comparison, but that is not supported on older SQLite versions.
+ tuple_clause, tuple_args = make_tuple_comparison_clause(
+ self.database_engine,
+ [
+ ("topological_ordering", last_depth),
+ ("stream_ordering", last_stream),
+ ],
+ )
+
+ sql = """
+ SELECT
+ event_id, state_events.type, state_events.state_key,
+ topological_ordering, stream_ordering
+ FROM events
+ INNER JOIN state_events USING (event_id)
+ LEFT JOIN event_auth_chains USING (event_id)
+ LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+ WHERE events.room_id = ?
+ AND event_auth_chains.event_id IS NULL
+ AND event_auth_chain_to_calculate.event_id IS NULL
+ AND %(tuple_cmp)s
+ ORDER BY topological_ordering, stream_ordering
+ LIMIT ?
+ """ % {
+ "tuple_cmp": tuple_clause,
+ }
+
+ args = [current_room_id]
+ args.extend(tuple_args)
+ args.append(batch_size)
+
+ txn.execute(sql, args)
+ rows = txn.fetchall()
+
+ # Put the results in the necessary format for
+ # `_add_chain_cover_index`
+ event_to_room_id = {row[0]: current_room_id for row in rows}
+ event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+ new_last_depth = rows[-1][3] if rows else last_depth # type: int
+ new_last_stream = rows[-1][4] if rows else last_stream # type: int
+
+ count = len(rows)
+
+ # We also need to fetch the auth events for them.
+ auth_events = self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth",
+ column="event_id",
+ iterable=event_to_room_id,
+ keyvalues={},
+ retcols=("event_id", "auth_id"),
+ )
+
+ event_to_auth_chain = {} # type: Dict[str, List[str]]
+ for row in auth_events:
+ event_to_auth_chain.setdefault(row["event_id"], []).append(
+ row["auth_id"]
+ )
+
+ # Calculate and persist the chain cover index for this set of events.
+ #
+ # Annoyingly we need to gut wrench into the persist event store so that
+ # we can reuse the function to calculate the chain cover for rooms.
+ PersistEventsStore._add_chain_cover_index(
+ txn,
+ self.db_pool,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ )
+
+ return new_last_depth, new_last_stream, count
+
+ last_depth, last_stream, count = await self.db_pool.runInteraction(
+ "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+ )
+
+ total_rows_processed = count
+
+ if count < batch_size and not has_set_room_has_chain_index:
+ # If we've done all the events in the room, we flip the
+ # `has_auth_chain_index` in the DB. Note that it's possible for
+ # further events to be persisted between the above and setting the
+ # flag without having the chain cover calculated for them. This is
+ # fine as a) the code gracefully handles these cases and b) we'll
+ # calculate them below.
+
+ await self.db_pool.simple_update(
+ table="rooms",
+ keyvalues={"room_id": current_room_id},
+ updatevalues={"has_auth_chain_index": True},
+ desc="_chain_cover_index",
+ )
+ has_set_room_has_chain_index = True
+
+ # Handle any events that might have raced with us flipping the
+ # bit above.
+ last_depth, last_stream, count = await self.db_pool.runInteraction(
+ "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+ )
+
+ total_rows_processed += count
+
+ # Note that at this point it's technically possible that more events
+ # than our `batch_size` have been persisted without their chain
+ # cover, so we need to continue processing this room if the last
+ # count returned was equal to the `batch_size`.
+
+ if count < batch_size:
+ # We've finished calculating the index for this room; move on to the
+ # next room.
+ await self.db_pool.updates._background_update_progress(
+ "chain_cover", {"current_room_id": current_room_id, "finished": True},
+ )
+ else:
+ # We still have outstanding events to calculate the index for.
+ await self.db_pool.updates._background_update_progress(
+ "chain_cover",
+ {
+ "current_room_id": current_room_id,
+ "last_depth": last_depth,
+ "last_stream": last_stream,
+ "has_auth_chain_index": has_set_room_has_chain_index,
+ "finished": False,
+ },
+ )
+
+ return total_rows_processed
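
One detail worth spelling out: `make_tuple_comparison_clause` exists because row-value comparisons such as `(topological_ordering, stream_ordering) > (?, ?)` only arrived in SQLite 3.15, so older versions need the comparison expanded by hand. A hedged sketch of the kind of expansion such a helper can emit for the two-column case (illustrative only, not Synapse's exact output):

```python
from typing import List, Tuple


def tuple_gt_clause(keys: List[Tuple[str, int]]) -> Tuple[str, List[int]]:
    """Expand `(a, b) > (?, ?)` into `a > ? OR (a = ? AND b > ?)`."""
    (col_a, val_a), (col_b, val_b) = keys
    clause = "(%s > ? OR (%s = ? AND %s > ?))" % (col_a, col_a, col_b)
    return clause, [val_a, val_a, val_b]


clause, args = tuple_gt_clause(
    [("topological_ordering", 10), ("stream_ordering", 100)]
)
# clause: "(topological_ordering > ? OR (topological_ordering = ? AND stream_ordering > ?))"
# args:   [10, 10, 100]
```
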
diff --git a/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
new file mode 100644
index 0000000000..fe3dca71dd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5906, 'chain_cover', '{}', 'rejected_events_metadata');
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 59207cadd4..cea595ff19 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -464,19 +464,17 @@ class TransactionStore(TransactionWorkerStore):
txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
) -> List[str]:
q = """
- SELECT destination FROM destinations
- WHERE destination IN (
- SELECT destination FROM destination_rooms
- WHERE destination_rooms.stream_ordering >
- destinations.last_successful_stream_ordering
- )
- AND destination > ?
- AND (
- retry_last_ts IS NULL OR
- retry_last_ts + retry_interval < ?
- )
- ORDER BY destination
- LIMIT 25
+ SELECT DISTINCT destination FROM destinations
+ INNER JOIN destination_rooms USING (destination)
+ WHERE
+ stream_ordering > last_successful_stream_ordering
+ AND destination > ?
+ AND (
+ retry_last_ts IS NULL OR
+ retry_last_ts + retry_interval < ?
+ )
+ ORDER BY destination
+ LIMIT 25
"""
txn.execute(
q,