diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8143168107..264e625bd7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -19,9 +19,10 @@ from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -488,10 +489,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
+ # We exclude hidden devices (such as cross-signing keys) here as they are
+ # not expected to receive to-device messages.
devices = self.db_pool.simple_select_onecol_txn(
txn,
table="devices",
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
retcol="device_id",
)
@@ -504,10 +507,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if not devices:
continue
+ # We exclude hidden devices (such as cross-signing keys) here as they are
+ # not expected to receive to-device messages.
rows = self.db_pool.simple_select_many_txn(
txn,
table="devices",
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
column="device_id",
iterable=devices,
retcols=("device_id",),
@@ -555,6 +560,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+ REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
+ REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -570,6 +577,16 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
+ self.db_pool.updates.register_background_update_handler(
+ self.REMOVE_DELETED_DEVICES,
+ self._remove_deleted_devices_from_device_inbox,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ self.REMOVE_HIDDEN_DEVICES,
+ self._remove_hidden_devices_from_device_inbox,
+ )
+
async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
@@ -582,6 +599,172 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return 1
+ async def _remove_deleted_devices_from_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """A background update that deletes all device_inboxes for deleted devices.
+
+ This should only need to be run once (when users upgrade to v1.47.0)
+
+ Args:
+ progress: JsonDict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ def _remove_deleted_devices_from_device_inbox_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ """stream_id is not unique
+ we need to use an inclusive `stream_id >= ?` clause,
+ since we might not have deleted all dead device messages for the stream_id
+ returned from the previous query
+
+ Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+ to avoid problems of deleting a large number of rows all at once
+ due to a single device having lots of device messages.
+ """
+
+ last_stream_id = progress.get("stream_id", 0)
+
+ sql = """
+ SELECT device_id, user_id, stream_id
+ FROM device_inbox
+ WHERE
+ stream_id >= ?
+ AND (device_id, user_id) NOT IN (
+ SELECT device_id, user_id FROM devices
+ )
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_stream_id, batch_size))
+ rows = txn.fetchall()
+
+ num_deleted = 0
+ for row in rows:
+ num_deleted += self.db_pool.simple_delete_txn(
+ txn,
+ "device_inbox",
+ {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+ )
+
+ if rows:
+ # send more than stream_id to progress
+ # otherwise it can happen in large deployments that
+ # no change of status is visible in the log file
+ # it may be that the stream_id does not change in several runs
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_DELETED_DEVICES,
+ {
+ "device_id": rows[-1][0],
+ "user_id": rows[-1][1],
+ "stream_id": rows[-1][2],
+ },
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_deleted_devices_from_device_inbox",
+ _remove_deleted_devices_from_device_inbox_txn,
+ )
+
+ # The task is finished when no more lines are deleted.
+ if not number_deleted:
+ await self.db_pool.updates._end_background_update(
+ self.REMOVE_DELETED_DEVICES
+ )
+
+ return number_deleted
+
+ async def _remove_hidden_devices_from_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """A background update that deletes all device_inboxes for hidden devices.
+
+ This should only need to be run once (when users upgrade to v1.47.0)
+
+ Args:
+ progress: JsonDict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ def _remove_hidden_devices_from_device_inbox_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ """stream_id is not unique
+ we need to use an inclusive `stream_id >= ?` clause,
+ since we might not have deleted all hidden device messages for the stream_id
+ returned from the previous query
+
+ Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+ to avoid problems of deleting a large number of rows all at once
+ due to a single device having lots of device messages.
+ """
+
+ last_stream_id = progress.get("stream_id", 0)
+
+ sql = """
+ SELECT device_id, user_id, stream_id
+ FROM device_inbox
+ WHERE
+ stream_id >= ?
+ AND (device_id, user_id) IN (
+ SELECT device_id, user_id FROM devices WHERE hidden = ?
+ )
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_stream_id, True, batch_size))
+ rows = txn.fetchall()
+
+ num_deleted = 0
+ for row in rows:
+ num_deleted += self.db_pool.simple_delete_txn(
+ txn,
+ "device_inbox",
+ {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+ )
+
+ if rows:
+ # We don't just save the `stream_id` in progress as
+ # otherwise it can happen in large deployments that
+ # no change of status is visible in the log file, as
+ # it may be that the stream_id does not change in several runs
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_HIDDEN_DEVICES,
+ {
+ "device_id": rows[-1][0],
+ "user_id": rows[-1][1],
+ "stream_id": rows[-1][2],
+ },
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_hidden_devices_from_device_inbox",
+ _remove_hidden_devices_from_device_inbox_txn,
+ )
+
+ # The task is finished when no more lines are deleted.
+ if not number_deleted:
+ await self.db_pool.updates._end_background_update(
+ self.REMOVE_HIDDEN_DEVICES
+ )
+
+ return number_deleted
+
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
pass
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a01bf2c5b7..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -427,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore):
user_ids: the users who were signed
Returns:
- THe new stream ID.
+ The new stream ID.
"""
async with self._device_list_id_gen.get_next() as stream_id:
@@ -1134,19 +1134,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
raise StoreError(500, "Problem storing device.")
async def delete_device(self, user_id: str, device_id: str) -> None:
- """Delete a device.
+ """Delete a device and its device_inbox.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to delete
"""
- await self.db_pool.simple_delete_one(
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
- desc="delete_device",
- )
- self.device_id_exists_cache.invalidate((user_id, device_id))
+ await self.delete_devices(user_id, [device_id])
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
"""Deletes several devices.
@@ -1155,13 +1150,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user which owns the devices
device_ids: The IDs of the devices to delete
"""
- await self.db_pool.simple_delete_many(
- table="devices",
- column="device_id",
- iterable=device_ids,
- keyvalues={"user_id": user_id, "hidden": False},
- desc="delete_devices",
- )
+
+ def _delete_devices_txn(txn: LoggingTransaction) -> None:
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="devices",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id, "hidden": False},
+ )
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="device_inbox",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id},
+ )
+
+ await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
@@ -1315,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self, user_id: str, device_ids: Collection[str], hosts: List[str]
- ):
+ ) -> int:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8d9086ecf0..596275c23c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -24,6 +24,7 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
Tuple,
)
@@ -494,7 +495,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
) -> None:
"""Calculate the chain cover index for the given events.
@@ -786,7 +787,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
@@ -1794,7 +1795,7 @@ class PersistEventsStore:
)
# Insert an edge for every prev_event connection
- for prev_event_id in event.prev_events:
+ for prev_event_id in event.prev_event_ids():
self.db_pool.simple_insert_txn(
txn,
table="insertion_event_edges",
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 12cf6995eb..cc0eebdb46 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
prefilled_cache=presence_cache_prefill,
)
- async def update_presence(self, presence_states):
+ async def update_presence(self, presence_states) -> Tuple[int, int]:
assert self._can_persist_presence
stream_ordering_manager = self._presence_id_gen.get_next_mult(
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index ba7075caa5..dd8e27e226 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def update_remote_profile_cache(
- self, user_id: str, displayname: str, avatar_url: str
+ self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
) -> int:
return await self.db_pool.simple_update(
table="remote_profile_cache",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 40760fbd1b..53576ad52f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,13 +13,14 @@
# limitations under the License.
import logging
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
where_clause = ["relates_to_id = ?"]
- where_args = [event_id]
+ where_args: List[Union[str, int]] = [event_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore):
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
+ from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
+ to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)
@@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore):
order,
)
- def _get_recent_references_for_event_txn(txn):
+ def _get_recent_references_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args = [event_id, RelationTypes.ANNOTATION]
+ where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
if event_type:
where_clause.append("type = ?")
@@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause = generate_pagination_where_clause(
direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
+ from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
+ to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)
@@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause=having_clause,
)
- def _get_aggregation_groups_for_event_txn(txn):
+ def _get_aggregation_groups_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])
next_batch = None
@@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1
"""
- def _get_applicable_edit_txn(txn):
+ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
+ return None
edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
@@ -267,7 +273,7 @@ class RelationsWorkerStore(SQLBaseStore):
if not edit_id:
return None
- return await self.get_event(edit_id, allow_none=True)
+ return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
@cached()
async def get_thread_summary(
@@ -283,7 +289,9 @@ class RelationsWorkerStore(SQLBaseStore):
The number of items in the thread and the most recent response, if any.
"""
- def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
+ def _get_thread_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events.
sql = """
@@ -312,7 +320,7 @@ class RelationsWorkerStore(SQLBaseStore):
AND relation_type = ?
"""
txn.execute(sql, (event_id, RelationTypes.THREAD))
- count = txn.fetchone()[0]
+ count = txn.fetchone()[0] # type: ignore[index]
return count, latest_event_id
@@ -322,7 +330,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event = None
if latest_event_id:
- latest_event = await self.get_event(latest_event_id, allow_none=True)
+ latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
return count, latest_event
@@ -354,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1;
"""
- def _get_if_user_has_annotated_event(txn):
+ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
txn.execute(
sql,
(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f879bbe7c7..cefc77fa0f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -412,22 +412,33 @@ class RoomWorkerStore(SQLBaseStore):
limit: maximum amount of rooms to retrieve
order_by: the sort order of the returned list
reverse_order: whether to reverse the room list
- search_term: a string to filter room names by
+ search_term: a string to filter room names,
+ canonical alias and room ids by.
+ Room ID must match exactly. Canonical alias must match a substring of the local part.
Returns:
A list of room dicts and an integer representing the total number of
rooms that exist given this query
"""
# Filter room names by a string
where_statement = ""
+ search_pattern = []
if search_term:
- where_statement = "WHERE LOWER(state.name) LIKE ?"
+ where_statement = """
+ WHERE LOWER(state.name) LIKE ?
+ OR LOWER(state.canonical_alias) LIKE ?
+ OR state.room_id = ?
+ """
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_term = "%" + search_term.lower() + "%"
+ search_pattern = [
+ "%" + search_term.lower() + "%",
+ "#%" + search_term.lower() + "%:%",
+ search_term,
+ ]
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -519,12 +530,9 @@ class RoomWorkerStore(SQLBaseStore):
)
def _get_rooms_paginate_txn(txn):
- # Execute the data query
- sql_values = (limit, start)
- if search_term:
- # Add the search term into the WHERE clause
- sql_values = (search_term,) + sql_values
- txn.execute(info_sql, sql_values)
+ # Add the search term into the WHERE clause
+ # and execute the data query
+ txn.execute(info_sql, search_pattern + [limit, start])
# Refactor room query data into a structured dictionary
rooms = []
@@ -551,8 +559,7 @@ class RoomWorkerStore(SQLBaseStore):
# Execute the count query
# Add the search term into the WHERE clause if present
- sql_values = (search_term,) if search_term else ()
- txn.execute(count_sql, sql_values)
+ txn.execute(count_sql, search_pattern)
room_count = txn.fetchone()
return rooms, room_count[0]
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4b288bb2e7..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
- ):
+ ) -> Dict[str, ProfileInfo]:
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event.room_id, state_group, current_state_ids, event=event, context=context
)
- async def get_joined_users_from_state(self, room_id, state_entry):
+ async def get_joined_users_from_state(
+ self, room_id, state_entry
+ ) -> Dict[str, ProfileInfo]:
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cache_context,
event=None,
context=None,
- ):
+ ) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
|