diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 20b755056b..cfe887b7f7 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -13,33 +13,49 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar
+from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-class Databases:
+DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True)
+
+
+class Databases(Generic[DataStoreT]):
"""The various databases.
These are low level interfaces to physical databases.
Attributes:
- main (DataStore)
+ databases
+ main
+ state
+ persist_events
"""
- def __init__(self, main_store_class, hs):
+ databases: List[DatabasePool]
+ main: DataStoreT
+ state: StateGroupDataStore
+ persist_events: Optional[PersistEventsStore]
+
+ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
# Note we pass in the main store class here as workers use a different main
# store.
self.databases = []
- main = None
- state = None
- persist_events = None
+ main: Optional[DataStoreT] = None
+ state: Optional[StateGroupDataStore] = None
+ persist_events: Optional[PersistEventsStore] = None
for database_config in hs.config.database.databases:
db_name = database_config.name
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5c21402dea..259cae5b37 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool
@@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -126,7 +129,7 @@ class DataStore(
LockStore,
SessionStore,
):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 70ca3e09f7..f8bec266ac 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -28,6 +28,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore):
`get_max_account_data_stream_id` which can be called in the initializer.
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index c57ae5ef15..36e8422fc6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@
import itertools
import logging
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index b81d9218ce..1dc7f0ebe3 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -478,6 +478,58 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
return {(d["user_id"], d["device_id"]): d for d in res}
+ async def get_user_ip_and_agents(
+ self, user: UserID, since_ts: int = 0
+ ) -> List[LastConnectionInfo]:
+ """Fetch the IPs and user agents for a user since the given timestamp.
+
+ The result might be slightly out of date as client IPs are inserted in batches.
+
+ Args:
+ user: The user for which to fetch IP addresses and user agents.
+ since_ts: The timestamp after which to fetch IP addresses and user agents,
+ in milliseconds.
+
+ Returns:
+ A list of dictionaries, each containing:
+ * `access_token`: The access token used.
+ * `ip`: The IP address used.
+ * `user_agent`: The last user agent seen for this access token and IP
+ address combination.
+ * `last_seen`: The timestamp at which this access token and IP address
+ combination was last seen, in milliseconds.
+
+ Only the latest user agent for each access token and IP address combination
+ is available.
+ """
+ user_id = user.to_string()
+
+ def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
+ txn.execute(
+ """
+ SELECT access_token, ip, user_agent, last_seen FROM user_ips
+ WHERE last_seen >= ? AND user_id = ?
+ ORDER BY last_seen
+ DESC
+ """,
+ (since_ts, user_id),
+ )
+ return cast(List[Tuple[str, str, str, int]], txn.fetchall())
+
+ rows = await self.db_pool.runInteraction(
+ desc="get_user_ip_and_agents", func=get_recent
+ )
+
+ return [
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for access_token, ip, user_agent, last_seen in rows
+ ]
+
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
@@ -622,49 +674,43 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
async def get_user_ip_and_agents(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
+ """Fetch the IPs and user agents for a user since the given timestamp.
+
+ Args:
+ user: The user for which to fetch IP addresses and user agents.
+ since_ts: The timestamp after which to fetch IP addresses and user agents,
+ in milliseconds.
+
+ Returns:
+ A list of dictionaries, each containing:
+ * `access_token`: The access token used.
+ * `ip`: The IP address used.
+ * `user_agent`: The last user agent seen for this access token and IP
+ address combination.
+ * `last_seen`: The timestamp at which this access token and IP address
+ combination was last seen, in milliseconds.
+
+ Only the latest user agent for each access token and IP address combination
+ is available.
"""
- Fetch IP/User Agent connection since a given timestamp.
- """
- user_id = user.to_string()
- results: Dict[Tuple[str, str], Tuple[str, int]] = {}
+ results: Dict[Tuple[str, str], LastConnectionInfo] = {
+ (connection["access_token"], connection["ip"]): connection
+ for connection in await super().get_user_ip_and_agents(user, since_ts)
+ }
+ # Overlay data that is pending insertion on top of the results from the
+ # database.
+ user_id = user.to_string()
for key in self._batch_row_update:
- (
- uid,
- access_token,
- ip,
- ) = key
+ uid, access_token, ip = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
if last_seen >= since_ts:
- results[(access_token, ip)] = (user_agent, last_seen)
-
- def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
- txn.execute(
- """
- SELECT access_token, ip, user_agent, last_seen FROM user_ips
- WHERE last_seen >= ? AND user_id = ?
- ORDER BY last_seen
- DESC
- """,
- (since_ts, user_id),
- )
- return cast(List[Tuple[str, str, str, int]], txn.fetchall())
-
- rows = await self.db_pool.runInteraction(
- desc="get_user_ip_and_agents", func=get_recent
- )
+ results[(access_token, ip)] = {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
- results.update(
- ((access_token, ip), (user_agent, last_seen))
- for access_token, ip, user_agent, last_seen in rows
- )
- return [
- {
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "last_seen": last_seen,
- }
- for (access_token, ip), (user_agent, last_seen) in results.items()
- ]
+ return list(results.values())
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3154906d45..264e625bd7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -13,24 +13,28 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -485,10 +489,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
+ # We exclude hidden devices (such as cross-signing keys) here as they are
+ # not expected to receive to-device messages.
devices = self.db_pool.simple_select_onecol_txn(
txn,
table="devices",
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
retcol="device_id",
)
@@ -501,10 +507,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if not devices:
continue
+ # We exclude hidden devices (such as cross-signing keys) here as they are
+ # not expected to receive to-device messages.
rows = self.db_pool.simple_select_many_txn(
txn,
table="devices",
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
column="device_id",
iterable=devices,
retcols=("device_id",),
@@ -552,8 +560,10 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+ REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
+ REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -567,6 +577,16 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
+ self.db_pool.updates.register_background_update_handler(
+ self.REMOVE_DELETED_DEVICES,
+ self._remove_deleted_devices_from_device_inbox,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ self.REMOVE_HIDDEN_DEVICES,
+ self._remove_hidden_devices_from_device_inbox,
+ )
+
async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
@@ -579,6 +599,172 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return 1
+ async def _remove_deleted_devices_from_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """A background update that deletes all device_inboxes for deleted devices.
+
+ This should only need to be run once (when users upgrade to v1.47.0)
+
+ Args:
+ progress: JsonDict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ def _remove_deleted_devices_from_device_inbox_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ """stream_id is not unique
+ we need to use an inclusive `stream_id >= ?` clause,
+ since we might not have deleted all dead device messages for the stream_id
+ returned from the previous query
+
+ Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+ to avoid problems of deleting a large number of rows all at once
+ due to a single device having lots of device messages.
+ """
+
+ last_stream_id = progress.get("stream_id", 0)
+
+ sql = """
+ SELECT device_id, user_id, stream_id
+ FROM device_inbox
+ WHERE
+ stream_id >= ?
+ AND (device_id, user_id) NOT IN (
+ SELECT device_id, user_id FROM devices
+ )
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_stream_id, batch_size))
+ rows = txn.fetchall()
+
+ num_deleted = 0
+ for row in rows:
+ num_deleted += self.db_pool.simple_delete_txn(
+ txn,
+ "device_inbox",
+ {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+ )
+
+ if rows:
+ # send more than stream_id to progress
+ # otherwise it can happen in large deployments that
+ # no change of status is visible in the log file
+ # it may be that the stream_id does not change in several runs
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_DELETED_DEVICES,
+ {
+ "device_id": rows[-1][0],
+ "user_id": rows[-1][1],
+ "stream_id": rows[-1][2],
+ },
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_deleted_devices_from_device_inbox",
+ _remove_deleted_devices_from_device_inbox_txn,
+ )
+
+ # The task is finished when no more lines are deleted.
+ if not number_deleted:
+ await self.db_pool.updates._end_background_update(
+ self.REMOVE_DELETED_DEVICES
+ )
+
+ return number_deleted
+
+ async def _remove_hidden_devices_from_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """A background update that deletes all device_inboxes for hidden devices.
+
+ This should only need to be run once (when users upgrade to v1.47.0)
+
+ Args:
+ progress: JsonDict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ def _remove_hidden_devices_from_device_inbox_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ """stream_id is not unique
+ we need to use an inclusive `stream_id >= ?` clause,
+ since we might not have deleted all hidden device messages for the stream_id
+ returned from the previous query
+
+ Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+ to avoid problems of deleting a large number of rows all at once
+ due to a single device having lots of device messages.
+ """
+
+ last_stream_id = progress.get("stream_id", 0)
+
+ sql = """
+ SELECT device_id, user_id, stream_id
+ FROM device_inbox
+ WHERE
+ stream_id >= ?
+ AND (device_id, user_id) IN (
+ SELECT device_id, user_id FROM devices WHERE hidden = ?
+ )
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_stream_id, True, batch_size))
+ rows = txn.fetchall()
+
+ num_deleted = 0
+ for row in rows:
+ num_deleted += self.db_pool.simple_delete_txn(
+ txn,
+ "device_inbox",
+ {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+ )
+
+ if rows:
+ # We don't just save the `stream_id` in progress as
+ # otherwise it can happen in large deployments that
+ # no change of status is visible in the log file, as
+ # it may be that the stream_id does not change in several runs
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_HIDDEN_DEVICES,
+ {
+ "device_id": rows[-1][0],
+ "user_id": rows[-1][1],
+ "stream_id": rows[-1][2],
+ },
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_hidden_devices_from_device_inbox",
+ _remove_hidden_devices_from_device_inbox_txn,
+ )
+
+ # The task is finished when no more lines are deleted.
+ if not number_deleted:
+ await self.db_pool.updates._end_background_update(
+ self.REMOVE_HIDDEN_DEVICES
+ )
+
+ return number_deleted
+
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
pass
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 6464520386..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,17 @@
# limitations under the License.
import abc
import logging
-from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -414,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore):
user_ids: the users who were signed
Returns:
- THe new stream ID.
+ The new stream ID.
"""
async with self._device_list_id_gen.get_next() as stream_id:
@@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
@@ -1121,19 +1134,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
raise StoreError(500, "Problem storing device.")
async def delete_device(self, user_id: str, device_id: str) -> None:
- """Delete a device.
+ """Delete a device and its device_inbox.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to delete
"""
- await self.db_pool.simple_delete_one(
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
- desc="delete_device",
- )
- self.device_id_exists_cache.invalidate((user_id, device_id))
+ await self.delete_devices(user_id, [device_id])
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
"""Deletes several devices.
@@ -1142,13 +1150,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user which owns the devices
device_ids: The IDs of the devices to delete
"""
- await self.db_pool.simple_delete_many(
- table="devices",
- column="device_id",
- iterable=device_ids,
- keyvalues={"user_id": user_id, "hidden": False},
- desc="delete_devices",
- )
+
+ def _delete_devices_txn(txn: LoggingTransaction) -> None:
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="devices",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id, "hidden": False},
+ )
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="device_inbox",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id},
+ )
+
+ await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
@@ -1302,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self, user_id: str, device_ids: Collection[str], hosts: List[str]
- ):
+ ) -> int:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index d7bddfc154..c58f7dd009 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,17 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Collection, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ NamedTuple,
+ Optional,
+ Set,
+ Tuple,
+)
from prometheus_client import Counter, Gauge
@@ -34,6 +44,9 @@ from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
oldest_pdu_in_federation_staging = Gauge(
"synapse_federation_server_oldest_inbound_pdu_in_staging",
"The age in seconds since we received the oldest pdu in the federation staging area",
@@ -67,7 +80,7 @@ class _NoChainCoverIndex(Exception):
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.hs = hs
@@ -1621,7 +1634,7 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 97b3e92d3f..d957e770dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
@@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
@@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index e4f5c86244..3790a52a89 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -24,6 +24,7 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
Tuple,
)
@@ -494,7 +495,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
) -> None:
"""Calculate the chain cover index for the given events.
@@ -786,7 +787,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
@@ -1710,6 +1711,7 @@ class PersistEventsStore:
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
+ RelationTypes.THREAD,
):
# Unknown relation type
return
@@ -1740,6 +1742,9 @@ class PersistEventsStore:
if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+ if rel_type == RelationTypes.THREAD:
+ txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
Part of MSC2716.
@@ -1790,7 +1795,7 @@ class PersistEventsStore:
)
# Insert an edge for every prev_event connection
- for prev_event_id in event.prev_events:
+ for prev_event_id in event.prev_event_ids():
self.db_pool.simple_insert_txn(
txn,
table="insertion_event_edges",
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 1afc59fafb..ae3a8a63e4 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,19 +13,26 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import attr
-from synapse.api.constants import EventContentFields
+from synapse.api.constants import EventContentFields, RelationTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -76,7 +83,7 @@ class _CalculateChainCover:
class EventsBackgroundUpdatesStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
@@ -164,6 +171,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
+ self.db_pool.updates.register_background_update_handler(
+ "event_thread_relation", self._event_thread_relation
+ )
+
################################################################################
# bg updates for replacing stream_ordering with a BIGINT
@@ -1088,6 +1099,79 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
+ async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
+ """Background update handler which will store thread relations for existing events."""
+ last_event_id = progress.get("last_event_id", "")
+
+ def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+ txn.execute(
+ """
+ SELECT event_id, json FROM event_json
+ LEFT JOIN event_relations USING (event_id)
+ WHERE event_id > ? AND event_relations.event_id IS NULL
+ ORDER BY event_id LIMIT ?
+ """,
+ (last_event_id, batch_size),
+ )
+
+ results = list(txn)
+ missing_thread_relations = []
+ for (event_id, event_json_raw) in results:
+ try:
+ event_json = db_to_json(event_json_raw)
+ except Exception as e:
+ logger.warning(
+ "Unable to load event %s (no relations will be updated): %s",
+ event_id,
+ e,
+ )
+ continue
+
+ # If there's no relation (or it is not a thread), skip!
+ relates_to = event_json["content"].get("m.relates_to")
+ if not relates_to or not isinstance(relates_to, dict):
+ continue
+ if relates_to.get("rel_type") != RelationTypes.THREAD:
+ continue
+
+ # Get the parent ID.
+ parent_id = relates_to.get("event_id")
+ if not isinstance(parent_id, str):
+ continue
+
+ missing_thread_relations.append((event_id, parent_id))
+
+ # Insert the missing data.
+ self.db_pool.simple_insert_many_txn(
+ txn=txn,
+ table="event_relations",
+ values=[
+ {
+ "event_id": event_id,
+ "relates_to_Id": parent_id,
+ "relation_type": RelationTypes.THREAD,
+ }
+ for event_id, parent_id in missing_thread_relations
+ ],
+ )
+
+ if results:
+ latest_event_id = results[-1][0]
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_thread_relation", {"last_event_id": latest_event_id}
+ )
+
+ return len(results)
+
+ num_rows = await self.db_pool.runInteraction(
+ desc="event_thread_relation", func=_event_thread_relation_txn
+ )
+
+ if not num_rows:
+ await self.db_pool.updates._end_background_update("event_thread_relation")
+
+ return num_rows
+
async def _background_populate_stream_ordering2(
self, progress: JsonDict, batch_size: int
) -> int:
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index e70d3649ff..bb621df0dd 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
_DEFAULT_CATEGORY_ID = ""
@@ -35,6 +40,16 @@ class _RoomInGroup(TypedDict):
class GroupServerWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ database.updates.register_background_index_update(
+ update_name="local_group_updates_index",
+ index_name="local_group_updates_stream_id_index",
+ table="local_group_updates",
+ columns=("stream_id",),
+ unique=True,
+ )
+ super().__init__(database, db_conn, hs)
+
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="groups",
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 2fa945d171..717487be28 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
"media_repository_drop_index_wo_method"
)
@@ -43,7 +46,7 @@ class MediaSortOrder(Enum):
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index dac3d14da8..d901933ae4 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,7 +14,7 @@
import calendar
import logging
import time
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# Collect metrics on the number of forward extremities that exist.
@@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics.
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index a14ac03d4b..b5284e4f67 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp
@@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 12cf6995eb..cc0eebdb46 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
prefilled_cache=presence_cache_prefill,
)
- async def update_presence(self, presence_states):
+ async def update_presence(self, presence_states) -> Tuple[int, int]:
assert self._can_persist_presence
stream_ordering_manager = self._presence_id_gen.get_next_mult(
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index ba7075caa5..dd8e27e226 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def update_remote_profile_cache(
- self, user_id: str, displayname: str, avatar_url: str
+ self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
) -> int:
return await self.db_pool.simple_update(
table="remote_profile_cache",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fc720f5947..fa782023d4 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import logging
-from typing import Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
@@ -33,6 +33,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -75,7 +78,7 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 01a4281301..c99f8aebdb 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
@@ -29,11 +29,14 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 37d47aa823..6c7d6ba508 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -499,6 +499,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+ async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None:
+ """Sets the user type.
+
+ Args:
+ user: user ID of the user.
+ user_type: type of the user or None for a user without a type.
+ """
+
+ def set_user_type_txn(txn):
+ self.db_pool.simple_update_one_txn(
+ txn, "users", {"name": user.to_string()}, {"user_type": user_type}
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user.to_string(),)
+ )
+
+ await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
+
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2bbf6d6a95..53576ad52f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,13 +13,14 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
where_clause = ["relates_to_id = ?"]
- where_args = [event_id]
+ where_args: List[Union[str, int]] = [event_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore):
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
+ from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
+ to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)
@@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore):
order,
)
- def _get_recent_references_for_event_txn(txn):
+ def _get_recent_references_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args = [event_id, RelationTypes.ANNOTATION]
+ where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
if event_type:
where_clause.append("type = ?")
@@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause = generate_pagination_where_clause(
direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
+ from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
+ to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)
@@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause=having_clause,
)
- def _get_aggregation_groups_for_event_txn(txn):
+ def _get_aggregation_groups_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])
next_batch = None
@@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1
"""
- def _get_applicable_edit_txn(txn):
+ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
+ return None
edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
@@ -267,7 +273,66 @@ class RelationsWorkerStore(SQLBaseStore):
if not edit_id:
return None
- return await self.get_event(edit_id, allow_none=True)
+ return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
+
+ @cached()
+ async def get_thread_summary(
+ self, event_id: str
+ ) -> Tuple[int, Optional[EventBase]]:
+ """Get the number of threaded replies, the senders of those replies, and
+ the latest reply (if any) for the given event.
+
+ Args:
+ event_id: The original event ID
+
+ Returns:
+ The number of items in the thread and the most recent response, if any.
+ """
+
+ def _get_thread_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, Optional[str]]:
+ # Fetch the count of threaded events and the latest event ID.
+ # TODO Should this only allow m.room.message events.
+ sql = """
+ SELECT event_id
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ ORDER BY topological_ordering DESC, stream_ordering DESC
+ LIMIT 1
+ """
+
+ txn.execute(sql, (event_id, RelationTypes.THREAD))
+ row = txn.fetchone()
+ if row is None:
+ return 0, None
+
+ latest_event_id = row[0]
+
+ sql = """
+ SELECT COALESCE(COUNT(event_id), 0)
+ FROM event_relations
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ """
+ txn.execute(sql, (event_id, RelationTypes.THREAD))
+ count = txn.fetchone()[0] # type: ignore[index]
+
+ return count, latest_event_id
+
+ count, latest_event_id = await self.db_pool.runInteraction(
+ "get_thread_summary", _get_thread_summary_txn
+ )
+
+ latest_event = None
+ if latest_event_id:
+ latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
+
+ return count, latest_event
async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
@@ -297,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1;
"""
- def _get_if_user_has_annotated_event(txn):
+ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
txn.execute(
sql,
(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 835d7889cb..cefc77fa0f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -17,7 +17,7 @@ import collections
import logging
from abc import abstractmethod
from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
@@ -32,6 +32,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import MXC_REGEX
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -69,7 +72,7 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -409,22 +412,33 @@ class RoomWorkerStore(SQLBaseStore):
limit: maximum amount of rooms to retrieve
order_by: the sort order of the returned list
reverse_order: whether to reverse the room list
- search_term: a string to filter room names by
+ search_term: a string to filter room names,
+ canonical alias and room ids by.
+ Room ID must match exactly. Canonical alias must match a substring of the local part.
Returns:
A list of room dicts and an integer representing the total number of
rooms that exist given this query
"""
# Filter room names by a string
where_statement = ""
+ search_pattern = []
if search_term:
- where_statement = "WHERE LOWER(state.name) LIKE ?"
+ where_statement = """
+ WHERE LOWER(state.name) LIKE ?
+ OR LOWER(state.canonical_alias) LIKE ?
+ OR state.room_id = ?
+ """
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_term = "%" + search_term.lower() + "%"
+ search_pattern = [
+ "%" + search_term.lower() + "%",
+ "#%" + search_term.lower() + "%:%",
+ search_term,
+ ]
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -516,12 +530,9 @@ class RoomWorkerStore(SQLBaseStore):
)
def _get_rooms_paginate_txn(txn):
- # Execute the data query
- sql_values = (limit, start)
- if search_term:
- # Add the search term into the WHERE clause
- sql_values = (search_term,) + sql_values
- txn.execute(info_sql, sql_values)
+ # Add the search term into the WHERE clause
+ # and execute the data query
+ txn.execute(info_sql, search_pattern + [limit, start])
# Refactor room query data into a structured dictionary
rooms = []
@@ -548,8 +559,7 @@ class RoomWorkerStore(SQLBaseStore):
# Execute the count query
# Add the search term into the WHERE clause if present
- sql_values = (search_term,) if search_term else ()
- txn.execute(count_sql, sql_values)
+ txn.execute(count_sql, search_pattern)
room_count = txn.fetchone()
return rooms, room_count[0]
@@ -1026,7 +1036,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
class RoomBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -1411,7 +1421,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ddb162a4fc..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure
if TYPE_CHECKING:
+ from synapse.server import HomeServer
from synapse.state import _StateCacheEntry
logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -569,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
- ):
+ ) -> Dict[str, ProfileInfo]:
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -583,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event.room_id, state_group, current_state_ids, event=event, context=context
)
- async def get_joined_users_from_state(self, room_id, state_entry):
+ async def get_joined_users_from_state(
+ self, room_id, state_entry
+ ) -> Dict[str, ProfileInfo]:
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -606,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cache_context,
event=None,
context=None,
- ):
+ ) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
@@ -982,7 +985,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1132,7 +1135,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index c85383c975..7fe233767f 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -15,7 +15,7 @@
import logging
import re
from collections import namedtuple
-from typing import Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
from synapse.api.errors import SynapseError
from synapse.events import EventBase
@@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
SearchEntry = namedtuple(
@@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if not hs.config.server.enable_search:
@@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys):
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index a8e8dd4577..fa2c3b1feb 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -15,7 +15,7 @@
import collections.abc
import logging
from collections import namedtuple
-from typing import Iterable, Optional, Set
+from typing import TYPE_CHECKING, Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -30,6 +30,9 @@ from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -53,7 +56,7 @@ class _GetStateGroupDelta(
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index e20033bb28..5d7b59d861 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
import logging
from enum import Enum
from itertools import chain
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing_extensions import Counter
@@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# these fields track absolutes (e.g. total number of rooms on the server)
@@ -93,7 +96,7 @@ class UserSortOrder(Enum):
class StatsStore(StateDeltasStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 860146cd1b..d7dc1f73ac 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
import logging
from collections import namedtuple
-from typing import Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
db_binary_type = memoryview
logger = logging.getLogger(__name__)
@@ -57,7 +60,7 @@ class DestinationRetryTimings:
class TransactionWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
|