diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 20b755056b..cfe887b7f7 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -13,33 +13,49 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar
+from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-class Databases:
+DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True)
+
+
+class Databases(Generic[DataStoreT]):
"""The various databases.
These are low level interfaces to physical databases.
Attributes:
- main (DataStore)
+ databases
+ main
+ state
+ persist_events
"""
- def __init__(self, main_store_class, hs):
+ databases: List[DatabasePool]
+ main: DataStoreT
+ state: StateGroupDataStore
+ persist_events: Optional[PersistEventsStore]
+
+ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
# Note we pass in the main store class here as workers use a different main
# store.
self.databases = []
- main = None
- state = None
- persist_events = None
+ main: Optional[DataStoreT] = None
+ state: Optional[StateGroupDataStore] = None
+ persist_events: Optional[PersistEventsStore] = None
for database_config in hs.config.database.databases:
db_name = database_config.name
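For reference, a minimal sketch (not part of the patch) of what the new `Generic[DataStoreT]` parameter buys callers; `DataStore` here stands in for whichever main-store class a given process passes in:

    # Hypothetical illustration only: the master process binds the type
    # parameter to the full DataStore; a worker would bind its own class.
    databases = Databases(DataStore, hs)  # inferred as Databases[DataStore]
    main_store = databases.main          # typed as DataStore, not merely SQLBaseStore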
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5c21402dea..259cae5b37 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool
@@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -126,7 +129,7 @@ class DataStore(
LockStore,
SessionStore,
):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 70ca3e09f7..f8bec266ac 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -28,6 +28,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore):
`get_max_account_data_stream_id` which can be called in the initializer.
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index c57ae5ef15..36e8422fc6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@
import itertools
import logging
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 72af4c1fc3..a6fd9f2636 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -13,14 +13,26 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+
+from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.types import UserID
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
+from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 10 * 60 * 1000
+class DeviceLastConnectionInfo(TypedDict):
+ """Metadata for the last connection seen for a user and device combination"""
+
+ # These types must match the columns in the `devices` table
+ user_id: str
+ device_id: str
+
+ ip: Optional[str]
+ user_agent: Optional[str]
+ last_seen: Optional[int]
+
+
+class LastConnectionInfo(TypedDict):
+ """Metadata for the last connection seen for an access token and IP combination"""
+
+ # These types must match the columns in the `user_ips` table
+ access_token: str
+ ip: str
+
+ user_agent: str
+ last_seen: int
+
+
class ClientIpBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
"devices_last_seen", self._devices_last_seen_update
)
- async def _remove_user_ip_nonunique(self, progress, batch_size):
- def f(conn):
+ async def _remove_user_ip_nonunique(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ def f(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
@@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
)
return 1
- async def _analyze_user_ip(self, progress, batch_size):
+ async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
# Background update to analyze user_ips table before we run the
# deduplication background update. The table may not have been analyzed
# for ages due to the table locks.
#
# This will lock out the naive upserts to user_ips while it happens, but
# the analyze should be quick (28GB table takes ~10s)
- def user_ips_analyze(txn):
+ def user_ips_analyze(txn: LoggingTransaction) -> None:
txn.execute("ANALYZE user_ips")
await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
@@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return 1
- async def _remove_user_ip_dupes(self, progress, batch_size):
+ async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
# This function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates, if there are then they
# are removed and replaced with a suitable row.
# Fetch the start of the batch
- begin_last_seen = progress.get("last_seen", 0)
+ begin_last_seen: int = progress.get("last_seen", 0)
- def get_last_seen(txn):
+ def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
txn.execute(
"""
SELECT last_seen FROM user_ips
@@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
""",
(begin_last_seen, batch_size),
)
- row = txn.fetchone()
+ row = cast(Optional[Tuple[int]], txn.fetchone())
if row:
return row[0]
else:
@@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
end_last_seen,
)
- def remove(txn):
+ def remove(txn: LoggingTransaction) -> None:
# This works by looking at all entries in the given time span, and
# then for each (user_id, access_token, ip) tuple in that range
# checking for any duplicates in the rest of the table (via a join).
@@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
# Define the search space, which requires handling the last batch in
# a different way
+ args: Tuple[int, ...]
if last:
clause = "? <= last_seen"
args = (begin_last_seen,)
else:
+ assert end_last_seen is not None
clause = "? <= last_seen AND last_seen < ?"
args = (begin_last_seen, end_last_seen)
@@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
),
args,
)
- res = txn.fetchall()
+ res = cast(
+ List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
+ )
# We've got some duplicates
for i in res:
@@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return batch_size
- async def _devices_last_seen_update(self, progress, batch_size):
+ async def _devices_last_seen_update(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to insert last seen info into devices table"""
- last_user_id = progress.get("last_user_id", "")
- last_device_id = progress.get("last_device_id", "")
+ last_user_id: str = progress.get("last_user_id", "")
+ last_device_id: str = progress.get("last_device_id", "")
- def _devices_last_seen_update_txn(txn):
+ def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
@@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
# we'll just end up updating the same device row multiple
# times, which is fine.
+ where_args: List[Union[str, int]]
where_clause, where_args = make_tuple_comparison_clause(
[("user_id", last_user_id), ("device_id", last_device_id)],
)
@@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
}
txn.execute(sql, where_args + [batch_size])
- rows = txn.fetchall()
+ rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
if not rows:
return 0
@@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@wrap_as_background_process("prune_old_user_ips")
- async def _prune_old_user_ips(self):
+ async def _prune_old_user_ips(self) -> None:
"""Removes entries in user IPs older than the configured period."""
if self.user_ips_max_age is None:
@@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
)
"""
- timestamp = self.clock.time_msec() - self.user_ips_max_age
+ timestamp = self._clock.time_msec() - self.user_ips_max_age
- def _prune_old_user_ips_txn(txn):
+ def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
txn.execute(sql, (timestamp,))
await self.db_pool.runInteraction(
@@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
- ) -> Dict[Tuple[str, str], dict]:
+ ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on.
The result might be slightly out of date as client IPs are inserted in batches.
@@ -423,26 +467,84 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
if device_id is not None:
keyvalues["device_id"] = device_id
- res = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues=keyvalues,
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ res = cast(
+ List[DeviceLastConnectionInfo],
+ await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ ),
)
return {(d["user_id"], d["device_id"]): d for d in res}
+ async def get_user_ip_and_agents(
+ self, user: UserID, since_ts: int = 0
+ ) -> List[LastConnectionInfo]:
+ """Fetch the IPs and user agents for a user since the given timestamp.
+
+ The result might be slightly out of date as client IPs are inserted in batches.
+
+ Args:
+ user: The user for which to fetch IP addresses and user agents.
+ since_ts: The timestamp after which to fetch IP addresses and user agents,
+ in milliseconds.
+
+ Returns:
+ A list of dictionaries, each containing:
+ * `access_token`: The access token used.
+ * `ip`: The IP address used.
+ * `user_agent`: The last user agent seen for this access token and IP
+ address combination.
+ * `last_seen`: The timestamp at which this access token and IP address
+ combination was last seen, in milliseconds.
+
+ Only the latest user agent for each access token and IP address combination
+ is available.
+ """
+ user_id = user.to_string()
+
+ def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
+ txn.execute(
+ """
+ SELECT access_token, ip, user_agent, last_seen FROM user_ips
+ WHERE last_seen >= ? AND user_id = ?
+ ORDER BY last_seen
+ DESC
+ """,
+ (since_ts, user_id),
+ )
+ return cast(List[Tuple[str, str, str, int]], txn.fetchall())
+
+ rows = await self.db_pool.runInteraction(
+ desc="get_user_ip_and_agents", func=get_recent
+ )
+
+ return [
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for access_token, ip, user_agent, last_seen in rows
+ ]
+
-class ClientIpStore(ClientIpWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
- self.client_ip_last_seen = LruCache(
+ # (user_id, access_token, ip,) -> last_seen
+ self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
cache_name="client_ip_last_seen", max_size=50000
)
super().__init__(database, db_conn, hs)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
- self._batch_row_update = {}
+ self._batch_row_update: Dict[
+ Tuple[str, str, str], Tuple[str, Optional[str], int]
+ ] = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
@@ -452,8 +554,14 @@ class ClientIpStore(ClientIpWorkerStore):
)
async def insert_client_ip(
- self, user_id, access_token, ip, user_agent, device_id, now=None
- ):
+ self,
+ user_id: str,
+ access_token: str,
+ ip: str,
+ user_agent: str,
+ device_id: Optional[str],
+ now: Optional[int] = None,
+ ) -> None:
if not now:
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)
@@ -485,7 +593,11 @@ class ClientIpStore(ClientIpWorkerStore):
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
- def _update_client_ips_batch_txn(self, txn, to_update):
+ def _update_client_ips_batch_txn(
+ self,
+ txn: LoggingTransaction,
+ to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
+ ) -> None:
if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
@@ -525,7 +637,7 @@ class ClientIpStore(ClientIpWorkerStore):
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
- ) -> Dict[Tuple[str, str], dict]:
+ ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on
Args:
@@ -561,50 +673,44 @@ class ClientIpStore(ClientIpWorkerStore):
async def get_user_ip_and_agents(
self, user: UserID, since_ts: int = 0
- ) -> List[Dict[str, Union[str, int]]]:
- """
- Fetch IP/User Agent connection since a given timestamp.
+ ) -> List[LastConnectionInfo]:
+ """Fetch the IPs and user agents for a user since the given timestamp.
+
+ Args:
+ user: The user for which to fetch IP addresses and user agents.
+ since_ts: The timestamp after which to fetch IP addresses and user agents,
+ in milliseconds.
+
+ Returns:
+ A list of dictionaries, each containing:
+ * `access_token`: The access token used.
+ * `ip`: The IP address used.
+ * `user_agent`: The last user agent seen for this access token and IP
+ address combination.
+ * `last_seen`: The timestamp at which this access token and IP address
+ combination was last seen, in milliseconds.
+
+ Only the latest user agent for each access token and IP address combination
+ is available.
"""
- user_id = user.to_string()
- results = {}
+ results: Dict[Tuple[str, str], LastConnectionInfo] = {
+ (connection["access_token"], connection["ip"]): connection
+ for connection in await super().get_user_ip_and_agents(user, since_ts)
+ }
+ # Overlay data that is pending insertion on top of the results from the
+ # database.
+ user_id = user.to_string()
for key in self._batch_row_update:
- (
- uid,
- access_token,
- ip,
- ) = key
+ uid, access_token, ip = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
if last_seen >= since_ts:
- results[(access_token, ip)] = (user_agent, last_seen)
-
- def get_recent(txn):
- txn.execute(
- """
- SELECT access_token, ip, user_agent, last_seen FROM user_ips
- WHERE last_seen >= ? AND user_id = ?
- ORDER BY last_seen
- DESC
- """,
- (since_ts, user_id),
- )
- return txn.fetchall()
-
- rows = await self.db_pool.runInteraction(
- desc="get_user_ip_and_agents", func=get_recent
- )
+ results[(access_token, ip)] = {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
- results.update(
- ((access_token, ip), (user_agent, last_seen))
- for access_token, ip, user_agent, last_seen in rows
- )
- return [
- {
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "last_seen": last_seen,
- }
- for (access_token, ip), (user_agent, last_seen) in results.items()
- ]
+ return list(results.values())
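A rough usage sketch (not part of the patch) of the shape now returned by `get_user_ip_and_agents`; `store` and the user ID are placeholders:

    # Each entry is a `LastConnectionInfo` TypedDict, so the keys below are
    # checked by mypy instead of being looked up on an untyped dict.
    entries = await store.get_user_ip_and_agents(UserID.from_string("@alice:example.com"))
    for entry in entries:
        print(entry["access_token"], entry["ip"], entry["user_agent"], entry["last_seen"])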
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3154906d45..8143168107 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -26,11 +26,14 @@ from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 6464520386..a01bf2c5b7 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,17 @@
# limitations under the License.
import abc
import logging
-from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 10184d6ae7..ef5d1ef01e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,7 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Gauge
@@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
oldest_pdu_in_federation_staging = Gauge(
"synapse_federation_server_oldest_inbound_pdu_in_staging",
"The age in seconds since we received the oldest pdu in the federation staging area",
@@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception):
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -906,7 +909,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
desc="get_latest_event_ids_in_room",
)
- async def get_min_depth(self, room_id: str) -> int:
+ async def get_min_depth(self, room_id: str) -> Optional[int]:
"""For the given room, get the minimum depth we have seen for it."""
return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
@@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 97b3e92d3f..d957e770dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
@@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
@@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 19f55c19c5..8d9086ecf0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1710,6 +1710,7 @@ class PersistEventsStore:
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
+ RelationTypes.THREAD,
):
# Unknown relation type
return
@@ -1740,6 +1741,9 @@ class PersistEventsStore:
if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+ if rel_type == RelationTypes.THREAD:
+ txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
Part of MSC2716.
@@ -2069,12 +2073,14 @@ class PersistEventsStore:
state_groups[event.event_id] = context.state_group
- self.db_pool.simple_insert_many_txn(
+ self.db_pool.simple_upsert_many_txn(
txn,
table="event_to_state_groups",
- values=[
- {"state_group": state_group_id, "event_id": event_id}
- for event_id, state_group_id in state_groups.items()
+ key_names=["event_id"],
+ key_values=[[event_id] for event_id, _ in state_groups.items()],
+ value_names=["state_group"],
+ value_values=[
+ [state_group_id] for _, state_group_id in state_groups.items()
],
)
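The switch from `simple_insert_many_txn` to `simple_upsert_many_txn` changes the argument shape from a list of per-row dicts to parallel key/value column lists. A small illustration with made-up IDs:

    # Two events mapped to their state groups, in the form the upsert helper expects:
    key_names = ["event_id"]
    key_values = [["$event_a"], ["$event_b"]]
    value_names = ["state_group"]
    value_values = [[101], [102]]
    # i.e. one inner list per row, aligned with key_names/value_names.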
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 1afc59fafb..f92d824876 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,19 +13,26 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import attr
-from synapse.api.constants import EventContentFields
+from synapse.api.constants import EventContentFields, RelationTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -76,7 +83,7 @@ class _CalculateChainCover:
class EventsBackgroundUpdatesStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
@@ -164,6 +171,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
+ self.db_pool.updates.register_background_update_handler(
+ "event_thread_relation", self._event_thread_relation
+ )
+
################################################################################
# bg updates for replacing stream_ordering with a BIGINT
@@ -1088,6 +1099,79 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
+ async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
+ """Background update handler which will store thread relations for existing events."""
+ last_event_id = progress.get("last_event_id", "")
+
+ def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+ txn.execute(
+ """
+ SELECT event_id, json FROM event_json
+ LEFT JOIN event_relations USING (event_id)
+ WHERE event_id > ? AND relates_to_id IS NULL
+ ORDER BY event_id LIMIT ?
+ """,
+ (last_event_id, batch_size),
+ )
+
+ results = list(txn)
+ missing_thread_relations = []
+ for (event_id, event_json_raw) in results:
+ try:
+ event_json = db_to_json(event_json_raw)
+ except Exception as e:
+ logger.warning(
+ "Unable to load event %s (no relations will be updated): %s",
+ event_id,
+ e,
+ )
+ continue
+
+ # If there's no relation (or it is not a thread), skip!
+ relates_to = event_json["content"].get("m.relates_to")
+ if not relates_to or not isinstance(relates_to, dict):
+ continue
+ if relates_to.get("rel_type") != RelationTypes.THREAD:
+ continue
+
+ # Get the parent ID.
+ parent_id = relates_to.get("event_id")
+ if not isinstance(parent_id, str):
+ continue
+
+ missing_thread_relations.append((event_id, parent_id))
+
+ # Insert the missing data.
+ self.db_pool.simple_insert_many_txn(
+ txn=txn,
+ table="event_relations",
+ values=[
+ {
+ "event_id": event_id,
+ "relates_to_Id": parent_id,
+ "relation_type": RelationTypes.THREAD,
+ }
+ for event_id, parent_id in missing_thread_relations
+ ],
+ )
+
+ if results:
+ latest_event_id = results[-1][0]
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_thread_relation", {"last_event_id": latest_event_id}
+ )
+
+ return len(results)
+
+ num_rows = await self.db_pool.runInteraction(
+ desc="event_thread_relation", func=_event_thread_relation_txn
+ )
+
+ if not num_rows:
+ await self.db_pool.updates._end_background_update("event_thread_relation")
+
+ return num_rows
+
async def _background_populate_stream_ordering2(
self, progress: JsonDict, batch_size: int
) -> int:
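For context, the new `event_thread_relation` background update is looking for event JSON of roughly this shape (field values are illustrative); rows matching it get a corresponding `event_relations` entry backfilled:

    event_json = {
        "content": {
            "m.relates_to": {
                "rel_type": RelationTypes.THREAD,
                "event_id": "$thread_root_event_id",  # the parent/thread root
            },
        },
    }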
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4a1a2f4a6a..ae37901be9 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
@@ -86,6 +87,47 @@ class _EventCacheEntry:
redacted_event: Optional[EventBase]
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRow:
+ """
+ An event, as pulled from the database.
+
+ Properties:
+ event_id: The event ID of the event.
+
+ stream_ordering: stream ordering for this event
+
+ json: json-encoded event structure
+
+ internal_metadata: json-encoded internal metadata dict
+
+ format_version: The format of the event. Hopefully one of EventFormatVersions.
+ 'None' means the event predates EventFormatVersions (so the event is format V1).
+
+ room_version_id: The version of the room which contains the event. Hopefully
+ one of RoomVersions.
+
+ Due to historical reasons, there may be a few events in the database which
+ do not have an associated room; in this case None will be returned here.
+
+ rejected_reason: if the event was rejected, the reason why.
+
+ redactions: a list of event-ids which (claim to) redact this event.
+
+ outlier: True if this event is an outlier.
+ """
+
+ event_id: str
+ stream_ordering: int
+ json: str
+ internal_metadata: str
+ format_version: Optional[int]
+    room_version_id: Optional[str]
+ rejected_reason: Optional[str]
+ redactions: List[str]
+ outlier: bool
+
+
class EventRedactBehaviour(Names):
"""
What to do when retrieving a redacted event from the database.
@@ -686,7 +728,7 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _do_fetch(self, conn):
+ def _do_fetch(self, conn: Connection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
@@ -713,13 +755,15 @@ class EventsWorkerStore(SQLBaseStore):
self._fetch_event_list(conn, event_list)
- def _fetch_event_list(self, conn, event_list):
+ def _fetch_event_list(
+ self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+ ) -> None:
"""Handle a load of requests from the _event_fetch_list queue
Args:
- conn (twisted.enterprise.adbapi.Connection): database connection
+ conn: database connection
- event_list (list[Tuple[list[str], Deferred]]):
+ event_list:
The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the
events have been fetched.
@@ -788,7 +832,7 @@ class EventsWorkerStore(SQLBaseStore):
row = row_map.get(event_id)
fetched_events[event_id] = row
if row:
- redaction_ids.update(row["redactions"])
+ redaction_ids.update(row.redactions)
events_to_fetch = redaction_ids.difference(fetched_events.keys())
if events_to_fetch:
@@ -799,32 +843,32 @@ class EventsWorkerStore(SQLBaseStore):
for event_id, row in fetched_events.items():
if not row:
continue
- assert row["event_id"] == event_id
+ assert row.event_id == event_id
- rejected_reason = row["rejected_reason"]
+ rejected_reason = row.rejected_reason
# If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown.
try:
- d = db_to_json(row["json"])
+ d = db_to_json(row.json)
except ValueError:
logger.error("Unable to parse json from event: %s", event_id)
continue
try:
- internal_metadata = db_to_json(row["internal_metadata"])
+ internal_metadata = db_to_json(row.internal_metadata)
except ValueError:
logger.error(
"Unable to parse internal_metadata from event: %s", event_id
)
continue
- format_version = row["format_version"]
+ format_version = row.format_version
if format_version is None:
# This means that we stored the event before we had the concept
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1
- room_version_id = row["room_version_id"]
+ room_version_id = row.room_version_id
if not room_version_id:
# this should only happen for out-of-band membership events which
@@ -889,8 +933,8 @@ class EventsWorkerStore(SQLBaseStore):
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
- original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
- original_ev.internal_metadata.outlier = row["outlier"]
+ original_ev.internal_metadata.stream_ordering = row.stream_ordering
+ original_ev.internal_metadata.outlier = row.outlier
event_map[event_id] = original_ev
@@ -898,7 +942,7 @@ class EventsWorkerStore(SQLBaseStore):
# the cache entries.
result_map = {}
for event_id, original_ev in event_map.items():
- redactions = fetched_events[event_id]["redactions"]
+ redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
@@ -912,17 +956,17 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events):
+ async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
Args:
- events (Iterable[str]): events to be fetched.
+ events: events to be fetched.
Returns:
- Dict[str, Dict]: map from event id to row data from the database.
- May contain events that weren't requested.
+ A map from event id to row data from the database. May contain events
+ that weren't requested.
"""
events_d = defer.Deferred()
@@ -949,43 +993,19 @@ class EventsWorkerStore(SQLBaseStore):
return row_map
- def _fetch_event_rows(self, txn, event_ids):
+ def _fetch_event_rows(
+ self, txn: LoggingTransaction, event_ids: Iterable[str]
+ ) -> Dict[str, _EventRow]:
"""Fetch event rows from the database
Events which are not found are omitted from the result.
- The returned per-event dicts contain the following keys:
-
- * event_id (str)
-
- * stream_ordering (int): stream ordering for this event
-
- * json (str): json-encoded event structure
-
- * internal_metadata (str): json-encoded internal metadata dict
-
- * format_version (int|None): The format of the event. Hopefully one
- of EventFormatVersions. 'None' means the event predates
- EventFormatVersions (so the event is format V1).
-
- * room_version_id (str|None): The version of the room which contains the event.
- Hopefully one of RoomVersions.
-
- Due to historical reasons, there may be a few events in the database which
- do not have an associated room; in this case None will be returned here.
-
- * rejected_reason (str|None): if the event was rejected, the reason
- why.
-
- * redactions (List[str]): a list of event-ids which (claim to) redact
- this event.
-
Args:
- txn (twisted.enterprise.adbapi.Connection):
- event_ids (Iterable[str]): event IDs to fetch
+ txn: The database transaction.
+ event_ids: event IDs to fetch
Returns:
- Dict[str, Dict]: a map from event id to event info.
+ A map from event id to event info.
"""
event_dict = {}
for evs in batch_iter(event_ids, 200):
@@ -1013,17 +1033,17 @@ class EventsWorkerStore(SQLBaseStore):
for row in txn:
event_id = row[0]
- event_dict[event_id] = {
- "event_id": event_id,
- "stream_ordering": row[1],
- "internal_metadata": row[2],
- "json": row[3],
- "format_version": row[4],
- "room_version_id": row[5],
- "rejected_reason": row[6],
- "redactions": [],
- "outlier": row[7],
- }
+ event_dict[event_id] = _EventRow(
+ event_id=event_id,
+ stream_ordering=row[1],
+ internal_metadata=row[2],
+ json=row[3],
+ format_version=row[4],
+ room_version_id=row[5],
+ rejected_reason=row[6],
+ redactions=[],
+ outlier=row[7],
+ )
# check for redactions
redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
@@ -1035,7 +1055,7 @@ class EventsWorkerStore(SQLBaseStore):
for (redacter, redacted) in txn:
d = event_dict.get(redacted)
if d:
- d["redactions"].append(redacter)
+ d.redactions.append(redacter)
return event_dict
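A before/after sketch (illustrative, not part of the patch) of how callers consume fetched rows now that `_EventRow` is an attrs class rather than a plain dict:

    row = row_map[event_id]
    # before: row["format_version"], row["redactions"]
    # after: attribute access, which mypy can check against the attrs class
    if row.format_version is None:
        format_version = EventFormatVersions.V1
    redaction_ids.update(row.redactions)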
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 2fa945d171..717487be28 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
"media_repository_drop_index_wo_method"
)
@@ -43,7 +46,7 @@ class MediaSortOrder(Enum):
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index dac3d14da8..d901933ae4 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,7 +14,7 @@
import calendar
import logging
import time
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# Collect metrics on the number of forward extremities that exist.
@@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics.
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index a14ac03d4b..b5284e4f67 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp
@@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fc720f5947..fa782023d4 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import logging
-from typing import Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
@@ -33,6 +33,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -75,7 +78,7 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 01a4281301..c99f8aebdb 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
@@ -29,11 +29,14 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 181841ee06..6c7d6ba508 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,11 @@ import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor
@@ -40,6 +44,13 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
+class ExternalIDReuseException(Exception):
+ """Exception if writing an external id for a user fails,
+ because this external id is given to an other user."""
+
+ pass
+
+
@attr.s(frozen=True, slots=True)
class TokenLookupResult:
"""Result of looking up an access token.
@@ -488,6 +499,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+ async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None:
+ """Sets the user type.
+
+ Args:
+ user: user ID of the user.
+ user_type: type of the user or None for a user without a type.
+ """
+
+ def set_user_type_txn(txn):
+ self.db_pool.simple_update_one_txn(
+ txn, "users", {"name": user.to_string()}, {"user_type": user_type}
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user.to_string(),)
+ )
+
+ await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
+
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
@@ -588,24 +617,44 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
+ Raises:
+ ExternalIDReuseException if the new external_id could not be mapped.
"""
- await self.db_pool.simple_insert(
+
+ try:
+ await self.db_pool.runInteraction(
+ "record_user_external_id",
+ self._record_user_external_id_txn,
+ auth_provider,
+ external_id,
+ user_id,
+ )
+ except self.database_engine.module.IntegrityError:
+ raise ExternalIDReuseException()
+
+ def _record_user_external_id_txn(
+ self,
+ txn: LoggingTransaction,
+ auth_provider: str,
+ external_id: str,
+ user_id: str,
+ ) -> None:
+
+ self.db_pool.simple_insert_txn(
+ txn,
table="user_external_ids",
values={
"auth_provider": auth_provider,
"external_id": external_id,
"user_id": user_id,
},
- desc="record_user_external_id",
)
async def remove_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> None:
"""Remove a mapping from an external user id to a mxid
-
If the mapping is not found, this method does nothing.
-
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
@@ -621,6 +670,60 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="remove_user_external_id",
)
+ async def replace_user_external_id(
+ self,
+ record_external_ids: List[Tuple[str, str]],
+ user_id: str,
+ ) -> None:
+ """Replace mappings from external user ids to a mxid in a single transaction.
+ All mappings are deleted and the new ones are created.
+
+ Args:
+ record_external_ids:
+ List with tuple of auth_provider and external_id to record
+ user_id: complete mxid that it is mapped to
+ Raises:
+ ExternalIDReuseException if the new external_id could not be mapped.
+ """
+
+ def _remove_user_external_ids_txn(
+ txn: LoggingTransaction,
+ user_id: str,
+ ) -> None:
+ """Remove all mappings from external user ids to a mxid
+ If these mappings are not found, this method does nothing.
+
+ Args:
+ user_id: complete mxid that it is mapped to
+ """
+
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="user_external_ids",
+ keyvalues={"user_id": user_id},
+ )
+
+ def _replace_user_external_id_txn(
+ txn: LoggingTransaction,
+ ):
+ _remove_user_external_ids_txn(txn, user_id)
+
+ for auth_provider, external_id in record_external_ids:
+ self._record_user_external_id_txn(
+ txn,
+ auth_provider,
+ external_id,
+ user_id,
+ )
+
+ try:
+ await self.db_pool.runInteraction(
+ "replace_user_external_id",
+ _replace_user_external_id_txn,
+ )
+ except self.database_engine.module.IntegrityError:
+ raise ExternalIDReuseException()
+
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> Optional[str]:
@@ -2237,7 +2340,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# accident.
row = {"client_secret": None, "validated_at": None}
else:
- raise ThreepidValidationError(400, "Unknown session_id")
+ raise ThreepidValidationError("Unknown session_id")
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
@@ -2252,14 +2355,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
if not row:
raise ThreepidValidationError(
- 400, "Validation token not found or has expired"
+ "Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id"
+ "This client_secret does not match the provided session_id"
)
# If the session is already validated, no need to revalidate
@@ -2268,7 +2371,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
if expires <= current_ts:
raise ThreepidValidationError(
- 400, "This token has expired. Please request a new one"
+ "This token has expired. Please request a new one"
)
# Looks good. Validate the session
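A hedged usage sketch of the new `replace_user_external_id` helper; the `store` reference, provider names and IDs are placeholders:

    try:
        await store.replace_user_external_id(
            [("oidc", "subject-123"), ("saml", "nameid-456")],
            "@alice:example.com",
        )
    except ExternalIDReuseException:
        # At least one external ID is already mapped to a different user; the
        # whole replacement runs in a single transaction, so nothing changed.
        ...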
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2bbf6d6a95..40760fbd1b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import Optional, Tuple
import attr
@@ -269,6 +269,63 @@ class RelationsWorkerStore(SQLBaseStore):
return await self.get_event(edit_id, allow_none=True)
+ @cached()
+ async def get_thread_summary(
+ self, event_id: str
+ ) -> Tuple[int, Optional[EventBase]]:
+ """Get the number of threaded replies, the senders of those replies, and
+ the latest reply (if any) for the given event.
+
+ Args:
+ event_id: The original event ID
+
+ Returns:
+ The number of items in the thread and the most recent response, if any.
+ """
+
+ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
+ # Fetch the count of threaded events and the latest event ID.
+            # TODO Should this only allow m.room.message events?
+ sql = """
+ SELECT event_id
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ ORDER BY topological_ordering DESC, stream_ordering DESC
+ LIMIT 1
+ """
+
+ txn.execute(sql, (event_id, RelationTypes.THREAD))
+ row = txn.fetchone()
+ if row is None:
+ return 0, None
+
+ latest_event_id = row[0]
+
+ sql = """
+ SELECT COALESCE(COUNT(event_id), 0)
+ FROM event_relations
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ """
+ txn.execute(sql, (event_id, RelationTypes.THREAD))
+ count = txn.fetchone()[0]
+
+ return count, latest_event_id
+
+ count, latest_event_id = await self.db_pool.runInteraction(
+ "get_thread_summary", _get_thread_summary_txn
+ )
+
+ latest_event = None
+ if latest_event_id:
+ latest_event = await self.get_event(latest_event_id, allow_none=True)
+
+ return count, latest_event
+
async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
) -> bool:
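A usage sketch (not part of the patch) of the new cached `get_thread_summary`; the `store` reference and event ID are placeholders:

    count, latest_event = await store.get_thread_summary("$thread_root_event_id")
    if latest_event is not None:
        # `count` threaded replies exist; `latest_event` is the most recent one.
        ...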
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d69eaf80ce..f879bbe7c7 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -17,7 +17,7 @@ import collections
import logging
from abc import abstractmethod
from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
@@ -32,6 +32,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import MXC_REGEX
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -69,7 +72,7 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -679,8 +682,8 @@ class RoomWorkerStore(SQLBaseStore):
# policy.
if not ret:
return {
- "min_lifetime": self.config.server.retention_default_min_lifetime,
- "max_lifetime": self.config.server.retention_default_max_lifetime,
+ "min_lifetime": self.config.retention.retention_default_min_lifetime,
+ "max_lifetime": self.config.retention.retention_default_max_lifetime,
}
row = ret[0]
@@ -690,10 +693,10 @@ class RoomWorkerStore(SQLBaseStore):
# The default values will be None if no default policy has been defined, or if one
# of the attributes is missing from the default policy.
if row["min_lifetime"] is None:
- row["min_lifetime"] = self.config.server.retention_default_min_lifetime
+ row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
if row["max_lifetime"] is None:
- row["max_lifetime"] = self.config.server.retention_default_max_lifetime
+ row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
return row
@@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
class RoomBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
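
Illustrative note (not part of the patch): the retention hunks above read per-room values first and fall back to the renamed retention config section when a value is missing. A minimal standalone sketch of that fallback rule, with hypothetical names:

def apply_retention_defaults(row, default_min, default_max):
    # No per-room policy at all: use the configured defaults outright.
    if row is None:
        return {"min_lifetime": default_min, "max_lifetime": default_max}
    # Otherwise only fill in the attributes the room's policy leaves unset.
    if row.get("min_lifetime") is None:
        row["min_lifetime"] = default_min
    if row.get("max_lifetime") is None:
        row["max_lifetime"] = default_max
    return row
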
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index 300a563c9e..dcbce8fdcf 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -36,3 +36,14 @@ class RoomBatchStore(SQLBaseStore):
retcol="event_id",
allow_none=True,
)
+
+ async def store_state_group_id_for_event_id(
+ self, event_id: str, state_group_id: int
+ ) -> None:
+ await self.db_pool.simple_upsert(
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ values={"state_group": state_group_id, "event_id": event_id},
+ # Unique constraint on event_id so we don't have to lock
+ lock=False,
+ )
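
Illustrative note (not part of the patch): because store_state_group_id_for_event_id upserts keyed on event_id, repeating the call for the same event simply rewrites the mapping. A hypothetical batch-import caller, with made-up names:

async def link_batch_to_state_group(store, event_ids, state_group_id: int) -> None:
    for event_id in event_ids:
        # Idempotent: replaying the same batch does not create duplicate rows.
        await store.store_state_group_id_for_event_id(event_id, state_group_id)
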
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ddb162a4fc..4b288bb2e7 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure
if TYPE_CHECKING:
+ from synapse.server import HomeServer
from synapse.state import _StateCacheEntry
logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 25df8758bd..642560a70d 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -15,7 +15,7 @@
import logging
import re
from collections import namedtuple
-from typing import Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
from synapse.api.errors import SynapseError
from synapse.events import EventBase
@@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
SearchEntry = namedtuple(
@@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if not hs.config.server.enable_search:
@@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys):
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index a8e8dd4577..fa2c3b1feb 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -15,7 +15,7 @@
import collections.abc
import logging
from collections import namedtuple
-from typing import Iterable, Optional, Set
+from typing import TYPE_CHECKING, Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -30,6 +30,9 @@ from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -53,7 +56,7 @@ class _GetStateGroupDelta(
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
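
Illustrative note (not part of the patch): the recurring hs: "HomeServer" annotations rely on importing HomeServer only under TYPE_CHECKING and quoting the annotation, which keeps synapse.server out of the runtime import graph and avoids an import cycle. A minimal standalone sketch of the idiom, with made-up module and attribute names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never imported at runtime.
    from myapp.server import Server


class ExampleStore:
    def __init__(self, server: "Server") -> None:
        # The quoted annotation is resolved lazily, so importing this module
        # does not pull in myapp.server.
        self.hostname = server.hostname
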
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index e20033bb28..5d7b59d861 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
import logging
from enum import Enum
from itertools import chain
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing_extensions import Counter
@@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# these fields track absolutes (e.g. total number of rooms on the server)
@@ -93,7 +96,7 @@ class UserSortOrder(Enum):
class StatsStore(StateDeltasStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 860146cd1b..d7dc1f73ac 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
import logging
from collections import namedtuple
-from typing import Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
db_binary_type = memoryview
logger = logging.getLogger(__name__)
@@ -57,7 +60,7 @@ class DestinationRetryTimings:
class TransactionWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
|