diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn, self.get_account_data_for_room, (user_id,)
)
self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
- self._invalidate_cache_and_stream(
- txn, self.get_push_rules_enabled_for_user, (user_id,)
- )
# This user might be contained in the ignored_by cache for other users,
# so we have to invalidate it all.
self._invalidate_all_cache_and_stream(txn, self.ignored_by)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 7ceb7a202b..dfca34550d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -53,6 +53,7 @@ from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -668,6 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
...
@trace
+ @cancellable
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
@@ -743,6 +745,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
+ @cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
@@ -1221,6 +1224,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
desc="get_min_device_lists_changes_in_room",
)
+ @cancellable
async def get_device_list_changes_in_rooms(
self, room_ids: Collection[str], from_id: int
) -> Optional[Set[str]]:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 2df8101390..210cfab073 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -50,6 +50,7 @@ from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -135,6 +136,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return now_stream_id, []
@trace
+ @cancellable
async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]]
) -> Dict[str, Dict[str, JsonDict]]:
@@ -197,6 +199,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
@trace
+ @cancellable
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
@@ -887,6 +890,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return keys
+ @cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
@@ -902,7 +906,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
keys were not found, either their user ID will not be in the dict,
or their user ID will map to None.
"""
-
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 41b015dba1..0669d54822 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -48,6 +48,7 @@ from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
+from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -976,6 +977,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return int(min_depth) if min_depth is not None else None
+ @cancellable
async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
) -> List[str]:
@@ -1292,6 +1294,51 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return event_id_results
+ @trace
+ async def record_event_failed_pull_attempt(
+ self, room_id: str, event_id: str, cause: str
+ ) -> None:
+ """
+ Record when we fail to pull an event over federation.
+
+ This information allows us to be more intelligent when we decide to
+ retry (we don't need to fail over and over) and we can process that
+ event in the background so we don't block on it each time.
+
+ Args:
+ room_id: The room that the event we failed to pull belongs to
+ event_id: The event that failed to be fetched or processed
+ cause: The error message or reason that we failed to pull the event
+ """
+ await self.db_pool.runInteraction(
+ "record_event_failed_pull_attempt",
+ self._record_event_failed_pull_attempt_upsert_txn,
+ room_id,
+ event_id,
+ cause,
+ db_autocommit=True, # Safe as it's a single upsert
+ )
+
+ def _record_event_failed_pull_attempt_upsert_txn(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ event_id: str,
+ cause: str,
+ ) -> None:
+ sql = """
+ INSERT INTO event_failed_pull_attempts (
+ room_id, event_id, num_attempts, last_attempt_ts, last_cause
+ )
+ VALUES (?, ?, ?, ?, ?)
+ ON CONFLICT (room_id, event_id) DO UPDATE SET
+ num_attempts=event_failed_pull_attempts.num_attempts + 1,
+ last_attempt_ts=EXCLUDED.last_attempt_ts,
+ last_cause=EXCLUDED.last_cause;
+ """
+
+ txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause))
+
async def get_missing_events(
self,
room_id: str,
@@ -1606,7 +1653,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
logger.info("Invalid prev_events for %s", event_id)
continue
- if room_version.event_format == EventFormatVersions.V1:
+ if room_version.event_format == EventFormatVersions.ROOM_V1_V2:
for prev_event_tuple in prev_events:
if (
not isinstance(prev_event_tuple, list)
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index eabf9c9739..6b8668d2dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -98,6 +98,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -232,6 +233,104 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
replaces_index="event_push_summary_user_rm",
)
+ self.db_pool.updates.register_background_index_update(
+ "event_push_summary_unique_index2",
+ index_name="event_push_summary_unique_index2",
+ table="event_push_summary",
+ columns=["user_id", "room_id", "thread_id"],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ "event_push_backfill_thread_id",
+ self._background_backfill_thread_id,
+ )
+
+ async def _background_backfill_thread_id(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """
+ Fill in the thread_id field for event_push_actions and event_push_summary.
+
+ This is preparatory so that it can be made non-nullable in the future.
+
+ Because all existing data with a null thread_id was written in an
+ unthreaded manner, this simply assumes it is on the "main" timeline.
+ Since event_push_actions are periodically cleared, it is not possible
+ to correctly re-calculate the thread_id.
+ """
+ event_push_actions_done = progress.get("event_push_actions_done", False)
+
+ def add_thread_id_txn(
+ txn: LoggingTransaction, table_name: str, start_stream_ordering: int
+ ) -> int:
+ sql = f"""
+ SELECT stream_ordering
+ FROM {table_name}
+ WHERE
+ thread_id IS NULL
+ AND stream_ordering > ?
+ ORDER BY stream_ordering
+ LIMIT ?
+ """
+ txn.execute(sql, (start_stream_ordering, batch_size))
+
+ # No more rows to process.
+ rows = txn.fetchall()
+ if not rows:
+ progress[f"{table_name}_done"] = True
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_push_backfill_thread_id", progress
+ )
+ return 0
+
+ # Update the thread ID for any of those rows.
+ max_stream_ordering = rows[-1][0]
+
+ sql = f"""
+ UPDATE {table_name}
+ SET thread_id = 'main'
+ WHERE stream_ordering <= ? AND thread_id IS NULL
+ """
+ txn.execute(sql, (max_stream_ordering,))
+
+ # Update progress.
+ processed_rows = txn.rowcount
+ progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_push_backfill_thread_id", progress
+ )
+
+ return processed_rows
+
+ # First update the event_push_actions table, then the event_push_summary table.
+ #
+ # Note that the event_push_actions_staging table is ignored since it is
+ # assumed that items in that table will only exist for a short period of
+ # time.
+ if not event_push_actions_done:
+ result = await self.db_pool.runInteraction(
+ "event_push_backfill_thread_id",
+ add_thread_id_txn,
+ "event_push_actions",
+ progress.get("max_event_push_actions_stream_ordering", 0),
+ )
+ else:
+ result = await self.db_pool.runInteraction(
+ "event_push_backfill_thread_id",
+ add_thread_id_txn,
+ "event_push_summary",
+ progress.get("max_event_push_summary_stream_ordering", 0),
+ )
+
+ # Only done after the event_push_summary table is done.
+ if not result:
+ await self.db_pool.updates._end_background_update(
+ "event_push_backfill_thread_id"
+ )
+
+ return result
+
@cached(tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self,
@@ -274,7 +373,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
receipt_types=(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
),
)
@@ -459,6 +557,31 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
+ def _get_receipts_by_room_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> List[Tuple[str, int]]:
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT room_id, MAX(stream_ordering)
+ FROM receipts_linearized
+ INNER JOIN events USING (room_id, event_id)
+ WHERE {receipt_types_clause}
+ AND user_id = ?
+ GROUP BY room_id
+ """
+
+ args.extend((user_id,))
+ txn.execute(sql, args)
+ return cast(List[Tuple[str, int]], txn.fetchall())
+
async def get_unread_push_actions_for_user_in_range_for_http(
self,
user_id: str,
@@ -482,106 +605,45 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
- # find rooms that have a read receipt in them and return the next
- # push actions
- def get_after_receipt(
- txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool]]:
- # find rooms that have a read receipt in them and return the next
- # push actions
-
- receipt_types_clause, args = make_in_list_sql_clause(
- self.database_engine,
- "receipt_type",
- (
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
- ),
- )
-
- sql = f"""
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
- ep.highlight
- FROM (
- SELECT room_id,
- MAX(stream_ordering) as stream_ordering
- FROM events
- INNER JOIN receipts_linearized USING (room_id, event_id)
- WHERE {receipt_types_clause} AND user_id = ?
- GROUP BY room_id
- ) AS rl,
- event_push_actions AS ep
- WHERE
- ep.room_id = rl.room_id
- AND ep.stream_ordering > rl.stream_ordering
- AND ep.user_id = ?
- AND ep.stream_ordering > ?
- AND ep.stream_ordering <= ?
- AND ep.notif = 1
- ORDER BY ep.stream_ordering ASC LIMIT ?
- """
- args.extend(
- (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
- )
- txn.execute(sql, args)
- return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
-
- after_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
+ receipts_by_room = dict(
+ await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http_receipts",
+ self._get_receipts_by_room_txn,
+ user_id=user_id,
+ ),
)
- # There are rooms with push actions in them but you don't have a read receipt in
- # them e.g. rooms you've been invited to, so get push actions for rooms which do
- # not have read receipts in them too.
- def get_no_receipt(
+ def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]:
- receipt_types_clause, args = make_in_list_sql_clause(
- self.database_engine,
- "receipt_type",
- (
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
- ),
- )
-
- sql = f"""
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
- ep.highlight
+ sql = """
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
FROM event_push_actions AS ep
- INNER JOIN events AS e USING (room_id, event_id)
WHERE
- ep.room_id NOT IN (
- SELECT room_id FROM receipts_linearized
- WHERE {receipt_types_clause} AND user_id = ?
- GROUP BY room_id
- )
- AND ep.user_id = ?
+ ep.user_id = ?
AND ep.stream_ordering > ?
AND ep.stream_ordering <= ?
AND ep.notif = 1
ORDER BY ep.stream_ordering ASC LIMIT ?
"""
- args.extend(
- (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
- )
- txn.execute(sql, args)
+ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
- no_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+ push_actions = await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
)
notifs = [
HttpPushAction(
- event_id=row[0],
- room_id=row[1],
- stream_ordering=row[2],
- actions=_deserialize_action(row[3], row[4]),
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ actions=_deserialize_action(actions, highlight),
)
- for row in after_read_receipt + no_read_receipt
+ for event_id, room_id, stream_ordering, actions, highlight in push_actions
+ # Only include push actions with a stream ordering after the user's receipt in that
+ # room, or in rooms with no receipt at all (e.g. rooms the user was invited to but never read).
+ if stream_ordering > receipts_by_room.get(room_id, 0)
]
# Now sort it so it's ordered correctly, since currently it will
@@ -617,106 +679,49 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
- # find rooms that have a read receipt in them and return the most recent
- # push actions
- def get_after_receipt(
- txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool, int]]:
- receipt_types_clause, args = make_in_list_sql_clause(
- self.database_engine,
- "receipt_type",
- (
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
- ),
- )
-
- sql = f"""
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
- ep.highlight, e.received_ts
- FROM (
- SELECT room_id,
- MAX(stream_ordering) as stream_ordering
- FROM events
- INNER JOIN receipts_linearized USING (room_id, event_id)
- WHERE {receipt_types_clause} AND user_id = ?
- GROUP BY room_id
- ) AS rl,
- event_push_actions AS ep
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- ep.room_id = rl.room_id
- AND ep.stream_ordering > rl.stream_ordering
- AND ep.user_id = ?
- AND ep.stream_ordering > ?
- AND ep.stream_ordering <= ?
- AND ep.notif = 1
- ORDER BY ep.stream_ordering DESC LIMIT ?
- """
- args.extend(
- (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
- )
- txn.execute(sql, args)
- return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
-
- after_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+ receipts_by_room = dict(
+ await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email_receipts",
+ self._get_receipts_by_room_txn,
+ user_id=user_id,
+ ),
)
- # There are rooms with push actions in them but you don't have a read receipt in
- # them e.g. rooms you've been invited to, so get push actions for rooms which do
- # not have read receipts in them too.
- def get_no_receipt(
+ def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
- receipt_types_clause, args = make_in_list_sql_clause(
- self.database_engine,
- "receipt_type",
- (
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
- ),
- )
-
- sql = f"""
+ sql = """
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
ep.highlight, e.received_ts
FROM event_push_actions AS ep
INNER JOIN events AS e USING (room_id, event_id)
WHERE
- ep.room_id NOT IN (
- SELECT room_id FROM receipts_linearized
- WHERE {receipt_types_clause} AND user_id = ?
- GROUP BY room_id
- )
- AND ep.user_id = ?
+ ep.user_id = ?
AND ep.stream_ordering > ?
AND ep.stream_ordering <= ?
AND ep.notif = 1
ORDER BY ep.stream_ordering DESC LIMIT ?
"""
- args.extend(
- (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
- )
- txn.execute(sql, args)
+ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
- no_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
+ push_actions = await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
)
# Make a list of dicts from the two sets of results.
notifs = [
EmailPushAction(
- event_id=row[0],
- room_id=row[1],
- stream_ordering=row[2],
- actions=_deserialize_action(row[3], row[4]),
- received_ts=row[5],
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ actions=_deserialize_action(actions, highlight),
+ received_ts=received_ts,
)
- for row in after_read_receipt + no_read_receipt
+ for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
+ # Only include push actions with a stream ordering after the user's receipt in that
+ # room, or in rooms with no receipt at all (e.g. rooms the user was invited to but never read).
+ if stream_ordering > receipts_by_room.get(room_id, 0)
]
# Now sort it so it's ordered correctly, since currently it will
@@ -764,6 +769,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
event_id: str,
user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
count_as_unread: bool,
+ thread_id: str,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -772,6 +778,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict.
count_as_unread: Whether this event should increment unread counts.
+ thread_id: The thread this event is part of, if applicable.
"""
if not user_id_actions:
return
@@ -780,7 +787,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
user_id: str, actions: Collection[Union[Mapping, str]]
- ) -> Tuple[str, str, str, int, int, int]:
+ ) -> Tuple[str, str, str, int, int, int, str]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return (
@@ -790,28 +797,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
+ thread_id, # thread_id column
)
- def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
- # We don't use simple_insert_many here to avoid the overhead
- # of generating lists of dicts.
-
- sql = """
- INSERT INTO event_push_actions_staging
- (event_id, user_id, actions, notif, highlight, unread)
- VALUES (?, ?, ?, ?, ?, ?)
- """
-
- txn.execute_batch(
- sql,
- (
- _gen_entry(user_id, actions)
- for user_id, actions in user_id_actions.items()
- ),
- )
-
- return await self.db_pool.runInteraction(
- "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+ await self.db_pool.simple_insert_many(
+ "event_push_actions_staging",
+ keys=(
+ "event_id",
+ "user_id",
+ "actions",
+ "notif",
+ "highlight",
+ "unread",
+ "thread_id",
+ ),
+ values=[
+ _gen_entry(user_id, actions)
+ for user_id, actions in user_id_actions.items()
+ ],
+ desc="add_push_actions_to_staging",
)
async def remove_push_actions_from_staging(self, event_id: str) -> None:
@@ -1087,6 +1091,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
# Replace the previous summary with the new counts.
+ #
+ # TODO(threads): Upsert per-thread instead of setting them all to main.
self.db_pool.simple_upsert_txn(
txn,
table="event_push_summary",
@@ -1096,6 +1102,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
"unread_count": unread_count,
"stream_ordering": old_rotate_stream_ordering,
"last_receipt_stream_ordering": stream_ordering,
+ "thread_id": "main",
},
)
@@ -1244,17 +1251,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
logger.info("Rotating notifications, handling %d rows", len(summaries))
+ # TODO(threads): Update on a per-thread basis.
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
key_names=("user_id", "room_id"),
key_values=[(user_id, room_id) for user_id, room_id in summaries],
- value_names=("notif_count", "unread_count", "stream_ordering"),
+ value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"),
value_values=[
(
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
+ "main",
)
for summary in summaries.values()
],
@@ -1361,7 +1370,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
table="event_push_actions",
columns=["highlight", "stream_ordering"],
where_clause="highlight=0",
- psql_only=True,
)
async def get_push_actions_for_user(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1c3b804da0..5932668f2f 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2192,9 +2192,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
- topological_ordering, notif, highlight, unread
+ topological_ordering, notif, highlight, unread, thread_id
)
- SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
FROM event_push_actions_staging
WHERE event_id = ?
"""
@@ -2435,17 +2435,31 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
+ backward_extremity_tuples_to_remove = [
+ (ev.event_id, ev.room_id)
+ for ev in events
+ if not ev.internal_metadata.is_outlier()
+ # If we encountered an event with no prev_events, then we might
+ # as well remove it now because it won't ever have anything else
+ # to backfill from.
+ or len(ev.prev_event_ids()) == 0
+ ]
txn.execute_batch(
query,
- [
- (ev.event_id, ev.room_id)
- for ev in events
- if not ev.internal_metadata.is_outlier()
- # If we encountered an event with no prev_events, then we might
- # as well remove it now because it won't ever have anything else
- # to backfill from.
- or len(ev.prev_event_ids()) == 0
- ],
+ backward_extremity_tuples_to_remove,
+ )
+
+ # Clear out the failed pull attempts now that we have successfully pulled
+ # the event. Since these events are no longer needed as backward
+ # extremities, we won't try to backfill from them again, so there is no
+ # need to keep their failed pull attempt records around.
+ query = """
+ DELETE FROM event_failed_pull_attempts
+ WHERE event_id = ? and room_id = ?
+ """
+ txn.execute_batch(
+ query,
+ backward_extremity_tuples_to_remove,
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 90e6d82058..9f6b1fcef1 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import AsyncLruCache
+from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -339,6 +340,7 @@ class EventsWorkerStore(SQLBaseStore):
) -> Optional[EventBase]:
...
+ @cancellable
async def get_event(
self,
event_id: str,
@@ -433,6 +435,7 @@ class EventsWorkerStore(SQLBaseStore):
@trace
@tag_args
+ @cancellable
async def get_events_as_list(
self,
event_ids: Collection[str],
@@ -584,6 +587,7 @@ class EventsWorkerStore(SQLBaseStore):
return events
+ @cancellable
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Dict[str, EventCacheEntry]:
@@ -1156,7 +1160,7 @@ class EventsWorkerStore(SQLBaseStore):
if format_version is None:
# This means that we stored the event before we had the concept
# of a event format version, so it must be a V1 event.
- format_version = EventFormatVersions.V1
+ format_version = EventFormatVersions.ROOM_V1_V2
room_version_id = row.room_version_id
@@ -1186,10 +1190,10 @@ class EventsWorkerStore(SQLBaseStore):
#
# So, the following approximations should be adequate.
- if format_version == EventFormatVersions.V1:
+ if format_version == EventFormatVersions.ROOM_V1_V2:
# if it's event format v1 then it must be room v1 or v2
room_version = RoomVersions.V1
- elif format_version == EventFormatVersions.V2:
+ elif format_version == EventFormatVersions.ROOM_V3:
# if it's event format v2 then it must be room v3
room_version = RoomVersions.V3
else:
@@ -2111,7 +2115,14 @@ class EventsWorkerStore(SQLBaseStore):
AND room_id = ?
/* Make sure event is not rejected */
AND rejections.event_id IS NULL
- ORDER BY origin_server_ts %s
+ /**
+ * First sort by the message timestamp. If the message timestamps are the
+ * same, we want the message that logically comes "next" (before/after
+ * the given timestamp) based on the DAG and its topological order (`depth`).
+ * Finally, we can tie-break based on when it was received on the server
+ * (`stream_ordering`).
+ */
+ ORDER BY origin_server_ts %s, depth %s, stream_ordering %s
LIMIT 1;
"""
@@ -2130,7 +2141,8 @@ class EventsWorkerStore(SQLBaseStore):
order = "ASC"
txn.execute(
- sql_template % (comparison_operator, order), (timestamp, room_id)
+ sql_template % (comparison_operator, order, order, order),
+ (timestamp, room_id),
)
row = txn.fetchone()
if row:
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 2d7633fbd5..7270ef09da 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -129,91 +129,48 @@ class LockStore(SQLBaseStore):
now = self._clock.time_msec()
token = random_string(6)
- if self.db_pool.engine.can_native_upsert:
-
- def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
- # We take out the lock if either a) there is no row for the lock
- # already, b) the existing row has timed out, or c) the row is
- # for this instance (which means the process got killed and
- # restarted)
- sql = """
- INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
- VALUES (?, ?, ?, ?, ?)
- ON CONFLICT (lock_name, lock_key)
- DO UPDATE
- SET
- token = EXCLUDED.token,
- instance_name = EXCLUDED.instance_name,
- last_renewed_ts = EXCLUDED.last_renewed_ts
- WHERE
- worker_locks.last_renewed_ts < ?
- OR worker_locks.instance_name = EXCLUDED.instance_name
- """
- txn.execute(
- sql,
- (
- lock_name,
- lock_key,
- self._instance_name,
- token,
- now,
- now - _LOCK_TIMEOUT_MS,
- ),
- )
-
- # We only acquired the lock if we inserted or updated the table.
- return bool(txn.rowcount)
-
- did_lock = await self.db_pool.runInteraction(
- "try_acquire_lock",
- _try_acquire_lock_txn,
- # We can autocommit here as we're executing a single query, this
- # will avoid serialization errors.
- db_autocommit=True,
+ def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
+ # We take out the lock if either a) there is no row for the lock
+ # already, b) the existing row has timed out, or c) the row is
+ # for this instance (which means the process got killed and
+ # restarted)
+ sql = """
+ INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
+ VALUES (?, ?, ?, ?, ?)
+ ON CONFLICT (lock_name, lock_key)
+ DO UPDATE
+ SET
+ token = EXCLUDED.token,
+ instance_name = EXCLUDED.instance_name,
+ last_renewed_ts = EXCLUDED.last_renewed_ts
+ WHERE
+ worker_locks.last_renewed_ts < ?
+ OR worker_locks.instance_name = EXCLUDED.instance_name
+ """
+ txn.execute(
+ sql,
+ (
+ lock_name,
+ lock_key,
+ self._instance_name,
+ token,
+ now,
+ now - _LOCK_TIMEOUT_MS,
+ ),
)
- if not did_lock:
- return None
-
- else:
- # If we're on an old SQLite we emulate the above logic by first
- # clearing out any existing stale locks and then upserting.
-
- def _try_acquire_lock_emulated_txn(txn: LoggingTransaction) -> bool:
- sql = """
- DELETE FROM worker_locks
- WHERE
- lock_name = ?
- AND lock_key = ?
- AND (last_renewed_ts < ? OR instance_name = ?)
- """
- txn.execute(
- sql,
- (lock_name, lock_key, now - _LOCK_TIMEOUT_MS, self._instance_name),
- )
-
- inserted = self.db_pool.simple_upsert_txn_emulated(
- txn,
- table="worker_locks",
- keyvalues={
- "lock_name": lock_name,
- "lock_key": lock_key,
- },
- values={},
- insertion_values={
- "token": token,
- "last_renewed_ts": self._clock.time_msec(),
- "instance_name": self._instance_name,
- },
- )
-
- return inserted
- did_lock = await self.db_pool.runInteraction(
- "try_acquire_lock_emulated", _try_acquire_lock_emulated_txn
- )
+ # We only acquired the lock if we inserted or updated the table.
+ return bool(txn.rowcount)
- if not did_lock:
- return None
+ did_lock = await self.db_pool.runInteraction(
+ "try_acquire_lock",
+ _try_acquire_lock_txn,
+ # We can autocommit here as we're executing a single query, this
+ # will avoid serialization errors.
+ db_autocommit=True,
+ )
+ if not did_lock:
+ return None
lock = Lock(
self._reactor,
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index f6822707e4..9213ce0b5a 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -419,6 +419,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_forward_extremities",
"event_push_actions",
"event_search",
+ "event_failed_pull_attempts",
"partial_state_events",
"events",
"federation_inbound_events_staging",
@@ -441,6 +442,10 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
+ "insertion_events",
+ "insertion_event_extremities",
+ "insertion_event_edges",
+ "batch_events",
"room_account_data",
"room_tags",
# "rooms" happens last, to keep the foreign keys in the other tables
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 255620f996..ed17b2e70c 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -30,9 +30,8 @@ from typing import (
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -51,6 +50,7 @@ from synapse.storage.util.id_generators import (
IdGenerator,
StreamIdGenerator,
)
+from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -72,18 +72,25 @@ def _load_rules(
"""
ruleslist = [
- PushRule(
+ PushRule.from_db(
rule_id=rawrule["rule_id"],
priority_class=rawrule["priority_class"],
- conditions=db_to_json(rawrule["conditions"]),
- actions=db_to_json(rawrule["actions"]),
+ conditions=rawrule["conditions"],
+ actions=rawrule["actions"],
)
for rawrule in rawrules
]
- push_rules = compile_push_rules(ruleslist)
+ push_rules = PushRules(
+ ruleslist,
+ )
- filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
+ filtered_rules = FilteredPushRules(
+ push_rules,
+ enabled_map,
+ msc3786_enabled=experimental_config.msc3786_enabled,
+ msc3772_enabled=experimental_config.msc3772_enabled,
+ )
return filtered_rules
@@ -165,7 +172,6 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, self.hs.config.experimental)
- @cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
@@ -229,9 +235,6 @@ class PushRulesWorkerStore(
return results
- @cachedList(
- cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
- )
async def bulk_get_push_rules_enabled(
self, user_ids: Collection[str]
) -> Dict[str, Dict[str, bool]]:
@@ -246,6 +249,7 @@ class PushRulesWorkerStore(
iterable=user_ids,
retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
+ batch_size=1000,
)
for row in rows:
enabled = bool(row["enabled"])
@@ -792,7 +796,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
- txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
@@ -849,7 +852,7 @@ class PushRuleStore(PushRulesWorkerStore):
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
- for rule, enabled in user_push_rules:
+ for rule, enabled in user_push_rules.rules():
if not enabled:
continue
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 124c70ad37..ddb8e80b69 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -113,6 +113,24 @@ class ReceiptsWorkerStore(SQLBaseStore):
prefilled_cache=receipts_stream_prefill,
)
+ self.db_pool.updates.register_background_index_update(
+ "receipts_linearized_unique_index",
+ index_name="receipts_linearized_unique_index",
+ table="receipts_linearized",
+ columns=["room_id", "receipt_type", "user_id"],
+ where_clause="thread_id IS NULL",
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ "receipts_graph_unique_index",
+ index_name="receipts_graph_unique_index",
+ table="receipts_graph",
+ columns=["room_id", "receipt_type", "user_id"],
+ where_clause="thread_id IS NULL",
+ unique=True,
+ )
+
def get_max_receipt_stream_id(self) -> int:
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
@@ -675,7 +693,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
values={
"stream_id": stream_id,
"event_id": event_id,
+ "event_stream_ordering": stream_ordering,
"data": json_encoder.encode(data),
+ "thread_id": None,
},
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
@@ -812,7 +832,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
- self.db_pool.simple_delete_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="receipts_graph",
keyvalues={
@@ -820,19 +840,87 @@ class ReceiptsWorkerStore(SQLBaseStore):
"receipt_type": receipt_type,
"user_id": user_id,
},
- )
- self.db_pool.simple_insert_txn(
- txn,
- table="receipts_graph",
values={
- "room_id": room_id,
- "receipt_type": receipt_type,
- "user_id": user_id,
"event_ids": json_encoder.encode(event_ids),
"data": json_encoder.encode(data),
+ "thread_id": None,
},
+ # receipts_graph has a unique constraint on
+ # (user_id, room_id, receipt_type), so no need to lock
+ lock=False,
)
-class ReceiptsStore(ReceiptsWorkerStore):
+class ReceiptsBackgroundUpdateStore(SQLBaseStore):
+ POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering"
+
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_update_handler(
+ self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING,
+ self._populate_receipt_event_stream_ordering,
+ )
+
+ async def _populate_receipt_event_stream_ordering(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ def _populate_receipt_event_stream_ordering_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
+
+ if "max_stream_id" in progress:
+ max_stream_id = progress["max_stream_id"]
+ else:
+ txn.execute("SELECT max(stream_id) FROM receipts_linearized")
+ res = txn.fetchone()
+ if res is None or res[0] is None:
+ return True
+ else:
+ max_stream_id = res[0]
+
+ start = progress.get("stream_id", 0)
+ stop = start + batch_size
+
+ sql = """
+ UPDATE receipts_linearized
+ SET event_stream_ordering = (
+ SELECT stream_ordering
+ FROM events
+ WHERE event_id = receipts_linearized.event_id
+ )
+ WHERE stream_id >= ? AND stream_id < ?
+ """
+ txn.execute(sql, (start, stop))
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING,
+ {
+ "stream_id": stop,
+ "max_stream_id": max_stream_id,
+ },
+ )
+
+ return stop > max_stream_id
+
+ finished = await self.db_pool.runInteraction(
+ "_remove_devices_from_device_inbox_txn",
+ _populate_receipt_event_stream_ordering_txn,
+ )
+
+ if finished:
+ await self.db_pool.updates._end_background_update(
+ self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING
+ )
+
+ return batch_size
+
+
+class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore):
pass
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cb63cd9b7d..ac821878b0 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult:
"""
user_id: str
+ token_id: int
is_guest: bool = False
shadow_banned: bool = False
- token_id: Optional[int] = None
device_id: Optional[str] = None
valid_until_ms: Optional[int] = None
token_owner: str = attr.ib()
@@ -175,6 +175,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"is_guest",
"admin",
"consent_version",
+ "consent_ts",
"consent_server_notice_sent",
"appservice_id",
"creation_ts",
@@ -2227,7 +2228,10 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
txn,
table="users",
keyvalues={"name": user_id},
- updatevalues={"consent_version": consent_version},
+ updatevalues={
+ "consent_version": consent_version,
+ "consent_ts": self._clock.time_msec(),
+ },
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index b7d4baa6bb..bef66f1992 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -641,8 +641,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"version": room[5],
"creator": room[6],
"encryption": room[7],
- "federatable": room[8],
- "public": room[9],
+ # room_stats_state.federatable is an integer on sqlite.
+ "federatable": bool(room[8]),
+ # rooms.is_public is an integer on sqlite.
+ "public": bool(room[9]),
"join_rules": room[10],
"guest_access": room[11],
"history_visibility": room[12],
@@ -1183,8 +1185,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
return False
- @staticmethod
- def _clear_partial_state_room_txn(txn: LoggingTransaction, room_id: str) -> None:
+ def _clear_partial_state_room_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> None:
DatabasePool.simple_delete_txn(
txn,
table="partial_state_rooms_servers",
@@ -1195,7 +1198,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
table="partial_state_rooms",
keyvalues={"room_id": room_id},
)
+ self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
+ @cached()
async def is_partial_state_room(self, room_id: str) -> bool:
"""Checks if this room has partial state.
@@ -1769,9 +1774,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
servers,
)
- @staticmethod
def _store_partial_state_room_txn(
- txn: LoggingTransaction, room_id: str, servers: Collection[str]
+ self, txn: LoggingTransaction, room_id: str, servers: Collection[str]
) -> None:
DatabasePool.simple_insert_txn(
txn,
@@ -1786,6 +1790,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
keys=("room_id", "server_name"),
values=((room_id, s) for s in servers),
)
+ self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 827c1f1efd..a8d224602a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -31,12 +31,8 @@ from typing import (
import attr
from synapse.api.constants import EventTypes, Membership
-from synapse.events import EventBase
from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import (
- run_as_background_process,
- wrap_as_background_process,
-)
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -56,6 +52,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -91,15 +88,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# at a time. Keyed by room_id.
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
- # Is the current_state_events.membership up to date? Or is the
- # background update still running?
- self._current_state_events_membership_up_to_date = False
-
- txn = db_conn.cursor(
- txn_name="_check_safe_current_state_events_membership_updated"
- )
- self._check_safe_current_state_events_membership_updated_txn(txn)
- txn.close()
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
if (
self.hs.config.worker.run_background_tasks
@@ -157,58 +146,41 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._known_servers_count = max([count, 1])
return self._known_servers_count
- def _check_safe_current_state_events_membership_updated_txn(
- self, txn: LoggingTransaction
- ) -> None:
- """Checks if it is safe to assume the new current_state_events
- membership column is up to date
- """
-
- pending_update = self.db_pool.simple_select_one_txn(
- txn,
- table="background_updates",
- keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
- retcols=["update_name"],
- allow_none=True,
- )
-
- self._current_state_events_membership_up_to_date = not pending_update
-
- # If the update is still running, reschedule to run.
- if pending_update:
- self._clock.call_later(
- 15.0,
- run_as_background_process,
- "_check_safe_current_state_events_membership_updated",
- self.db_pool.runInteraction,
- "_check_safe_current_state_events_membership_updated",
- self._check_safe_current_state_events_membership_updated_txn,
- )
-
@cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
+ """
+ Returns a list of users in the room sorted by longest in the room first
+ (aka. with the lowest depth). This is done to match the sort in
+ `get_current_hosts_in_room()` and so we can re-use the cache but it's
+ not horrible to have here either.
+
+ Uses `m.room.member`s in the room state at the current forward extremities to
+ determine which users are in the room.
+
+ Will return inaccurate results for rooms with partial state, since the state for
+ the forward extremities of those rooms will exclude most members. We may also
+ calculate room state incorrectly for such rooms and believe that a member is or
+ is not in the room when the opposite is true.
+ """
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
- # If we can assume current_state_events.membership is up to date
- # then we can avoid a join, which is a Very Good Thing given how
- # frequently this function gets called.
- if self._current_state_events_membership_up_to_date:
- sql = """
- SELECT state_key FROM current_state_events
- WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
- """
- else:
- sql = """
- SELECT state_key FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
- """
+ """
+ Returns a list of users in the room sorted by longest in the room first
+ (aka. with the lowest depth). This is done to match the sort in
+ `get_current_hosts_in_room()` and so we can re-use the cache but it's
+ not horrible to have here either.
+ """
+ sql = """
+ SELECT c.state_key FROM current_state_events as c
+ /* Get the depth of the event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
+ /* Sorted by lowest depth first */
+ ORDER BY e.depth ASC;
+ """
txn.execute(sql, (room_id, Membership.JOIN))
return [r[0] for r in txn]
@@ -325,28 +297,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We do this all in one transaction to keep the cache small.
# FIXME: get rid of this when we have room_stats
- # If we can assume current_state_events.membership is up to date
- # then we can avoid a join, which is a Very Good Thing given how
- # frequently this function gets called.
- if self._current_state_events_membership_up_to_date:
- # Note, rejected events will have a null membership field, so
- # we we manually filter them out.
- sql = """
- SELECT count(*), membership FROM current_state_events
- WHERE type = 'm.room.member' AND room_id = ?
- AND membership IS NOT NULL
- GROUP BY membership
- """
- else:
- sql = """
- SELECT count(*), m.membership FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
+ # Note, rejected events will have a null membership field, so
+ # we manually filter them out.
+ sql = """
+ SELECT count(*), membership FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ GROUP BY membership
+ """
txn.execute(sql, (room_id,))
res: Dict[str, MemberSummary] = {}
@@ -355,30 +313,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# we order by membership and then fairly arbitrarily by event_id so
# heroes are consistent
- if self._current_state_events_membership_up_to_date:
- # Note, rejected events will have a null membership field, so
- # we we manually filter them out.
- sql = """
- SELECT state_key, membership, event_id
- FROM current_state_events
- WHERE type = 'm.room.member' AND room_id = ?
- AND membership IS NOT NULL
- ORDER BY
- CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
- event_id ASC
- LIMIT ?
- """
- else:
- sql = """
- SELECT c.state_key, m.membership, c.event_id
- FROM room_memberships as m
- INNER JOIN current_state_events as c USING (room_id, event_id)
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- ORDER BY
- CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
- c.event_id ASC
- LIMIT ?
- """
+ # Note, rejected events will have a null membership field, so
+ # we manually filter them out.
+ sql = """
+ SELECT state_key, membership, event_id
+ FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ ORDER BY
+ CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ event_id ASC
+ LIMIT ?
+ """
# 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
@@ -534,6 +480,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_local_users_in_room",
)
+ async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
+ """
+ Check whether a given local user is currently joined to the given room.
+
+ Returns:
+ A boolean indicating whether the user is currently joined to the room
+
+ Raises:
+ Exception when called with a non-local user to this homeserver
+ """
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'check_local_user_in_room' on "
+ "non-local user %s" % (user_id,),
+ )
+
+ (
+ membership,
+ member_event_id,
+ ) = await self.get_local_current_membership_for_user_in_room(
+ user_id=user_id,
+ room_id=room_id,
+ )
+
+ return membership == Membership.JOIN
+
+ async def is_server_notice_room(self, room_id: str) -> bool:
+ """
+ Determines whether the given room is a 'Server Notices' room, used for
+ sending server notices to a user.
+
+ This is determined by seeing whether the server notices user is present
+ in the room.
+ """
+ if self._server_notices_mxid is None:
+ return False
+ is_server_notices_room = await self.check_local_user_in_room(
+ user_id=self._server_notices_mxid, room_id=room_id
+ )
+ return is_server_notices_room
+
async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]:
@@ -595,27 +582,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
- if self._current_state_events_membership_up_to_date:
- sql = """
- SELECT room_id, e.instance_name, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND c.state_key = ?
- AND c.membership = ?
- """
- else:
- sql = """
- SELECT room_id, e.instance_name, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (room_id, event_id)
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND c.state_key = ?
- AND m.membership = ?
- """
+ sql = """
+ SELECT room_id, e.instance_name, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND c.state_key = ?
+ AND c.membership = ?
+ """
txn.execute(sql, (user_id, Membership.JOIN))
return frozenset(
@@ -653,27 +628,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_ids,
)
- if self._current_state_events_membership_up_to_date:
- sql = f"""
- SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND c.membership = ?
- AND {clause}
- """
- else:
- sql = f"""
- SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (room_id, event_id)
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND m.membership = ?
- AND {clause}
- """
+ sql = f"""
+ SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND c.membership = ?
+ AND {clause}
+ """
txn.execute(sql, [Membership.JOIN] + args)
@@ -724,6 +687,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
_get_users_server_still_shares_room_with_txn,
)
+ @cancellable
async def get_rooms_for_user(
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
) -> FrozenSet[str]:
@@ -835,144 +799,92 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return shared_room_ids or frozenset()
- async def get_joined_users_from_state(
- self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
- ) -> Dict[str, ProfileInfo]:
- state_group: Union[object, int] = state_entry.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- assert state_group is not None
- with Measure(self._clock, "get_joined_users_from_state"):
- return await self._get_joined_users_from_context(
- room_id, state_group, state, context=state_entry
- )
+ async def get_joined_user_ids_from_state(
+ self, room_id: str, state: StateMap[str]
+ ) -> Set[str]:
+ """
+ For a given set of state IDs, get a set of user IDs in the room.
- @cached(num_args=2, iterable=True, max_entries=100000)
- async def _get_joined_users_from_context(
- self,
- room_id: str,
- state_group: Union[object, int],
- current_state_ids: StateMap[str],
- event: Optional[EventBase] = None,
- context: Optional["_StateCacheEntry"] = None,
- ) -> Dict[str, ProfileInfo]:
- # We don't use `state_group`, it's there so that we can cache based
- # on it. However, it's important that it's never None, since two current_states
- # with a state_group of None are likely to be different.
- assert state_group is not None
+ This method checks the local event cache, before calling
+ `_get_user_ids_from_membership_event_ids` for any uncached events.
+ """
- users_in_room = {}
- member_event_ids = [
- e_id
- for key, e_id in current_state_ids.items()
- if key[0] == EventTypes.Member
- ]
-
- if context is not None:
- # If we have a context with a delta from a previous state group,
- # check if we also have the result from the previous group in cache.
- # If we do then we can reuse that result and simply update it with
- # any membership changes in `delta_ids`
- if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get_immediate(
- (room_id, context.prev_group), None
- )
- if prev_res and isinstance(prev_res, dict):
- users_in_room = dict(prev_res)
- member_event_ids = [
- e_id
- for key, e_id in context.delta_ids.items()
- if key[0] == EventTypes.Member
- ]
- for etype, state_key in context.delta_ids:
- if etype == EventTypes.Member:
- users_in_room.pop(state_key, None)
-
- # We check if we have any of the member event ids in the event cache
- # before we ask the DB
-
- # We don't update the event cache hit ratio as it completely throws off
- # the hit ratio counts. After all, we don't populate the cache if we
- # miss it here
- event_map = self._get_events_from_local_cache(
- member_event_ids, update_metrics=False
- )
+ with Measure(self._clock, "get_joined_user_ids_from_state"):
+ users_in_room = set()
+ member_event_ids = [
+ e_id for key, e_id in state.items() if key[0] == EventTypes.Member
+ ]
- missing_member_event_ids = []
- for event_id in member_event_ids:
- ev_entry = event_map.get(event_id)
- if ev_entry and not ev_entry.event.rejected_reason:
- if ev_entry.event.membership == Membership.JOIN:
- users_in_room[ev_entry.event.state_key] = ProfileInfo(
- display_name=ev_entry.event.content.get("displayname", None),
- avatar_url=ev_entry.event.content.get("avatar_url", None),
- )
- else:
- missing_member_event_ids.append(event_id)
+ # We check if we have any of the member event ids in the event cache
+ # before we ask the DB
- if missing_member_event_ids:
- event_to_memberships = await self._get_joined_profiles_from_event_ids(
- missing_member_event_ids
+ # We don't update the event cache hit ratio as it completely throws off
+ # the hit ratio counts. After all, we don't populate the cache if we
+ # miss it here
+ event_map = self._get_events_from_local_cache(
+ member_event_ids, update_metrics=False
)
- users_in_room.update(row for row in event_to_memberships.values() if row)
-
- if event is not None and event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- if event.event_id in member_event_ids:
- users_in_room[event.state_key] = ProfileInfo(
- display_name=event.content.get("displayname", None),
- avatar_url=event.content.get("avatar_url", None),
+
+ missing_member_event_ids = []
+ for event_id in member_event_ids:
+ ev_entry = event_map.get(event_id)
+ if ev_entry and not ev_entry.event.rejected_reason:
+ if ev_entry.event.membership == Membership.JOIN:
+ users_in_room.add(ev_entry.event.state_key)
+ else:
+ missing_member_event_ids.append(event_id)
+
+ if missing_member_event_ids:
+ event_to_memberships = (
+ await self._get_user_ids_from_membership_event_ids(
+ missing_member_event_ids
)
+ )
+ users_in_room.update(
+ user_id for user_id in event_to_memberships.values() if user_id
+ )
- return users_in_room
+ return users_in_room
- @cached(max_entries=10000)
- def _get_joined_profile_from_event_id(
+ @cached(
+ max_entries=10000,
+ # This name matches the old function that has been replaced - the cache name
+ # is kept here to maintain backwards compatibility.
+ name="_get_joined_profile_from_event_id",
+ )
+ def _get_user_id_from_membership_event_id(
self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
+ cached_method_name="_get_user_id_from_membership_event_id",
list_name="event_ids",
)
- async def _get_joined_profiles_from_event_ids(
+ async def _get_user_ids_from_membership_event_ids(
self, event_ids: Iterable[str]
- ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+ ) -> Dict[str, Optional[str]]:
"""For given set of member event_ids check if they point to a join
- event and if so return the associated user and profile info.
+ event.
Args:
event_ids: The member event IDs to lookup
Returns:
- Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+ Map from event ID to `user_id`, or None if event is not a join.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
- retcols=("user_id", "display_name", "avatar_url", "event_id"),
+ retcols=("user_id", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=1000,
- desc="_get_joined_profiles_from_event_ids",
+ desc="_get_user_ids_from_membership_event_ids",
)
- return {
- row["event_id"]: (
- row["user_id"],
- ProfileInfo(
- avatar_url=row["avatar_url"], display_name=row["display_name"]
- ),
- )
- for row in rows
- }
+ return {row["event_id"]: row["user_id"] for row in rows}
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1018,37 +930,81 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cached(iterable=True, max_entries=10000)
- async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
- """Get current hosts in room based on current state."""
+ async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
+ """
+ Get current hosts in room based on current state.
+
+ The heuristic of sorting by servers who have been in the room the
+ longest is good because they're most likely to have anything we ask
+ about.
+
+ Uses `m.room.member`s in the room state at the current forward extremities to
+ determine which hosts are in the room.
+
+ Will return inaccurate results for rooms with partial state, since the state for
+ the forward extremities of those rooms will exclude most members. We may also
+ calculate room state incorrectly for such rooms and believe that a host is or
+ is not in the room when the opposite is true.
+
+ Returns:
+ Returns a list of servers sorted by longest in the room first. (aka.
+ sorted by join with the lowest depth first).
+ """
# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate result from that
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
- if users is not None:
- return {get_domain_from_id(u) for u in users}
-
- if isinstance(self.database_engine, Sqlite3Engine):
+ if users is None and isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
users = await self.get_users_in_room(room_id)
- return {get_domain_from_id(u) for u in users}
+
+ if users is not None:
+ # Because `users` is sorted from lowest -> highest depth, the list
+ # of domains will also be sorted that way.
+ domains: List[str] = []
+ # We use a `Set` just for fast lookups
+ domain_set: Set[str] = set()
+ for u in users:
+ if ":" not in u:
+ continue
+ domain = get_domain_from_id(u)
+ if domain not in domain_set:
+ domain_set.add(domain)
+ domains.append(domain)
+ return domains
# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.
- def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
+ # Returns a list of servers currently joined in the room sorted by
+ # longest in the room first (aka. with the lowest depth). The
+ # heuristic of sorting by servers who have been in the room the
+ # longest is good because they're most likely to have anything we
+ # ask about.
sql = """
- SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
- FROM current_state_events
+ SELECT
+ /* Match the domain part of the MXID */
+ substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain
+ FROM current_state_events c
+ /* Get the depth of the event from the events table */
+ INNER JOIN events AS e USING (event_id)
WHERE
- type = 'm.room.member'
- AND membership = 'join'
- AND room_id = ?
+ /* Find any join state events in the room */
+ c.type = 'm.room.member'
+ AND c.membership = 'join'
+ AND c.room_id = ?
+ /* Group all state events from the same domain into their own buckets (groups) */
+ GROUP BY server_domain
+ /* Sorted by lowest depth first */
+ ORDER BY min(e.depth) ASC;
"""
txn.execute(sql, (room_id,))
- return {d for d, in txn}
+ # `server_domain` will be `NULL` for malformed MXIDs with no colons.
+ return [d for d, in txn if d is not None]
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
@@ -1131,12 +1087,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
- joined_users = await self.get_joined_users_from_state(
- room_id, state, state_entry
+ joined_user_ids = await self.get_joined_user_ids_from_state(
+ room_id, state
)
cache.hosts_to_joined_users = {}
- for user_id in joined_users:
+ for user_id in joined_user_ids:
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 0b10af0e58..32095d7969 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -23,6 +23,7 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
+from synapse.logging.tracing import trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -36,6 +37,7 @@ from synapse.storage.state import StateFilter
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -142,6 +144,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
+ @trace
async def get_metadata_for_events(
self, event_ids: Collection[str]
) -> Dict[str, EventMetadata]:
@@ -281,6 +284,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
# FIXME: how should this be cached?
+ @cancellable
async def get_partial_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index b4c652acf3..356d4ca788 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -446,59 +446,41 @@ class StatsStore(StateDeltasStore):
absolutes: Absolute (set) fields
additive_relatives: Fields that will be added onto if existing row present.
"""
- if self.database_engine.can_native_upsert:
- absolute_updates = [
- "%(field)s = EXCLUDED.%(field)s" % {"field": field}
- for field in absolutes.keys()
- ]
-
- relative_updates = [
- "%(field)s = EXCLUDED.%(field)s + COALESCE(%(table)s.%(field)s, 0)"
- % {"table": table, "field": field}
- for field in additive_relatives.keys()
- ]
-
- insert_cols = []
- qargs = []
-
- for (key, val) in chain(
- keyvalues.items(), absolutes.items(), additive_relatives.items()
- ):
- insert_cols.append(key)
- qargs.append(val)
+ absolute_updates = [
+ "%(field)s = EXCLUDED.%(field)s" % {"field": field}
+ for field in absolutes.keys()
+ ]
+
+ relative_updates = [
+ "%(field)s = EXCLUDED.%(field)s + COALESCE(%(table)s.%(field)s, 0)"
+ % {"table": table, "field": field}
+ for field in additive_relatives.keys()
+ ]
+
+ insert_cols = []
+ qargs = []
+
+ for (key, val) in chain(
+ keyvalues.items(), absolutes.items(), additive_relatives.items()
+ ):
+ insert_cols.append(key)
+ qargs.append(val)
+
+ sql = """
+ INSERT INTO %(table)s (%(insert_cols_cs)s)
+ VALUES (%(insert_vals_qs)s)
+ ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s
+ """ % {
+ "table": table,
+ "insert_cols_cs": ", ".join(insert_cols),
+ "insert_vals_qs": ", ".join(
+ ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives))
+ ),
+ "key_columns": ", ".join(keyvalues),
+ "updates": ", ".join(chain(absolute_updates, relative_updates)),
+ }
- sql = """
- INSERT INTO %(table)s (%(insert_cols_cs)s)
- VALUES (%(insert_vals_qs)s)
- ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s
- """ % {
- "table": table,
- "insert_cols_cs": ", ".join(insert_cols),
- "insert_vals_qs": ", ".join(
- ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives))
- ),
- "key_columns": ", ".join(keyvalues),
- "updates": ", ".join(chain(absolute_updates, relative_updates)),
- }
-
- txn.execute(sql, qargs)
- else:
- self.database_engine.lock_table(txn, table)
- retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
- current_row = self.db_pool.simple_select_one_txn(
- txn, table, keyvalues, retcols, allow_none=True
- )
- if current_row is None:
- merged_dict = {**keyvalues, **absolutes, **additive_relatives}
- self.db_pool.simple_insert_txn(txn, table, merged_dict)
- else:
- for (key, val) in additive_relatives.items():
- if current_row[key] is None:
- current_row[key] = val
- else:
- current_row[key] += val
- current_row.update(absolutes)
- self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
+ txn.execute(sql, qargs)
async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
"""Calculate and insert an entry into room_stats_current.
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index f61f290547..f0b179eea5 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -72,6 +72,7 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.cancellation import cancellable
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -597,6 +598,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
+ @cancellable
async def get_membership_changes_for_user(
self,
user_id: str,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index ba79e19f7f..f8c6877ee8 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -221,25 +221,15 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
retry_interval: how long until next retry in ms
"""
- if self.database_engine.can_native_upsert:
- await self.db_pool.runInteraction(
- "set_destination_retry_timings",
- self._set_destination_retry_timings_native,
- destination,
- failure_ts,
- retry_last_ts,
- retry_interval,
- db_autocommit=True, # Safe as its a single upsert
- )
- else:
- await self.db_pool.runInteraction(
- "set_destination_retry_timings",
- self._set_destination_retry_timings_emulated,
- destination,
- failure_ts,
- retry_last_ts,
- retry_interval,
- )
+ await self.db_pool.runInteraction(
+ "set_destination_retry_timings",
+ self._set_destination_retry_timings_native,
+ destination,
+ failure_ts,
+ retry_last_ts,
+ retry_interval,
+ db_autocommit=True, # Safe as it's a single upsert
+ )
def _set_destination_retry_timings_native(
self,
@@ -249,8 +239,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
retry_last_ts: int,
retry_interval: int,
) -> None:
- assert self.database_engine.can_native_upsert
-
# Upsert retry time interval if retry_interval is zero (i.e. we're
# resetting it) or greater than the existing retry interval.
#
|