diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9c1e506da6..e2e6eb479f 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -127,6 +127,8 @@ class PersistEventsStore:
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+ self._msc3970_enabled = hs.config.experimental.msc3970_enabled
+
@trace
async def _persist_events_and_state_updates(
self,
@@ -977,23 +979,43 @@ class PersistEventsStore:
) -> None:
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
- to_insert = []
+ inserted_ts = self._clock.time_msec()
+ to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = []
+ to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = []
for event, _ in events_and_contexts:
- token_id = getattr(event.internal_metadata, "token_id", None)
txn_id = getattr(event.internal_metadata, "txn_id", None)
- if token_id and txn_id:
- to_insert.append(
- (
- event.event_id,
- event.room_id,
- event.sender,
- token_id,
- txn_id,
- self._clock.time_msec(),
+ token_id = getattr(event.internal_metadata, "token_id", None)
+ device_id = getattr(event.internal_metadata, "device_id", None)
+
+ if txn_id is not None:
+ if token_id is not None:
+ to_insert_token_id.append(
+ (
+ event.event_id,
+ event.room_id,
+ event.sender,
+ token_id,
+ txn_id,
+ inserted_ts,
+ )
+ )
+
+ if device_id is not None:
+ to_insert_device_id.append(
+ (
+ event.event_id,
+ event.room_id,
+ event.sender,
+ device_id,
+ txn_id,
+ inserted_ts,
+ )
)
- )
- if to_insert:
+ # Pre-MSC3970, we rely on the access_token_id to scope the txn_id for events.
+    # Since MSC3970 is still experimental, we keep storing this mapping regardless
+    # of whether the flag is enabled.
+ if to_insert_token_id:
self.db_pool.simple_insert_many_txn(
txn,
table="event_txn_id",
@@ -1005,7 +1027,25 @@ class PersistEventsStore:
"txn_id",
"inserted_ts",
),
- values=to_insert,
+ values=to_insert_token_id,
+ )
+
+ # With MSC3970, we rely on the device_id instead to scope the txn_id for events.
+    # We only insert the device-scoped mapping if MSC3970 is *enabled*, because
+    # otherwise the pre-MSC3970 behaviour could trigger a UNIQUE constraint
+    # violation on this table.
+ if to_insert_device_id and self._msc3970_enabled:
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="event_txn_id_device_id",
+ keys=(
+ "event_id",
+ "room_id",
+ "user_id",
+ "device_id",
+ "txn_id",
+ "inserted_ts",
+ ),
+ values=to_insert_device_id,
)
async def update_current_state(
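# A hedged sketch of the schema delta the new insert path above assumes for
# `event_txn_id_device_id`. The actual delta file is not part of this hunk, so
# the column types and constraints below are inferred from the keys passed to
# `simple_insert_many_txn` and from the UNIQUE-constraint comment:
def run_create_event_txn_id_device_id(cur) -> None:
    cur.execute(
        """
        CREATE TABLE event_txn_id_device_id (
            event_id    TEXT   NOT NULL,
            room_id     TEXT   NOT NULL,
            user_id     TEXT   NOT NULL,
            device_id   TEXT   NOT NULL,
            txn_id      TEXT   NOT NULL,
            inserted_ts BIGINT NOT NULL,
            -- the constraint the comment above guards against violating:
            -- one event per (room, user, device, txn) tuple
            UNIQUE (room_id, user_id, device_id, txn_id)
        )
        """
    )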
@@ -1127,11 +1167,15 @@ class PersistEventsStore:
# been inserted into room_memberships.
txn.execute_batch(
"""INSERT INTO current_state_events
- (room_id, type, state_key, event_id, membership)
- VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ (room_id, type, state_key, event_id, membership, event_stream_ordering)
+ VALUES (
+ ?, ?, ?, ?,
+ (SELECT membership FROM room_memberships WHERE event_id = ?),
+ (SELECT stream_ordering FROM events WHERE event_id = ?)
+ )
""",
[
- (room_id, key[0], key[1], ev_id, ev_id)
+ (room_id, key[0], key[1], ev_id, ev_id, ev_id)
for key, ev_id in to_insert.items()
],
)
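# The new `event_stream_ordering` column mirrors `events.stream_ordering` at
# insert time via the correlated sub-select above. A hedged guess at the
# matching schema change, which is not shown in this diff: the column carries
# a foreign key back to `events`, which is also what forces the purge-order
# change in purge_events.py further down.
def run_add_event_stream_ordering(cur) -> None:
    cur.execute(
        """
        ALTER TABLE current_state_events
            ADD COLUMN event_stream_ordering BIGINT
            REFERENCES events(stream_ordering)
        """
    )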
@@ -1158,11 +1202,15 @@ class PersistEventsStore:
if to_insert:
txn.execute_batch(
"""INSERT INTO local_current_membership
- (room_id, user_id, event_id, membership)
- VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ (room_id, user_id, event_id, membership, event_stream_ordering)
+ VALUES (
+ ?, ?, ?,
+ (SELECT membership FROM room_memberships WHERE event_id = ?),
+ (SELECT stream_ordering FROM events WHERE event_id = ?)
+ )
""",
[
- (room_id, key[1], ev_id, ev_id)
+ (room_id, key[1], ev_id, ev_id, ev_id)
for key, ev_id in to_insert.items()
if key[0] == EventTypes.Member and self.is_mine_id(key[1])
],
@@ -1768,6 +1816,7 @@ class PersistEventsStore:
table="room_memberships",
keys=(
"event_id",
+ "event_stream_ordering",
"user_id",
"sender",
"room_id",
@@ -1778,6 +1827,7 @@ class PersistEventsStore:
values=[
(
event.event_id,
+ event.internal_metadata.stream_ordering,
event.state_key,
event.user_id,
event.room_id,
@@ -1810,6 +1860,7 @@ class PersistEventsStore:
keyvalues={"room_id": event.room_id, "user_id": event.state_key},
values={
"event_id": event.event_id,
+ "event_stream_ordering": event.internal_metadata.stream_ordering,
"membership": event.membership,
},
)
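# For the device-scoped mapping above to be written at all, the event-creation
# path has to stash the sender's device ID on the event's internal metadata,
# next to the existing txn_id/token_id fields. A hedged sketch of that
# contract; the helper is hypothetical, but the attribute names match the
# getattr() calls in the transaction-ID persistence hunk above:
def annotate_event_for_dedup(event, requester, txn_id: str) -> None:
    # txn_id is what both dedup tables key on
    event.internal_metadata.txn_id = txn_id
    # pre-MSC3970 scope: the access token that sent the event
    if requester.access_token_id is not None:
        event.internal_metadata.token_id = requester.access_token_id
    # MSC3970 scope: the device that sent the event
    if requester.device_id is not None:
        event.internal_metadata.device_id = requester.device_id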
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 0cf46626d2..0ff3fc7369 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -2022,7 +2022,7 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
- async def get_event_id_from_transaction_id(
+ async def get_event_id_from_transaction_id_and_token_id(
self, room_id: str, user_id: str, token_id: int, txn_id: str
) -> Optional[str]:
"""Look up if we have already persisted an event for the transaction ID,
@@ -2038,7 +2038,26 @@ class EventsWorkerStore(SQLBaseStore):
},
retcol="event_id",
allow_none=True,
- desc="get_event_id_from_transaction_id",
+ desc="get_event_id_from_transaction_id_and_token_id",
+ )
+
+ async def get_event_id_from_transaction_id_and_device_id(
+ self, room_id: str, user_id: str, device_id: str, txn_id: str
+ ) -> Optional[str]:
+ """Look up if we have already persisted an event for the transaction ID,
+ returning the event ID if so.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="event_txn_id_device_id",
+ keyvalues={
+ "room_id": room_id,
+ "user_id": user_id,
+ "device_id": device_id,
+ "txn_id": txn_id,
+ },
+ retcol="event_id",
+ allow_none=True,
+ desc="get_event_id_from_transaction_id_and_device_id",
)
async def get_already_persisted_events(
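# A hedged sketch of how a send-event handler could pick between the two
# lookups, preferring the device-scoped table when MSC3970 is enabled. The
# helper itself and the attributes read off `requester` are assumptions; the
# two store methods are the ones defined above:
async def get_existing_event_id(store, msc3970_enabled: bool, room_id, requester, txn_id: str):
    if msc3970_enabled and requester.device_id is not None:
        return await store.get_event_id_from_transaction_id_and_device_id(
            room_id, requester.user.to_string(), requester.device_id, txn_id
        )
    if requester.access_token_id is not None:
        return await store.get_event_id_from_transaction_id_and_token_id(
            room_id, requester.user.to_string(), requester.access_token_id, txn_id
        )
    return None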
@@ -2068,7 +2087,7 @@ class EventsWorkerStore(SQLBaseStore):
# Check if this is a duplicate of an event we've already
# persisted.
- existing = await self.get_event_id_from_transaction_id(
+ existing = await self.get_event_id_from_transaction_id_and_token_id(
event.room_id, event.sender, token_id, txn_id
)
if existing:
@@ -2084,11 +2103,17 @@ class EventsWorkerStore(SQLBaseStore):
"""Cleans out transaction id mappings older than 24hrs."""
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
+ one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
"""
- one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ txn.execute(sql, (one_day_ago,))
+
+ sql = """
+ DELETE FROM event_txn_id_device_id
+ WHERE inserted_ts < ?
+ """
txn.execute(sql, (one_day_ago,))
return await self.db_pool.runInteraction(
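# Both mapping tables carry the same `inserted_ts` column, so one cutoff can
# serve both deletes. A stand-alone sketch of the transaction body above; the
# helper name is made up, the table names are the ones from the diff:
def cleanup_txn_id_mappings(txn, now_ms: int) -> None:
    one_day_ago = now_ms - 24 * 60 * 60 * 1000
    for table in ("event_txn_id", "event_txn_id_device_id"):
        # table names are fixed constants, so the interpolation is safe here
        txn.execute(f"DELETE FROM {table} WHERE inserted_ts < ?", (one_day_ago,))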
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 89c37a4eb5..1666e3c43b 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -14,10 +14,12 @@
# limitations under the License.
import itertools
+import json
import logging
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
+from unpaddedbase64 import decode_base64
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
@@ -36,15 +38,16 @@ class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys"""
@cached()
- def _get_server_verify_key(
+ def _get_server_signature_key(
self, server_name_and_key_id: Tuple[str, str]
) -> FetchKeyResult:
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
+ cached_method_name="_get_server_signature_key",
+ list_name="server_name_and_key_ids",
)
- async def get_server_verify_keys(
+ async def get_server_signature_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
@@ -62,10 +65,12 @@ class KeyStore(SQLBaseStore):
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)
- sql = (
- "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
- "FROM server_signature_keys WHERE 1=0"
- ) + " OR (server_name=? AND key_id=?)" * len(batch)
+ sql = """
+ SELECT server_name, key_id, verify_key, ts_valid_until_ms
+ FROM server_signature_keys WHERE 1=0
+ """ + " OR (server_name=? AND key_id=?)" * len(
+ batch
+ )
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
@@ -89,9 +94,9 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch)
return keys
- return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
+ return await self.db_pool.runInteraction("get_server_signature_keys", _txn)
- async def store_server_verify_keys(
+ async def store_server_signature_keys(
self,
from_server: str,
ts_added_ms: int,
@@ -119,7 +124,7 @@ class KeyStore(SQLBaseStore):
)
)
# invalidate takes a tuple corresponding to the params of
- # _get_server_verify_key. _get_server_verify_key only takes one
+ # _get_server_signature_key. _get_server_signature_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
@@ -134,10 +139,10 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
- desc="store_server_verify_keys",
+ desc="store_server_signature_keys",
)
- invalidate = self._get_server_verify_key.invalidate
+ invalidate = self._get_server_signature_key.invalidate
for i in invalidations:
invalidate((i,))
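# The rename from *verify* keys to *signature* keys is mechanical but ripples
# out to every caller. A hedged illustration of a call site after the rename;
# the fetcher class here is a simplified stand-in, not Synapse's real keyring
# code:
class StoreKeyFetcher:
    def __init__(self, store):
        self._store = store

    async def get_keys(self, server_name: str, key_ids: list):
        # returns {(server_name, key_id): FetchKeyResult} for the keys we have
        return await self._store.get_server_signature_keys(
            [(server_name, key_id) for key_id in key_ids]
        )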
@@ -180,7 +185,75 @@ class KeyStore(SQLBaseStore):
desc="store_server_keys_json",
)
+ # invalidate takes a tuple corresponding to the params of
+ # _get_server_keys_json. _get_server_keys_json only takes one
+ # param, which is itself the 2-tuple (server_name, key_id).
+        self._get_server_keys_json.invalidate(((server_name, key_id),))
+
+ @cached()
+ def _get_server_keys_json(
+ self, server_name_and_key_id: Tuple[str, str]
+ ) -> FetchKeyResult:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids"
+ )
async def get_server_keys_json(
+ self, server_name_and_key_ids: Iterable[Tuple[str, str]]
+ ) -> Dict[Tuple[str, str], FetchKeyResult]:
+ """
+ Args:
+ server_name_and_key_ids:
+ iterable of (server_name, key-id) tuples to fetch keys for
+
+ Returns:
+ A map from (server_name, key_id) -> FetchKeyResult, or None if the
+ key is unknown
+ """
+ keys = {}
+
+ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
+ """Processes a batch of keys to fetch, and adds the result to `keys`."""
+
+ # batch_iter always returns tuples so it's safe to do len(batch)
+ sql = """
+ SELECT server_name, key_id, key_json, ts_valid_until_ms
+ FROM server_keys_json WHERE 1=0
+ """ + " OR (server_name=? AND key_id=?)" * len(
+ batch
+ )
+
+ txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
+
+ for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn:
+ if ts_valid_until_ms is None:
+ # Old keys may be stored with a ts_valid_until_ms of null,
+ # in which case we treat this as if it was set to `0`, i.e.
+ # it won't match key requests that define a minimum
+ # `ts_valid_until_ms`.
+ ts_valid_until_ms = 0
+
+ # The entire signed JSON response is stored in server_keys_json,
+ # fetch out the bits needed.
+ key_json = json.loads(bytes(key_json_bytes))
+ key_base64 = key_json["verify_keys"][key_id]["key"]
+
+ keys[(server_name, key_id)] = FetchKeyResult(
+ verify_key=decode_verify_key_bytes(
+ key_id, decode_base64(key_base64)
+ ),
+ valid_until_ts=ts_valid_until_ms,
+ )
+
+ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
+ for batch in batch_iter(server_name_and_key_ids, 50):
+ _get_keys(txn, batch)
+ return keys
+
+ return await self.db_pool.runInteraction("get_server_keys_json", _txn)
+
+ async def get_server_keys_json_for_remote(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
"""Retrieve the key json for a list of server_keys and key ids.
@@ -188,8 +261,10 @@ class KeyStore(SQLBaseStore):
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response.
+
Args:
server_keys: List of (server_name, key_id, source) triplets.
+
Returns:
A mapping from (server_name, key_id, source) triplets to a list of dicts
"""
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 7a7c0d9c75..efbd3e75d9 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -428,14 +428,16 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"partial_state_events",
"partial_state_rooms_servers",
"partial_state_rooms",
+        # Note: the membership tables (`local_current_membership`,
+        # `room_memberships`) have foreign keys into the `events` table,
+        # so they must be deleted first.
+ "local_current_membership",
+ "room_memberships",
"events",
"federation_inbound_events_staging",
- "local_current_membership",
"receipts_graph",
"receipts_linearized",
"room_aliases",
"room_depth",
- "room_memberships",
"room_stats_state",
"room_stats_current",
"room_stats_earliest_token",