diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 259cae5b37..9ff2d8d8c3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -123,9 +123,9 @@ class DataStore(
RelationsStore,
CensorEventsStore,
UIAuthStore,
+ EventForwardExtremitiesStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
- EventForwardExtremitiesStore,
LockStore,
SessionStore,
):
@@ -154,6 +154,7 @@ class DataStore(
db_conn, "local_group_updates", "stream_id"
)
+ self._cache_id_gen: Optional[MultiWriterIdGenerator]
if isinstance(self.database_engine, PostgresEngine):
# We set the `writers` to an empty list here as we don't care about
# missing updates over restarts, as we'll not have anything in our
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index baec35ee27..4a883dc166 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
A list of ApplicationServices, which may be empty.
"""
results = await self.db_pool.simple_select_list(
- "application_services_state", {"state": state}, ["as_id"]
+ "application_services_state", {"state": state.value}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
desc="get_appservice_state",
)
if result:
- return result.get("state")
+ return ApplicationServiceState(result.get("state"))
return None
async def set_appservice_state(
@@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
- "application_services_state", {"as_id": service.id}, {"state": state}
+ "application_services_state", {"as_id": service.id}, {"state": state.value}
)
async def create_appservice_txn(
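
Note on the appservice changes above: the `state` column now stores `state.value` (a plain string) and reads rebuild the enum via `ApplicationServiceState(...)`. A minimal sketch of that round-trip, outside the diff; the UP/DOWN members and their string values are assumptions for illustration, not taken from the diff.

from enum import Enum

class ApplicationServiceState(Enum):
    # Illustrative members; the real enum lives in Synapse and may differ.
    UP = "up"
    DOWN = "down"

def state_to_db(state: ApplicationServiceState) -> str:
    # Persist the underlying string rather than the Enum object itself.
    return state.value

def state_from_db(raw: str) -> ApplicationServiceState:
    # Enum lookup by value; raises ValueError for an unknown state string.
    return ApplicationServiceState(raw)

assert state_from_db(state_to_db(ApplicationServiceState.DOWN)) is ApplicationServiceState.DOWN
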
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index eee07227ef..0f56e10220 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -13,12 +13,12 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder
@@ -41,7 +41,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
@wrap_as_background_process("_censor_redactions")
- async def _censor_redactions(self):
+ async def _censor_redactions(self) -> None:
"""Censors all redactions older than the configured period that haven't
been censored yet.
@@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
- pruned_json = json_encoder.encode(
+ pruned_json: Optional[str] = json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@@ -116,7 +116,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
updates.append((redaction_id, event_id, pruned_json))
- def _update_censor_txn(txn):
+ def _update_censor_txn(txn: LoggingTransaction) -> None:
for redaction_id, event_id, pruned_json in updates:
if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json)
@@ -130,14 +130,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
- def _censor_event_txn(self, txn, event_id, pruned_json):
+ def _censor_event_txn(
+ self, txn: LoggingTransaction, event_id: str, pruned_json: str
+ ) -> None:
"""Censor an event by replacing its JSON in the event_json table with the
provided pruned JSON.
Args:
- txn (LoggingTransaction): The database transaction.
- event_id (str): The ID of the event to censor.
- pruned_json (str): The pruned JSON
+ txn: The database transaction.
+ event_id: The ID of the event to censor.
+ pruned_json: The pruned JSON
"""
self.db_pool.simple_update_one_txn(
txn,
@@ -157,7 +159,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# Try to retrieve the event's content from the database or the event cache.
event = await self.get_event(event_id)
- def delete_expired_event_txn(txn):
+ def delete_expired_event_txn(txn: LoggingTransaction) -> None:
# Delete the expiry timestamp associated with this event from the database.
self._delete_event_expiry_txn(txn, event_id)
@@ -194,14 +196,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"delete_expired_event", delete_expired_event_txn
)
- def _delete_event_expiry_txn(self, txn, event_id):
+ def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None:
"""Delete the expiry timestamp associated with an event ID without deleting the
actual event.
Args:
- txn (LoggingTransaction): The transaction to use to perform the deletion.
- event_id (str): The event ID to delete the associated expiry timestamp of.
+ txn: The transaction to use to perform the deletion.
+ event_id: The event ID to delete the associated expiry timestamp of.
"""
- return self.db_pool.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ae3afdd5d2..ab8766c75b 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
@@ -34,14 +43,21 @@ logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
- self._last_device_delete_cache = ExpiringCache(
+ self._last_device_delete_cache: ExpiringCache[
+ Tuple[str, Optional[str]], int
+ ] = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
@@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self._instance_name in hs.config.worker.writers.to_device
)
- self._device_inbox_id_gen = MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- stream_name="to_device",
- instance_name=self._instance_name,
- tables=[("device_inbox", "instance_name", "stream_id")],
- sequence_name="device_inbox_sequence",
- writers=hs.config.worker.writers.to_device,
+ self._device_inbox_id_gen: AbstractStreamIdGenerator = (
+ MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="to_device",
+ instance_name=self._instance_name,
+ tables=[("device_inbox", "instance_name", "stream_id")],
+ sequence_name="device_inbox_sequence",
+ writers=hs.config.worker.writers.to_device,
+ )
)
else:
self._can_write_to_device = True
@@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
+ # If replication is happening then Postgres must be in use.
+ assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
@@ -220,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": f"deleted {count} messages for device", "count": count})
# Update the cache, ensuring that we only ever increase the value
- last_deleted_stream_id = self._last_device_delete_cache.get(
+ updated_last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
- last_deleted_stream_id, up_to_stream_id
+ updated_last_deleted_stream_id, up_to_stream_id
)
return count
@@ -432,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
async with self._device_inbox_id_gen.get_next() as stream_id:
- now_ms = self.clock.time_msec()
+ now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
@@ -483,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
async with self._device_inbox_id_gen.get_next() as stream_id:
- now_ms = self.clock.time_msec()
+ now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
@@ -579,6 +599,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
+ REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -594,14 +615,18 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
- self.db_pool.updates.register_background_update_handler(
- self.REMOVE_DELETED_DEVICES,
- self._remove_deleted_devices_from_device_inbox,
+ # Used to be a background update that deleted all device_inboxes for deleted
+ # devices.
+ self.db_pool.updates.register_noop_background_update(
+ self.REMOVE_DELETED_DEVICES
)
+ # Used to be a background update that deleted all device_inboxes for hidden
+ # devices.
+ self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES)
self.db_pool.updates.register_background_update_handler(
- self.REMOVE_HIDDEN_DEVICES,
- self._remove_hidden_devices_from_device_inbox,
+ self.REMOVE_DEAD_DEVICES_FROM_INBOX,
+ self._remove_dead_devices_from_device_inbox,
)
async def _background_drop_index_device_inbox(self, progress, batch_size):
@@ -616,171 +641,83 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return 1
- async def _remove_deleted_devices_from_device_inbox(
- self, progress: JsonDict, batch_size: int
+ async def _remove_dead_devices_from_device_inbox(
+ self,
+ progress: JsonDict,
+ batch_size: int,
) -> int:
- """A background update that deletes all device_inboxes for deleted devices.
-
- This should only need to be run once (when users upgrade to v1.47.0)
+ """A background update to remove devices that were either deleted or hidden from
+ the device_inbox table.
Args:
- progress: JsonDict used to store progress of this background update
- batch_size: the maximum number of rows to retrieve in a single select query
+ progress: The update's progress dict.
+ batch_size: The batch size for this update.
Returns:
- The number of deleted rows
+ The number of rows deleted.
"""
- def _remove_deleted_devices_from_device_inbox_txn(
+ def _remove_dead_devices_from_device_inbox_txn(
txn: LoggingTransaction,
- ) -> int:
- """stream_id is not unique
- we need to use an inclusive `stream_id >= ?` clause,
- since we might not have deleted all dead device messages for the stream_id
- returned from the previous query
+ ) -> Tuple[int, bool]:
- Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
- to avoid problems of deleting a large number of rows all at once
- due to a single device having lots of device messages.
- """
+ if "max_stream_id" in progress:
+ max_stream_id = progress["max_stream_id"]
+ else:
+ txn.execute("SELECT max(stream_id) FROM device_inbox")
+ # There's a type mismatch here between how we want to type the row and
+ # what fetchone says it returns, but we silence it because we know that
+ # res can't be None.
+ res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment]
+ if res[0] is None:
+ # this can only happen if the `device_inbox` table is empty, in which
+ # case we have no work to do.
+ return 0, True
+ else:
+ max_stream_id = res[0]
- last_stream_id = progress.get("stream_id", 0)
+ start = progress.get("stream_id", 0)
+ stop = start + batch_size
+ # delete rows in `device_inbox` which do *not* correspond to a known,
+ # unhidden device.
sql = """
- SELECT device_id, user_id, stream_id
- FROM device_inbox
+ DELETE FROM device_inbox
WHERE
- stream_id >= ?
- AND (device_id, user_id) NOT IN (
- SELECT device_id, user_id FROM devices
+ stream_id >= ? AND stream_id < ?
+ AND NOT EXISTS (
+ SELECT * FROM devices d
+ WHERE
+ d.device_id=device_inbox.device_id
+ AND d.user_id=device_inbox.user_id
+ AND NOT hidden
)
- ORDER BY stream_id
- LIMIT ?
- """
-
- txn.execute(sql, (last_stream_id, batch_size))
- rows = txn.fetchall()
-
- num_deleted = 0
- for row in rows:
- num_deleted += self.db_pool.simple_delete_txn(
- txn,
- "device_inbox",
- {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
- )
+ """
- if rows:
- # send more than stream_id to progress
- # otherwise it can happen in large deployments that
- # no change of status is visible in the log file
- # it may be that the stream_id does not change in several runs
- self.db_pool.updates._background_update_progress_txn(
- txn,
- self.REMOVE_DELETED_DEVICES,
- {
- "device_id": rows[-1][0],
- "user_id": rows[-1][1],
- "stream_id": rows[-1][2],
- },
- )
-
- return num_deleted
+ txn.execute(sql, (start, stop))
- number_deleted = await self.db_pool.runInteraction(
- "_remove_deleted_devices_from_device_inbox",
- _remove_deleted_devices_from_device_inbox_txn,
- )
-
- # The task is finished when no more lines are deleted.
- if not number_deleted:
- await self.db_pool.updates._end_background_update(
- self.REMOVE_DELETED_DEVICES
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_DEAD_DEVICES_FROM_INBOX,
+ {
+ "stream_id": stop,
+ "max_stream_id": max_stream_id,
+ },
)
- return number_deleted
-
- async def _remove_hidden_devices_from_device_inbox(
- self, progress: JsonDict, batch_size: int
- ) -> int:
- """A background update that deletes all device_inboxes for hidden devices.
-
- This should only need to be run once (when users upgrade to v1.47.0)
-
- Args:
- progress: JsonDict used to store progress of this background update
- batch_size: the maximum number of rows to retrieve in a single select query
-
- Returns:
- The number of deleted rows
- """
-
- def _remove_hidden_devices_from_device_inbox_txn(
- txn: LoggingTransaction,
- ) -> int:
- """stream_id is not unique
- we need to use an inclusive `stream_id >= ?` clause,
- since we might not have deleted all hidden device messages for the stream_id
- returned from the previous query
-
- Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
- to avoid problems of deleting a large number of rows all at once
- due to a single device having lots of device messages.
- """
-
- last_stream_id = progress.get("stream_id", 0)
-
- sql = """
- SELECT device_id, user_id, stream_id
- FROM device_inbox
- WHERE
- stream_id >= ?
- AND (device_id, user_id) IN (
- SELECT device_id, user_id FROM devices WHERE hidden = ?
- )
- ORDER BY stream_id
- LIMIT ?
- """
-
- txn.execute(sql, (last_stream_id, True, batch_size))
- rows = txn.fetchall()
-
- num_deleted = 0
- for row in rows:
- num_deleted += self.db_pool.simple_delete_txn(
- txn,
- "device_inbox",
- {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
- )
-
- if rows:
- # We don't just save the `stream_id` in progress as
- # otherwise it can happen in large deployments that
- # no change of status is visible in the log file, as
- # it may be that the stream_id does not change in several runs
- self.db_pool.updates._background_update_progress_txn(
- txn,
- self.REMOVE_HIDDEN_DEVICES,
- {
- "device_id": rows[-1][0],
- "user_id": rows[-1][1],
- "stream_id": rows[-1][2],
- },
- )
-
- return num_deleted
+ return stop > max_stream_id
- number_deleted = await self.db_pool.runInteraction(
- "_remove_hidden_devices_from_device_inbox",
- _remove_hidden_devices_from_device_inbox_txn,
+ finished = await self.db_pool.runInteraction(
+ "_remove_devices_from_device_inbox_txn",
+ _remove_dead_devices_from_device_inbox_txn,
)
- # The task is finished when no more lines are deleted.
- if not number_deleted:
+ if finished:
await self.db_pool.updates._end_background_update(
- self.REMOVE_HIDDEN_DEVICES
+ self.REMOVE_DEAD_DEVICES_FROM_INBOX,
)
- return number_deleted
+ return batch_size
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
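
The replacement background update above walks `device_inbox` in fixed `stream_id` ranges, recording `stream_id`/`max_stream_id` as progress and finishing once the range passes the maximum seen when the update started. A standalone sketch of that pattern using plain sqlite3 (not the Synapse `DatabasePool` API; the table layout is assumed from the SQL above):

import sqlite3

def remove_dead_rows(conn: sqlite3.Connection, batch_size: int = 1000) -> None:
    cur = conn.execute("SELECT max(stream_id) FROM device_inbox")
    max_stream_id = cur.fetchone()[0]
    if max_stream_id is None:
        return  # table is empty, nothing to do

    start = 0
    while start <= max_stream_id:
        stop = start + batch_size
        # Delete rows in [start, stop) which do not belong to a known, unhidden device.
        conn.execute(
            """
            DELETE FROM device_inbox
            WHERE stream_id >= ? AND stream_id < ?
              AND NOT EXISTS (
                SELECT 1 FROM devices d
                WHERE d.device_id = device_inbox.device_id
                  AND d.user_id = device_inbox.user_id
                  AND NOT d.hidden
              )
            """,
            (start, stop),
        )
        conn.commit()
        start = stop  # in Synapse this is persisted as background-update progress
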
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9ccc66e589..838a2a6a3d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
+ async def get_devices_by_auth_provider_session_id(
+ self, auth_provider_id: str, auth_provider_session_id: str
+ ) -> List[Dict[str, Any]]:
+ """Retrieve the list of devices associated with a SSO IdP session ID.
+
+ Args:
+ auth_provider_id: The SSO IdP ID as defined in the server config
+ auth_provider_session_id: The session ID within the IdP
+ Returns:
+ A list of dicts containing the device_id and the user_id of each device
+ """
+ return await self.db_pool.simple_select_list(
+ table="device_auth_providers",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ retcols=("user_id", "device_id"),
+ desc="get_devices_by_auth_provider_session_id",
+ )
+
@trace
async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int
@@ -253,7 +274,9 @@ class DeviceWorkerStore(SQLBaseStore):
# add the updated cross-signing keys to the results list
for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
- # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("m.signing_key_update", result))
+ # also send the unstable version
+ # FIXME: remove this when enough servers have upgraded
results.append(("org.matrix.signing_key_update", result))
return now_stream_id, results
@@ -1070,7 +1093,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def store_device(
- self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
+ self,
+ user_id: str,
+ device_id: str,
+ initial_device_display_name: Optional[str],
+ auth_provider_id: Optional[str] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> bool:
"""Ensure the given device is known; add it to the store if not
@@ -1079,6 +1107,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: id of device
initial_device_display_name: initial displayname of the device.
Ignored if device exists.
+ auth_provider_id: The SSO IdP the user used, if any.
+ auth_provider_session_id: The session ID (sid) from an OIDC login.
Returns:
Whether the device was inserted or an existing device existed with that ID.
@@ -1115,6 +1145,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
+ if auth_provider_id and auth_provider_session_id:
+ await self.db_pool.simple_insert(
+ "device_auth_providers",
+ values={
+ "user_id": user_id,
+ "device_id": device_id,
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ desc="store_device_auth_provider",
+ )
+
self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
@@ -1168,6 +1210,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
keyvalues={"user_id": user_id},
)
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="device_auth_providers",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id},
+ )
+
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
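
A hypothetical caller-side sketch of the new `get_devices_by_auth_provider_session_id` method above, e.g. to resolve the devices created through a given SSO session. The wrapper function and its name are assumptions; only the store call and the returned dict keys come from the diff.

from typing import List, Tuple

async def devices_for_sso_session(
    store, auth_provider_id: str, sid: str
) -> List[Tuple[str, str]]:
    rows = await store.get_devices_by_auth_provider_session_id(
        auth_provider_id=auth_provider_id,
        auth_provider_session_id=sid,
    )
    # Each row is a dict with "user_id" and "device_id", per the docstring above.
    return [(row["user_id"], row["device_id"]) for row in rows]
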
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 6daf8b8ffb..a3442814d7 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -13,17 +13,18 @@
# limitations under the License.
from collections import namedtuple
-from typing import Iterable, List, Optional
+from typing import Iterable, List, Optional, Tuple
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
-class DirectoryWorkerStore(SQLBaseStore):
+class DirectoryWorkerStore(CacheInvalidationWorkerStore):
async def get_association_from_room_alias(
self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]:
@@ -91,7 +92,7 @@ class DirectoryWorkerStore(SQLBaseStore):
creator: Optional user_id of creator.
"""
- def alias_txn(txn):
+ def alias_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn(
txn,
"room_aliases",
@@ -126,14 +127,16 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore):
- async def delete_room_alias(self, room_alias: RoomAlias) -> str:
+ async def delete_room_alias(self, room_alias: RoomAlias) -> Optional[str]:
room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
return room_id
- def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
+ def _delete_room_alias_txn(
+ self, txn: LoggingTransaction, room_alias: RoomAlias
+ ) -> Optional[str]:
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),),
@@ -173,9 +176,9 @@ class DirectoryStore(DirectoryWorkerStore):
If None, the creator will be left unchanged.
"""
- def _update_aliases_for_room_txn(txn):
+ def _update_aliases_for_room_txn(txn: LoggingTransaction) -> None:
update_creator_sql = ""
- sql_params = (new_room_id, old_room_id)
+ sql_params: Tuple[str, ...] = (new_room_id, old_room_id)
if creator:
update_creator_sql = ", creator = ?"
sql_params = (new_room_id, creator, old_room_id)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a95ac34f09..b06c1dc45b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
+ await self.db_pool.runInteraction(
+ "set_e2e_fallback_keys_txn",
+ self._set_e2e_fallback_keys_txn,
+ user_id,
+ device_id,
+ fallback_keys,
+ )
+
+ await self.invalidate_cache_and_stream(
+ "get_e2e_unused_fallback_key_types", (user_id, device_id)
+ )
+
+ def _set_e2e_fallback_keys_txn(
+ self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+ ) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
- await self.db_pool.simple_upsert(
- "e2e_fallback_keys_json",
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
- values={
- "key_id": key_id,
- "key_json": json_encoder.encode(fallback_key),
- "used": False,
- },
- desc="set_e2e_fallback_key",
+ retcol="key_json",
+ allow_none=True,
)
- await self.invalidate_cache_and_stream(
- "get_e2e_unused_fallback_key_types", (user_id, device_id)
- )
+ new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
+
+ # If the uploaded key is the same as the current fallback key,
+ # don't do anything. This prevents marking the key as unused if it
+ # was already used.
+ if old_key_json != new_key_json:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ values={
+ "key_id": key_id,
+ "key_json": json_encoder.encode(fallback_key),
+ "used": False,
+ },
+ )
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
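
The fallback-key change above only rewrites the row (and resets its `used` flag) when the uploaded key differs from the stored one. A minimal dict-backed sketch of that compare-then-upsert logic, using `json.dumps(..., sort_keys=True)` as a stand-in for canonical JSON:

import json
from typing import Dict, Tuple

_store: Dict[Tuple[str, str, str], Dict[str, object]] = {}

def set_fallback_key(user_id: str, device_id: str, key_id: str, key: dict) -> None:
    algorithm, _ = key_id.split(":", 1)
    new_json = json.dumps(key, sort_keys=True)  # stand-in for canonical JSON
    row = _store.get((user_id, device_id, algorithm))
    if row is not None and row["key_json"] == new_json:
        # Same key re-uploaded: leave the existing row (and its `used` flag) alone.
        return
    _store[(user_id, device_id, algorithm)] = {
        "key_id": key_id,
        "key_json": new_json,
        "used": False,
    }
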
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index c58f7dd009..df70524fef 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1672,9 +1672,9 @@ class EventFederationStore(EventFederationWorkerStore):
DELETE FROM event_auth
WHERE event_id IN (
SELECT event_id FROM events
- LEFT JOIN state_events USING (room_id, event_id)
+ LEFT JOIN state_events AS se USING (room_id, event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
- AND state_key IS null
+ AND se.state_key IS null
)
"""
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d957e770dc..3efdd0c920 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
+from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
]
+class BasePushAction(TypedDict):
+ event_id: str
+ actions: List[Union[dict, str]]
+
+
+class HttpPushAction(BasePushAction):
+ room_id: str
+ stream_ordering: int
+
+
+class EmailPushAction(HttpPushAction):
+ received_ts: Optional[int]
+
+
def _serialize_action(actions, is_highlight):
"""Custom serializer for actions. This allows us to "compress" common actions.
@@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[dict]:
+ ) -> List[HttpPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
@@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[dict]:
+ ) -> List[EmailPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
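
The new TypedDicts above are plain dicts at runtime; they only constrain the shape for the type checker. Illustrative construction of a value matching `EmailPushAction` (the class definitions are repeated only so the snippet stands alone; the field values are made up):

from typing import List, Optional, Union
from typing_extensions import TypedDict

class BasePushAction(TypedDict):
    event_id: str
    actions: List[Union[dict, str]]

class HttpPushAction(BasePushAction):
    room_id: str
    stream_ordering: int

class EmailPushAction(HttpPushAction):
    received_ts: Optional[int]

action: EmailPushAction = {
    "event_id": "$example_event_id",
    "actions": ["notify", {"set_tweak": "highlight", "value": True}],
    "room_id": "!example_room_id",
    "stream_ordering": 1234,
    "received_ts": None,
}
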
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3790a52a89..28d9d99b46 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1,6 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
)
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -108,23 +106,30 @@ class PersistEventsStore:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
- # Ideally we'd move these ID gens here, unfortunately some other ID
- # generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
- self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
+ # Since we have been configured to write, we ought to have id generators,
+ # rather than id trackers.
+ assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+ assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+ # Ideally we'd move these ID gens here, unfortunately some other ID
+ # generators are chained off them so doing so is a bit of a PITA.
+ self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+ self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
+ *,
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
- backfilled: bool = False,
+ use_negative_stream_ordering: bool = False,
+ inhibit_local_membership_updates: bool = False,
) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -137,7 +142,14 @@ class PersistEventsStore:
room state
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
- backfilled
+ use_negative_stream_ordering: Whether to start stream_ordering on
+ the negative side and decrement. This should be set to True
+ for backfilled events because backfilled events get a negative
+ stream ordering so they don't come down incremental `/sync`.
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
Returns:
Resolves when the events have been persisted
@@ -159,7 +171,7 @@ class PersistEventsStore:
#
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
- if backfilled:
+ if use_negative_stream_ordering:
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
@@ -176,13 +188,13 @@ class PersistEventsStore:
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(events_and_contexts))
- if not backfilled:
+ if stream >= 0:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
@@ -316,8 +328,9 @@ class PersistEventsStore:
def _persist_events_txn(
self,
txn: LoggingTransaction,
+ *,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool,
+ inhibit_local_membership_updates: bool = False,
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
):
@@ -330,7 +343,10 @@ class PersistEventsStore:
Args:
txn
events_and_contexts: events to persist
- backfilled: True if the events were backfilled
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
delete_existing True to purge existing table rows for the events
from the database. This is useful when retrying due to
IntegrityError.
@@ -363,9 +379,7 @@ class PersistEventsStore:
events_and_contexts
)
- self._update_room_depths_txn(
- txn, events_and_contexts=events_and_contexts, backfilled=backfilled
- )
+ self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
# _update_outliers_txn filters out any events which have already been
# persisted, and returns the filtered list.
@@ -398,7 +412,7 @@ class PersistEventsStore:
txn,
events_and_contexts=events_and_contexts,
all_events_and_contexts=all_events_and_contexts,
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# We call this last as it assumes we've inserted the events into
@@ -561,9 +575,9 @@ class PersistEventsStore:
# fetch their auth event info.
while missing_auth_chains:
sql = """
- SELECT event_id, events.type, state_key, chain_id, sequence_number
+ SELECT event_id, events.type, se.state_key, chain_id, sequence_number
FROM events
- INNER JOIN state_events USING (event_id)
+ INNER JOIN state_events AS se USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
WHERE
"""
@@ -1200,7 +1214,6 @@ class PersistEventsStore:
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool,
):
"""Update min_depth for each room
@@ -1208,13 +1221,18 @@ class PersistEventsStore:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
- backfilled (bool): True if the events were backfilled
"""
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
- if not backfilled:
+ # Then update the `stream_ordering` position to mark the latest
+ # event as the front of the room. This should not be done for
+ # backfilled events because backfilled events have negative
+ # stream_ordering and happened in the past so we know that we don't
+ # need to update the stream_ordering tip/front for the room.
+ assert event.internal_metadata.stream_ordering is not None
+ if event.internal_metadata.stream_ordering >= 0:
txn.call_after(
self.store._events_stream_cache.entity_has_changed,
event.room_id,
@@ -1427,7 +1445,12 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
def _update_metadata_tables_txn(
- self, txn, events_and_contexts, all_events_and_contexts, backfilled
+ self,
+ txn,
+ *,
+ events_and_contexts,
+ all_events_and_contexts,
+ inhibit_local_membership_updates: bool = False,
):
"""Update all the miscellaneous tables for new events
@@ -1439,7 +1462,10 @@ class PersistEventsStore:
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
- backfilled (bool): True if the events were backfilled
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
"""
# Insert all the push actions into the event_push_actions table.
@@ -1513,7 +1539,7 @@ class PersistEventsStore:
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# Insert event_reference_hashes table.
@@ -1553,11 +1579,13 @@ class PersistEventsStore:
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
- to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+ to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
def prefill():
for cache_entry in to_prefill:
- self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+ self.store._get_event_cache.set(
+ (cache_entry.event.event_id,), cache_entry
+ )
txn.call_after(prefill)
@@ -1638,11 +1666,22 @@ class PersistEventsStore:
txn, table="event_reference_hashes", values=vals
)
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database."""
+ def _store_room_members_txn(
+ self, txn, events, *, inhibit_local_membership_updates: bool = False
+ ):
+ """
+ Store a room member in the database.
+ Args:
+ txn: The transaction to use.
+ events: List of events to store.
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
+ """
- def str_or_none(val: Any) -> Optional[str]:
- return val if isinstance(val, str) else None
+ def non_null_str_or_none(val: Any) -> Optional[str]:
+ return val if isinstance(val, str) and "\u0000" not in val else None
self.db_pool.simple_insert_many_txn(
txn,
@@ -1654,8 +1693,10 @@ class PersistEventsStore:
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
- "display_name": str_or_none(event.content.get("displayname")),
- "avatar_url": str_or_none(event.content.get("avatar_url")),
+ "display_name": non_null_str_or_none(
+ event.content.get("displayname")
+ ),
+ "avatar_url": non_null_str_or_none(event.content.get("avatar_url")),
}
for event in events
],
@@ -1680,7 +1721,7 @@ class PersistEventsStore:
# band membership", like a remote invite or a rejection of a remote invite.
if (
self.is_mine_id(event.state_key)
- and not backfilled
+ and not inhibit_local_membership_updates
and event.internal_metadata.is_outlier()
and event.internal_metadata.is_out_of_band_membership()
):
@@ -1694,34 +1735,33 @@ class PersistEventsStore:
},
)
- def _handle_event_relations(self, txn, event):
- """Handles inserting relation data during peristence of events
+ def _handle_event_relations(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
+ """Handles inserting relation data during persistence of events
Args:
- txn
- event (EventBase)
+ txn: The current database transaction.
+ event: The event which might have relations.
"""
relation = event.content.get("m.relates_to")
if not relation:
# No relations
return
+ # Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
- if rel_type not in (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.REPLACE,
- RelationTypes.THREAD,
- ):
- # Unknown relation type
+ if not isinstance(rel_type, str):
return
parent_id = relation.get("event_id")
- if not parent_id:
- # Invalid relation
+ if not isinstance(parent_id, str):
return
- aggregation_key = relation.get("key")
+ # Annotations have a key field.
+ aggregation_key = None
+ if rel_type == RelationTypes.ANNOTATION:
+ aggregation_key = relation.get("key")
self.db_pool.simple_insert_txn(
txn,
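
The `_handle_event_relations` change above accepts any string `rel_type`, requires a string parent event ID, and only keeps an aggregation key for annotations. A standalone sketch of that parsing (a plain function rather than the store method; the `m.annotation` literal stands in for `RelationTypes.ANNOTATION`):

from typing import Optional, Tuple

ANNOTATION = "m.annotation"  # assumed value of RelationTypes.ANNOTATION

def parse_relation(content: dict) -> Optional[Tuple[str, str, Optional[str]]]:
    relation = content.get("m.relates_to")
    if not isinstance(relation, dict):
        return None
    rel_type = relation.get("rel_type")
    if not isinstance(rel_type, str):
        return None
    parent_id = relation.get("event_id")
    if not isinstance(parent_id, str):
        return None
    # Only annotations carry an aggregation key.
    aggregation_key = None
    if rel_type == ANNOTATION:
        aggregation_key = relation.get("key")
    return rel_type, parent_id, aggregation_key
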
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index ae3a8a63e4..c88fd35e7f 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1,4 +1,4 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
+ # The event_thread_relation background update was replaced with the
+ # event_arbitrary_relations one, which handles any relation type to avoid
+ # needing to potentially crawl the entire events table in the future.
+ self.db_pool.updates.register_noop_background_update("event_thread_relation")
+
self.db_pool.updates.register_background_update_handler(
- "event_thread_relation", self._event_thread_relation
+ "event_arbitrary_relations",
+ self._event_arbitrary_relations,
)
################################################################################
@@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
- async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
- """Background update handler which will store thread relations for existing events."""
+ async def _event_arbitrary_relations(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Background update handler which will store previously unknown relations for existing events."""
last_event_id = progress.get("last_event_id", "")
- def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
+ # Fetch events and then filter based on whether the event has a
+ # relation or not.
txn.execute(
"""
SELECT event_id, json FROM event_json
- LEFT JOIN event_relations USING (event_id)
- WHERE event_id > ? AND event_relations.event_id IS NULL
+ WHERE event_id > ?
ORDER BY event_id LIMIT ?
""",
(last_event_id, batch_size),
)
results = list(txn)
- missing_thread_relations = []
+ # (event_id, parent_id, rel_type) for each relation
+ relations_to_insert: List[Tuple[str, str, str]] = []
for (event_id, event_json_raw) in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
continue
- # If there's no relation (or it is not a thread), skip!
+ # If there's no relation, skip!
relates_to = event_json["content"].get("m.relates_to")
if not relates_to or not isinstance(relates_to, dict):
continue
- if relates_to.get("rel_type") != RelationTypes.THREAD:
+
+ # If the relation type or parent event ID is not a string, skip it.
+ #
+ # Do not consider relation types that have existed for a long time,
+ # since they will already be listed in the `event_relations` table.
+ rel_type = relates_to.get("rel_type")
+ if not isinstance(rel_type, str) or rel_type in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
continue
- # Get the parent ID.
parent_id = relates_to.get("event_id")
if not isinstance(parent_id, str):
continue
- missing_thread_relations.append((event_id, parent_id))
+ relations_to_insert.append((event_id, parent_id, rel_type))
+
+ # Insert the missing data, note that we upsert here in case the event
+ # has already been processed.
+ if relations_to_insert:
+ self.db_pool.simple_upsert_many_txn(
+ txn=txn,
+ table="event_relations",
+ key_names=("event_id",),
+ key_values=[(r[0],) for r in relations_to_insert],
+ value_names=("relates_to_id", "relation_type"),
+ value_values=[r[1:] for r in relations_to_insert],
+ )
- # Insert the missing data.
- self.db_pool.simple_insert_many_txn(
- txn=txn,
- table="event_relations",
- values=[
- {
- "event_id": event_id,
- "relates_to_Id": parent_id,
- "relation_type": RelationTypes.THREAD,
- }
- for event_id, parent_id in missing_thread_relations
- ],
- )
+ # Iterate the parent IDs and invalidate caches.
+ for parent_id in {r[1] for r in relations_to_insert}:
+ cache_tuple = (parent_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.get_relations_for_event, cache_tuple
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_aggregation_groups_for_event, cache_tuple
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_thread_summary, cache_tuple
+ )
if results:
latest_event_id = results[-1][0]
self.db_pool.updates._background_update_progress_txn(
- txn, "event_thread_relation", {"last_event_id": latest_event_id}
+ txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
)
return len(results)
num_rows = await self.db_pool.runInteraction(
- desc="event_thread_relation", func=_event_thread_relation_txn
+ desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
)
if not num_rows:
- await self.db_pool.updates._end_background_update("event_thread_relation")
+ await self.db_pool.updates._end_background_update(
+ "event_arbitrary_relations"
+ )
return num_rows
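
For context, the handler above follows the usual background-update contract: each call processes one batch keyed on `last_event_id`, persists its progress, returns the number of rows handled, and the update is ended once a batch handles zero rows. A generic sketch of that loop; `run_batch` and `save_progress` are placeholders, not Synapse APIs.

from typing import Callable, Tuple

def drive_update(
    run_batch: Callable[[str, int], Tuple[int, str]],
    save_progress: Callable[[str], None],
    batch_size: int = 100,
) -> None:
    last_event_id = ""
    while True:
        handled, last_event_id = run_batch(last_event_id, batch_size)
        if handled == 0:
            # Nothing left to do: the real code calls _end_background_update here.
            break
        save_progress(last_event_id)
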
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index 6d2688d711..68901b4335 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -13,15 +13,20 @@
# limitations under the License.
import logging
-from typing import Dict, List
+from typing import Any, Dict, List
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
logger = logging.getLogger(__name__)
-class EventForwardExtremitiesStore(SQLBaseStore):
+class EventForwardExtremitiesStore(
+ EventFederationWorkerStore,
+ CacheInvalidationWorkerStore,
+):
async def delete_forward_extremities_for_room(self, room_id: str) -> int:
"""Delete any extra forward extremities for a room.
@@ -31,7 +36,7 @@ class EventForwardExtremitiesStore(SQLBaseStore):
Returns count deleted.
"""
- def delete_forward_extremities_for_room_txn(txn):
+ def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int:
# First we need to get the event_id to not delete
sql = """
SELECT event_id FROM event_forward_extremities
@@ -82,10 +87,14 @@ class EventForwardExtremitiesStore(SQLBaseStore):
delete_forward_extremities_for_room_txn,
)
- async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+ async def get_forward_extremities_for_room(
+ self, room_id: str
+ ) -> List[Dict[str, Any]]:
"""Get list of forward extremities for a room."""
- def get_forward_extremities_for_room_txn(txn):
+ def get_forward_extremities_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Any]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..c7b660ac5a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
import logging
import threading
from typing import (
+ TYPE_CHECKING,
+ Any,
Collection,
Container,
Dict,
Iterable,
List,
+ NoReturn,
Optional,
Set,
Tuple,
+ cast,
overload,
)
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
+ RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `enqueue_event` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thin air to make initial sync run faster
# on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
- room_version_id: Optional[int]
+ room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
+ self._stream_id_gen: AbstractStreamIdTracker
+ self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache = LruCache(
+ self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
- str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+ str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
+ self._event_fetch_list: List[
+ Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+ ] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
- def get_chain_id_txn(txn):
+ def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[False] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[False] = ...,
+ check_room_id: Optional[str] = ...,
) -> EventBase:
...
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[True] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[True] = ...,
+ check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
- event_ids: Iterable[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
- ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
- Dict[str, _EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred())
+ Dict[str, EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id):
+ def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _do_fetch(self, conn: Connection) -> None:
+ def _maybe_start_fetch_thread(self) -> None:
+ """Starts an event fetch thread if we are not yet at the maximum number."""
+ with self._event_fetch_lock:
+ if (
+ self._event_fetch_list
+ and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+ ):
+ self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process("fetch_events", self._fetch_thread)
+
+ async def _fetch_thread(self) -> None:
+ """Services requests for events from `_event_fetch_list`."""
+ exc = None
+ try:
+ await self.db_pool.runWithConnection(self._fetch_loop)
+ except BaseException as e:
+ exc = e
+ raise
+ finally:
+ should_restart = False
+ event_fetches_to_fail = []
+ with self._event_fetch_lock:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+ # There may still be work remaining in `_event_fetch_list` if we
+ # failed, or it was added in between us deciding to exit and
+ # decrementing `_event_fetch_ongoing`.
+ if self._event_fetch_list:
+ if exc is None:
+ # We decided to exit, but then some more work was added
+ # before `_event_fetch_ongoing` was decremented.
+ # If a new event fetch thread was not started, we should
+ # restart ourselves since the remaining event fetch threads
+ # may take a while to get around to the new work.
+ #
+ # Unfortunately it is not possible to tell whether a new
+ # event fetch thread was started, so we restart
+ # unconditionally. If we are unlucky, we will end up with
+ # an idle fetch thread, but it will time out after
+ # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+ # in any case.
+ #
+ # Note that multiple fetch threads may run down this path at
+ # the same time.
+ should_restart = True
+ elif isinstance(exc, Exception):
+ if self._event_fetch_ongoing == 0:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will
+ # handle them.
+ event_fetches_to_fail = self._event_fetch_list
+ self._event_fetch_list = []
+ else:
+ # We weren't the last remaining fetcher, so another
+ # fetcher will pick up the work. This will either happen
+ # after their existing work, however long that takes,
+ # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+ # they are idle.
+ pass
+ else:
+ # The exception is a `SystemExit`, `KeyboardInterrupt` or
+ # `GeneratorExit`. Don't try to do anything clever here.
+ pass
+
+ if should_restart:
+ # We exited cleanly but noticed more work.
+ self._maybe_start_fetch_thread()
+
+ if event_fetches_to_fail:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will handle them.
+ assert exc is not None
+ with PreserveLoggingContext():
+ for _, deferred in event_fetches_to_fail:
+ deferred.errback(exc)
+
+ def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- try:
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- break
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ # There are no requests waiting. If we haven't yet reached the
+ # maximum iteration limit, wait for some more requests to turn up.
+ # Otherwise, bail out.
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ return
+
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
- self._fetch_event_list(conn, event_list)
- finally:
- self._event_fetch_ongoing -= 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ self._fetch_event_list(conn, event_list)
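The loop above is the classic condition-variable consumer: drain the shared list while holding the lock, and if nothing is there, wait up to `EVENT_QUEUE_TIMEOUT_S` for new work, giving up after `EVENT_QUEUE_ITERATIONS` idle rounds so idle threads do not linger. A minimal, self-contained sketch of that pattern (toy names, not Synapse's actual classes):

```python
import threading
from typing import Callable, List

QUEUE_TIMEOUT_S = 0.1   # how long one idle wait lasts
QUEUE_ITERATIONS = 3    # idle waits before the worker gives up

class WorkQueue:
    """Toy version of the `_event_fetch_list` plus condition-variable pattern."""

    def __init__(self) -> None:
        self._lock = threading.Condition()
        self._items: List[Callable[[], None]] = []

    def add(self, item: Callable[[], None]) -> None:
        with self._lock:
            self._items.append(item)
            self._lock.notify()

    def worker_loop(self) -> None:
        idle_rounds = 0
        while True:
            with self._lock:
                batch, self._items = self._items, []
                if not batch:
                    if idle_rounds > QUEUE_ITERATIONS:
                        # Nothing has turned up for a while; let the thread exit.
                        return
                    self._lock.wait(QUEUE_TIMEOUT_S)
                    idle_rounds += 1
                    continue
                idle_rounds = 0
            # Run the work outside the lock, as `_fetch_event_list` does.
            for item in batch:
                item()
```

Doing the work outside the lock is what allows `_maybe_start_fetch_thread` to spin up additional consumers while one is busy.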
def _fetch_event_list(
- self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+ self,
+ conn: LoggingDatabaseConnection,
+ event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
- def fire():
+ def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs, exc):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(exc)
+ def fire_errback(exc: Exception) -> None:
+ for _, d in event_list:
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, e)
+ self.hs.get_reactor().callFromThread(fire_errback, e)
async def _get_events_from_db(
- self, event_ids: Iterable[str]
- ) -> Dict[str, _EventCacheEntry]:
+ self, event_ids: Collection[str]
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
- fetched_events = {}
+ fetched_event_ids: Set[str] = set()
+ fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
- redaction_ids = set()
+ redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
- fetched_events[event_id] = row
+ fetched_event_ids.add(event_id)
if row:
+ fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
- event_map = {}
+ event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
- if not row:
- continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
+ room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
- result_map = {}
+ result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
- cache_entry = _EventCacheEntry(
+ cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+ async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
- events_d = defer.Deferred()
+ events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
-
self._event_fetch_lock.notify()
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process(
- "fetch_events", self.db_pool.runWithConnection, self._do_fetch
- )
+ self._maybe_start_fetch_thread()
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- async def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
- set[str]: The events we have already seen.
+ The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
- def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+ def have_seen_events_txn(
+ txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+ ) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str):
+ async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
- def _get_current_state_event_counts_txn(self, txn, room_id):
+ def _get_current_state_event_counts_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> int:
"""
See get_current_state_event_counts.
"""
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- async def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- dict[str:int] of complexity version to complexity.
+ dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_events_token(self):
+ def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_all_new_forward_event_rows(txn):
+ def get_all_new_forward_event_rows(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_ex_outlier_stream_rows_txn(txn):
+ def get_ex_outlier_stream_rows_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_backfill_event_rows(txn):
+ def get_all_new_backfill_event_rows(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " se.state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" AND instance_name = ?"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
- new_event_updates = [(row[0], row[1:]) for row in txn]
+ new_event_updates: List[
+ Tuple[int, Tuple[str, str, str, str, str, str]]
+ ] = []
+ row: Tuple[int, str, str, str, str, str, str]
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " se.state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
- new_event_updates.extend((row[0], row[1:]) for row in txn)
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
- ) -> Tuple[List[Tuple], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
- def get_all_updated_current_state_deltas_txn(txn):
+ def get_all_updated_current_state_deltas_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
- def get_deltas_for_stream_id_txn(txn, stream_id):
+ def get_deltas_for_stream_id_txn(
+ txn: LoggingTransaction, stream_id: int
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows: List[Tuple] = await self.db_pool.runInteraction(
+ rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- async def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
- async def get_event_ordering(self, event_id):
+ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
- def get_next_event_to_expire_txn(txn):
+ def get_next_event_to_expire_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
- return txn.fetchone()
+ return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
- async def _cleanup_old_transaction_ids(self):
+ async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
- def _cleanup_old_transaction_ids_txn(txn):
+ def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
@@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
"_cleanup_old_transaction_ids",
_cleanup_old_transaction_ids_txn,
)
+
+ async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
+ """Check if the given event is next to a backward gap of missing events.
+ <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages>
+
+ Args:
+            event: The event to check.
+
+ Returns:
+            True if the event is next to a backward gap of missing events.
+ """
+
+ def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
+ # If the event in question has any of its prev_events listed as a
+ # backward extremity, it's next to a gap.
+ #
+ # We can't just check the backward edges in `event_edges` because
+ # when we persist events, we will also record the prev_events as
+ # edges to the event in question regardless of whether we have those
+ # prev_events yet. We need to check whether those prev_events are
+ # backward extremities, also known as gaps, that need to be
+ # backfilled.
+ backward_extremity_query = """
+ SELECT 1 FROM event_backward_extremities
+ WHERE
+ room_id = ?
+ AND %s
+ LIMIT 1
+ """
+
+ # If the event in question is a backward extremity or has any of its
+ # prev_events listed as a backward extremity, it's next to a
+ # backward gap.
+ clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "event_id",
+ [event.event_id] + list(event.prev_event_ids()),
+ )
+
+ txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
+ backward_extremities = txn.fetchall()
+
+ # We consider any backward extremity as a backward gap
+ if len(backward_extremities):
+ return True
+
+ return False
+
+ return await self.db_pool.runInteraction(
+ "is_event_next_to_backward_gap_txn",
+ is_event_next_to_backward_gap_txn,
+ )
+
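The query is assembled with `make_in_list_sql_clause`, which turns a membership test over a Python collection into a SQL fragment plus bind parameters. A rough sketch of the simple placeholder-per-value form it can produce, offered as an illustration rather than a transcription of the real helper, which also has an engine-specific array variant:

```python
from typing import Any, Iterable, List, Tuple

def make_in_list_sql_clause_sketch(
    column: str, values: Iterable[Any]
) -> Tuple[str, List[Any]]:
    """Approximation of the generic code path: one `?` per value."""
    values = list(values)
    placeholders = ", ".join("?" for _ in values)
    return f"{column} IN ({placeholders})", values

clause, args = make_in_list_sql_clause_sketch(
    "event_id", ["$event", "$prev_1", "$prev_2"]
)
# clause == "event_id IN (?, ?, ?)", so the executed query becomes roughly:
#   SELECT 1 FROM event_backward_extremities
#   WHERE room_id = ? AND event_id IN (?, ?, ?) LIMIT 1
# with params ["!room:example.org", "$event", "$prev_1", "$prev_2"]
```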
+ async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
+ """Check if the given event is next to a forward gap of missing events.
+ The gap in front of the latest events is not considered a gap.
+ <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
+ <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>
+
+ Args:
+            event: The event to check.
+
+ Returns:
+            True if the event is next to a forward gap of missing events.
+ """
+
+ def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
+ # If the event in question is a forward extremity, we will just
+ # consider any potential forward gap as not a gap since it's one of
+ # the latest events in the room.
+ #
+ # `event_forward_extremities` does not include backfilled or outlier
+ # events so we can't rely on it to find forward gaps. We can only
+ # use it to determine whether a message is the latest in the room.
+ #
+ # We can't combine this query with the `forward_edge_query` below
+ # because if the event in question has no forward edges (isn't
+ # referenced by any other event's prev_events) but is in
+ # `event_forward_extremities`, we don't want to return 0 rows and
+ # say it's next to a gap.
+ forward_extremity_query = """
+ SELECT 1 FROM event_forward_extremities
+ WHERE
+ room_id = ?
+ AND event_id = ?
+ LIMIT 1
+ """
+
+ # Check to see whether the event in question is already referenced
+ # by another event. If we don't see any edges, we're next to a
+ # forward gap.
+ forward_edge_query = """
+ SELECT 1 FROM event_edges
+ /* Check to make sure the event referencing our event in question is not rejected */
+            LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
+ WHERE
+ event_edges.room_id = ?
+ AND event_edges.prev_event_id = ?
+ /* It's not a valid edge if the event referencing our event in
+ * question is rejected.
+ */
+ AND rejections.event_id IS NULL
+ LIMIT 1
+ """
+
+ # We consider any forward extremity as the latest in the room and
+ # not a forward gap.
+ #
+ # To expand, even though there is technically a gap at the front of
+ # the room where the forward extremities are, we consider those the
+ # latest messages in the room so asking other homeservers for more
+ # is useless. The new latest messages will just be federated as
+ # usual.
+ txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+ forward_extremities = txn.fetchall()
+ if len(forward_extremities):
+ return False
+
+ # If there are no forward edges to the event in question (another
+ # event hasn't referenced this event in their prev_events), then we
+ # assume there is a forward gap in the history.
+ txn.execute(forward_edge_query, (event.room_id, event.event_id))
+ forward_edges = txn.fetchall()
+ if not len(forward_edges):
+ return True
+
+ return False
+
+ return await self.db_pool.runInteraction(
+ "is_event_next_to_gap_txn",
+ is_event_next_to_gap_txn,
+ )
+
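Together, the two checks give callers a cheap way to ask whether a point in the timeline borders missing history. A hedged usage sketch; the wrapper function below is hypothetical and not part of this changeset:

```python
from synapse.events import EventBase

async def should_backfill_around(store, event: EventBase) -> bool:
    """Return True if the history around `event` looks incomplete.

    `store` is assumed to be an EventsWorkerStore (or a DataStore that
    includes it); only the two methods added above are relied upon.
    """
    # Missing events *older* than this one: its prev_events are backward
    # extremities, so a backfill would be needed to fill the gap.
    next_to_backward_gap = await store.is_event_next_to_backward_gap(event)

    # Missing events *newer* than this one: no accepted event refers to it
    # in its prev_events, and it is not itself a forward extremity.
    next_to_forward_gap = await store.is_event_next_to_forward_gap(event)

    return next_to_backward_gap or next_to_forward_gap
```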
+ async def get_event_id_for_timestamp(
+ self, room_id: str, timestamp: int, direction: str
+ ) -> Optional[str]:
+ """Find the closest event to the given timestamp in the given direction.
+
+ Args:
+ room_id: Room to fetch the event from
+ timestamp: The point in time (inclusive) we should navigate from in
+ the given direction to find the closest event.
+ direction: ["f"|"b"] to indicate whether we should navigate forward
+ or backward from the given timestamp to find the closest event.
+
+ Returns:
+            The closest event_id to the given timestamp in the given
+            direction, or None if no such event exists.
+ """
+
+ sql_template = """
+ SELECT event_id FROM events
+ LEFT JOIN rejections USING (event_id)
+ WHERE
+ origin_server_ts %s ?
+ AND room_id = ?
+ /* Make sure event is not rejected */
+ AND rejections.event_id IS NULL
+ ORDER BY origin_server_ts %s
+ LIMIT 1;
+ """
+
+ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
+ if direction == "b":
+ # Find closest event *before* a given timestamp. We use descending
+ # (which gives values largest to smallest) because we want the
+ # largest possible timestamp *before* the given timestamp.
+ comparison_operator = "<="
+ order = "DESC"
+ else:
+ # Find closest event *after* a given timestamp. We use ascending
+ # (which gives values smallest to largest) because we want the
+ # closest possible timestamp *after* the given timestamp.
+ comparison_operator = ">="
+ order = "ASC"
+
+ txn.execute(
+ sql_template % (comparison_operator, order), (timestamp, room_id)
+ )
+ row = txn.fetchone()
+ if row:
+ (event_id,) = row
+ return event_id
+
+ return None
+
+ if direction not in ("f", "b"):
+ raise ValueError("Unknown direction: %s" % (direction,))
+
+ return await self.db_pool.runInteraction(
+ "get_event_id_for_timestamp_txn",
+ get_event_id_for_timestamp_txn,
+ )
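`get_event_id_for_timestamp` is a small jump-to-date primitive: direction "f" returns the first non-rejected event at or after the timestamp, "b" the last one at or before it. A sketch of a caller that prefers one direction and falls back to the other (the wrapper is illustrative, not an API added here):

```python
from typing import Optional

async def closest_event_to(store, room_id: str, ts_ms: int) -> Optional[str]:
    """Prefer the first event at or after `ts_ms`, falling back to the last
    event at or before it."""
    event_id = await store.get_event_id_for_timestamp(room_id, ts_ms, "f")
    if event_id is None:
        event_id = await store.get_event_id_for_timestamp(room_id, ts_ms, "b")
    return event_id
```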
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 434986fa64..cf842803bc 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -49,7 +51,7 @@ class FilteringStore(SQLBaseStore):
# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
- def _do_txn(txn):
+ def _do_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
@@ -61,7 +63,7 @@ class FilteringStore(SQLBaseStore):
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
- max_id = txn.fetchone()[0]
+ max_id = txn.fetchone()[0] # type: ignore[index]
if max_id is None:
filter_id = 0
else:
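The `# type: ignore[index]` is needed because `fetchone()` is typed as returning an optional row, even though an aggregate such as `MAX(...)` always yields exactly one row. For reference, the more explicit way to satisfy mypy looks roughly like this (a sketch of the general pattern, not a change made in this diff):

```python
from typing import Optional, Tuple, cast

def max_filter_id(txn, user_localpart: str) -> int:
    txn.execute(
        "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?",
        (user_localpart,),
    )
    # The aggregate always returns one row, but its single column may still
    # be NULL if the user has no filters yet.
    row = cast(Optional[Tuple[Optional[int]]], txn.fetchone())
    assert row is not None
    max_id = row[0]
    return 0 if max_id is None else max_id + 1
```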
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 3d0df0cbd4..a540f7fb26 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from types import TracebackType
-from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@@ -62,7 +62,9 @@ class LockStore(SQLBaseStore):
# A map from `(lock_name, lock_key)` to the token of any locks that we
# think we currently hold.
- self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
+ self._live_tokens: WeakValueDictionary[
+ Tuple[str, str], Lock
+ ] = WeakValueDictionary()
# When we shut down we want to remove the locks. Technically this can
# lead to a race, as we may drop the lock while we are still processing.
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 717487be28..1b076683f7 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,10 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -46,7 +61,12 @@ class MediaSortOrder(Enum):
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -102,13 +122,15 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
self._drop_media_index_without_method,
)
- async def _drop_media_index_without_method(self, progress, batch_size):
+ async def _drop_media_index_without_method(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""background update handler which removes the old constraints.
Note that this is only run on postgres.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
)
@@ -126,7 +148,12 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -174,7 +201,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
plus the total count of all the user's media
"""
- def get_local_media_by_user_paginate_txn(txn):
+ def get_local_media_by_user_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], int]:
# Set ordering
order_by_column = MediaSortOrder(order_by).value
@@ -184,14 +213,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
else:
order = "ASC"
- args = [user_id]
+ args: List[Union[str, int]] = [user_id]
sql = """
SELECT COUNT(*) as total_media
FROM local_media_repository
WHERE user_id = ?
"""
txn.execute(sql, args)
- count = txn.fetchone()[0]
+ count = txn.fetchone()[0] # type: ignore[index]
sql = """
SELECT
@@ -268,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
sql += sql_keep
- def _get_local_media_before_txn(txn):
+ def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts, before_ts, size_gt))
return [row[0] for row in txn]
@@ -278,13 +307,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_local_media(
self,
- media_id,
- media_type,
- time_now_ms,
- upload_name,
- media_length,
- user_id,
- url_cache=None,
+ media_id: str,
+ media_type: str,
+ time_now_ms: int,
+ upload_name: Optional[str],
+ media_length: int,
+ user_id: UserID,
+ url_cache: Optional[str] = None,
) -> None:
await self.db_pool.simple_insert(
"local_media_repository",
@@ -315,7 +344,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
None if the URL isn't cached.
"""
- def get_url_cache_txn(txn):
+ def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# get the most recently cached result (relative to the given ts)
sql = (
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
@@ -359,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
- ):
+ ) -> None:
await self.db_pool.simple_insert(
"local_media_repository_url_cache",
{
@@ -390,13 +419,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_local_thumbnail(
self,
- media_id,
- thumbnail_width,
- thumbnail_height,
- thumbnail_type,
- thumbnail_method,
- thumbnail_length,
- ):
+ media_id: str,
+ thumbnail_width: int,
+ thumbnail_height: int,
+ thumbnail_type: str,
+ thumbnail_method: str,
+ thumbnail_length: int,
+ ) -> None:
await self.db_pool.simple_upsert(
table="local_media_repository_thumbnails",
keyvalues={
@@ -430,14 +459,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_cached_remote_media(
self,
- origin,
- media_id,
- media_type,
- media_length,
- time_now_ms,
- upload_name,
- filesystem_id,
- ):
+ origin: str,
+ media_id: str,
+ media_type: str,
+ media_length: int,
+ time_now_ms: int,
+ upload_name: Optional[str],
+ filesystem_id: str,
+ ) -> None:
await self.db_pool.simple_insert(
"remote_media_cache",
{
@@ -458,7 +487,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
local_media: Iterable[str],
remote_media: Iterable[Tuple[str, str]],
time_ms: int,
- ):
+ ) -> None:
"""Updates the last access time of the given media
Args:
@@ -467,7 +496,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
time_ms: Current time in milliseconds
"""
- def update_cache_txn(txn):
+ def update_cache_txn(txn: LoggingTransaction) -> None:
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
" WHERE media_origin = ? AND media_id = ?"
@@ -488,7 +517,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
@@ -542,15 +571,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_remote_media_thumbnail(
self,
- origin,
- media_id,
- filesystem_id,
- thumbnail_width,
- thumbnail_height,
- thumbnail_type,
- thumbnail_method,
- thumbnail_length,
- ):
+ origin: str,
+ media_id: str,
+ filesystem_id: str,
+ thumbnail_width: int,
+ thumbnail_height: int,
+ thumbnail_type: str,
+ thumbnail_method: str,
+ thumbnail_length: int,
+ ) -> None:
await self.db_pool.simple_upsert(
table="remote_media_cache_thumbnails",
keyvalues={
@@ -566,7 +595,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail",
)
- async def get_remote_media_before(self, before_ts):
+ async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
@@ -602,7 +631,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" LIMIT 500"
)
- def _get_expired_url_cache_txn(txn):
+ def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
@@ -610,18 +639,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_expired_url_cache", _get_expired_url_cache_txn
)
- async def delete_url_cache(self, media_ids):
+ async def delete_url_cache(self, media_ids: Collection[str]) -> None:
if len(media_ids) == 0:
return
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
- def _delete_url_cache_txn(txn):
+ def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
- return await self.db_pool.runInteraction(
- "delete_url_cache", _delete_url_cache_txn
- )
+ await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)
async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
@@ -631,7 +658,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" LIMIT 500"
)
- def _get_url_cache_media_before_txn(txn):
+ def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
@@ -639,11 +666,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
- async def delete_url_cache_media(self, media_ids):
+ async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
if len(media_ids) == 0:
return
- def _delete_url_cache_media_txn(txn):
+ def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
@@ -652,6 +679,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index 2aac64901b..a46685219f 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,6 +1,21 @@
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import Optional
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
class OpenIdStore(SQLBaseStore):
@@ -20,7 +35,7 @@ class OpenIdStore(SQLBaseStore):
async def get_user_id_for_open_id_token(
self, token: str, ts_now_ms: int
) -> Optional[str]:
- def get_user_id_for_token_txn(txn):
+ def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
sql = (
"SELECT user_id FROM open_id_tokens"
" WHERE token = ? AND ? <= ts_valid_until_ms"
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index dd8e27e226..e197b7203e 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
@@ -104,7 +105,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="update_remote_profile_cache",
)
- async def maybe_delete_remote_profile_cache(self, user_id):
+ async def maybe_delete_remote_profile_cache(self, user_id: str) -> None:
"""Check if we still care about the remote user's profile, and if we
don't then remove their profile from the cache
"""
@@ -116,9 +117,9 @@ class ProfileWorkerStore(SQLBaseStore):
desc="delete_remote_profile_cache",
)
- async def is_subscribed_remote_profile_for_user(self, user_id):
+ async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool:
"""Check whether we are interested in a remote user's profile."""
- res = await self.db_pool.simple_select_one_onecol(
+ res: Optional[str] = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@@ -139,13 +140,16 @@ class ProfileWorkerStore(SQLBaseStore):
if res:
return True
+ return False
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`"""
- def _get_remote_profile_cache_entries_that_expire_txn(txn):
+ def _get_remote_profile_cache_entries_that_expire_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, str]]:
sql = """
SELECT user_id, displayname, avatar_url
FROM remote_profile_cache
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3eb30944bf..91b0576b85 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
logger.info("[purge] looking for events to delete")
- should_delete_expr = "state_key IS NULL"
+ should_delete_expr = "state_events.state_key IS NULL"
should_delete_params: Tuple[Any, ...] = ()
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ StreamIdGenerator,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen: Union[
- StreamIdGenerator, SlavedIdTracker
- ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+ self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ )
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c99f8aebdb..9c5625c8bb 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,14 +14,25 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
from twisted.internet import defer
+from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
@@ -78,17 +89,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
- def get_max_receipt_stream_id(self):
- """Get the current max stream ID for receipts stream
-
- Returns:
- int
- """
+ def get_max_receipt_stream_id(self) -> int:
+ """Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
@cached()
- async def get_users_with_read_receipts_in_room(self, room_id):
- receipts = await self.get_receipts_for_room(room_id, "m.read")
+ async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
+ receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -119,7 +126,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=2)
- async def get_receipts_for_user(self, user_id, receipt_type):
+ async def get_receipts_for_user(
+ self, user_id: str, receipt_type: str
+ ) -> Dict[str, str]:
rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -129,8 +138,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
- async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
- def f(txn):
+ async def get_receipts_for_user_with_orderings(
+ self, user_id: str, receipt_type: str
+ ) -> JsonDict:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
sql = (
"SELECT rl.room_id, rl.event_id,"
" e.topological_ordering, e.stream_ordering"
@@ -209,10 +220,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[dict]:
+ ) -> List[JsonDict]:
"""See get_linearized_receipts_for_room"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +261,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
list_name="room_ids",
num_args=3,
)
- async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(
+ self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ ) -> Dict[str, List[JsonDict]]:
if not room_ids:
return {}
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
@@ -323,7 +336,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
@@ -379,7 +392,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if last_id == current_id:
return defer.succeed([])
- def _get_users_sent_receipts_between_txn(txn):
+ def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT user_id FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +432,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_receipts_txn(txn):
+ def get_all_updated_receipts_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized
@@ -446,8 +461,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
def _invalidate_get_users_with_receipts_in_room(
self, room_id: str, receipt_type: str, user_id: str
- ):
- if receipt_type != "m.read":
+ ) -> None:
+ if receipt_type != ReceiptTypes.READ:
return
res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
@@ -461,7 +476,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
- def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+ def invalidate_caches_for_receipt(
+ self, room_id: str, receipt_type: str, user_id: str
+ ) -> None:
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
@@ -482,11 +499,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
- self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_id: str,
+ data: JsonDict,
+ stream_id: int,
+ ) -> Optional[int]:
"""Inserts a read-receipt into the database if it's newer than the current RR
- Returns: int|None
+ Returns:
None if the RR is older than the current RR
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
@@ -550,7 +574,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
lock=False,
)
- if receipt_type == "m.read" and stream_ordering is not None:
+ if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
@@ -580,7 +604,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
else:
             # we need to map points in graph -> linearized form.
# TODO: Make this better.
- def graph_to_linear(txn):
+ def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
@@ -634,11 +658,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
return stream_id, max_persisted_id
async def insert_graph_receipt(
- self, room_id, receipt_type, user_id, event_ids, data
- ):
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: JsonDict,
+ ) -> None:
assert self._can_write_to_receipts
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -649,8 +678,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def insert_graph_receipt_txn(
- self, txn, room_id, receipt_type, user_id, event_ids, data
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: JsonDict,
+ ) -> None:
assert self._can_write_to_receipts
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
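The reworked docstring spells out the `Optional[int]` contract of `insert_linearized_receipt_txn`: `None` means the receipt was older than the stored one and nothing changed, otherwise the value is the reception timestamp of the referenced event (0 if unknown). A hedged sketch of a consumer of that contract; the wrapper and its stream-id handling are hypothetical:

```python
from typing import Optional

async def handle_new_receipt(
    store, room_id: str, user_id: str, event_id: str, data: dict, stream_id: int
) -> None:
    # The real caller lives in the receipts handler and allocates the
    # stream_id itself; this only shows how the return value is consumed.
    rx_ts: Optional[int] = await store.db_pool.runInteraction(
        "insert_linearized_receipt",
        store.insert_linearized_receipt_txn,
        room_id,
        "m.read",  # ReceiptTypes.READ
        user_id,
        event_id,
        data,
        stream_id,
    )
    if rx_ts is None:
        return  # the receipt was older than the one already stored
    # Otherwise rx_ts is the reception timestamp of the referenced event
    # (0 if the event is not known locally), e.g. usable for metrics.
```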
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 6c7d6ba508..e1ddf06916 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -84,28 +84,37 @@ class TokenLookupResult:
return self.user_id
-@attr.s(frozen=True, slots=True)
+@attr.s(auto_attribs=True, frozen=True, slots=True)
class RefreshTokenLookupResult:
"""Result of looking up a refresh token."""
- user_id = attr.ib(type=str)
+ user_id: str
"""The user this token belongs to."""
- device_id = attr.ib(type=str)
+ device_id: str
"""The device associated with this refresh token."""
- token_id = attr.ib(type=int)
+ token_id: int
"""The ID of this refresh token."""
- next_token_id = attr.ib(type=Optional[int])
+ next_token_id: Optional[int]
"""The ID of the refresh token which replaced this one."""
- has_next_refresh_token_been_refreshed = attr.ib(type=bool)
+ has_next_refresh_token_been_refreshed: bool
"""True if the next refresh token was used for another refresh."""
- has_next_access_token_been_used = attr.ib(type=bool)
+ has_next_access_token_been_used: bool
"""True if the next access token was already used at least once."""
+ expiry_ts: Optional[int]
+ """The time at which the refresh token expires and can not be used.
+ If None, the refresh token doesn't expire."""
+
+ ultimate_session_expiry_ts: Optional[int]
+ """The time at which the session comes to an end and can no longer be
+ refreshed.
+ If None, the session can be refreshed indefinitely."""
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
@@ -476,7 +485,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
shadow_banned: true iff the user is to be shadow-banned, false otherwise.
"""
- def set_shadow_banned_txn(txn):
+ def set_shadow_banned_txn(txn: LoggingTransaction) -> None:
user_id = user.to_string()
self.db_pool.simple_update_one_txn(
txn,
@@ -1198,8 +1207,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts = now_ms + self._account_validity_period
if use_delta:
+ assert self._account_validity_startup_job_max_delta is not None
expiration_ts = random.randrange(
- expiration_ts - self._account_validity_startup_job_max_delta,
+ int(expiration_ts - self._account_validity_startup_job_max_delta),
expiration_ts,
)
@@ -1625,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
rt.user_id,
rt.device_id,
rt.next_token_id,
- (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
- at.used has_next_access_token_been_used
+ (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
+ at.used AS has_next_access_token_been_used,
+ rt.expiry_ts,
+ rt.ultimate_session_expiry_ts
FROM refresh_tokens rt
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1647,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
has_next_refresh_token_been_refreshed=row[4],
# This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False),
+ expiry_ts=row[6],
+ ultimate_session_expiry_ts=row[7],
)
return await self.db_pool.runInteraction(
@@ -1728,11 +1742,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
)
self.db_pool.updates.register_background_update_handler(
- "user_threepids_grandfather", self._bg_user_threepids_grandfather
+ "users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
- self.db_pool.updates.register_background_update_handler(
- "users_set_deactivated_flag", self._background_update_set_deactivated_flag
+ self.db_pool.updates.register_noop_background_update(
+ "user_threepids_grandfather"
)
self.db_pool.updates.register_background_index_update(
@@ -1805,35 +1819,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return nb_processed
- async def _bg_user_threepids_grandfather(self, progress, batch_size):
- """We now track which identity servers a user binds their 3PID to, so
- we need to handle the case of existing bindings where we didn't track
- this.
-
- We do this by grandfathering in existing user threepids assuming that
- they used one of the server configured trusted identity servers.
- """
- id_servers = set(self.config.registration.trusted_third_party_id_servers)
-
- def _bg_user_threepids_grandfather_txn(txn):
- sql = """
- INSERT INTO user_threepid_id_server
- (user_id, medium, address, id_server)
- SELECT user_id, medium, address, ?
- FROM user_threepids
- """
-
- txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
-
- if id_servers:
- await self.db_pool.runInteraction(
- "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
- )
-
- await self.db_pool.updates._end_background_update("user_threepids_grandfather")
-
- return 1
-
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
@@ -1943,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: str,
token: str,
device_id: Optional[str],
+ expiry_ts: Optional[int],
+ ultimate_session_expiry_ts: Optional[int],
) -> int:
"""Adds a refresh token for the given user.
@@ -1950,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: The user ID.
token: The new access token to add.
device_id: ID of the device to associate with the refresh token.
+ expiry_ts (milliseconds since the epoch): Time after which the
+ refresh token cannot be used.
+                If None, the refresh token does not expire until it has been
+                used.
+ ultimate_session_expiry_ts (milliseconds since the epoch):
+ Time at which the session will end and can not be extended any
+ further.
+ If None, the session can be refreshed indefinitely.
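+                Both timestamps are stored as given; this method does not
+                validate them or the relationship between them.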
Raises:
StoreError if there was a problem adding this.
Returns:
@@ -1965,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"token": token,
"next_token_id": None,
+ "expiry_ts": expiry_ts,
+ "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
},
desc="add_refresh_token_to_user",
)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 907af10995..0a43acda07 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore):
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
+ async def event_includes_relation(self, event_id: str) -> bool:
+ """Check if the given event relates to another event.
+
+ An event has a relation if it has a valid m.relates_to with a rel_type
+ and event_id in the content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$other_event_id"
+ }
+ }
+ }
+
+ Args:
+ event_id: The event to check.
+
+ Returns:
+ True if the event includes a valid relation.
+ """
+
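+        # The event_relations table only contains rows for events whose
+        # m.relates_to content parsed as a valid relation, so the existence of
+        # a row is enough to answer the question.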
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"event_id": event_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_includes_relation",
+ )
+ return result is not None
+
+ async def event_is_target_of_relation(self, parent_id: str) -> bool:
+ """Check if the given event is the target of another event's relation.
+
+ An event is the target of an event relation if it has a valid
+ m.relates_to with a rel_type and event_id pointing to parent_id in the
+ content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$parent_id"
+ }
+ }
+ }
+
+ Args:
+ parent_id: The event to check.
+
+ Returns:
+ True if the event is the target of another event's relation.
+ """
+
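+        # As above: a row in event_relations with relates_to_id = parent_id
+        # means some event carries a valid relation pointing at parent_id.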
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"relates_to_id": parent_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_is_target_of_relation",
+ )
+ return result is not None
+
@cached(tree=True)
async def get_aggregation_groups_for_event(
self,
@@ -362,7 +425,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s;
"""
- def _get_if_event_has_relations(txn) -> List[str]:
+ def _get_if_events_have_relations(txn) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
@@ -387,7 +450,7 @@ class RelationsWorkerStore(SQLBaseStore):
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
- "get_if_event_has_relations", _get_if_event_has_relations
+ "get_if_events_have_relations", _get_if_events_have_relations
)
async def has_user_annotated_event(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 17b398bb69..7d694d852d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -397,6 +397,20 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ async def room_is_blocked_by(self, room_id: str) -> Optional[str]:
+ """
+        Returns the user ID of the user who blocked the room, if any.
+        The user_id column is non-nullable, so a None return value means
+        the room is not blocked.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="room_is_blocked_by",
+ )
+
async def get_rooms_paginate(
self,
start: int,
@@ -1775,3 +1789,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.is_room_blocked,
(room_id,),
)
+
+ async def unblock_room(self, room_id: str) -> None:
+ """Remove the room from blocking list.
+
+ Args:
+ room_id: Room to unblock
+ """
+ await self.db_pool.simple_delete(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ desc="unblock_room",
+ )
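+        # Invalidate the cached is_room_blocked result, and stream the
+        # invalidation to other workers, so that the unblock takes effect
+        # everywhere immediately.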
+ await self.db_pool.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked,
+ (room_id,),
+ )
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index 97b2618437..39e80f6f5b 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -39,13 +39,11 @@ class RoomBatchStore(SQLBaseStore):
async def store_state_group_id_for_event_id(
self, event_id: str, state_group_id: int
- ) -> Optional[str]:
- {
- await self.db_pool.simple_upsert(
- table="event_to_state_groups",
- keyvalues={"event_id": event_id},
- values={"state_group": state_group_id, "event_id": event_id},
- # Unique constraint on event_id so we don't have to lock
- lock=False,
- )
- }
+ ) -> None:
+ await self.db_pool.simple_upsert(
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ values={"state_group": state_group_id, "event_id": event_id},
+ # Unique constraint on event_id so we don't have to lock
+ lock=False,
+ )
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 033a9831d6..6b2a8d06a6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
- AND state_key = ?
+ AND c.state_key = ?
AND c.membership = ?
"""
else:
@@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
- AND state_key = ?
+ AND c.state_key = ?
AND m.membership = ?
"""
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index ab2159c2d3..3201623fe4 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -63,12 +63,12 @@ class SignatureWorkerStore(SQLBaseStore):
A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
"""
hashes = await self.get_event_reference_hashes(event_ids)
- hashes = {
+ encoded_hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items()
}
- return list(hashes.items())
+ return list(encoded_hashes.items())
def _get_event_reference_hashes_txn(
self, txn: Cursor, event_id: str
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index a89747d741..188afec332 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -16,11 +16,17 @@ import logging
from typing import Any, Dict, List, Tuple
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
+from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
+ # This class must be mixed in with a child class which provides the following
+ # attribute. TODO: can we get static analysis to enforce this?
+ _curr_state_delta_stream_cache: StreamChangeCache
+
async def get_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
@@ -50,7 +56,9 @@ class StateDeltasStore(SQLBaseStore):
prev_stream_id = int(prev_stream_id)
# check we're not going backwards
- assert prev_stream_id <= max_stream_id
+ assert (
+ prev_stream_id <= max_stream_id
+ ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
@@ -60,7 +68,9 @@ class StateDeltasStore(SQLBaseStore):
# max_stream_id.
return max_stream_id, []
- def get_current_state_deltas_txn(txn):
+ def get_current_state_deltas_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, List[Dict[str, Any]]]:
# First we calculate the max stream id that will give us less than
# N results.
# We arbitrarily limit to 100 stream_id entries to ensure we don't
@@ -106,7 +116,9 @@ class StateDeltasStore(SQLBaseStore):
"get_current_state_deltas", get_current_state_deltas_txn
)
- def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
+ def _get_max_stream_id_in_current_state_deltas_txn(
+ self, txn: LoggingTransaction
+ ) -> int:
return self.db_pool.simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
@@ -114,7 +126,7 @@ class StateDeltasStore(SQLBaseStore):
retcol="COALESCE(MAX(stream_id), -1)",
)
- async def get_max_stream_id_in_current_state_deltas(self):
+ async def get_max_stream_id_in_current_state_deltas(self) -> int:
return await self.db_pool.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 42dc807d17..57aab55259 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
oldest `limit` events.
Returns:
- The list of events (in ascending order) and the token from the start
+ The list of events (in ascending stream order) and the token from the start
of the chunk of events returned.
"""
if from_key == to_key:
@@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return [], from_key
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
+ """Fetch membership events for a given user.
+
+ All such events whose stream ordering `s` lies in the range
+ `from_key < s <= to_key` are returned. Events are ordered by ascending stream
+ order.
+ """
+ # Start by ruling out cases where a DB query is not necessary.
if from_key == to_key:
return []
@@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return []
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
Returns:
A list of events and a token pointing to the start of the returned
- events. The events returned are in ascending order.
+ events. The events returned are in ascending topological order.
"""
rows, token = await self.get_recent_event_ids_for_room(
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index f93ff0a545..8f510de53d 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -1,5 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,9 +15,10 @@
# limitations under the License.
import logging
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, cast
from synapse.storage._base import db_to_json
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -50,7 +52,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
"""Get updates for tags replication stream.
Args:
@@ -75,7 +77,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_tags_txn(txn):
+ def get_all_updated_tags_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str]]:
sql = (
"SELECT stream_id, user_id, room_id"
" FROM room_tags_revisions as r"
@@ -83,13 +87,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
+ # mypy doesn't understand what the query is selecting.
+ return cast(List[Tuple[int, str, str]], txn.fetchall())
tag_ids = await self.db_pool.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
- def get_tag_content(txn, tag_ids):
+ def get_tag_content(
+            txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
+ ) -> List[Tuple[int, Tuple[str, str, str]]]:
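+            # For each (stream_id, user_id, room_id) row fetched above,
+            # re-read that user's tags for the room and emit
+            # (stream_id, (user_id, room_id, serialised_tag_content)).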
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
for stream_id, user_id, room_id in tag_ids:
@@ -127,15 +134,15 @@ class TagsWorkerStore(AccountDataWorkerStore):
given version
Args:
- user_id(str): The user to get the tags for.
- stream_id(int): The earliest update to get for the user.
+ user_id: The user to get the tags for.
+ stream_id: The earliest update to get for the user.
Returns:
A mapping from room_id strings to lists of tag strings for all the
rooms that changed since the stream_id token.
"""
- def get_updated_tags_txn(txn):
+ def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]:
sql = (
"SELECT room_id from room_tags_revisions"
" WHERE user_id = ? AND stream_id > ?"
@@ -200,7 +207,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
content_json = json_encoder.encode(content)
- def add_tag_txn(txn, next_id):
+ def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="room_tags",
@@ -224,7 +231,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"""
assert self._can_write_to_account_data
- def remove_tag_txn(txn, next_id):
+ def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
sql = (
"DELETE FROM room_tags "
" WHERE user_id = ? AND room_id = ? AND tag = ?"
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index d7dc1f73ac..1622822552 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,6 +14,7 @@
import logging
from collections import namedtuple
+from enum import Enum
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
import attr
@@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple(
)
+class DestinationSortOrder(Enum):
+ """Enum to define the sorting method used when returning destinations."""
+
+ DESTINATION = "destination"
+ RETRY_LAST_TS = "retry_last_ts"
+    RETRY_INTERVAL = "retry_interval"
+ FAILURE_TS = "failure_ts"
+ LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:
"""The current destination retry timing info for a remote server."""
@@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destinations = [row[0] for row in txn]
return destinations
+
+ async def get_destinations_paginate(
+ self,
+ start: int,
+ limit: int,
+ destination: Optional[str] = None,
+ order_by: str = DestinationSortOrder.DESTINATION.value,
+ direction: str = "f",
+ ) -> Tuple[List[JsonDict], int]:
+ """Function to retrieve a paginated list of destinations.
+        This will return a JSON-serialisable list of destinations and the
+ total number of destinations matching the filter criteria.
+
+ Args:
+ start: start number to begin the query from
+ limit: number of rows to retrieve
+ destination: search string in destination
+ order_by: the sort order of the returned list
+            direction: sort ascending ("f") or descending ("b")
+ Returns:
+ A tuple of a list of mappings from destination to information
+ and a count of total destinations.
+ """
+
+ def get_destinations_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
+ order_by_column = DestinationSortOrder(order_by).value
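+            # The DestinationSortOrder constructor above raises ValueError for
+            # unknown sort orders, so order_by_column is safe to interpolate
+            # into the SQL below.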
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ args = []
+ where_statement = ""
+ if destination:
+ args.extend(["%" + destination.lower() + "%"])
+ where_statement = "WHERE LOWER(destination) LIKE ?"
+
+ sql_base = f"FROM destinations {where_statement} "
+ sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = f"""
+ SELECT destination, retry_last_ts, retry_interval, failure_ts,
+ last_successful_stream_ordering
+ {sql_base}
+ ORDER BY {order_by_column} {order}, destination ASC
+ LIMIT ? OFFSET ?
+ """
+ txn.execute(sql, args + [limit, start])
+ destinations = self.db_pool.cursor_to_dict(txn)
+ return destinations, count
+
+ return await self.db_pool.runInteraction(
+ "get_destinations_paginate_txn", get_destinations_paginate_txn
+ )
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 1ecdd40c38..f79006533f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -14,11 +14,12 @@
from typing import Dict, Iterable
-from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.util.caches.descriptors import cached, cachedList
-class UserErasureWorkerStore(SQLBaseStore):
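+# The erasure stores invalidate the is_user_erased cache via
+# _invalidate_cache_and_stream, which CacheInvalidationWorkerStore provides
+# (SQLBaseStore does not).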
+class UserErasureWorkerStore(CacheInvalidationWorkerStore):
@cached()
async def is_user_erased(self, user_id: str) -> bool:
"""
@@ -69,7 +70,7 @@ class UserErasureStore(UserErasureWorkerStore):
user_id: full user_id to be erased
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
# first check if they are already in the list
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
if txn.fetchone():
@@ -89,7 +90,7 @@ class UserErasureStore(UserErasureWorkerStore):
user_id: full user_id to be un-erased
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
# first check if they are already in the list
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
if not txn.fetchone():
|