diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9ff2d8d8c3..f024761ba7 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@ import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@@ -68,7 +68,7 @@ from .session import SessionStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
-from .stream import StreamStore
+from .stream import StreamWorkerStore
from .tags import TagsStore
from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
@@ -87,7 +87,7 @@ class DataStore(
RoomStore,
RoomBatchStore,
RegistrationStore,
- StreamStore,
+ StreamWorkerStore,
ProfileStore,
PresenceStore,
TransactionWorkerStore,
@@ -129,7 +129,12 @@ class DataStore(
LockStore,
SessionStore,
):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
@@ -143,11 +148,7 @@ class DataStore(
("device_lists_outbound_pokes", "stream_id"),
],
)
- self._cross_signing_id_gen = StreamIdGenerator(
- db_conn, "e2e_cross_signing_keys", "stream_id"
- )
- self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._group_updates_id_gen = StreamIdGenerator(
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index f8bec266ac..32a553fdd7 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,15 +14,25 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage._base import db_to_json
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -34,13 +44,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class AccountDataWorkerStore(SQLBaseStore):
- """This is an abstract base class where subclasses must implement
- `get_max_account_data_stream_id` which can be called in the initializer.
- """
+class AccountDataWorkerStore(CacheInvalidationWorkerStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
- self._instance_name = hs.get_instance_name()
+ # `_can_write_to_account_data` indicates whether the current worker is allowed
+ # to write account data. A value of `True` implies that `_account_data_id_gen`
+ # is an `AbstractStreamIdGenerator` and not just a tracker.
+ self._account_data_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
self._can_write_to_account_data = (
@@ -61,8 +77,6 @@ class AccountDataWorkerStore(SQLBaseStore):
writers=hs.config.worker.writers.account_data,
)
else:
- self._can_write_to_account_data = True
-
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
#
@@ -70,7 +84,8 @@ class AccountDataWorkerStore(SQLBaseStore):
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
- if hs.get_instance_name() in hs.config.worker.writers.account_data:
+ if self._instance_name in hs.config.worker.writers.account_data:
+ self._can_write_to_account_data = True
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
@@ -90,8 +105,6 @@ class AccountDataWorkerStore(SQLBaseStore):
"AccountDataAndTagsChangeCache", account_max
)
- super().__init__(database, db_conn, hs)
-
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
@@ -113,7 +126,9 @@ class AccountDataWorkerStore(SQLBaseStore):
room_id string to per room account_data dicts.
"""
- def get_account_data_for_user_txn(txn):
+ def get_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
rows = self.db_pool.simple_select_list_txn(
txn,
"account_data",
@@ -132,7 +147,7 @@ class AccountDataWorkerStore(SQLBaseStore):
["room_id", "account_data_type", "content"],
)
- by_room = {}
+ by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
room_data[row["account_data_type"]] = db_to_json(row["content"])
@@ -177,7 +192,9 @@ class AccountDataWorkerStore(SQLBaseStore):
A dict of the room account_data
"""
- def get_account_data_for_room_txn(txn):
+ def get_account_data_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, JsonDict]:
rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
@@ -207,7 +224,9 @@ class AccountDataWorkerStore(SQLBaseStore):
The room account_data for that type, or None if there isn't any set.
"""
- def get_account_data_for_room_and_type_txn(txn):
+ def get_account_data_for_room_and_type_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[JsonDict]:
content_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_account_data",
@@ -243,14 +262,16 @@ class AccountDataWorkerStore(SQLBaseStore):
if last_id == current_id:
return []
- def get_updated_global_account_data_txn(txn):
+ def get_updated_global_account_data_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str]]:
sql = (
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn
@@ -273,14 +294,16 @@ class AccountDataWorkerStore(SQLBaseStore):
if last_id == current_id:
return []
- def get_updated_room_account_data_txn(txn):
+ def get_updated_room_account_data_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str]]:
sql = (
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn
@@ -299,7 +322,9 @@ class AccountDataWorkerStore(SQLBaseStore):
mapping from room_id string to per room account_data dicts.
"""
- def get_updated_account_data_for_user_txn(txn):
+ def get_updated_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
sql = (
"SELECT account_data_type, content FROM account_data"
" WHERE user_id = ? AND stream_id > ?"
@@ -316,7 +341,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
- account_data_by_room = {}
+ account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = db_to_json(row[2])
@@ -353,12 +378,15 @@ class AccountDataWorkerStore(SQLBaseStore):
)
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
- for row in rows:
- self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
@@ -372,7 +400,8 @@ class AccountDataWorkerStore(SQLBaseStore):
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -389,6 +418,7 @@ class AccountDataWorkerStore(SQLBaseStore):
The maximum stream ID.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
content_json = json_encoder.encode(content)
@@ -431,6 +461,7 @@ class AccountDataWorkerStore(SQLBaseStore):
The maximum stream ID.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
@@ -452,7 +483,7 @@ class AccountDataWorkerStore(SQLBaseStore):
def _add_account_data_for_user(
self,
- txn,
+ txn: LoggingTransaction,
next_id: int,
user_id: str,
account_data_type: str,
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..92c95a41d7 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -24,9 +24,8 @@ from synapse.appservice import (
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -58,7 +57,12 @@ def _make_exclusive_regex(
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
self.services_cache = load_appservices(
hs.hostname, hs.config.appservice.app_service_config_files
)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc6..0024348067 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -25,7 +25,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -41,7 +41,12 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 0f56e10220..fd3fc298b3 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -18,7 +18,11 @@ from typing import TYPE_CHECKING, Optional
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder
@@ -31,7 +35,12 @@ logger = logging.getLogger(__name__)
class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if (
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index a6fd9f2636..f3881671fd 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -26,7 +26,6 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
-from synapse.storage.types import Connection
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
@@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):
class ClientIpBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -394,7 +398,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -532,7 +541,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
# (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ab8766c75b..3682cb6a81 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -668,7 +673,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
# There's a type mismatch here between how we want to type the row and
# what fetchone says it returns, but we silence it because we know that
# res can't be None.
- res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment]
+ res = cast(Tuple[Optional[int]], txn.fetchone())
if res[0] is None:
# this can only happen if the `device_inbox` table is empty, in which
# case we have no work to do.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd..273adb61fd 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,6 +38,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
+ LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
@@ -61,7 +62,12 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -101,7 +107,9 @@ class DeviceWorkerStore(SQLBaseStore):
"count_devices_by_users", count_devices_by_users_txn, user_ids
)
- async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+ async def get_device(
+ self, user_id: str, device_id: str
+ ) -> Optional[Dict[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -109,15 +117,35 @@ class DeviceWorkerStore(SQLBaseStore):
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
- A dict containing the device information
- Raises:
- StoreError: if the device is not found
+ A dict containing the device information, or `None` if the device does not
+ exist.
"""
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
+ allow_none=True,
+ )
+
+ async def get_device_opt(
+ self, user_id: str, device_id: str
+ ) -> Optional[Dict[str, Any]]:
+ """Retrieve a device. Only returns devices that are not marked as
+ hidden.
+
+ Args:
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to retrieve
+ Returns:
+ A dict containing the device information, or None if the device does not exist.
+ """
+ return await self.db_pool.simple_select_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_device",
+ allow_none=True,
)
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
@@ -274,7 +302,9 @@ class DeviceWorkerStore(SQLBaseStore):
# add the updated cross-signing keys to the results list
for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
- # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("m.signing_key_update", result))
+ # also send the unstable version
+ # FIXME: remove this when enough servers have upgraded
results.append(("org.matrix.signing_key_update", result))
return now_stream_id, results
@@ -949,7 +979,12 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -1081,7 +1116,12 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index a3442814d7..f76c6121e8 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
+import attr
+
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
-RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomAliasMapping:
+ room_id: str
+ room_alias: str
+ servers: List[str]
class DirectoryWorkerStore(CacheInvalidationWorkerStore):
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b15fb71e62..0cb48b9dd7 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,35 +13,71 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
+
+from typing_extensions import Literal, TypedDict
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonDict, JsonSerializable
from synapse.util import json_encoder
+class RoomKey(TypedDict):
+ """`KeyBackupData` in the Matrix spec.
+
+ https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
+ """
+
+ first_message_index: int
+ forwarded_count: int
+ is_verified: bool
+ session_data: JsonSerializable
+
+
class EndToEndRoomKeyStore(SQLBaseStore):
+ """The store for end to end room key backups.
+
+ See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups
+
+ As per the spec, backups are identified by an opaque version string. Internally,
+ version identifiers are assigned using incrementing integers. Non-numeric version
+ strings are treated as if they do not exist, since we would have never issued them.
+ """
+
async def update_e2e_room_key(
- self, user_id, version, room_id, session_id, room_key
- ):
+ self,
+ user_id: str,
+ version: str,
+ room_id: str,
+ session_id: str,
+ room_key: RoomKey,
+ ) -> None:
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_id(str): the ID of the room whose keys we're setting
- session_id(str): the session whose room_key we're setting
- room_key(dict): the room_key being set
+ user_id: the user whose backup we're setting
+ version: the version ID of the backup we're updating
+ room_id: the ID of the room whose keys we're setting
+ session_id: the session whose room_key we're setting
+ room_key: the room_key being set
Raises:
StoreError
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ raise StoreError(404, "No backup with that version exists")
await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
- "version": version,
+ "version": version_int,
"room_id": room_id,
"session_id": session_id,
},
@@ -54,22 +90,29 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="update_e2e_room_key",
)
- async def add_e2e_room_keys(self, user_id, version, room_keys):
+ async def add_e2e_room_keys(
+ self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]]
+ ) -> None:
"""Bulk add room keys to a given backup.
Args:
- user_id (str): the user whose backup we're adding to
- version (str): the version ID of the backup for the set of keys we're adding to
- room_keys (iterable[(str, str, dict)]): the keys to add, in the form
- (roomID, sessionID, keyData)
+ user_id: the user whose backup we're adding to
+ version: the version ID of the backup for the set of keys we're adding to
+ room_keys: the keys to add, in the form (roomID, sessionID, keyData)
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ raise StoreError(404, "No backup with that version exists")
values = []
for (room_id, session_id, room_key) in room_keys:
values.append(
{
"user_id": user_id,
- "version": version,
+ "version": version_int,
"room_id": room_id,
"session_id": session_id,
"first_message_index": room_key["first_message_index"],
@@ -92,31 +135,39 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_e2e_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> Dict[
+ Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+ ]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup for the set of keys we're querying
- room_id (str): Optional. the ID of the room whose keys we're querying, if any.
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup for the set of keys we're querying
+ room_id: Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup.
- session_id (str): Optional. the session whose room_key we're querying, if any.
+ session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of
the backup (or for the specified room)
Returns:
- A list of dicts giving the session_data and message metadata for
- these room keys.
+ A dict giving the session_data and message metadata for these room keys.
+ `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
"""
try:
- version = int(version)
+ version_int = int(version)
except ValueError:
return {"rooms": {}}
- keyvalues = {"user_id": user_id, "version": version}
+ keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
@@ -137,7 +188,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="get_e2e_room_keys",
)
- sessions = {"rooms": {}}
+ sessions: Dict[
+ Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+ ] = {"rooms": {}}
for row in rows:
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
room_entry["sessions"][row["session_id"]] = {
@@ -150,7 +203,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions
- async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+ async def get_e2e_room_keys_multi(
+ self,
+ user_id: str,
+ version: str,
+ room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+ ) -> Dict[str, Dict[str, RoomKey]]:
"""Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -158,26 +216,36 @@ class EndToEndRoomKeyStore(SQLBaseStore):
specific key.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup we're querying about
- room_keys (dict[str, dict[str, iterable[str]]]): a map from
- room ID -> {"session": [session ids]} indicating the session IDs
- that we want to query
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup we're querying about
+ room_keys: a map from room ID -> {"sessions": [session ids]}
+ indicating the session IDs that we want to query
Returns:
- dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
+ A map of room IDs to session IDs to room key
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ return {}
return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
- version,
+ version_int,
room_keys,
)
@staticmethod
- def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+ def _get_e2e_room_keys_multi_txn(
+ txn: LoggingTransaction,
+ user_id: str,
+ version: int,
+ room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+ ) -> Dict[str, Dict[str, RoomKey]]:
if not room_keys:
return {}
@@ -209,7 +277,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
txn.execute(sql, params)
- ret = {}
+ ret: Dict[str, Dict[str, RoomKey]] = {}
for row in txn:
room_id = row[0]
@@ -231,36 +299,49 @@ class EndToEndRoomKeyStore(SQLBaseStore):
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ return 0
return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
- keyvalues={"user_id": user_id, "version": version},
+ keyvalues={"user_id": user_id, "version": version_int},
retcol="COUNT(*)",
desc="count_e2e_room_keys",
)
@trace
async def delete_e2e_room_keys(
- self, user_id, version, room_id=None, session_id=None
- ):
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> None:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
Args:
- user_id(str): the user whose backup we're deleting from
- version(str): the version ID of the backup for the set of keys we're deleting
- room_id(str): Optional. the ID of the room whose keys we're deleting, if any.
+ user_id: the user whose backup we're deleting from
+ version: the version ID of the backup for the set of keys we're deleting
+ room_id: Optional. the ID of the room whose keys we're deleting, if any.
If not specified, we delete the keys for all the rooms in the backup.
- session_id(str): Optional. the session whose room_key we're querying, if any.
+ session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we delete all the keys in this version of
the backup (or for the specified room)
-
- Returns:
- The deletion transaction
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ return
- keyvalues = {"user_id": user_id, "version": int(version)}
+ keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
@@ -271,23 +352,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@staticmethod
- def _get_current_version(txn, user_id):
+ def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions "
"WHERE user_id=? AND deleted=0",
(user_id,),
)
- row = txn.fetchone()
- if not row:
+ # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
+ # be `NULL` when there are no available versions.
+ row = cast(Tuple[Optional[int]], txn.fetchone())
+ if row[0] is None:
raise StoreError(404, "No current backup version")
return row[0]
- async def get_e2e_room_keys_version_info(self, user_id, version=None):
+ async def get_e2e_room_keys_version_info(
+ self, user_id: str, version: Optional[str] = None
+ ) -> JsonDict:
"""Get info metadata about a version of our room_keys backup.
Args:
- user_id(str): the user whose backup we're querying
- version(str): Optional. the version ID of the backup we're querying about
+ user_id: the user whose backup we're querying
+ version: Optional. the version ID of the backup we're querying about
If missing, we return the information about the current version.
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present
@@ -300,7 +385,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
etag(int): tag of the keys in the backup
"""
- def _get_e2e_room_keys_version_info_txn(txn):
+ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
if version is None:
this_version = self._get_current_version(txn, user_id)
else:
@@ -309,14 +394,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it isn't there.
- raise StoreError(404, "No row found")
+ raise StoreError(404, "No backup with that version exists")
result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
+ allow_none=False,
)
+ assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
@@ -328,28 +415,28 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
+ async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
Args:
- user_id(str): the user whose backup we're creating a version
- info(dict): the info about the backup version to be created
+ user_id: the user whose backup we're creating a version
+ info: the info about the backup version to be created
Returns:
The newly created version ID
"""
- def _create_e2e_room_keys_version_txn(txn):
+ def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
(user_id,),
)
- current_version = txn.fetchone()[0]
+ current_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
if current_version is None:
- current_version = "0"
+ current_version = 0
- new_version = str(int(current_version) + 1)
+ new_version = current_version + 1
self.db_pool.simple_insert_txn(
txn,
@@ -362,7 +449,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
},
)
- return new_version
+ return str(new_version)
return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
@@ -373,7 +460,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
self,
user_id: str,
version: str,
- info: Optional[dict] = None,
+ info: Optional[JsonDict] = None,
version_etag: Optional[int] = None,
) -> None:
"""Update a given backup version
@@ -386,7 +473,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version_etag: etag of the keys in the backup. If None, then the etag
is not updated.
"""
- updatevalues = {}
+ updatevalues: Dict[str, object] = {}
if info is not None and "auth_data" in info:
updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
@@ -394,9 +481,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues["etag"] = version_etag
if updatevalues:
- await self.db_pool.simple_update(
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ raise StoreError(404, "No backup with that version exists")
+
+ await self.db_pool.simple_update_one(
table="e2e_room_keys_versions",
- keyvalues={"user_id": user_id, "version": version},
+ keyvalues={"user_id": user_id, "version": version_int},
updatevalues=updatevalues,
desc="update_e2e_room_keys_version",
)
@@ -417,13 +511,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
or if the version requested doesn't exist.
"""
- def _delete_e2e_room_keys_version_txn(txn):
+ def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
if version is None:
this_version = self._get_current_version(txn, user_id)
- if this_version is None:
- raise StoreError(404, "No current backup version")
else:
- this_version = version
+ try:
+ this_version = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it isn't there.
+ raise StoreError(404, "No backup with that version exists")
self.db_pool.simple_delete_txn(
txn,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b06c1dc45b..57b5ffbad3 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,19 +14,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ cast,
+)
import attr
from canonicaljson import encode_canonical_json
-from twisted.enterprise.adbapi import Connection
-
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -50,7 +63,12 @@ class DeviceKeyLookupResult:
class EndToEndKeyBackgroundStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
)
-class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._allow_device_name_lookup_over_federation = (
@@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
- rv = {}
+ rv: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
@@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
# add each cross-signing signature to the correct device in the result dict.
for (user_id, key_id, device_id, signature) in cross_sigs_result:
target_device_result = result[user_id][device_id]
+ # We've only looked up cross-signatures for non-deleted devices with key
+ # data.
+ assert target_device_result is not None
+ assert target_device_result.keys is not None
target_device_signatures = target_device_result.keys.setdefault(
"signatures", {}
)
@@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return result
def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+ self,
+ txn: LoggingTransaction,
+ query_list: Collection[Tuple[str, str]],
+ include_all_devices: bool = False,
+ include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database
@@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return result
def _get_e2e_cross_signing_signatures_for_devices_txn(
- self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+ self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
@@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
txn.execute(signature_sql, signature_query_params)
- return txn.fetchall()
+ return cast(
+ List[
+ Tuple[
+ str,
+ str,
+ str,
+ str,
+ ]
+ ],
+ txn.fetchall(),
+ )
async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
@@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
- def _add_e2e_one_time_keys(txn):
+ def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", new_keys)
@@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
A mapping from algorithm to number of keys for that algorithm.
"""
- def _count_e2e_one_time_keys(txn):
+ def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ?"
@@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
def _set_e2e_fallback_keys_txn(
- self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ fallback_keys: JsonDict,
) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
@@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""Returns a user's cross-signing key.
Args:
@@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return user_keys.get(key_type)
@cached(num_args=1)
- def _get_bare_e2e_cross_signing_keys(self, user_id):
+ def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
their user ID will map to None.
"""
- return await self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
)
+ # The `Optional` comes from the `@cachedList` decorator.
+ return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
- txn: Connection,
+ txn: LoggingTransaction,
user_ids: Iterable[str],
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Dict[str, JsonDict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_ids (list[str]): the users whose keys are being requested
+ txn: db connection
+ user_ids: the users whose keys are being requested
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. If a user's cross-signing keys were not found, their user
- ID will not be in the dict.
+ Mapping from user ID to key type to key data.
+ If a user's cross-signing keys were not found, their user ID will not be in
+ the dict.
"""
- result = {}
+ result: Dict[str, Dict[str, JsonDict]] = {}
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
@@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
user_id = row["user_id"]
key_type = row["keytype"]
key = db_to_json(row["keydata"])
- user_info = result.setdefault(user_id, {})
- user_info[key_type] = key
+ user_keys = result.setdefault(user_id, {})
+ user_keys[key_type] = key
return result
def _get_e2e_cross_signing_signatures_txn(
self,
- txn: Connection,
- keys: Dict[str, Dict[str, dict]],
+ txn: LoggingTransaction,
+ keys: Dict[str, Optional[Dict[str, JsonDict]]],
from_user_id: str,
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing signatures made by a user on a set of keys.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- keys (dict[str, dict[str, dict]]): a map of user ID to key type to
- key data. This dict will be modified to add signatures.
- from_user_id (str): fetch the signatures made by this user
+ txn: db connection
+ keys: a map of user ID to key type to key data.
+ This dict will be modified to add signatures.
+ from_user_id: fetch the signatures made by this user
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. The return value will be the same as the keys argument,
- with the modifications included.
+ Mapping from user ID to key type to key data.
+ The return value will be the same as the keys argument, with the
+ modifications included.
"""
# find out what cross-signing keys (a.k.a. devices) we need to get
# signatures for. This is a map of (user_id, device_id) to key type
# (device_id is the key's public part).
- devices = {}
+ devices: Dict[Tuple[str, str], str] = {}
- for user_id, user_info in keys.items():
- if user_info is None:
+ for user_id, user_keys in keys.items():
+ if user_keys is None:
continue
- for key_type, key in user_info.items():
+ for key_type, key in user_keys.items():
device_id = None
for k in key["keys"].values():
device_id = k
+ # `key` ought to be a `CrossSigningKey`, whose .keys property is a
+ # dictionary with a single entry:
+ # "algorithm:base64_public_key": "base64_public_key"
+ # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
+ assert isinstance(device_id, str)
devices[(user_id, device_id)] = key_type
for batch in batch_iter(devices.keys(), size=100):
@@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
# and add the signatures to the appropriate keys
for row in rows:
- key_id = row["key_id"]
- target_user_id = row["target_user_id"]
- target_device_id = row["target_device_id"]
+ key_id: str = row["key_id"]
+ target_user_id: str = row["target_user_id"]
+ target_device_id: str = row["target_device_id"]
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
# need to recursively copy the dicts that will be modified.
- user_info = keys[target_user_id] = keys[target_user_id].copy()
- target_user_key = user_info[key_type] = user_info[key_type].copy()
+ user_keys = keys[target_user_id]
+ # `user_keys` cannot be `None` because we only fetched signatures for
+ # users with keys
+ assert user_keys is not None
+ user_keys = keys[target_user_id] = user_keys.copy()
+
+ target_user_key = user_keys[key_type] = user_keys[key_type].copy()
if "signatures" in target_user_key:
signatures = target_user_key["signatures"] = target_user_key[
"signatures"
@@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Optional[Dict[str, dict]]]:
+ ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
if last_id == current_id:
return [], current_id, False
- def _get_all_user_signature_changes_for_remotes_txn(txn):
+ def _get_all_user_signature_changes_for_remotes_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
@@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
@trace
def _claim_e2e_one_time_key_simple(
- txn, user_id: str, device_id: str, algorithm: str
+ txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.
@@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
@trace
def _claim_e2e_one_time_key_returning(
- txn, user_id: str, device_id: str, algorithm: str
+ txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
@@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
- results = {}
+ results: Dict[str, Dict[str, Dict[str, str]]] = {}
for user_id, device_id, algorithm in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
@@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self._cross_signing_id_gen = StreamIdGenerator(
+ db_conn, "e2e_cross_signing_keys", "stream_id"
+ )
+
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
@@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
or the keys were already in the database.
"""
- def _set_e2e_device_keys_txn(txn):
+ def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
@@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
- def delete_e2e_keys_by_device_txn(txn):
+ def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
log_kv(
{
"message": "Deleting keys for device",
@@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
+ def _set_e2e_cross_signing_key_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ key_type: str,
+ key: JsonDict,
+ stream_id: int,
+ ) -> None:
"""Set a user's cross-signing key.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_id (str): the user to set the signing key for
- key_type (str): the type of key that is being set: either 'master'
+ txn: db connection
+ user_id: the user to set the signing key for
+ key_type: the type of key that is being set: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
- key (dict): the key data
- stream_id (int)
+ key: the key data
+ stream_id
"""
# the 'key' dict will look something like:
# {
@@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
- async def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ async def set_e2e_cross_signing_key(
+ self, user_id: str, key_type: str, key: JsonDict
+ ) -> None:
"""Set a user's cross-signing key.
Args:
- user_id (str): the user to set the user-signing key for
- key_type (str): the type of cross-signing key to set
- key (dict): the key data
+ user_id: the user to set the user-signing key for
+ key_type: the type of cross-signing key to set
+ key: the key data
"""
async with self._cross_signing_id_gen.get_next() as stream_id:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..270b30800b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
@@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception):
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -279,7 +288,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
new_front = set()
for chunk in batch_iter(front, 100):
# Pull the auth events either from the cache or DB.
- to_fetch = [] # Event IDs to fetch from DB # type: List[str]
+ to_fetch: List[str] = [] # Event IDs to fetch from DB
for event_id in chunk:
res = self._event_auth_cache.get(event_id)
if res is None:
@@ -606,8 +615,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# currently walking, either from cache or DB.
search, chunk = search[:-100], search[-100:]
- found = [] # Results found # type: List[Tuple[str, str, int]]
- to_fetch = [] # Event IDs to fetch from DB # type: List[str]
+ found: List[Tuple[str, str, int]] = [] # Results found
+ to_fetch: List[str] = [] # Event IDs to fetch from DB
for _, event_id in chunk:
res = self._event_auth_cache.get(event_id)
if res is None:
@@ -1384,7 +1393,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
count = await self.db_pool.simple_select_one_onecol(
table="federation_inbound_events_staging",
keyvalues={"room_id": room_id},
- retcol="COALESCE(COUNT(*), 0)",
+ retcol="COUNT(*)",
desc="prune_staged_events_in_room_count",
)
@@ -1476,9 +1485,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""Update the prometheus metrics for the inbound federation staging area."""
def _get_stats_for_federation_staging_txn(txn):
- txn.execute(
- "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging"
- )
+ txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
(count,) = txn.fetchone()
txn.execute(
@@ -1514,7 +1521,12 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..a98e6b2593 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
import attr
-from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -30,29 +33,64 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
-DEFAULT_HIGHLIGHT_ACTION = [
+DEFAULT_NOTIF_ACTION: List[Union[dict, str]] = [
+ "notify",
+ {"set_tweak": "highlight", "value": False},
+]
+DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
]
-class BasePushAction(TypedDict):
- event_id: str
- actions: List[Union[dict, str]]
-
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class HttpPushAction:
+ """
+ HttpPushAction instances include the information used to generate HTTP
+ requests to a push gateway.
+ """
-class HttpPushAction(BasePushAction):
+ event_id: str
room_id: str
stream_ordering: int
+ actions: List[Union[dict, str]]
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EmailPushAction(HttpPushAction):
+ """
+ EmailPushAction instances include the information used to render an email
+ push notification.
+ """
+
received_ts: Optional[int]
-def _serialize_action(actions, is_highlight):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPushAction(EmailPushAction):
+ """
+ UserPushAction instances include the necessary information to respond to
+ /notifications requests.
+ """
+
+ topological_ordering: int
+ highlight: bool
+ profile_tag: str
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class NotifCounts:
+ """
+ The per-user, per-room count of notifications. Used by sync and push.
+ """
+
+ notify_count: int
+ unread_count: int
+ highlight_count: int
+
+
+def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for
@@ -70,7 +108,7 @@ def _serialize_action(actions, is_highlight):
return json_encoder.encode(actions)
-def _deserialize_action(actions, is_highlight):
+def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]:
"""Custom deserializer for actions. This allows us to "compress" common actions"""
if actions:
return db_to_json(actions)
@@ -82,12 +120,17 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
- self.stream_ordering_month_ago = None
- self.stream_ordering_day_ago = None
+ self.stream_ordering_month_ago: Optional[int] = None
+ self.stream_ordering_day_ago: Optional[int] = None
cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
self._find_stream_orderings_for_times_txn(cur)
@@ -111,7 +154,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
room_id: str,
user_id: str,
last_read_event_id: Optional[str],
- ) -> Dict[str, int]:
+ ) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
@@ -140,15 +183,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_unread_counts_by_receipt_txn(
self,
- txn,
- room_id,
- user_id,
- last_read_event_id,
- ):
+ txn: LoggingTransaction,
+ room_id: str,
+ user_id: str,
+ last_read_event_id: Optional[str],
+ ) -> NotifCounts:
stream_ordering = None
if last_read_event_id is not None:
- stream_ordering = self.get_stream_id_for_event_txn(
+ stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined]
txn,
last_read_event_id,
allow_none=True,
@@ -166,13 +209,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
retcol="event_id",
)
- stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined]
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
- def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
+ def _get_unread_counts_by_pos_txn(
+ self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ ) -> NotifCounts:
sql = (
"SELECT"
" COUNT(CASE WHEN notif = 1 THEN 1 END),"
@@ -210,16 +255,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# for this row.
unread_count += row[1]
- return {
- "notify_count": notif_count,
- "unread_count": unread_count,
- "highlight_count": highlight_count,
- }
+ return NotifCounts(
+ notify_count=notif_count,
+ unread_count=unread_count,
+ highlight_count=highlight_count,
+ )
async def get_push_action_users_in_range(
- self, min_stream_ordering, max_stream_ordering
- ):
- def f(txn):
+ self, min_stream_ordering: int, max_stream_ordering: int
+ ) -> List[str]:
+ def f(txn: LoggingTransaction) -> List[str]:
sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
" stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
@@ -227,8 +272,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
- ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f)
- return ret
+ return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
async def get_unread_push_actions_for_user_in_range_for_http(
self,
@@ -254,7 +298,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
# find rooms that have a read receipt in them and return the next
# push actions
- def get_after_receipt(txn):
+ def get_after_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool]]:
# find rooms that have a read receipt in them and return the next
# push actions
sql = (
@@ -280,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
@@ -289,7 +335,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
- def get_no_receipt(txn):
+ def get_no_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool]]:
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight "
@@ -309,19 +357,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
notifs = [
- {
- "event_id": row[0],
- "room_id": row[1],
- "stream_ordering": row[2],
- "actions": _deserialize_action(row[3], row[4]),
- }
+ HttpPushAction(
+ event_id=row[0],
+ room_id=row[1],
+ stream_ordering=row[2],
+ actions=_deserialize_action(row[3], row[4]),
+ )
for row in after_read_receipt + no_read_receipt
]
@@ -329,7 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by stream_ordering, oldest first.
- notifs.sort(key=lambda r: r["stream_ordering"])
+ notifs.sort(key=lambda r: r.stream_ordering)
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
@@ -359,7 +407,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
# find rooms that have a read receipt in them and return the most recent
# push actions
- def get_after_receipt(txn):
+ def get_after_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight, e.received_ts"
@@ -384,7 +434,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
@@ -393,7 +443,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
- def get_no_receipt(txn):
+ def get_no_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight, e.received_ts"
@@ -413,7 +465,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
@@ -421,13 +473,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Make a list of dicts from the two sets of results.
notifs = [
- {
- "event_id": row[0],
- "room_id": row[1],
- "stream_ordering": row[2],
- "actions": _deserialize_action(row[3], row[4]),
- "received_ts": row[5],
- }
+ EmailPushAction(
+ event_id=row[0],
+ room_id=row[1],
+ stream_ordering=row[2],
+ actions=_deserialize_action(row[3], row[4]),
+ received_ts=row[5],
+ )
for row in after_read_receipt + no_read_receipt
]
@@ -435,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by received_ts (most recent first)
- notifs.sort(key=lambda r: -(r["received_ts"] or 0))
+ notifs.sort(key=lambda r: -(r.received_ts or 0))
# Now return the first `limit`
return notifs[:limit]
@@ -456,7 +508,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
True if there may be push to process, False if there definitely isn't.
"""
- def _get_if_maybe_push_in_range_for_user_txn(txn):
+ def _get_if_maybe_push_in_range_for_user_txn(txn: LoggingTransaction) -> bool:
sql = """
SELECT 1 FROM event_push_actions
WHERE user_id = ? AND stream_ordering > ? AND notif = 1
@@ -490,19 +542,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# This is a helper function for generating the necessary tuple that
# can be used to insert into the `event_push_actions_staging` table.
- def _gen_entry(user_id, actions):
+ def _gen_entry(
+ user_id: str, actions: List[Union[dict, str]]
+ ) -> Tuple[str, str, str, int, int, int]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return (
event_id, # event_id column
user_id, # user_id column
- _serialize_action(actions, is_highlight), # actions column
+ _serialize_action(actions, bool(is_highlight)), # actions column
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
)
- def _add_push_actions_to_staging_txn(txn):
+ def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
# We don't use simple_insert_many here to avoid the overhead
# of generating lists of dicts.
@@ -530,12 +584,11 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
try:
- res = await self.db_pool.simple_delete(
+ await self.db_pool.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
)
- return res
except Exception:
# this method is called from an exception handler, so propagating
# another exception here really isn't helpful - there's nothing
@@ -588,7 +641,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
@staticmethod
- def _find_first_stream_ordering_after_ts_txn(txn, ts):
+ def _find_first_stream_ordering_after_ts_txn(
+ txn: LoggingTransaction, ts: int
+ ) -> int:
"""
Find the stream_ordering of the first event that was received on or
after a given timestamp. This is relatively slow as there is no index
@@ -600,14 +655,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
stream_ordering
Args:
- txn (twisted.enterprise.adbapi.Transaction):
- ts (int): timestamp to search for
+ txn:
+ ts: timestamp to search for
Returns:
- int: stream ordering
+ The stream ordering
"""
txn.execute("SELECT MAX(stream_ordering) FROM events")
- max_stream_ordering = txn.fetchone()[0]
+ max_stream_ordering = cast(Tuple[Optional[int]], txn.fetchone())[0]
if max_stream_ordering is None:
return 0
@@ -663,8 +718,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
- async def get_time_of_last_push_action_before(self, stream_ordering):
- def f(txn):
+ async def get_time_of_last_push_action_before(
+ self, stream_ordering: int
+ ) -> Optional[int]:
+ def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
@@ -674,7 +731,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
- return txn.fetchone()
+ return cast(Optional[Tuple[int]], txn.fetchone())
result = await self.db_pool.runInteraction(
"get_time_of_last_push_action_before", f
@@ -682,7 +739,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return result[0] if result else None
@wrap_as_background_process("rotate_notifs")
- async def _rotate_notifs(self):
+ async def _rotate_notifs(self) -> None:
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
self._doing_notif_rotation = True
@@ -700,7 +757,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
finally:
self._doing_notif_rotation = False
- def _rotate_notifs_txn(self, txn):
+ def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
"""Archives older notifications into event_push_summary. Returns whether
the archiving process has caught up or not.
"""
@@ -725,6 +782,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
stream_row = txn.fetchone()
if stream_row:
(offset_stream_ordering,) = stream_row
+ assert self.stream_ordering_day_ago is not None
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
@@ -740,7 +798,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# We have caught up iff we were limited by `stream_ordering_day_ago`
return caught_up
- def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
+ def _rotate_notifs_before_txn(
+ self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+ ) -> None:
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
@@ -861,8 +921,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _remove_old_push_actions_before_txn(
- self, txn, room_id, user_id, stream_ordering
- ):
+ self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ ) -> None:
"""
Purges old push actions for a user and room before a given
stream_ordering.
@@ -910,7 +970,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -929,9 +994,15 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
async def get_push_actions_for_user(
- self, user_id, before=None, limit=50, only_highlight=False
- ):
- def f(txn):
+ self,
+ user_id: str,
+ before: Optional[str] = None,
+ limit: int = 50,
+ only_highlight: bool = False,
+ ) -> List[UserPushAction]:
+ def f(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, int, str, bool, str, int]]:
before_clause = ""
if before:
before_clause = "AND epa.stream_ordering < ?"
@@ -958,32 +1029,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(
+ List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall()
+ )
push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
- for pa in push_actions:
- pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- return push_actions
+ return [
+ UserPushAction(
+ event_id=row[0],
+ room_id=row[1],
+ stream_ordering=row[2],
+ actions=_deserialize_action(row[4], row[5]),
+ received_ts=row[7],
+ topological_ordering=row[3],
+ highlight=row[5],
+ profile_tag=row[6],
+ )
+ for row in push_actions
+ ]
-def _action_has_highlight(actions):
+def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
for action in actions:
- try:
- if action.get("set_tweak", None) == "highlight":
- return action.get("value", True)
- except AttributeError:
- pass
+ if not isinstance(action, dict):
+ continue
+
+ if action.get("set_tweak", None) == "highlight":
+ return action.get("value", True)
return False
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
"""
- unread_count = attr.ib(type=int)
- stream_ordering = attr.ib(type=int)
- old_user_id = attr.ib(type=str)
- notif_count = attr.ib(type=int)
+ unread_count: int
+ stream_ordering: int
+ old_user_id: str
+ notif_count: int
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..dd255aefb9 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@ from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
+ Collection,
Dict,
Generator,
Iterable,
@@ -40,10 +41,13 @@ from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.types import Connection
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
@@ -94,7 +98,7 @@ class PersistEventsStore:
hs: "HomeServer",
db: DatabasePool,
main_data_store: "DataStore",
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
):
self.hs = hs
self.db_pool = db
@@ -1319,14 +1323,13 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
- def _store_event_txn(self, txn, events_and_contexts):
+ def _store_event_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+ ) -> None:
"""Insert new events into the event, event_json, redaction and
state_events tables.
-
- Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
"""
if not events_and_contexts:
@@ -1339,46 +1342,58 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
- self.db_pool.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_values_txn(
txn,
table="event_json",
- values=[
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "internal_metadata": json_encoder.encode(
- event.internal_metadata.get_dict()
- ),
- "json": json_encoder.encode(event_dict(event)),
- "format_version": event.format_version,
- }
+ keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
+ values=(
+ (
+ event.event_id,
+ event.room_id,
+ json_encoder.encode(event.internal_metadata.get_dict()),
+ json_encoder.encode(event_dict(event)),
+ event.format_version,
+ )
for event, _ in events_and_contexts
- ],
+ ),
)
- self.db_pool.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_values_txn(
txn,
table="events",
- values=[
- {
- "instance_name": self._instance_name,
- "stream_ordering": event.internal_metadata.stream_ordering,
- "topological_ordering": event.depth,
- "depth": event.depth,
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "processed": True,
- "outlier": event.internal_metadata.is_outlier(),
- "origin_server_ts": int(event.origin_server_ts),
- "received_ts": self._clock.time_msec(),
- "sender": event.sender,
- "contains_url": (
- "url" in event.content and isinstance(event.content["url"], str)
- ),
- }
+ keys=(
+ "instance_name",
+ "stream_ordering",
+ "topological_ordering",
+ "depth",
+ "event_id",
+ "room_id",
+ "type",
+ "processed",
+ "outlier",
+ "origin_server_ts",
+ "received_ts",
+ "sender",
+ "contains_url",
+ ),
+ values=(
+ (
+ self._instance_name,
+ event.internal_metadata.stream_ordering,
+ event.depth, # topological_ordering
+ event.depth, # depth
+ event.event_id,
+ event.room_id,
+ event.type,
+ True, # processed
+ event.internal_metadata.is_outlier(),
+ int(event.origin_server_ts),
+ self._clock.time_msec(),
+ event.sender,
+ "url" in event.content and isinstance(event.content["url"], str),
+ )
for event, _ in events_and_contexts
- ],
+ ),
)
# If we're persisting an unredacted event we go and ensure
@@ -1397,27 +1412,15 @@ class PersistEventsStore:
)
txn.execute(sql + clause, [False] + args)
- state_events_and_contexts = [
- ec for ec in events_and_contexts if ec[0].is_state()
- ]
-
- state_values = []
- for event, _ in state_events_and_contexts:
- vals = {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
-
- # TODO: How does this work with backfilling?
- if hasattr(event, "replaces_state"):
- vals["prev_state"] = event.replaces_state
-
- state_values.append(vals)
-
- self.db_pool.simple_insert_many_txn(
- txn, table="state_events", values=state_values
+ self.db_pool.simple_insert_many_values_txn(
+ txn,
+ table="state_events",
+ keys=("event_id", "room_id", "type", "state_key"),
+ values=(
+ (event.event_id, event.room_id, event.type, event.state_key)
+ for event, _ in events_and_contexts
+ if event.is_state()
+ ),
)
def _store_rejected_events_txn(self, txn, events_and_contexts):
@@ -1780,10 +1783,14 @@ class PersistEventsStore:
)
if rel_type == RelationTypes.REPLACE:
- txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+ txn.call_after(
+ self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
+ )
if rel_type == RelationTypes.THREAD:
- txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+ txn.call_after(
+ self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
+ )
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
@@ -1969,14 +1976,17 @@ class PersistEventsStore:
txn, self.store.get_retention_policy_for_room, (event.room_id,)
)
- def store_event_search_txn(self, txn, event, key, value):
+ def store_event_search_txn(
+ self, txn: LoggingTransaction, event: EventBase, key: str, value: str
+ ) -> None:
"""Add event to the search table
Args:
- txn (cursor):
- event (EventBase):
- key (str):
- value (str):
+ txn: The database transaction.
+ event: The event being added to the search table.
+ key: A key describing the search value (one of "content.name",
+ "content.topic", or "content.body")
+ value: The value from the event's content.
"""
self.store.store_search_entries_txn(
txn,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c88fd35e7f..a68f14ba48 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
import attr
@@ -23,6 +23,7 @@ from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
+ LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
@@ -83,7 +84,12 @@ class _CalculateChainCover:
class EventsBackgroundUpdatesStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
@@ -234,12 +240,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
################################################################################
- async def _background_reindex_fields_sender(self, progress, batch_size):
+ async def _background_reindex_fields_sender(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- def reindex_txn(txn):
+ def reindex_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, json FROM events"
" INNER JOIN event_json USING (event_id)"
@@ -301,12 +309,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
- async def _background_reindex_origin_server_ts(self, progress, batch_size):
+ async def _background_reindex_origin_server_ts(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- def reindex_search_txn(txn):
+ def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
@@ -375,7 +385,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
- async def _cleanup_extremities_bg_update(self, progress, batch_size):
+ async def _cleanup_extremities_bg_update(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to clean out extremities that should have been
deleted previously.
@@ -396,12 +408,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# have any descendants, but if they do then we should delete those
# extremities.
- def _cleanup_extremities_bg_update_txn(txn):
+ def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
# The set of extremity event IDs that we're checking this round
original_set = set()
- # A dict[str, set[str]] of event ID to their prev events.
- graph = {}
+ # A dict[str, Set[str]] of event ID to their prev events.
+ graph: Dict[str, Set[str]] = {}
# The set of descendants of the original set that are not rejected
# nor soft-failed. Ancestors of these events should be removed
@@ -530,7 +542,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
room_ids = {row["room_id"] for row in rows}
for room_id in room_ids:
txn.call_after(
- self.get_latest_event_ids_in_room.invalidate, (room_id,)
+ self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
)
self.db_pool.simple_delete_many_txn(
@@ -552,7 +564,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
_BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
)
- def _drop_table_txn(txn):
+ def _drop_table_txn(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE _extremities_to_check")
await self.db_pool.runInteraction(
@@ -561,11 +573,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return num_handled
- async def _redactions_received_ts(self, progress, batch_size):
+ async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
"""Handles filling out the `received_ts` column in redactions."""
last_event_id = progress.get("last_event_id", "")
- def _redactions_received_ts_txn(txn):
+ def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
# Fetch the set of event IDs that we want to update
sql = """
SELECT event_id FROM redactions
@@ -616,10 +628,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return count
- async def _event_fix_redactions_bytes(self, progress, batch_size):
+ async def _event_fix_redactions_bytes(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Undoes hex encoded censored redacted event JSON."""
- def _event_fix_redactions_bytes_txn(txn):
+ def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
# This update is quite fast due to new index.
txn.execute(
"""
@@ -644,11 +658,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return 1
- async def _event_store_labels(self, progress, batch_size):
+ async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
- def _event_store_labels_txn(txn):
+ def _event_store_labels_txn(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT event_id, json FROM event_json
@@ -748,7 +762,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
),
)
- return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
+ return cast(
+ List[Tuple[str, str, JsonDict, bool, bool]],
+ [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
+ )
results = await self.db_pool.runInteraction(
desc="_rejected_events_metadata_get", func=get_rejected_events
@@ -906,7 +923,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
def _calculate_chain_cover_txn(
self,
- txn: Cursor,
+ txn: LoggingTransaction,
last_room_id: str,
last_depth: int,
last_stream: int,
@@ -1017,10 +1034,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
- self.event_chain_id_gen,
+ self.event_chain_id_gen, # type: ignore[attr-defined]
event_to_room_id,
event_to_types,
- event_to_auth_chain,
+ cast(Dict[str, Sequence[str]], event_to_auth_chain),
)
return _CalculateChainCover(
@@ -1040,7 +1057,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"""
current_event_id = progress.get("current_event_id", "")
- def purged_chain_cover_txn(txn) -> int:
+ def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
# The event ID from events will be null if the chain ID / sequence
# number points to a purged event.
sql = """
@@ -1175,14 +1192,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
- self._invalidate_cache_and_stream(
- txn, self.get_relations_for_event, cache_tuple
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
+ txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
)
- self._invalidate_cache_and_stream(
- txn, self.get_aggregation_groups_for_event, cache_tuple
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
+ txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined]
)
- self._invalidate_cache_and_stream(
- txn, self.get_thread_summary, cache_tuple
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
+ txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
)
if results:
@@ -1214,7 +1231,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"""
batch_size = max(batch_size, 1)
- def process(txn: Cursor) -> int:
+ def process(txn: LoggingTransaction) -> int:
last_stream = progress.get("last_stream", -(1 << 31))
txn.execute(
"""
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a..8d4287045a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1383,10 +1383,6 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_events_token(self) -> int:
- """The current maximum token that events have reached"""
- return self._stream_id_gen.get_current_token()
-
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index cf842803bc..cb9ee08fa8 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Union
+from typing import Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
@@ -63,7 +63,7 @@ class FilteringStore(SQLBaseStore):
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
- max_id = txn.fetchone()[0] # type: ignore[index]
+ max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
if max_id is None:
filter_id = 0
else:
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index bb621df0dd..3f6086050b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -19,8 +19,7 @@ from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict):
class GroupServerWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
database.updates.register_background_index_update(
update_name="local_group_updates_index",
index_name="local_group_updates_stream_id_index",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index a540f7fb26..bedacaf0d7 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -20,8 +20,11 @@ from twisted.internet.interfaces import IReactorCore
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.types import Connection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -54,7 +57,12 @@ class LockStore(SQLBaseStore):
`last_renewed_ts` column with the current time.
"""
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._reactor = hs.get_reactor()
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 1b076683f7..cbba356b4a 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -23,6 +23,7 @@ from typing import (
Optional,
Tuple,
Union,
+ cast,
)
from synapse.storage._base import SQLBaseStore
@@ -220,7 +221,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
WHERE user_id = ?
"""
txn.execute(sql, args)
- count = txn.fetchone()[0] # type: ignore[index]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = """
SELECT
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d901933ae4..1480a0f048 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Dict
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
@@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics.
"""
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes
@@ -100,7 +105,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
def _count_messages(txn):
sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
+ SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
@@ -117,7 +122,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
like_clause = "%:" + self.hs.hostname
sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
+ SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND sender LIKE ?
AND stream_ordering > ?
@@ -134,7 +139,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
async def count_daily_active_e2ee_rooms(self):
def _count(txn):
sql = """
- SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
@@ -156,7 +161,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
def _count_messages(txn):
sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
+ SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
@@ -173,7 +178,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
like_clause = "%:" + self.hs.hostname
sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
+ SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND sender LIKE ?
AND stream_ordering > ?
@@ -190,7 +195,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
async def count_daily_active_rooms(self):
def _count(txn):
sql = """
- SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
@@ -226,7 +231,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
Returns number of users seen in the past time_from period
"""
sql = """
- SELECT COALESCE(count(*), 0) FROM (
+ SELECT COUNT(*) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
@@ -253,7 +258,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
thirty_days_ago_in_secs = now - thirty_days_in_secs
sql = """
- SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT platform, COUNT(*) FROM (
SELECT
users.name, platform, users.creation_ts * 1000,
MAX(uip.last_seen)
@@ -291,7 +296,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
results[row[0]] = row[1]
sql = """
- SELECT COALESCE(count(*), 0) FROM (
+ SELECT COUNT(*) FROM (
SELECT users.name, users.creation_ts * 1000,
MAX(uip.last_seen)
FROM users
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b5284e4f67..8f09dd8e87 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -16,8 +16,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ make_in_list_sql_clause,
+)
from synapse.util.caches.descriptors import cached
+from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -30,7 +35,12 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@@ -49,7 +59,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def _count_users(txn):
# Exclude app service users
sql = """
- SELECT COALESCE(count(*), 0)
+ SELECT COUNT(*)
FROM monthly_active_users
LEFT JOIN users
ON monthly_active_users.user_id=users.name
@@ -76,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def _count_users_by_service(txn):
sql = """
- SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
+ SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users
LEFT JOIN users ON monthly_active_users.user_id=users.name
GROUP BY appservice_id;
@@ -103,7 +113,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
: self.hs.config.server.max_mau_value
]:
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
- tp["medium"], tp["address"]
+ tp["medium"], canonicalise_email(tp["address"])
)
if user_id:
users.append(user_id)
@@ -212,7 +222,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cc0eebdb46..cbf9ec38f7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
@@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
@@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
def __init__(
self,
database: DatabasePool,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
@@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
"""
# Add user entries to the table, updating the presence_stream_id column if the user already
# exists in the table.
+ presence_stream_id = self._presence_id_gen.get_current_token()
await self.db_pool.simple_upsert_many(
table="users_to_send_full_presence_to",
key_names=("user_id",),
@@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
# devices at different times, each device will receive full presence once - when
# the presence stream ID in their sync token is less than the one in the table
# for their user ID.
- value_values=(
- (self._presence_id_gen.get_current_token(),) for _ in user_ids
- ),
+ value_values=[(presence_stream_id,) for _ in user_ids],
desc="add_users_to_send_full_presence_to",
)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..e01c94930a 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -81,7 +81,12 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b73ce53c91..747b4f31df 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -22,7 +22,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -196,27 +196,6 @@ class PusherWorkerStore(SQLBaseStore):
# This only exists for the cachedList decorator
raise NotImplementedError()
- @cachedList(
- cached_method_name="get_if_user_has_pusher",
- list_name="user_ids",
- num_args=1,
- )
- async def get_if_users_have_pushers(
- self, user_ids: Iterable[str]
- ) -> Dict[str, bool]:
- rows = await self.db_pool.simple_select_many_batch(
- table="pushers",
- column="user_name",
- iterable=user_ids,
- retcols=["user_name"],
- desc="get_if_users_have_pushers",
- )
-
- result = {user_id: False for user_id in user_ids}
- result.update({r["user_name"]: True for r in rows})
-
- return result
-
async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
) -> None:
@@ -515,7 +494,7 @@ class PusherStore(PusherWorkerStore):
# invalidate, since we the user might not have had a pusher before
await self.db_pool.runInteraction(
"add_pusher",
- self._invalidate_cache_and_stream, # type: ignore
+ self._invalidate_cache_and_stream, # type: ignore[attr-defined]
self.get_if_user_has_pusher,
(user_id,),
)
@@ -524,7 +503,7 @@ class PusherStore(PusherWorkerStore):
self, app_id: str, pushkey: str, user_id: str
) -> None:
def delete_pusher_txn(txn, stream_id):
- self._invalidate_cache_and_stream( # type: ignore
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)
@@ -569,7 +548,7 @@ class PusherStore(PusherWorkerStore):
pushers = list(await self.get_pushers_by_user_id(user_id))
def delete_pushers_txn(txn, stream_ids):
- self._invalidate_cache_and_stream( # type: ignore
+ self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c99f8aebdb..bf0b903af2 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,14 +14,29 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
from twisted.internet import defer
+from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
@@ -36,7 +51,12 @@ logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
@@ -78,17 +98,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
- def get_max_receipt_stream_id(self):
- """Get the current max stream ID for receipts stream
-
- Returns:
- int
- """
+ def get_max_receipt_stream_id(self) -> int:
+ """Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
@cached()
- async def get_users_with_read_receipts_in_room(self, room_id):
- receipts = await self.get_receipts_for_room(room_id, "m.read")
+ async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
+ receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -119,7 +135,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=2)
- async def get_receipts_for_user(self, user_id, receipt_type):
+ async def get_receipts_for_user(
+ self, user_id: str, receipt_type: str
+ ) -> Dict[str, str]:
rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -129,8 +147,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
- async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
- def f(txn):
+ async def get_receipts_for_user_with_orderings(
+ self, user_id: str, receipt_type: str
+ ) -> JsonDict:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
sql = (
"SELECT rl.room_id, rl.event_id,"
" e.topological_ordering, e.stream_ordering"
@@ -209,10 +229,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[dict]:
+ ) -> List[JsonDict]:
"""See get_linearized_receipts_for_room"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +270,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
list_name="room_ids",
num_args=3,
)
- async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(
+ self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ ) -> Dict[str, List[JsonDict]]:
if not room_ids:
return {}
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
@@ -323,7 +345,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
@@ -379,7 +401,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if last_id == current_id:
return defer.succeed([])
- def _get_users_sent_receipts_between_txn(txn):
+ def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT user_id FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +441,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_receipts_txn(txn):
+ def get_all_updated_receipts_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized
@@ -446,8 +470,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
def _invalidate_get_users_with_receipts_in_room(
self, room_id: str, receipt_type: str, user_id: str
- ):
- if receipt_type != "m.read":
+ ) -> None:
+ if receipt_type != ReceiptTypes.READ:
return
res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
@@ -461,7 +485,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
- def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+ def invalidate_caches_for_receipt(
+ self, room_id: str, receipt_type: str, user_id: str
+ ) -> None:
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
@@ -482,11 +508,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
- self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_id: str,
+ data: JsonDict,
+ stream_id: int,
+ ) -> Optional[int]:
"""Inserts a read-receipt into the database if it's newer than the current RR
- Returns: int|None
+ Returns:
None if the RR is older than the current RR
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
@@ -550,7 +583,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
lock=False,
)
- if receipt_type == "m.read" and stream_ordering is not None:
+ if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
@@ -580,7 +613,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
else:
# we need to points in graph -> linearized form.
# TODO: Make this better.
- def graph_to_linear(txn):
+ def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
@@ -634,11 +667,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
return stream_id, max_persisted_id
async def insert_graph_receipt(
- self, room_id, receipt_type, user_id, event_ids, data
- ):
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: JsonDict,
+ ) -> None:
assert self._can_write_to_receipts
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -649,8 +687,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def insert_graph_receipt_txn(
- self, txn, room_id, receipt_type, user_id, event_ids, data
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: JsonDict,
+ ) -> None:
assert self._can_write_to_receipts
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf06916..4175c82a25 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
import logging
import random
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr
@@ -794,7 +794,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
- SELECT user_type, COALESCE(count(*), 0) AS count FROM (
+ SELECT user_type, COUNT(*) AS count FROM (
SELECT
CASE
WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
@@ -819,7 +819,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def _count_users(txn):
txn.execute(
"""
- SELECT COALESCE(COUNT(*), 0) FROM users
+ SELECT COUNT(*) FROM users
WHERE appservice_id IS NULL
"""
)
@@ -856,7 +856,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Args:
medium: threepid medium e.g. email
- address: threepid address e.g. me@example.com
+ address: threepid address e.g. me@example.com. This must already be
+ in canonical form.
Returns:
The user ID or None if no user id/threepid mapping exists
@@ -1356,12 +1357,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
- res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
- txn,
- "registration_tokens",
- keyvalues={"token": token},
- retcols=["pending", "completed"],
- ) # type: ignore
+ res = cast(
+ Dict[str, Any],
+ self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ ),
+ )
# Decrement pending and increment completed
self.db_pool.simple_update_one_txn(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07..4ff6aed253 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, cast
import attr
@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_relations_for_event(
self,
event_id: str,
+ room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
@@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
the form `{"event_id": "..."}`.
"""
- where_clause = ["relates_to_id = ?"]
- where_args: List[Union[str, int]] = [event_id]
+ where_clause = ["relates_to_id = ?", "room_id = ?"]
+ where_args: List[Union[str, int]] = [event_id, room_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_aggregation_groups_for_event(
self,
event_id: str,
+ room_id: str,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
@@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
`type`, `key` and `count` fields.
"""
- where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
+ where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
+ where_args: List[Union[str, int]] = [
+ event_id,
+ room_id,
+ RelationTypes.ANNOTATION,
+ ]
if event_type:
where_clause.append("type = ?")
@@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
)
@cached()
- async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+ async def get_applicable_edit(
+ self, event_id: str, room_id: str
+ ) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given
event.
@@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: The original event ID
+ room_id: The original event's room ID
Returns:
The most recent edit, if any.
@@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
WHERE
relates_to_id = ?
AND relation_type = ?
+ AND edit.room_id = ?
AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1
"""
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
- txn.execute(sql, (event_id, RelationTypes.REPLACE))
+ txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
row = txn.fetchone()
if row:
return row[0]
@@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
@cached()
async def get_thread_summary(
- self, event_id: str
+ self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.
Args:
- event_id: The original event ID
+ event_id: Summarize the thread related to this event ID.
+ room_id: The room the event belongs to.
Returns:
The number of items in the thread and the most recent response, if any.
@@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
+ AND room_id = ?
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
"""
- txn.execute(sql, (event_id, RelationTypes.THREAD))
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None
@@ -376,14 +390,16 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event_id = row[0]
sql = """
- SELECT COALESCE(COUNT(event_id), 0)
+ SELECT COUNT(event_id)
FROM event_relations
+ INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
+ AND room_id = ?
AND relation_type = ?
"""
- txn.execute(sql, (event_id, RelationTypes.THREAD))
- count = txn.fetchone()[0] # type: ignore[index]
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
+ count = cast(Tuple[int], txn.fetchone())[0]
return count, latest_event_id
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d694d852d..c0e837854a 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
import logging
from abc import abstractmethod
from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
+
+import attr
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import IdGenerator
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -38,9 +54,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-RatelimitOverride = collections.namedtuple(
- "RatelimitOverride", ("messages_per_second", "burst_count")
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RatelimitOverride:
+ messages_per_second: int
+ burst_count: int
class RoomSortOrder(Enum):
@@ -71,8 +88,13 @@ class RoomSortOrder(Enum):
STATE_EVENTS = "state_events"
-class RoomWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+class RoomWorkerStore(CacheInvalidationWorkerStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -83,7 +105,7 @@ class RoomWorkerStore(SQLBaseStore):
room_creator_user_id: str,
is_public: bool,
room_version: RoomVersion,
- ):
+ ) -> None:
"""Stores a room.
Args:
@@ -111,7 +133,7 @@ class RoomWorkerStore(SQLBaseStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- async def get_room(self, room_id: str) -> dict:
+ async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve a room.
Args:
@@ -136,7 +158,9 @@ class RoomWorkerStore(SQLBaseStore):
A dict containing the room information, or None if the room is unknown.
"""
- def get_room_with_stats_txn(txn, room_id):
+ def get_room_with_stats_txn(
+ txn: LoggingTransaction, room_id: str
+ ) -> Optional[Dict[str, Any]]:
sql = """
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -185,7 +209,7 @@ class RoomWorkerStore(SQLBaseStore):
ignore_non_federatable: If true filters out non-federatable rooms
"""
- def _count_public_rooms_txn(txn):
+ def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
if network_tuple:
@@ -195,6 +219,7 @@ class RoomWorkerStore(SQLBaseStore):
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
+ assert network_tuple.network_id is not None
query_args.append(network_tuple.network_id)
else:
published_sql = """
@@ -208,7 +233,7 @@ class RoomWorkerStore(SQLBaseStore):
sql = """
SELECT
- COALESCE(COUNT(*), 0)
+ COUNT(*)
FROM (
%(published_sql)s
) published
@@ -226,7 +251,7 @@ class RoomWorkerStore(SQLBaseStore):
}
txn.execute(sql, query_args)
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
@@ -235,11 +260,11 @@ class RoomWorkerStore(SQLBaseStore):
async def get_room_count(self) -> int:
"""Retrieve the total number of rooms."""
- def f(txn):
+ def f(txn: LoggingTransaction) -> int:
sql = "SELECT count(*) FROM rooms"
txn.execute(sql)
- row = txn.fetchone()
- return row[0] or 0
+ row = cast(Tuple[int], txn.fetchone())
+ return row[0]
return await self.db_pool.runInteraction("get_rooms", f)
@@ -251,7 +276,7 @@ class RoomWorkerStore(SQLBaseStore):
bounds: Optional[Tuple[int, str]],
forwards: bool,
ignore_non_federatable: bool = False,
- ):
+ ) -> List[Dict[str, Any]]:
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
@@ -272,7 +297,7 @@ class RoomWorkerStore(SQLBaseStore):
"""
where_clauses = []
- query_args = []
+ query_args: List[Union[str, int]] = []
if network_tuple:
if network_tuple.appservice_id:
@@ -281,6 +306,7 @@ class RoomWorkerStore(SQLBaseStore):
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
+ assert network_tuple.network_id is not None
query_args.append(network_tuple.network_id)
else:
published_sql = """
@@ -372,7 +398,9 @@ class RoomWorkerStore(SQLBaseStore):
LIMIT ?
"""
- def _get_largest_public_rooms_txn(txn):
+ def _get_largest_public_rooms_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Any]]:
txn.execute(sql, query_args)
results = self.db_pool.cursor_to_dict(txn)
@@ -435,7 +463,7 @@ class RoomWorkerStore(SQLBaseStore):
"""
# Filter room names by a string
where_statement = ""
- search_pattern = []
+ search_pattern: List[object] = []
if search_term:
where_statement = """
WHERE LOWER(state.name) LIKE ?
@@ -543,7 +571,9 @@ class RoomWorkerStore(SQLBaseStore):
where_statement,
)
- def _get_rooms_paginate_txn(txn):
+ def _get_rooms_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], int]:
# Add the search term into the WHERE clause
# and execute the data query
txn.execute(info_sql, search_pattern + [limit, start])
@@ -575,7 +605,7 @@ class RoomWorkerStore(SQLBaseStore):
# Add the search term into the WHERE clause if present
txn.execute(count_sql, search_pattern)
- room_count = txn.fetchone()
+ room_count = cast(Tuple[int], txn.fetchone())
return rooms, room_count[0]
return await self.db_pool.runInteraction(
@@ -620,7 +650,7 @@ class RoomWorkerStore(SQLBaseStore):
burst_count: How many actions that can be performed before being limited.
"""
- def set_ratelimit_txn(txn):
+ def set_ratelimit_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="ratelimit_override",
@@ -643,7 +673,7 @@ class RoomWorkerStore(SQLBaseStore):
user_id: user ID of the user
"""
- def delete_ratelimit_txn(txn):
+ def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
row = self.db_pool.simple_select_one_txn(
txn,
table="ratelimit_override",
@@ -667,7 +697,7 @@ class RoomWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
@cached()
- async def get_retention_policy_for_room(self, room_id):
+ async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
"""Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined
@@ -676,13 +706,15 @@ class RoomWorkerStore(SQLBaseStore):
configuration).
Args:
- room_id (str): The ID of the room to get the retention policy of.
+ room_id: The ID of the room to get the retention policy of.
Returns:
- dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+ A dict containing "min_lifetime" and "max_lifetime" for this room.
"""
- def get_retention_policy_for_room_txn(txn):
+ def get_retention_policy_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@@ -707,19 +739,23 @@ class RoomWorkerStore(SQLBaseStore):
"max_lifetime": self.config.retention.retention_default_max_lifetime,
}
- row = ret[0]
+ min_lifetime = ret[0]["min_lifetime"]
+ max_lifetime = ret[0]["max_lifetime"]
# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
# The default values will be None if no default policy has been defined, or if one
# of the attributes is missing from the default policy.
- if row["min_lifetime"] is None:
- row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
+ if min_lifetime is None:
+ min_lifetime = self.config.retention.retention_default_min_lifetime
- if row["max_lifetime"] is None:
- row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
+ if max_lifetime is None:
+ max_lifetime = self.config.retention.retention_default_max_lifetime
- return row
+ return {
+ "min_lifetime": min_lifetime,
+ "max_lifetime": max_lifetime,
+ }
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
@@ -731,7 +767,9 @@ class RoomWorkerStore(SQLBaseStore):
The local and remote media as a lists of the media IDs.
"""
- def _get_media_mxcs_in_room_txn(txn):
+ def _get_media_mxcs_in_room_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], List[str]]:
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
@@ -757,7 +795,7 @@ class RoomWorkerStore(SQLBaseStore):
logger.info("Quarantining media in room: %s", room_id)
- def _quarantine_media_in_room_txn(txn):
+ def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
return self._quarantine_media_txn(
txn, local_mxcs, remote_mxcs, quarantined_by
@@ -767,13 +805,11 @@ class RoomWorkerStore(SQLBaseStore):
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
- def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ def _get_media_mxcs_in_room_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> Tuple[List[str], List[Tuple[str, str]]]:
"""Retrieves all the local and remote media MXC URIs in a given room
- Args:
- txn (cursor)
- room_id (str)
-
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
@@ -841,7 +877,7 @@ class RoomWorkerStore(SQLBaseStore):
logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server.server_name
- def _quarantine_media_by_id_txn(txn):
+ def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
local_mxcs = [media_id] if is_local else []
remote_mxcs = [(server_name, media_id)] if not is_local else []
@@ -863,7 +899,7 @@ class RoomWorkerStore(SQLBaseStore):
quarantined_by: The ID of the user who made the quarantine request
"""
- def _quarantine_media_by_user_txn(txn):
+ def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
@@ -871,7 +907,9 @@ class RoomWorkerStore(SQLBaseStore):
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
- def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+ def _get_media_ids_by_user_txn(
+ self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
+ ) -> List[str]:
"""Retrieves local media IDs by a given user
Args:
@@ -900,7 +938,7 @@ class RoomWorkerStore(SQLBaseStore):
def _quarantine_media_txn(
self,
- txn,
+ txn: LoggingTransaction,
local_mxcs: List[str],
remote_mxcs: List[Tuple[str, str]],
quarantined_by: Optional[str],
@@ -928,12 +966,15 @@ class RoomWorkerStore(SQLBaseStore):
# set quarantine
if quarantined_by is not None:
sql += "AND safe_from_quarantine = ?"
- rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
+ txn.executemany(
+ sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
+ )
# remove from quarantine
else:
- rows = [(quarantined_by, media_id) for media_id in local_mxcs]
+ txn.executemany(
+ sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+ )
- txn.executemany(sql, rows)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
@@ -951,7 +992,7 @@ class RoomWorkerStore(SQLBaseStore):
async def get_rooms_for_retention_period_in_range(
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
- ) -> Dict[str, dict]:
+ ) -> Dict[str, Dict[str, Optional[int]]]:
"""Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy.
@@ -971,7 +1012,9 @@ class RoomWorkerStore(SQLBaseStore):
"min_lifetime" (int|None), and "max_lifetime" (int|None).
"""
- def get_rooms_for_retention_period_in_range_txn(txn):
+ def get_rooms_for_retention_period_in_range_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, Optional[int]]]:
range_conditions = []
args = []
@@ -1050,11 +1093,14 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
class RoomBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
- self.config = hs.config
-
self.db_pool.updates.register_background_update_handler(
"insert_room_retention",
self._background_insert_retention,
@@ -1085,7 +1131,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_populate_rooms_creator_column,
)
- async def _background_insert_retention(self, progress, batch_size):
+ async def _background_insert_retention(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table.
NULLs the property's columns if missing from the retention event in the room's
@@ -1095,7 +1143,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
last_room = progress.get("room_id", "")
- def _background_insert_retention_txn(txn):
+ def _background_insert_retention_txn(txn: LoggingTransaction) -> bool:
txn.execute(
"""
SELECT state.room_id, state.event_id, events.json
@@ -1154,15 +1202,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _background_add_rooms_room_version_column(
- self, progress: dict, batch_size: int
- ):
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to go and add room version information to `rooms`
table from `current_state_events` table.
"""
last_room_id = progress.get("room_id", "")
- def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+ def _background_add_rooms_room_version_column_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
sql = """
SELECT room_id, json FROM current_state_events
INNER JOIN event_json USING (room_id, event_id)
@@ -1223,7 +1273,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _remove_tombstoned_rooms_from_directory(
- self, progress, batch_size
+ self, progress: JsonDict, batch_size: int
) -> int:
"""Removes any rooms with tombstone events from the room directory
@@ -1233,7 +1283,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
last_room = progress.get("room_id", "")
- def _get_rooms(txn):
+ def _get_rooms(txn: LoggingTransaction) -> List[str]:
txn.execute(
"""
SELECT room_id
@@ -1271,7 +1321,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return len(rooms)
@abstractmethod
- def set_room_is_public(self, room_id, is_public):
+ def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
# this will need to be implemented if a background update is performed with
# existing (tombstoned, public) rooms in the database.
#
@@ -1318,7 +1368,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
32-bit integer field.
"""
- def process(txn: Cursor) -> int:
+ def process(txn: LoggingTransaction) -> int:
last_room = progress.get("last_room", "")
txn.execute(
"""
@@ -1375,15 +1425,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return 0
async def _background_populate_rooms_creator_column(
- self, progress: dict, batch_size: int
- ):
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to go and add creator information to `rooms`
table from `current_state_events` table.
"""
last_room_id = progress.get("room_id", "")
- def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
+ def _background_populate_rooms_creator_column_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
sql = """
SELECT room_id, json FROM event_json
INNER JOIN rooms AS room USING (room_id)
@@ -1434,15 +1486,20 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
-class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
- self.config = hs.config
+ self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
async def upsert_room_on_join(
self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
- ):
+ ) -> None:
"""Ensure that the room is stored in the table
Called when we join a room over federation, and overwrites any room version
@@ -1488,7 +1545,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
- ):
+ ) -> None:
"""
When we receive an invite or any other event over federation that may relate to a room
we are not in, store the version of the room if we don't already know the room version.
@@ -1528,8 +1585,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.hs.get_notifier().on_new_replication_data()
async def set_room_is_public_appservice(
- self, room_id, appservice_id, network_id, is_public
- ):
+ self, room_id: str, appservice_id: str, network_id: str, is_public: bool
+ ) -> None:
"""Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated
@@ -1538,11 +1595,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
network.
Args:
- room_id (str)
- appservice_id (str)
- network_id (str)
- is_public (bool): Whether to publish or unpublish the room from the
- list.
+ room_id
+ appservice_id
+ network_id
+ is_public: Whether to publish or unpublish the room from the list.
"""
if is_public:
@@ -1607,7 +1663,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
event_report: json list of information from event report
"""
- def _get_event_report_txn(txn, report_id):
+ def _get_event_report_txn(
+ txn: LoggingTransaction, report_id: int
+ ) -> Optional[Dict[str, Any]]:
sql = """
SELECT
@@ -1679,9 +1737,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
count: total number of event reports matching the filter criteria
"""
- def _get_event_reports_paginate_txn(txn):
+ def _get_event_reports_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], int]:
filters = []
- args = []
+ args: List[object] = []
if user_id:
filters.append("er.user_id LIKE ?")
@@ -1705,7 +1765,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
where_clause
)
txn.execute(sql, args)
- count = txn.fetchone()[0]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = """
SELECT
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..cda80d6511 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@@ -64,7 +64,12 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -985,7 +990,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1135,7 +1145,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 642560a70d..3cbaca21b5 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,13 +14,18 @@
import logging
import re
-from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
+import attr
+
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -29,10 +34,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-SearchEntry = namedtuple(
- "SearchEntry",
- ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
-)
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SearchEntry:
+ key: str
+ value: str
+ event_id: str
+ room_id: str
+ stream_ordering: Optional[int]
+ origin_server_ts: int
def _clean_value_for_search(value: str) -> str:
@@ -105,7 +115,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if not hs.config.server.enable_search:
@@ -358,7 +373,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys):
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index fa2c3b1feb..2fb3e65192 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,6 @@
# limitations under the License.
import collections.abc
import logging
-from collections import namedtuple
from typing import TYPE_CHECKING, Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
@@ -22,7 +21,11 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
@@ -39,24 +42,16 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
-class _GetStateGroupDelta(
- namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
- """Return type of get_state_group_delta that implements __len__, which lets
- us use the itrable flag when caching
- """
-
- __slots__ = []
-
- def __len__(self):
- return len(self.delta_ids) if self.delta_ids else 0
-
-
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -182,11 +177,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
NotFoundError if the room is unknown
"""
state_ids = await self.get_current_state_ids(room_id)
+
+ if not state_ids:
+ raise NotFoundError(f"Current state for room {room_id} is empty")
+
create_id = state_ids.get((EventTypes.Create, ""))
# If we can't find the create event, assume we've hit a dead end
if not create_id:
- raise NotFoundError("Unknown room %s" % (room_id,))
+ raise NotFoundError(f"No create event in current state for room {room_id}")
# Retrieve the room's create event and return
create_event = await self.get_event(create_id)
@@ -349,7 +348,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -536,5 +540,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 7f3624b128..188afec332 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -56,7 +56,9 @@ class StateDeltasStore(SQLBaseStore):
prev_stream_id = int(prev_stream_id)
# check we're not going backwards
- assert prev_stream_id <= max_stream_id
+ assert (
+ prev_stream_id <= max_stream_id
+ ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5d7b59d861..427ae1f649 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
import logging
from enum import Enum
from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import Counter
@@ -24,7 +24,11 @@ from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -96,7 +100,12 @@ class UserSortOrder(Enum):
class StatsStore(StateDeltasStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -117,7 +126,9 @@ class StatsStore(StateDeltasStore):
self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
- async def _populate_stats_process_users(self, progress, batch_size):
+ async def _populate_stats_process_users(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""
This is a background update which regenerates statistics for users.
"""
@@ -129,7 +140,7 @@ class StatsStore(StateDeltasStore):
last_user_id = progress.get("last_user_id", "")
- def _get_next_batch(txn):
+ def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT name FROM users
WHERE name > ?
@@ -163,7 +174,9 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on)
- async def _populate_stats_process_rooms(self, progress, batch_size):
+ async def _populate_stats_process_rooms(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""This is a background update which regenerates statistics for rooms."""
if not self.stats_enabled:
await self.db_pool.updates._end_background_update(
@@ -173,7 +186,7 @@ class StatsStore(StateDeltasStore):
last_room_id = progress.get("last_room_id", "")
- def _get_next_batch(txn):
+ def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ?
@@ -302,7 +315,7 @@ class StatsStore(StateDeltasStore):
stream_id: Current position.
"""
- def _bulk_update_stats_delta_txn(txn):
+ def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None:
for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items():
logger.debug(
@@ -334,7 +347,7 @@ class StatsStore(StateDeltasStore):
stats_type: str,
stats_id: str,
fields: Dict[str, int],
- complete_with_stream_id: Optional[int],
+ complete_with_stream_id: int,
absolute_field_overrides: Optional[Dict[str, int]] = None,
) -> None:
"""
@@ -367,14 +380,14 @@ class StatsStore(StateDeltasStore):
def _update_stats_delta_txn(
self,
- txn,
- ts,
- stats_type,
- stats_id,
- fields,
- complete_with_stream_id,
- absolute_field_overrides=None,
- ):
+ txn: LoggingTransaction,
+ ts: int,
+ stats_type: str,
+ stats_id: str,
+ fields: Dict[str, int],
+ complete_with_stream_id: int,
+ absolute_field_overrides: Optional[Dict[str, int]] = None,
+ ) -> None:
if absolute_field_overrides is None:
absolute_field_overrides = {}
@@ -417,20 +430,23 @@ class StatsStore(StateDeltasStore):
)
def _upsert_with_additive_relatives_txn(
- self, txn, table, keyvalues, absolutes, additive_relatives
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ absolutes: Dict[str, Any],
+ additive_relatives: Dict[str, int],
+ ) -> None:
"""Used to update values in the stats tables.
This is basically a slightly convoluted upsert that *adds* to any
existing rows.
Args:
- txn
- table (str): Table name
- keyvalues (dict[str, any]): Row-identifying key values
- absolutes (dict[str, any]): Absolute (set) fields
- additive_relatives (dict[str, int]): Fields that will be added onto
- if existing row present.
+ table: Table name
+ keyvalues: Row-identifying key values
+ absolutes: Absolute (set) fields
+ additive_relatives: Fields that will be added onto if existing row present.
"""
if self.database_engine.can_native_upsert:
absolute_updates = [
@@ -486,20 +502,17 @@ class StatsStore(StateDeltasStore):
current_row.update(absolutes)
self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
- async def _calculate_and_set_initial_state_for_room(
- self, room_id: str
- ) -> Tuple[dict, dict, int]:
+ async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
"""Calculate and insert an entry into room_stats_current.
Args:
room_id: The room ID under calculation.
-
- Returns:
- A tuple of room state, membership counts and stream position.
"""
- def _fetch_current_state_stats(txn):
- pos = self.get_room_max_stream_ordering()
+ def _fetch_current_state_stats(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
+ pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
rows = self.db_pool.simple_select_many_txn(
txn,
@@ -519,7 +532,7 @@ class StatsStore(StateDeltasStore):
retcols=["event_id"],
)
- event_ids = [row["event_id"] for row in rows]
+ event_ids = cast(List[str], [row["event_id"] for row in rows])
txn.execute(
"""
@@ -533,15 +546,15 @@ class StatsStore(StateDeltasStore):
txn.execute(
"""
- SELECT COALESCE(count(*), 0) FROM current_state_events
+ SELECT COUNT(*) FROM current_state_events
WHERE room_id = ?
""",
(room_id,),
)
- (current_state_events_count,) = txn.fetchone()
+ current_state_events_count = cast(Tuple[int], txn.fetchone())[0]
- users_in_room = self.get_users_in_room_txn(txn, room_id)
+ users_in_room = self.get_users_in_room_txn(txn, room_id) # type: ignore[attr-defined]
return (
event_ids,
@@ -561,7 +574,7 @@ class StatsStore(StateDeltasStore):
"get_initial_state_for_room", _fetch_current_state_stats
)
- state_event_map = await self.get_events(event_ids, get_prev_content=False)
+ state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
room_state = {
"join_rules": None,
@@ -617,8 +630,10 @@ class StatsStore(StateDeltasStore):
},
)
- async def _calculate_and_set_initial_state_for_user(self, user_id):
- def _calculate_and_set_initial_state_for_user_txn(txn):
+ async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None:
+ def _calculate_and_set_initial_state_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, int]:
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
txn.execute(
@@ -629,7 +644,7 @@ class StatsStore(StateDeltasStore):
""",
(user_id,),
)
- (count,) = txn.fetchone()
+ count = cast(Tuple[int], txn.fetchone())[0]
return count, pos
joined_rooms, pos = await self.db_pool.runInteraction(
@@ -673,7 +688,9 @@ class StatsStore(StateDeltasStore):
users that exist given this query
"""
- def get_users_media_usage_paginate_txn(txn):
+ def get_users_media_usage_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
filters = []
args = [self.hs.config.server.server_name]
@@ -728,7 +745,7 @@ class StatsStore(StateDeltasStore):
sql_base=sql_base,
)
txn.execute(sql, args)
- count = txn.fetchone()[0]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = """
SELECT
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..319464b1fa 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -34,11 +34,11 @@ what sort order was used:
- topological tokems: "t%d-%d", where the integers map to the topological
and stream ordering columns respectively.
"""
-import abc
+
import logging
-from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+import attr
from frozendict import frozendict
from twisted.internet import defer
@@ -49,6 +49,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
+ LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
@@ -73,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs
-_EventDictReturn = namedtuple(
- "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventDictReturn:
+ event_id: str
+ topological_ordering: Optional[int]
+ stream_ordering: int
def generate_pagination_where_clause(
@@ -333,13 +336,13 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
return " AND ".join(clauses), args
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
- """This is an abstract base class where subclasses must implement
- `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
- which can be called in the initializer.
- """
-
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -371,13 +374,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
self._stream_order_on_start = self.get_room_max_stream_ordering()
- @abc.abstractmethod
def get_room_max_stream_ordering(self) -> int:
- raise NotImplementedError()
+ """Get the stream_ordering of regular events that we have committed up to
+
+ Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
+ """
+ return self._stream_id_gen.get_current_token()
- @abc.abstractmethod
def get_room_min_stream_ordering(self) -> int:
- raise NotImplementedError()
+ """Get the stream_ordering of backfilled events that we have committed up to
+
+ Backfilled events use *negative* stream orderings, so this returns the
+ minimum negative stream id such that all stream ids greater than or
+ equal to it have been successfully persisted.
+ """
+ return self._backfill_id_gen.get_current_token()
def get_room_max_token(self) -> RoomStreamToken:
"""Get a `RoomStreamToken` that marks the current maximum persisted
@@ -819,7 +831,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
for event, row in zip(events, rows):
stream = row.stream_ordering
if topo_order and row.topological_ordering:
- topo = row.topological_ordering
+ topo: Optional[int] = row.topological_ordering
else:
topo = None
internal = event.internal_metadata
@@ -1343,11 +1355,3 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
retcol="instance_name",
desc="get_name_from_instance_id",
)
-
-
-class StreamStore(StreamWorkerStore):
- def get_room_max_stream_ordering(self) -> int:
- return self._stream_id_gen.get_current_token()
-
- def get_room_min_stream_ordering(self) -> int:
- return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 8f510de53d..c8e508a910 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,11 +15,13 @@
# limitations under the License.
import logging
-from typing import Dict, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Tuple, cast
+from synapse.replication.tcp.streams import TagAccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -204,6 +206,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
The next account data ID.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
content_json = json_encoder.encode(content)
@@ -230,6 +233,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
The next account data ID.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
sql = (
@@ -258,6 +262,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
next_id: The the revision to advance to.
"""
assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
txn.call_after(
self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -287,6 +292,21 @@ class TagsWorkerStore(AccountDataWorkerStore):
# than the id that the client has.
pass
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
+ if stream_name == TagAccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
class TagsStore(TagsWorkerStore):
pass
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..6c299cafa5 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -13,16 +13,19 @@
# limitations under the License.
import logging
-from collections import namedtuple
from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
import attr
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -35,16 +38,6 @@ db_binary_type = memoryview
logger = logging.getLogger(__name__)
-_TransactionRow = namedtuple(
- "_TransactionRow",
- ("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
-)
-
-_UpdateTransactionRow = namedtuple(
- "_TransactionRow", ("response_code", "response_json")
-)
-
-
class DestinationSortOrder(Enum):
"""Enum to define the sorting method used when returning destinations."""
@@ -71,7 +64,12 @@ class DestinationRetryTimings:
class TransactionWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -82,7 +80,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
- def _cleanup_transactions_txn(txn):
+ def _cleanup_transactions_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
await self.db_pool.runInteraction(
@@ -112,7 +110,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
origin,
)
- def _get_received_txn_response(self, txn, transaction_id, origin):
+ def _get_received_txn_response(
+ self, txn: LoggingTransaction, transaction_id: str, origin: str
+ ) -> Optional[Tuple[int, JsonDict]]:
result = self.db_pool.simple_select_one_txn(
txn,
table="received_transactions",
@@ -187,7 +187,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
return result
def _get_destination_retry_timings(
- self, txn, destination: str
+ self, txn: LoggingTransaction, destination: str
) -> Optional[DestinationRetryTimings]:
result = self.db_pool.simple_select_one_txn(
txn,
@@ -222,7 +222,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
"""
if self.database_engine.can_native_upsert:
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings_native,
destination,
@@ -232,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
db_autocommit=True, # Safe as its a single upsert
)
else:
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings_emulated,
destination,
@@ -242,8 +242,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)
def _set_destination_retry_timings_native(
- self, txn, destination, failure_ts, retry_last_ts, retry_interval
- ):
+ self,
+ txn: LoggingTransaction,
+ destination: str,
+ failure_ts: Optional[int],
+ retry_last_ts: int,
+ retry_interval: int,
+ ) -> None:
assert self.database_engine.can_native_upsert
# Upsert retry time interval if retry_interval is zero (i.e. we're
@@ -273,8 +278,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)
def _set_destination_retry_timings_emulated(
- self, txn, destination, failure_ts, retry_last_ts, retry_interval
- ):
+ self,
+ txn: LoggingTransaction,
+ destination: str,
+ failure_ts: Optional[int],
+ retry_last_ts: int,
+ retry_interval: int,
+ ) -> None:
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
@@ -384,7 +394,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
last_successful_stream_ordering: the stream_ordering of the most
recent successfully-sent PDU
"""
- return await self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
"destinations",
keyvalues={"destination": destination},
values={"last_successful_stream_ordering": last_successful_stream_ordering},
@@ -525,7 +535,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
else:
order = "ASC"
- args = []
+ args: List[object] = []
where_statement = ""
if destination:
args.extend(["%" + destination.lower() + "%"])
@@ -534,7 +544,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
sql_base = f"FROM destinations {where_statement} "
sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
txn.execute(sql, args)
- count = txn.fetchone()[0]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = f"""
SELECT destination, retry_last_ts, retry_interval, failure_ts,
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 340ca9e47d..a1a1a6a14a 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
import attr
@@ -225,11 +225,14 @@ class UIAuthWorkerStore(SQLBaseStore):
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
):
# Get the current value.
- result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore
- txn,
- table="ui_auth_sessions",
- keyvalues={"session_id": session_id},
- retcols=("serverdict",),
+ result = cast(
+ Dict[str, Any],
+ self.db_pool.simple_select_one_txn(
+ txn,
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("serverdict",),
+ ),
)
# Update it and add it back to the database.
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e98a45b6af..0f9b8575d3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -32,11 +32,14 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.types import Connection
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
@@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
def __init__(
self,
database: DatabasePool,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
@@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(
self,
database: DatabasePool,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
) -> None:
super().__init__(database, db_conn, hs)
|