diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 25e9c1efe1..264e625bd7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -561,6 +561,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
+ REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -581,6 +582,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self._remove_deleted_devices_from_device_inbox,
)
+ self.db_pool.updates.register_background_update_handler(
+ self.REMOVE_HIDDEN_DEVICES,
+ self._remove_hidden_devices_from_device_inbox,
+ )
+
async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
@@ -676,6 +682,89 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return number_deleted
+ async def _remove_hidden_devices_from_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """A background update that deletes all device_inboxes for hidden devices.
+
+ This should only need to be run once (when users upgrade to v1.47.0)
+
+ Args:
+ progress: JsonDict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ def _remove_hidden_devices_from_device_inbox_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ """stream_id is not unique
+ we need to use an inclusive `stream_id >= ?` clause,
+ since we might not have deleted all hidden device messages for the stream_id
+ returned from the previous query
+
+ Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+ to avoid problems of deleting a large number of rows all at once
+ due to a single device having lots of device messages.
+ """
+
+ last_stream_id = progress.get("stream_id", 0)
+
+ sql = """
+ SELECT device_id, user_id, stream_id
+ FROM device_inbox
+ WHERE
+ stream_id >= ?
+ AND (device_id, user_id) IN (
+ SELECT device_id, user_id FROM devices WHERE hidden = ?
+ )
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_stream_id, True, batch_size))
+ rows = txn.fetchall()
+
+ num_deleted = 0
+ for row in rows:
+ num_deleted += self.db_pool.simple_delete_txn(
+ txn,
+ "device_inbox",
+ {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+ )
+
+ if rows:
+ # We don't just save the `stream_id` in progress as
+ # otherwise it can happen in large deployments that
+ # no change of status is visible in the log file, as
+ # it may be that the stream_id does not change in several runs
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_HIDDEN_DEVICES,
+ {
+ "device_id": rows[-1][0],
+ "user_id": rows[-1][1],
+ "stream_id": rows[-1][2],
+ },
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_hidden_devices_from_device_inbox",
+ _remove_hidden_devices_from_device_inbox_txn,
+ )
+
+ # The task is finished when no more lines are deleted.
+ if not number_deleted:
+ await self.db_pool.updates._end_background_update(
+ self.REMOVE_HIDDEN_DEVICES
+ )
+
+ return number_deleted
+
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
pass
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index b15cd030e0..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -427,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore):
user_ids: the users who were signed
Returns:
- THe new stream ID.
+ The new stream ID.
"""
async with self._device_list_id_gen.get_next() as stream_id:
@@ -1322,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self, user_id: str, device_ids: Collection[str], hosts: List[str]
- ):
+ ) -> int:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8d9086ecf0..596275c23c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -24,6 +24,7 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
Tuple,
)
@@ -494,7 +495,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
) -> None:
"""Calculate the chain cover index for the given events.
@@ -786,7 +787,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
@@ -1794,7 +1795,7 @@ class PersistEventsStore:
)
# Insert an edge for every prev_event connection
- for prev_event_id in event.prev_events:
+ for prev_event_id in event.prev_event_ids():
self.db_pool.simple_insert_txn(
txn,
table="insertion_event_edges",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index ae37901be9..c6bf316d5b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -28,6 +28,7 @@ from typing import (
import attr
from constantly import NamedConstant, Names
+from prometheus_client import Gauge
from typing_extensions import Literal
from twisted.internet import defer
@@ -81,6 +82,12 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+event_fetch_ongoing_gauge = Gauge(
+ "synapse_event_fetch_ongoing",
+ "The number of event fetchers that are running",
+)
+
+
@attr.s(slots=True, auto_attribs=True)
class _EventCacheEntry:
event: EventBase
@@ -222,6 +229,7 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
@@ -732,28 +740,31 @@ class EventsWorkerStore(SQLBaseStore):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- self._fetch_event_list(conn, event_list)
+ try:
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ break
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ self._fetch_event_list(conn, event_list)
+ finally:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
@@ -977,6 +988,7 @@ class EventsWorkerStore(SQLBaseStore):
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
should_start = True
else:
should_start = False
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index e70d3649ff..bb621df0dd 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
_DEFAULT_CATEGORY_ID = ""
@@ -35,6 +40,16 @@ class _RoomInGroup(TypedDict):
class GroupServerWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ database.updates.register_background_index_update(
+ update_name="local_group_updates_index",
+ index_name="local_group_updates_stream_id_index",
+ table="local_group_updates",
+ columns=("stream_id",),
+ unique=True,
+ )
+ super().__init__(database, db_conn, hs)
+
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="groups",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 3d1dff660b..3d0df0cbd4 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -14,6 +14,7 @@
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@@ -61,7 +62,7 @@ class LockStore(SQLBaseStore):
# A map from `(lock_name, lock_key)` to the token of any locks that we
# think we currently hold.
- self._live_tokens: Dict[Tuple[str, str], str] = {}
+ self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
# When we shut down we want to remove the locks. Technically this can
# lead to a race, as we may drop the lock while we are still processing.
@@ -80,10 +81,10 @@ class LockStore(SQLBaseStore):
# We need to take a copy of the tokens dict as dropping the locks will
# cause the dictionary to change.
- tokens = dict(self._live_tokens)
+ locks = dict(self._live_tokens)
- for (lock_name, lock_key), token in tokens.items():
- await self._drop_lock(lock_name, lock_key, token)
+ for lock in locks.values():
+ await lock.release()
logger.info("Dropped locks due to shutdown")
@@ -93,6 +94,11 @@ class LockStore(SQLBaseStore):
used (otherwise the lock will leak).
"""
+ # Check if this process has taken out a lock and if it's still valid.
+ lock = self._live_tokens.get((lock_name, lock_key))
+ if lock and await lock.is_still_valid():
+ return None
+
now = self._clock.time_msec()
token = random_string(6)
@@ -100,7 +106,9 @@ class LockStore(SQLBaseStore):
def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
# We take out the lock if either a) there is no row for the lock
- # already or b) the existing row has timed out.
+ # already, b) the existing row has timed out, or c) the row is
+ # for this instance (which means the process got killed and
+ # restarted)
sql = """
INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?)
@@ -112,6 +120,7 @@ class LockStore(SQLBaseStore):
last_renewed_ts = EXCLUDED.last_renewed_ts
WHERE
worker_locks.last_renewed_ts < ?
+ OR worker_locks.instance_name = EXCLUDED.instance_name
"""
txn.execute(
sql,
@@ -148,11 +157,11 @@ class LockStore(SQLBaseStore):
WHERE
lock_name = ?
AND lock_key = ?
- AND last_renewed_ts < ?
+ AND (last_renewed_ts < ? OR instance_name = ?)
"""
txn.execute(
sql,
- (lock_name, lock_key, now - _LOCK_TIMEOUT_MS),
+ (lock_name, lock_key, now - _LOCK_TIMEOUT_MS, self._instance_name),
)
inserted = self.db_pool.simple_upsert_txn_emulated(
@@ -179,9 +188,7 @@ class LockStore(SQLBaseStore):
if not did_lock:
return None
- self._live_tokens[(lock_name, lock_key)] = token
-
- return Lock(
+ lock = Lock(
self._reactor,
self._clock,
self,
@@ -190,6 +197,10 @@ class LockStore(SQLBaseStore):
token=token,
)
+ self._live_tokens[(lock_name, lock_key)] = lock
+
+ return lock
+
async def _is_lock_still_valid(
self, lock_name: str, lock_key: str, token: str
) -> bool:
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 12cf6995eb..cc0eebdb46 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
prefilled_cache=presence_cache_prefill,
)
- async def update_presence(self, presence_states):
+ async def update_presence(self, presence_states) -> Tuple[int, int]:
assert self._can_persist_presence
stream_ordering_manager = self._presence_id_gen.get_next_mult(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f879bbe7c7..cefc77fa0f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -412,22 +412,33 @@ class RoomWorkerStore(SQLBaseStore):
limit: maximum amount of rooms to retrieve
order_by: the sort order of the returned list
reverse_order: whether to reverse the room list
- search_term: a string to filter room names by
+ search_term: a string to filter room names,
+ canonical alias and room ids by.
+ Room ID must match exactly. Canonical alias must match a substring of the local part.
Returns:
A list of room dicts and an integer representing the total number of
rooms that exist given this query
"""
# Filter room names by a string
where_statement = ""
+ search_pattern = []
if search_term:
- where_statement = "WHERE LOWER(state.name) LIKE ?"
+ where_statement = """
+ WHERE LOWER(state.name) LIKE ?
+ OR LOWER(state.canonical_alias) LIKE ?
+ OR state.room_id = ?
+ """
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_term = "%" + search_term.lower() + "%"
+ search_pattern = [
+ "%" + search_term.lower() + "%",
+ "#%" + search_term.lower() + "%:%",
+ search_term,
+ ]
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -519,12 +530,9 @@ class RoomWorkerStore(SQLBaseStore):
)
def _get_rooms_paginate_txn(txn):
- # Execute the data query
- sql_values = (limit, start)
- if search_term:
- # Add the search term into the WHERE clause
- sql_values = (search_term,) + sql_values
- txn.execute(info_sql, sql_values)
+ # Add the search term into the WHERE clause
+ # and execute the data query
+ txn.execute(info_sql, search_pattern + [limit, start])
# Refactor room query data into a structured dictionary
rooms = []
@@ -551,8 +559,7 @@ class RoomWorkerStore(SQLBaseStore):
# Execute the count query
# Add the search term into the WHERE clause if present
- sql_values = (search_term,) if search_term else ()
- txn.execute(count_sql, sql_values)
+ txn.execute(count_sql, search_pattern)
room_count = txn.fetchone()
return rooms, room_count[0]
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4b288bb2e7..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
- ):
+ ) -> Dict[str, ProfileInfo]:
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event.room_id, state_group, current_state_ids, event=event, context=context
)
- async def get_joined_users_from_state(self, room_id, state_entry):
+ async def get_joined_users_from_state(
+ self, room_id, state_entry
+ ) -> Dict[str, ProfileInfo]:
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cache_context,
event=None,
context=None,
- ):
+ ) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
|