diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 6305414e3d..eee07227ef 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -36,7 +36,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
if (
hs.config.worker.run_background_tasks
- and self.hs.config.redaction_retention_period is not None
+ and self.hs.config.server.redaction_retention_period is not None
):
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
@@ -48,7 +48,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
By censor we mean update the event_json table with the redacted event.
"""
- if self.hs.config.redaction_retention_period is None:
+ if self.hs.config.server.redaction_retention_period is None:
return
if not (
@@ -60,7 +60,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# created.
return
- before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+ before_ts = (
+ self._clock.time_msec() - self.hs.config.server.redaction_retention_period
+ )
# We fetch all redactions that:
# 1. point to an event we have,
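Note on this hunk and the similar ones below (client_ips, monthly_active_users, push_rule, registration, room, search): the change is purely mechanical, moving a flat `hs.config.<option>` read to the per-section namespace that now owns it, e.g. `hs.config.server.<option>`. A minimal sketch of the namespaced pattern, using made-up dataclass names rather than Synapse's real config classes:

    # Illustrative only: RootConfig/ServerConfig are stand-ins for the real
    # config classes; the option value is invented.
    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class ServerConfig:
        # Retention period in milliseconds; None disables censoring entirely.
        redaction_retention_period: Optional[int] = 7 * 24 * 60 * 60 * 1000

    @dataclass
    class RootConfig:
        server: ServerConfig = field(default_factory=ServerConfig)

    config = RootConfig()

    # The old code read config.redaction_retention_period off the root object;
    # the new code reads the same value via the owning section.
    if config.server.redaction_retention_period is not None:
        print("censor redactions older than",
              config.server.redaction_retention_period, "ms")

The behaviour is unchanged; only the attribute path moves.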
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 5f611d7b09..400831282a 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -353,7 +353,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self.user_ips_max_age = hs.config.user_ips_max_age
+ self.user_ips_max_age = hs.config.server.user_ips_max_age
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 584f818ff3..19f55c19c5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -104,7 +104,7 @@ class PersistEventsStore:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
- self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+ self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
# Ideally we'd move these ID gens here, unfortunately some other ID
@@ -1276,13 +1276,6 @@ class PersistEventsStore:
logger.exception("")
raise
- # update the stored internal_metadata to update the "outlier" flag.
- # TODO: This is unused as of Synapse 1.31. Remove it once we are happy
- # to drop backwards-compatibility with 1.30.
- metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
- sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
- txn.execute(sql, (metadata_json, event.event_id))
-
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
@@ -1327,19 +1320,6 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
- def get_internal_metadata(event):
- im = event.internal_metadata.get_dict()
-
- # temporary hack for database compatibility with Synapse 1.30 and earlier:
- # store the `outlier` flag inside the internal_metadata json as well as in
- # the `events` table, so that if anyone rolls back to an older Synapse,
- # things keep working. This can be removed once we are happy to drop support
- # for that
- if event.internal_metadata.is_outlier():
- im["outlier"] = True
-
- return im
-
self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
@@ -1348,7 +1328,7 @@ class PersistEventsStore:
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": json_encoder.encode(
- get_internal_metadata(event)
+ event.internal_metadata.get_dict()
),
"json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
@@ -1783,9 +1763,8 @@ class PersistEventsStore:
retcol="creator",
allow_none=True,
)
- if (
- not room_version.msc2716_historical
- or not self.hs.config.experimental.msc2716_enabled
+ if not room_version.msc2716_historical and (
+ not self.hs.config.experimental.msc2716_enabled
or event.sender != room_creator
):
return
@@ -1845,9 +1824,8 @@ class PersistEventsStore:
retcol="creator",
allow_none=True,
)
- if (
- not room_version.msc2716_historical
- or not self.hs.config.experimental.msc2716_enabled
+ if not room_version.msc2716_historical and (
+ not self.hs.config.experimental.msc2716_enabled
or event.sender != room_creator
):
return
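The two MSC2716 hunks above rewrite the early-return condition so that a room version which natively marks these events as historical no longer also needs the experimental flag and the room creator as sender. A small truth-table sketch of the difference, with shorthand names for `room_version.msc2716_historical`, the experimental `msc2716_enabled` flag, and `event.sender == room_creator`:

    from itertools import product

    def old_returns_early(historical: bool, enabled: bool, creator_sent: bool) -> bool:
        return not historical or not enabled or not creator_sent

    def new_returns_early(historical: bool, enabled: bool, creator_sent: bool) -> bool:
        return not historical and (not enabled or not creator_sent)

    for case in product([False, True], repeat=3):
        if old_returns_early(*case) != new_returns_early(*case):
            print("behaviour differs for (historical, enabled, creator_sent) =", case)

The old predicate only continued when all three held; the new one continues whenever the room version itself supports MSC2716, or when the experimental flag is enabled and the sender is the room creator.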
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index bb244a03c0..434986fa64 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Union
+
from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
@@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached
class FilteringStore(SQLBaseStore):
@cached(num_args=2)
- async def get_user_filter(self, user_localpart, filter_id):
+ async def get_user_filter(
+ self, user_localpart: str, filter_id: Union[int, str]
+ ) -> JsonDict:
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
@@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
- async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
+ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
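The new annotations on `get_user_filter`/`add_user_filter` record that a filter ID can arrive as a string (it comes in as a URL path segment) but is stored and returned as an integer. A hedged sketch of the validation the comment above alludes to; the function name and error type are illustrative, not the ones used in the file:

    from typing import Union

    def parse_filter_id(filter_id: Union[int, str]) -> int:
        # filter_id is BIGINT UNSIGNED in the schema, so reject anything that
        # is not a number with a coherent error rather than a generic 500.
        try:
            return int(filter_id)
        except ValueError:
            raise ValueError("Invalid filter ID: expected a number") from None

    print(parse_filter_id("7"))  # -> 7
    print(parse_filter_id(12))   # -> 12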
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b76ee51a9b..ec4d47a560 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -32,8 +32,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._max_mau_value = hs.config.max_mau_value
+ self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
+ self._max_mau_value = hs.config.server.max_mau_value
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
@@ -96,8 +96,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
users = []
- for tp in self.hs.config.mau_limits_reserved_threepids[
- : self.hs.config.max_mau_value
+ for tp in self.hs.config.server.mau_limits_reserved_threepids[
+ : self.hs.config.server.max_mau_value
]:
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
tp["medium"], tp["address"]
@@ -212,7 +212,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._mau_stats_only = hs.config.mau_stats_only
+ self._mau_stats_only = hs.config.server.mau_stats_only
# Do not add more reserved users than the total allowable number
self.db_pool.new_transaction(
@@ -221,7 +221,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
[],
[],
self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+ hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
@@ -354,3 +354,27 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
await self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
await self.upsert_monthly_active_user(user_id)
+
+ async def remove_deactivated_user_from_mau_table(self, user_id: str) -> None:
+ """
+ Removes a deactivated user from the monthly active user
+ table and resets affected caches.
+
+ Args:
+ user_id(str): the user_id to remove
+ """
+
+ rows_deleted = await self.db_pool.simple_delete(
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ desc="simple_delete",
+ )
+
+ if rows_deleted != 0:
+ await self.invalidate_cache_and_stream(
+ "user_last_seen_monthly_active", (user_id,)
+ )
+ await self.invalidate_cache_and_stream("get_monthly_active_count", ())
+ await self.invalidate_cache_and_stream(
+ "get_monthly_active_count_by_service", ()
+ )
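The new `remove_deactivated_user_from_mau_table` deletes the user's `monthly_active_users` row and, only when a row was actually removed, invalidates the three caches derived from that table. A hedged sketch of a call site; the handler and the fake store below are illustrative, not Synapse's deactivation code:

    import asyncio

    class FakeStore:
        """Stand-in for the datastore; only the newly added method is modelled."""

        async def remove_deactivated_user_from_mau_table(self, user_id: str) -> None:
            print("removed", user_id, "from monthly_active_users")

    async def deactivate_account(store: FakeStore, user_id: str) -> None:
        # ...remove devices, access tokens, pushers, and so on...
        # Finally drop the user from the MAU table so they no longer count
        # towards max_mau_value.
        await store.remove_deactivated_user_from_mau_table(user_id)

    asyncio.run(deactivate_account(FakeStore(), "@alice:example.com"))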
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index a7fb8cd848..fc720f5947 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import logging
-from typing import List, Tuple, Union
+from typing import Dict, List, Tuple, Union
from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
@@ -101,7 +101,9 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
- self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+ self._users_new_default_push_rules = (
+ hs.config.server.users_new_default_push_rules
+ )
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
@@ -137,7 +139,7 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, use_new_defaults)
@cached(max_entries=5000)
- async def get_push_rules_enabled_for_user(self, user_id):
+ async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
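`get_push_rules_enabled_for_user` now declares that it yields a mapping of rule ID to enabled flag. A hedged sketch of shaping `simple_select_list` rows into that `Dict[str, bool]`; the row shape and rule IDs are assumptions for illustration, not taken from the file:

    from typing import Dict, List

    def rows_to_enabled_map(rows: List[Dict[str, object]]) -> Dict[str, bool]:
        # Each row carries a rule_id plus an integer enabled column.
        return {str(row["rule_id"]): bool(row["enabled"]) for row in rows}

    print(rows_to_enabled_map([
        {"rule_id": ".m.rule.contains_user_name", "enabled": 1},
        {"rule_id": "custom.quiet_hours", "enabled": 0},
    ]))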
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index a93caae8d0..b73ce53c91 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -18,8 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional,
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -32,7 +31,12 @@ logger = logging.getLogger(__name__)
class PusherWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c83089ee63..181841ee06 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID, UserInfo
@@ -207,7 +207,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return False
now = self._clock.time_msec()
- trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
+ trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial
@@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
- id_servers = set(self.config.trusted_third_party_id_servers)
+ id_servers = set(self.config.registration.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
@@ -1775,10 +1775,17 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
- self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+ self._ignore_unknown_session_error = (
+ hs.config.server.request_token_inhibit_3pid_errors
+ )
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 118b390e93..d69eaf80ce 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore):
# policy.
if not ret:
return {
- "min_lifetime": self.config.retention_default_min_lifetime,
- "max_lifetime": self.config.retention_default_max_lifetime,
+ "min_lifetime": self.config.server.retention_default_min_lifetime,
+ "max_lifetime": self.config.server.retention_default_max_lifetime,
}
row = ret[0]
@@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore):
# The default values will be None if no default policy has been defined, or if one
# of the attributes is missing from the default policy.
if row["min_lifetime"] is None:
- row["min_lifetime"] = self.config.retention_default_min_lifetime
+ row["min_lifetime"] = self.config.server.retention_default_min_lifetime
if row["max_lifetime"] is None:
- row["max_lifetime"] = self.config.retention_default_max_lifetime
+ row["max_lifetime"] = self.config.server.retention_default_max_lifetime
return row
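The retention hunks only change where the defaults are read from, but the fallback logic is worth spelling out: per-room values win, and each missing field independently falls back to the server-wide default. A standalone sketch with made-up values:

    from typing import Dict, Optional

    def resolve_policy(
        room_policy: Optional[Dict[str, Optional[int]]],
        default_min: Optional[int],
        default_max: Optional[int],
    ) -> Dict[str, Optional[int]]:
        if not room_policy:
            return {"min_lifetime": default_min, "max_lifetime": default_max}
        if room_policy["min_lifetime"] is None:
            room_policy["min_lifetime"] = default_min
        if room_policy["max_lifetime"] is None:
            room_policy["max_lifetime"] = default_max
        return room_policy

    week_ms = 7 * 24 * 60 * 60 * 1000
    print(resolve_policy(None, None, week_ms))
    print(resolve_policy({"min_lifetime": None, "max_lifetime": 3_600_000}, None, week_ms))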
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index a383388757..300a563c9e 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -18,7 +18,9 @@ from synapse.storage._base import SQLBaseStore
class RoomBatchStore(SQLBaseStore):
- async def get_insertion_event_by_batch_id(self, batch_id: str) -> Optional[str]:
+ async def get_insertion_event_by_batch_id(
+ self, room_id: str, batch_id: str
+ ) -> Optional[str]:
"""Retrieve a insertion event ID.
Args:
@@ -30,7 +32,7 @@ class RoomBatchStore(SQLBaseStore):
"""
return await self.db_pool.simple_select_one_onecol(
table="insertion_events",
- keyvalues={"next_batch_id": batch_id},
+ keyvalues={"room_id": room_id, "next_batch_id": batch_id},
retcol="event_id",
allow_none=True,
)
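Scoping the insertion-event lookup by `room_id` matters because nothing guarantees a `next_batch_id` is unique across rooms. A self-contained sqlite sketch of the difference between the old and the new key; the table and column names mirror the hunk, the rows are invented:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE insertion_events (event_id TEXT, room_id TEXT, next_batch_id TEXT)"
    )
    conn.executemany(
        "INSERT INTO insertion_events VALUES (?, ?, ?)",
        [
            ("$insert_a", "!room_a:example.com", "batch1"),
            ("$insert_b", "!room_b:example.com", "batch1"),
        ],
    )

    # Old key: next_batch_id only -- ambiguous across rooms.
    print(conn.execute(
        "SELECT event_id FROM insertion_events WHERE next_batch_id = ?", ("batch1",)
    ).fetchall())

    # New key: (room_id, next_batch_id) -- resolves within the room being asked about.
    print(conn.execute(
        "SELECT event_id FROM insertion_events WHERE room_id = ? AND next_batch_id = ?",
        ("!room_a:example.com", "batch1"),
    ).fetchone())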
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 9eb74a81a0..25df8758bd 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -51,7 +51,7 @@ class SearchWorkerStore(SQLBaseStore):
txn:
entries: entries to be added to the table
"""
- if not self.hs.config.enable_search:
+ if not self.hs.config.server.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
sql = (
@@ -105,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- if not hs.config.enable_search:
+ if not hs.config.server.enable_search:
return
self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 90d65edc42..5c713a732e 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -40,12 +40,10 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-
TEMP_TABLE = "_temp_populate_user_directory"
class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
-
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
@@ -230,38 +228,49 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
is_in_room = await self.is_host_joined(room_id, self.server_name)
if is_in_room:
- is_public = await self.is_room_world_readable_or_publicly_joinable(
- room_id
- )
-
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+ # Throw away users excluded from the directory.
+ users_with_profile = {
+ user_id: profile
+ for user_id, profile in users_with_profile.items()
+ if not self.hs.is_mine_id(user_id)
+ or await self.should_include_local_user_in_dir(user_id)
+ }
- # Update each user in the user directory.
+ # Upsert a user_directory record for each remote user we see.
for user_id, profile in users_with_profile.items():
+ # Local users are processed separately in
+ # `_populate_user_directory_users`; there we can read from
+ # the `profiles` table to ensure we don't leak their per-room
+ # profiles. It also means we write local users to this table
+ # exactly once, rather than once for every room they're in.
+ if self.hs.is_mine_id(user_id):
+ continue
+ # TODO `users_with_profile` above reads from the `user_directory`
+ # table, meaning that `profile` is bespoke to this room, and this
+ # leaks remote users' per-room profiles to the user directory.
await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- to_insert = set()
-
+ # Now update the room sharing tables to include this room.
+ is_public = await self.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
if is_public:
- for user_id in users_with_profile:
- if self.get_if_app_services_interested_in_user(user_id):
- continue
-
- to_insert.add(user_id)
-
- if to_insert:
- await self.add_users_in_public_rooms(room_id, to_insert)
- to_insert.clear()
+ if users_with_profile:
+ await self.add_users_in_public_rooms(
+ room_id, users_with_profile.keys()
+ )
else:
+ to_insert = set()
for user_id in users_with_profile:
+ # We want the set of pairs (L, M) where L and M are
+ # in `users_with_profile` and L is local.
+ # Do so by looking for the local user L first.
if not self.hs.is_mine_id(user_id):
continue
- if self.get_if_app_services_interested_in_user(user_id):
- continue
-
for other_user_id in users_with_profile:
if user_id == other_user_id:
continue
@@ -349,10 +358,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
for user_id in users_to_work_on:
- profile = await self.get_profileinfo(get_localpart_from_id(user_id))
- await self.update_profile_in_user_dir(
- user_id, profile.display_name, profile.avatar_url
- )
+ if await self.should_include_local_user_in_dir(user_id):
+ profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+ await self.update_profile_in_user_dir(
+ user_id, profile.display_name, profile.avatar_url
+ )
# We've finished processing a user. Delete it from the table.
await self.db_pool.simple_delete_one(
@@ -369,6 +379,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
+ async def should_include_local_user_in_dir(self, user: str) -> bool:
+ """Certain classes of local user are omitted from the user directory.
+ Is this user one of them?
+ """
+ # App service users aren't usually contactable, so exclude them.
+ if self.get_if_app_services_interested_in_user(user):
+ # TODO we might want to make this configurable for each app service
+ return False
+
+ # Support users are for diagnostics and should not appear in the user directory.
+ if await self.is_support_user(user):
+ return False
+
+ # Deactivated users aren't contactable, so should not appear in the user directory.
+ if await self.get_user_deactivated_status(user):
+ return False
+ return True
+
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""
@@ -527,7 +555,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- async def update_user_directory_stream_pos(self, stream_id: int) -> None:
+ async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
@@ -537,7 +565,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
-
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
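For private rooms the rewritten loop builds, as the new comment puts it, the set of pairs (L, M) where both are room members and L is local. A pure-Python sketch of that pair generation; `is_local` stands in for `hs.is_mine_id` and the member list is invented:

    from typing import Callable, Iterable, Set, Tuple

    def local_pairs(
        members: Iterable[str], is_local: Callable[[str], bool]
    ) -> Set[Tuple[str, str]]:
        member_list = list(members)
        pairs = set()
        for user_id in member_list:
            # Look for the local user L first...
            if not is_local(user_id):
                continue
            # ...then pair it with every other member M.
            for other_user_id in member_list:
                if user_id == other_user_id:
                    continue
                pairs.add((user_id, other_user_id))
        return pairs

    members = ["@alice:local", "@bob:remote.example", "@carol:local"]
    print(sorted(local_pairs(members, lambda u: u.endswith(":local"))))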
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index f31880b8ec..11ca47ea28 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -366,7 +366,7 @@ def _upgrade_existing_database(
+ "new for the server to understand"
)
- # some of the deltas assume that config.server_name is set correctly, so now
+ # some of the deltas assume that server_name is set correctly, so now
# is a good time to run the sanity check.
if not is_empty and "main" in databases:
from synapse.storage.databases.main import check_database_before_upgrade
@@ -487,6 +487,10 @@ def _upgrade_existing_database(
spec = importlib.util.spec_from_file_location(
module_name, absolute_path
)
+ if spec is None:
+ raise RuntimeError(
+ f"Could not build a module spec for {module_name} at {absolute_path}"
+ )
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
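`importlib.util.spec_from_file_location` is typed as returning `Optional[ModuleSpec]`, so the new guard converts a `None` spec into an explicit `RuntimeError` instead of an opaque failure a line later. A minimal standalone sketch of the same load-from-path pattern; the temporary file exists only to make the example runnable:

    import importlib.util
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "delta_demo.py")
        with open(path, "w") as f:
            f.write("VALUE = 42\n")

        spec = importlib.util.spec_from_file_location("delta_demo", path)
        if spec is None:
            raise RuntimeError(f"Could not build a module spec for delta_demo at {path}")

        module = importlib.util.module_from_spec(spec)
        assert spec.loader is not None  # loader is also Optional on ModuleSpec
        spec.loader.exec_module(module)
        print(module.VALUE)  # -> 42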
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 573e05a482..1aee741a8b 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# When updating these values, please leave a short summary of the changes below.
-
-SCHEMA_VERSION = 64
+SCHEMA_VERSION = 64 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -46,7 +44,7 @@ Changes in SCHEMA_VERSION = 64:
"""
-SCHEMA_COMPAT_VERSION = 59
+SCHEMA_COMPAT_VERSION = 60 # 60: "outlier" not in internal_metadata.
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
This value is stored in the database, and checked on startup. If the value in the
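Raising `SCHEMA_COMPAT_VERSION` to 60 records that, once this code has run, the `outlier` flag is no longer duplicated into `internal_metadata`, so anything older than schema 60 can no longer safely read the data. A hedged sketch of the kind of startup comparison the docstring describes; the real check lives in `synapse.storage.prepare_database`, this is only its shape:

    SCHEMA_VERSION = 64         # what this codebase understands and writes
    SCHEMA_COMPAT_VERSION = 60  # oldest codebase that can still read that data

    def can_start(codebase_schema_version: int, db_compat_version: int) -> bool:
        # A rolled-back Synapse may only start if its own schema version is at
        # least the compat version recorded in the database.
        return codebase_schema_version >= db_compat_version

    print(can_start(59, SCHEMA_COMPAT_VERSION))  # False: too old for the new data
    print(can_start(64, SCHEMA_COMPAT_VERSION))  # True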
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 6f7cbe40f4..852bd79fee 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,42 +16,62 @@ import logging
import threading
from collections import OrderedDict
from contextlib import contextmanager
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from types import TracebackType
+from typing import (
+ AsyncContextManager,
+ ContextManager,
+ Dict,
+ Generator,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
from sortedcontainers import SortedSet
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator
logger = logging.getLogger(__name__)
+T = TypeVar("T")
+
+
class IdGenerator:
- def __init__(self, db_conn, table, column):
+ def __init__(
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ column: str,
+ ):
self._lock = threading.Lock()
self._next_id = _load_current_id(db_conn, table, column)
- def get_next(self):
+ def get_next(self) -> int:
with self._lock:
self._next_id += 1
return self._next_id
-def _load_current_id(db_conn, table, column, step=1):
- """
-
- Args:
- db_conn (object):
- table (str):
- column (str):
- step (int):
-
- Returns:
- int
- """
+def _load_current_id(
+ db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
+) -> int:
# debug logging for https://github.com/matrix-org/synapse/issues/7968
logger.info("initialising stream generator for %s(%s)", table, column)
cur = db_conn.cursor(txn_name="_load_current_id")
@@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1):
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
- (val,) = cur.fetchone()
+ result = cur.fetchone()
+ assert result is not None
+ (val,) = result
cur.close()
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
@@ -93,16 +115,16 @@ class StreamIdGenerator:
def __init__(
self,
- db_conn,
- table,
- column,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
- step=1,
- ):
+ step: int = 1,
+ ) -> None:
assert step != 0
self._lock = threading.Lock()
- self._step = step
- self._current = _load_current_id(db_conn, table, column, step)
+ self._step: int = step
+ self._current: int = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
@@ -115,7 +137,7 @@ class StreamIdGenerator:
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
- def get_next(self):
+ def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
@@ -128,7 +150,7 @@ class StreamIdGenerator:
self._unfinished_ids[next_id] = next_id
@contextmanager
- def manager():
+ def manager() -> Generator[int, None, None]:
try:
yield next_id
finally:
@@ -137,7 +159,7 @@ class StreamIdGenerator:
return _AsyncCtxManagerWrapper(manager())
- def get_next_mult(self, n):
+ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
"""
Usage:
async with stream_id_gen.get_next(n) as stream_ids:
@@ -155,7 +177,7 @@ class StreamIdGenerator:
self._unfinished_ids[next_id] = next_id
@contextmanager
- def manager():
+ def manager() -> Generator[Sequence[int], None, None]:
try:
yield next_ids
finally:
@@ -215,7 +237,7 @@ class MultiWriterIdGenerator:
def __init__(
self,
- db_conn,
+ db_conn: LoggingDatabaseConnection,
db: DatabasePool,
stream_name: str,
instance_name: str,
@@ -223,7 +245,7 @@ class MultiWriterIdGenerator:
sequence_name: str,
writers: List[str],
positive: bool = True,
- ):
+ ) -> None:
self._db = db
self._stream_name = stream_name
self._instance_name = instance_name
@@ -285,9 +307,9 @@ class MultiWriterIdGenerator:
def _load_current_ids(
self,
- db_conn,
+ db_conn: LoggingDatabaseConnection,
tables: List[Tuple[str, str, str]],
- ):
+ ) -> None:
cur = db_conn.cursor(txn_name="_load_current_ids")
# Load the current positions of all writers for the stream.
@@ -335,7 +357,9 @@ class MultiWriterIdGenerator:
"agg": "MAX" if self._positive else "-MIN",
}
cur.execute(sql)
- (stream_id,) = cur.fetchone()
+ result = cur.fetchone()
+ assert result is not None
+ (stream_id,) = result
max_stream_id = max(max_stream_id, stream_id)
@@ -354,7 +378,7 @@ class MultiWriterIdGenerator:
self._persisted_upto_position = min_stream_id
- rows = []
+ rows: List[Tuple[str, int]] = []
for table, instance_column, id_column in tables:
sql = """
SELECT %(instance)s, %(id)s FROM %(table)s
@@ -367,7 +391,8 @@ class MultiWriterIdGenerator:
}
cur.execute(sql, (min_stream_id * self._return_factor,))
- rows.extend(cur)
+ # Cast safety: this corresponds to the types returned by the query above.
+ rows.extend(cast(Iterable[Tuple[str, int]], cur))
# Sort so that we handle rows in order for each instance.
rows.sort()
@@ -385,13 +410,13 @@ class MultiWriterIdGenerator:
cur.close()
- def _load_next_id_txn(self, txn) -> int:
+ def _load_next_id_txn(self, txn: Cursor) -> int:
return self._sequence_gen.get_next_id_txn(txn)
- def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+ def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
- def get_next(self):
+ def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
@@ -403,9 +428,12 @@ class MultiWriterIdGenerator:
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")
- return _MultiWriterCtxManager(self)
+ # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
+ # controls the return type. If `None` or omitted, the context manager yields
+ # a single integer stream_id; otherwise it yields a list of stream_ids.
+ return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
- def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
"""
Usage:
async with stream_id_gen.get_next_mult(5) as stream_ids:
@@ -417,9 +445,10 @@ class MultiWriterIdGenerator:
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")
- return _MultiWriterCtxManager(self, n)
+ # Cast safety: see get_next.
+ return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
- def get_next_txn(self, txn: LoggingTransaction):
+ def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Usage:
@@ -457,7 +486,7 @@ class MultiWriterIdGenerator:
return self._return_factor * next_id
- def _mark_id_as_finished(self, next_id: int):
+ def _mark_id_as_finished(self, next_id: int) -> None:
"""The ID has finished being processed so we should advance the
current position if possible.
"""
@@ -534,7 +563,7 @@ class MultiWriterIdGenerator:
for name, i in self._current_positions.items()
}
- def advance(self, instance_name: str, new_id: int):
+ def advance(self, instance_name: str, new_id: int) -> None:
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
@@ -560,7 +589,7 @@ class MultiWriterIdGenerator:
with self._lock:
return self._return_factor * self._persisted_upto_position
- def _add_persisted_position(self, new_id: int):
+ def _add_persisted_position(self, new_id: int) -> None:
"""Record that we have persisted a position.
This is used to keep the `_current_positions` up to date.
@@ -606,7 +635,7 @@ class MultiWriterIdGenerator:
# do.
break
- def _update_stream_positions_table_txn(self, txn: Cursor):
+ def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
"""Update the `stream_positions` table with newly persisted position."""
if not self._writers:
@@ -628,20 +657,25 @@ class MultiWriterIdGenerator:
txn.execute(sql, (self._stream_name, self._instance_name, pos))
-@attr.s(slots=True)
-class _AsyncCtxManagerWrapper:
+@attr.s(frozen=True, auto_attribs=True)
+class _AsyncCtxManagerWrapper(Generic[T]):
"""Helper class to convert a plain context manager to an async one.
This is mainly useful if you have a plain context manager but the interface
requires an async one.
"""
- inner = attr.ib()
+ inner: ContextManager[T]
- async def __aenter__(self):
+ async def __aenter__(self) -> T:
return self.inner.__enter__()
- async def __aexit__(self, exc_type, exc, tb):
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> Optional[bool]:
return self.inner.__exit__(exc_type, exc, tb)
@@ -671,7 +705,12 @@ class _MultiWriterCtxManager:
else:
return [i * self.id_gen._return_factor for i in self.stream_ids]
- async def __aexit__(self, exc_type, exc, tb):
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> bool:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
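Most of this file is type annotation, but two helpers deserve a gloss: `_AsyncCtxManagerWrapper` adapts a plain synchronous context manager to the `async with` protocol, and the `cast(...)` calls in `get_next`/`get_next_mult` document that `_MultiWriterCtxManager` yields an int or a list depending on how it is constructed. A standalone sketch of the wrapper idea, mirroring the class above without being it:

    import asyncio
    from contextlib import contextmanager
    from types import TracebackType
    from typing import ContextManager, Generator, Generic, Optional, Type, TypeVar

    T = TypeVar("T")

    class AsyncWrapper(Generic[T]):
        """Expose a sync context manager through __aenter__/__aexit__."""

        def __init__(self, inner: ContextManager[T]) -> None:
            self.inner = inner

        async def __aenter__(self) -> T:
            return self.inner.__enter__()

        async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc: Optional[BaseException],
            tb: Optional[TracebackType],
        ) -> Optional[bool]:
            return self.inner.__exit__(exc_type, exc, tb)

    @contextmanager
    def next_id() -> Generator[int, None, None]:
        # Hand out a fixed ID; the real generator allocates one and marks it
        # as finished in the finally block of its own manager.
        yield 123

    async def main() -> None:
        async with AsyncWrapper(next_id()) as stream_id:
            print("got stream id", stream_id)

    asyncio.run(main())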
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index bb33e04fb1..75268cbe15 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -81,7 +81,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
id_column: str,
stream_name: Optional[str] = None,
positive: bool = True,
- ):
+ ) -> None:
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
@@ -122,7 +122,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
id_column: str,
stream_name: Optional[str] = None,
positive: bool = True,
- ):
+ ) -> None:
"""See SequenceGenerator.check_consistency for docstring."""
txn = db_conn.cursor(txn_name="sequence.check_consistency")
@@ -244,7 +244,7 @@ class LocalSequenceGenerator(SequenceGenerator):
id_column: str,
stream_name: Optional[str] = None,
positive: bool = True,
- ):
+ ) -> None:
# There is nothing to do for in memory sequences
pass
|