From c66a06ac6b69b0a03f5c6284ded980399e9df94e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 21 Oct 2019 12:56:42 +0100 Subject: Move storage classes into a main "data store". This is in preparation for having multiple data stores that offer different functionality, e.g. splitting out state or event storage. --- tests/storage/test_appservice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests/storage/test_appservice.py') diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 622b16a071..dfeea24599 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -24,7 +24,7 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.appservice import ( +from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -- cgit 1.5.1 From 9a4fb457cf5918c85068ea249cd2d58b3e2e3cfc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 13:08:40 +0000 Subject: Change DataStores to accept 'database' param. --- synapse/app/federation_sender.py | 5 +++-- synapse/app/user_dir.py | 5 +++-- synapse/replication/slave/storage/_base.py | 5 +++-- synapse/replication/slave/storage/account_data.py | 5 +++-- synapse/replication/slave/storage/client_ips.py | 5 +++-- synapse/replication/slave/storage/deviceinbox.py | 5 +++-- synapse/replication/slave/storage/devices.py | 5 +++-- synapse/replication/slave/storage/events.py | 5 +++-- synapse/replication/slave/storage/filtering.py | 5 +++-- synapse/replication/slave/storage/groups.py | 5 +++-- synapse/replication/slave/storage/presence.py | 5 +++-- synapse/replication/slave/storage/push_rule.py | 5 +++-- synapse/replication/slave/storage/pushers.py | 5 +++-- synapse/replication/slave/storage/receipts.py | 5 +++-- synapse/replication/slave/storage/room.py | 5 +++-- synapse/storage/_base.py | 2 +- synapse/storage/data_stores/main/__init__.py | 5 +++-- synapse/storage/data_stores/main/account_data.py | 9 +++++---- synapse/storage/data_stores/main/appservice.py | 5 +++-- synapse/storage/data_stores/main/client_ips.py | 9 +++++---- synapse/storage/data_stores/main/deviceinbox.py | 9 +++++---- synapse/storage/data_stores/main/devices.py | 9 +++++---- synapse/storage/data_stores/main/event_federation.py | 5 +++-- synapse/storage/data_stores/main/event_push_actions.py | 9 +++++---- synapse/storage/data_stores/main/events.py | 5 +++-- synapse/storage/data_stores/main/events_bg_updates.py | 5 +++-- synapse/storage/data_stores/main/events_worker.py | 5 +++-- synapse/storage/data_stores/main/media_repository.py | 11 +++++++---- synapse/storage/data_stores/main/monthly_active_users.py | 7 ++++--- synapse/storage/data_stores/main/push_rule.py | 5 +++-- synapse/storage/data_stores/main/receipts.py | 9 +++++---- synapse/storage/data_stores/main/registration.py | 13 +++++++------ synapse/storage/data_stores/main/room.py | 9 +++++---- synapse/storage/data_stores/main/roommember.py | 13 +++++++------ synapse/storage/data_stores/main/search.py | 9 +++++---- synapse/storage/data_stores/main/state.py | 13 +++++++------ synapse/storage/data_stores/main/stats.py | 5 +++-- synapse/storage/data_stores/main/stream.py | 5 +++-- synapse/storage/data_stores/main/transactions.py | 5 +++-- synapse/storage/data_stores/main/user_directory.py | 9 +++++---- tests/storage/test_appservice.py | 5 +++-- 41 files changed, 156 insertions(+), 114 deletions(-) (limited to 'tests/storage/test_appservice.py') diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 448e45e00f..f24920a7d6 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer @@ -59,8 +60,8 @@ class FederationSenderSlaveStore( SlavedDeviceStore, SlavedPresenceStore, ): - def __init__(self, db_conn, hs): - super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs) # We pull out the current federation stream position now so that we # always have a known value for the federation position in memory so diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index b6d4481725..c01fb34a9b 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import ( from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree @@ -60,8 +61,8 @@ class UserDirectorySlaveStore( UserDirectoryStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): - super(UserDirectorySlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 6ece1d6745..b91a528245 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -20,6 +20,7 @@ import six from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker @@ -35,8 +36,8 @@ def __func__(inp): class BaseSlavedStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id" diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index bc2f6a12ae..ebe94909cb 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.storage.data_stores.main.tags import TagsWorkerStore +from synapse.storage.database import Database class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id" ) - super(SlavedAccountDataStore, self).__init__(db_conn, hs) + super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index b4f58cea19..fbf996e33a 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -21,8 +22,8 @@ from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedClientIpStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 9fb6c5c6ff..0c237c6e0f 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -16,13 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_max_stream_id", "stream_id" ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index de50748c30..dc625e0d7a 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage.data_stores.main.devices import DeviceWorkerStore from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d0a0eaf75b..29f35b9915 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.data_stores.main.stream import StreamWorkerStore from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -59,13 +60,13 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - super(SlavedEventStore, self).__init__(db_conn, hs) + super(SlavedEventStore, self).__init__(database, db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 5c84ebd125..bcb0688954 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -14,13 +14,14 @@ # limitations under the License. from synapse.storage.data_stores.main.filtering import FilteringStore +from synapse.storage.database import Database from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedFilteringStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 28a46edd28..69a4ae42f9 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage import DataStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedGroupServerStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 747ced0c84..f552e7c972 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -15,6 +15,7 @@ from synapse.storage import DataStore from synapse.storage.data_stores.main.presence import PresenceStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPresenceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 3655f05e54..eebd5a1fb6 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,17 +15,18 @@ # limitations under the License. from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.storage.database import Database from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" ) - super(SlavedPushRuleStore, self).__init__(db_conn, hs) + super(SlavedPushRuleStore, self).__init__(database, db_conn, hs) def get_push_rules_stream_token(self): return ( diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index b4331d0799..f22c2d44a3 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -15,14 +15,15 @@ # limitations under the License. from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPusherStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 43d823c601..d40dc6e1f5 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(db_conn, hs) + super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index d9ad386b28..3a20f45316 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -14,14 +14,15 @@ # limitations under the License. from synapse.storage.data_stores.main.room import RoomWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7e27d4e97..f9e7f9a71e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -37,7 +37,7 @@ class SQLBaseStore(object): per data store (and not one per physical database). """ - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 6adb8adb04..7f5fd81bcf 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -20,6 +20,7 @@ import logging import time from synapse.api.constants import PresenceState +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( ChainedIdGenerator, @@ -111,7 +112,7 @@ class DataStore( RelationsStore, CacheInvalidationStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine @@ -169,7 +170,7 @@ class DataStore( else: self._cache_id_gen = None - super(DataStore, self).__init__(db_conn, hs) + super(DataStore, self).__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index a96fe9485c..44d20c19bf 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(db_conn, hs) + super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore): class AccountDataStore(AccountDataWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) - super(AccountDataStore, self).__init__(db_conn, hs) + super(AccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 6b2e12719c..b2f39649fd 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 7b470a58f1..320c5b0f07 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -33,8 +34,8 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "user_ips_device_index", @@ -363,13 +364,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) - super(ClientIpStore, self).__init__(db_conn, hs) + super(ClientIpStore, self).__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.user_ips_max_age diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 3c9f09301a..85cfa16850 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -210,8 +211,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "device_inbox_stream_index", @@ -241,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 91ddaf137e..9a828231c4 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -31,6 +31,7 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import ( @@ -642,8 +643,8 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "device_lists_stream_idx", @@ -692,8 +693,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 31d2e8eb28..1f517e8fad 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -491,8 +492,8 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, db_conn, hs): - super(EventFederationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventFederationStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index eec054cd48..9988a6d3fc 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -611,8 +612,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d644c82784..da1529f6ea 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -41,6 +41,7 @@ from synapse.storage._base import make_in_list_sql_clause from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.database import Database from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -95,8 +96,8 @@ def _retry_on_integrity_error(func): class EventsStore( StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, ): - def __init__(self, db_conn, hs): - super(EventsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsStore, self).__init__(database, db_conn, hs) # Collect metrics on the number of forward extremities that exist. # Counter of number of extremities to count diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index cb1fc30c31..efee17b929 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventContentFields from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -33,8 +34,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e041fc5eac..9ee117ce0f 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -33,6 +33,7 @@ from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import Cache @@ -55,8 +56,8 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) class EventsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsWorkerStore, self).__init__(database, db_conn, hs) self._get_event_cache = Cache( "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 03c9c6f8ae..80ca36dedf 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__( + database, db_conn, hs + ) self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", @@ -31,8 +34,8 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, db_conn, hs): - super(MediaRepositoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryStore, self).__init__(database, db_conn, hs) def get_local_media(self, media_id): """Get the metadata for a local piece of media diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 34bf3a1880..27158534cb 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersStore(SQLBaseStore): - def __init__(self, dbconn, hs): - super(MonthlyActiveUsersStore, self).__init__(None, hs) + def __init__(self, database: Database, db_conn, hs): + super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number self.db.new_transaction( - dbconn, + db_conn, "initialise_mau_threepids", [], [], diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index de682cc63a..5ba13aa973 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.database import Database from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -72,8 +73,8 @@ class PushRulesWorkerStore( # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index ac2d45bd5c..96e54d145e 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -315,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(db_conn, hs) + super(ReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 1ef143c6d8..5e8ecac0ea 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -27,6 +27,7 @@ from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -36,8 +37,8 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() @@ -794,8 +795,8 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config @@ -920,8 +921,8 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index da42dae243..0148be20d3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.search import SearchStore +from synapse.storage.database import Database from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -361,8 +362,8 @@ class RoomWorkerStore(SQLBaseStore): class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.config = hs.config @@ -440,8 +441,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 929f6b0d39..92e3b9c512 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -32,6 +32,7 @@ from synapse.storage._base import ( make_in_list_sql_clause, ) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( GetRoomsForUserWithStreamOrdering, @@ -54,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? @@ -835,8 +836,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -991,8 +992,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberStore, self).__init__(database, db_conn, hs) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index ffa1817e64..4eec2fae5e 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -42,8 +43,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) if not hs.config.enable_search: return @@ -342,8 +343,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(SearchStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchStore, self).__init__(database, db_conn, hs) def store_event_search_txn(self, txn, event, key, value): """Add event to the search table diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 7d5a9f8128..9ef7b48c74 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -28,6 +28,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter from synapse.util.caches import get_cache_factor_for, intern_string @@ -213,8 +214,8 @@ class StateGroupWorkerStore( STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -1029,8 +1030,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - def __init__(self, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, @@ -1245,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateStore, self).__init__(database, db_conn, hs) def _store_event_state_mappings_txn( self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 40579bf965..7bc186e9a1 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached @@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): - def __init__(self, db_conn, hs): - super(StatsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StatsStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 2ff8c57109..140da8dad6 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -47,6 +47,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -251,8 +252,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(StreamWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StreamWorkerStore, self).__init__(database, db_conn, hs) events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self.db.get_cache_dict( diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index c0d155a43c..5b07c2fbc0 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache # py2 sqlite has buffer hardcoded as only binary type, so we must use it, @@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, db_conn, hs): - super(TransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TransactionStore, self).__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 62ffb34b29..90c180ec6d 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.storage.data_stores.main.state import StateFilter from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -37,8 +38,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -549,8 +550,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryStore, self).__init__(database, db_conn, hs) def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index dfeea24599..1679112d82 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.storage.database import Database from tests import unittest from tests.utils import setup_test_homeserver @@ -382,8 +383,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, db_conn, hs): - super(TestTransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TestTransactionStore, self).__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): -- cgit 1.5.1 From 852f80d8a697c9d556b1b2641a2e4c1797cbbb46 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 16:02:50 +0000 Subject: Fixup tests --- tests/replication/slave/storage/_base.py | 5 ++++- tests/storage/test_appservice.py | 12 +++++++----- tests/storage/test_base.py | 3 ++- tests/storage/test_profile.py | 3 +-- tests/storage/test_user_directory.py | 4 +--- tests/test_federation.py | 16 +++++----------- 6 files changed, 20 insertions(+), 23 deletions(-) (limited to 'tests/storage/test_appservice.py') diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index e7472e3a93..3dae83c543 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,6 +20,7 @@ from synapse.replication.tcp.client import ( ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.storage.database import Database from tests import unittest from tests.server import FakeTransport @@ -42,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() - self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) + self.slaved_store = self.STORE_TYPE( + Database(hs), self.hs.get_db_conn(), self.hs + ) self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1679112d82..2e521e9ab7 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -55,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -124,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = TestTransactionStore(database, hs.get_db_conn(), hs) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -417,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -433,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) @@ -454,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 7915d48a9e..537cfe9f64 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,6 +21,7 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from tests import unittest @@ -59,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): "test", db_pool=self.db_pool, config=config, database_engine=fake_engine ) - self.datastore = SQLBaseStore(None, hs) + self.datastore = SQLBaseStore(Database(hs), None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 24c7fe16c3..9b6f7211ae 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.profile import ProfileStore from synapse.types import UserID from tests import unittest @@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(hs.get_db_conn(), hs) + self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7eea57c0e2..6a545d2eb0 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore - from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) + self.store = self.hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_federation.py b/tests/test_federation.py index 7d82b58466..ad165d7295 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase): self.reactor.advance(0.1) self.room_id = self.successResultOf(room)["room_id"] + self.store = self.homeserver.get_datastore() + # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( @@ -77,10 +79,7 @@ class MessageAcceptTests(unittest.TestCase): # Make sure we actually joined the room self.assertEqual( self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0], "$join:test.serv", ) @@ -100,10 +99,7 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0] # Now lie about an event @@ -141,7 +137,5 @@ class MessageAcceptTests(unittest.TestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id - ) + extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") -- cgit 1.5.1 From 2284eb3a533a2df04784df08da28e67d6588a5ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Dec 2019 10:45:12 +0000 Subject: Add database config class (#6513) This encapsulates config for a given database and is the way to get new connections. --- changelog.d/6513.misc | 1 + scripts-dev/update_database | 9 +-- scripts/synapse_port_db | 58 ++++++++----------- synapse/config/database.py | 78 ++++++++++++++++++++------ synapse/handlers/presence.py | 2 +- synapse/server.py | 41 ++------------ synapse/storage/_base.py | 2 +- synapse/storage/data_stores/__init__.py | 40 ++++++++++--- synapse/storage/data_stores/main/client_ips.py | 2 +- synapse/storage/database.py | 45 ++++++++++++++- synapse/storage/engines/sqlite.py | 16 +++++- synapse/storage/prepare_database.py | 7 +-- tests/handlers/test_typing.py | 39 ++++++------- tests/replication/slave/storage/_base.py | 6 +- tests/server.py | 55 +++++++++--------- tests/storage/test_appservice.py | 37 ++++++++---- tests/storage/test_base.py | 14 +++-- tests/storage/test_registration.py | 1 - tests/utils.py | 43 +++++--------- 19 files changed, 287 insertions(+), 209 deletions(-) create mode 100644 changelog.d/6513.misc (limited to 'tests/storage/test_appservice.py') diff --git a/changelog.d/6513.misc b/changelog.d/6513.misc new file mode 100644 index 0000000000..36700f5657 --- /dev/null +++ b/changelog.d/6513.misc @@ -0,0 +1 @@ +Remove all assumptions of there being a single phyiscal DB apart from the `synapse.config`. diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 23017c21f8..1d62f0403a 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("update_database") @@ -77,12 +76,8 @@ if __name__ == "__main__": # Instantiate and initialise the homeserver object. hs = MockHomeserver(config) - db_conn = hs.get_db_conn() - # Update the database to the latest schema. - prepare_database(db_conn, hs.database_engine, config=config) - db_conn.commit() - - # setup instantiates the store within the homeserver object. + # Setup instantiates the store within the homeserver object and updates the + # DB. hs.setup() store = hs.get_datastore() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index e393a9b2f7..5b5368988c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -30,6 +30,7 @@ import yaml from twisted.enterprise import adbapi from twisted.internet import defer, reactor +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.logging.context import PreserveLoggingContext from synapse.storage._base import LoggingTransaction @@ -55,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -165,23 +166,17 @@ class Store( class MockHomeserver: - def __init__(self, config, database_engine, db_conn, db_pool): - self.database_engine = database_engine - self.db_conn = db_conn - self.db_pool = db_pool + def __init__(self, config): self.clock = Clock(reactor) self.config = config self.hostname = config.server_name - def get_db_conn(self): - return self.db_conn - - def get_db_pool(self): - return self.db_pool - def get_clock(self): return self.clock + def get_reactor(self): + return reactor + class Porter(object): def __init__(self, **kwargs): @@ -445,45 +440,36 @@ class Porter(object): else: return - def setup_db(self, db_config, database_engine): - db_conn = database_engine.module.connect( - **{ - k: v - for k, v in db_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) - - prepare_database(db_conn, database_engine, config=None) + def setup_db(self, db_config: DatabaseConnectionConfig, engine): + db_conn = make_conn(db_config, engine) + prepare_database(db_conn, engine, config=None) db_conn.commit() return db_conn @defer.inlineCallbacks - def build_db_store(self, config): + def build_db_store(self, db_config: DatabaseConnectionConfig): """Builds and returns a database store using the provided configuration. Args: - config: The database configuration, i.e. a dict following the structure of - the "database" section of Synapse's configuration file. + config: The database configuration Returns: The built Store object. """ - engine = create_engine(config) - - self.progress.set_state("Preparing %s" % config["name"]) - conn = self.setup_db(config, engine) + self.progress.set_state("Preparing %s" % db_config.config["name"]) - db_pool = adbapi.ConnectionPool(config["name"], **config["args"]) + engine = create_engine(db_config.config) + conn = self.setup_db(db_config, engine) - hs = MockHomeserver(self.hs_config, engine, conn, db_pool) + hs = MockHomeserver(self.hs_config) - store = Store(Database(hs), conn, hs) + store = Store(Database(hs, db_config, engine), conn, hs) yield store.db.runInteraction( - "%s_engine.check_database" % config["name"], engine.check_database, + "%s_engine.check_database" % db_config.config["name"], + engine.check_database, ) return store @@ -509,7 +495,11 @@ class Porter(object): @defer.inlineCallbacks def run(self): try: - self.sqlite_store = yield self.build_db_store(self.sqlite_config) + self.sqlite_store = yield self.build_db_store( + DatabaseConnectionConfig( + "master", self.sqlite_config, data_stores=["main"] + ) + ) # Check if all background updates are done, abort if not. updates_complete = ( @@ -524,7 +514,7 @@ class Porter(object): defer.returnValue(None) self.postgres_store = yield self.build_db_store( - self.hs_config.database_config + self.hs_config.get_single_database() ) yield self.run_background_updates_on_postgres() diff --git a/synapse/config/database.py b/synapse/config/database.py index 0e2509f0b1..5f2f3c7cfd 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -12,12 +12,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from textwrap import indent +from typing import List import yaml -from ._base import Config +from synapse.config._base import Config, ConfigError + +logger = logging.getLogger(__name__) + + +class DatabaseConnectionConfig: + """Contains the connection config for a particular database. + + Args: + name: A label for the database, used for logging. + db_config: The config for a particular database, as per `database` + section of main config. Has two fields: `name` for database + module name, and `args` for the args to give to the database + connector. + data_stores: The list of data stores that should be provisioned on the + database. + """ + + def __init__(self, name: str, db_config: dict, data_stores: List[str]): + if db_config["name"] not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_config["name"],)) + + if db_config["name"] == "sqlite3": + db_config.setdefault("args", {}).update( + {"cp_min": 1, "cp_max": 1, "check_same_thread": False} + ) + + self.name = name + self.config = db_config + self.data_stores = data_stores class DatabaseConfig(Config): @@ -26,20 +57,14 @@ class DatabaseConfig(Config): def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - self.database_config = config.get("database") + database_config = config.get("database") - if self.database_config is None: - self.database_config = {"name": "sqlite3", "args": {}} + if database_config is None: + database_config = {"name": "sqlite3", "args": {}} - name = self.database_config.get("name", None) - if name == "psycopg2": - pass - elif name == "sqlite3": - self.database_config.setdefault("args", {}).update( - {"cp_min": 1, "cp_max": 1, "check_same_thread": False} - ) - else: - raise RuntimeError("Unsupported database type '%s'" % (name,)) + self.databases = [ + DatabaseConnectionConfig("master", database_config, data_stores=["main"]) + ] self.set_databasepath(config.get("database_path")) @@ -76,11 +101,24 @@ class DatabaseConfig(Config): self.set_databasepath(args.database_path) def set_databasepath(self, database_path): + if database_path is None: + return + if database_path != ":memory:": database_path = self.abspath(database_path) - if self.database_config.get("name", None) == "sqlite3": - if database_path is not None: - self.database_config["args"]["database"] = database_path + + # We only support setting a database path if we have a single sqlite3 + # database. + if len(self.databases) != 1: + raise ConfigError("Cannot specify 'database_path' with multiple databases") + + database = self.get_single_database() + if database.config["name"] != "sqlite3": + # We don't raise here as we haven't done so before for this case. + logger.warn("Ignoring 'database_path' for non-sqlite3 database") + return + + database.config["args"]["database"] = database_path @staticmethod def add_arguments(parser): @@ -91,3 +129,11 @@ class DatabaseConfig(Config): metavar="SQLITE_DATABASE_PATH", help="The path to a sqlite database to use.", ) + + def get_single_database(self) -> DatabaseConnectionConfig: + """Returns the database if there is only one, useful for e.g. tests + """ + if len(self.databases) != 1: + raise Exception("More than one database exists") + + return self.databases[0] diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index eda15bc623..240c4add12 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -230,7 +230,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.store.database.is_running(): return logger.info( diff --git a/synapse/server.py b/synapse/server.py index 5021068ce0..7926867b77 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,7 +25,6 @@ import abc import logging import os -from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail from twisted.web.client import BrowserLikePolicyForHTTPS @@ -98,7 +97,6 @@ from synapse.server_notices.worker_server_notices_sender import ( ) from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStores, Storage -from synapse.storage.engines import create_engine from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -134,7 +132,6 @@ class HomeServer(object): DEPENDENCIES = [ "http_client", - "db_pool", "federation_client", "federation_server", "handlers", @@ -233,12 +230,6 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() - self.database_engine = create_engine(config.database_config) - config.database_config.setdefault("args", {})[ - "cp_openfun" - ] = self.database_engine.on_new_connection - self.db_config = config.database_config - self.datastores = None # Other kwargs are explicit dependencies @@ -247,10 +238,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") - with self.get_db_conn() as conn: - self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) - conn.commit() self.start_time = int(self.get_clock().time()) + self.datastores = DataStores(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") def setup_master(self): @@ -284,6 +273,9 @@ class HomeServer(object): def get_datastore(self): return self.datastores.main + def get_datastores(self): + return self.datastores + def get_config(self): return self.config @@ -433,31 +425,6 @@ class HomeServer(object): ) return MatrixFederationHttpClient(self, tls_client_options_factory) - def build_db_pool(self): - name = self.db_config["name"] - - return adbapi.ConnectionPool( - name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) - ) - - def get_db_conn(self, run_new_connection=True): - """Makes a new connection to the database, skipping the db pool - - Returns: - Connection: a connection object implementing the PEP-249 spec - """ - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v - for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def build_media_repository_resource(self): # build the media repo resource. This indirects through the HomeServer # to ensure that we only have a single instance of diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7637b5dc0..88546ad614 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -40,7 +40,7 @@ class SQLBaseStore(object): def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine self.db = database self.rand = random.SystemRandom() diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cafedd5c0d..0983e059c0 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,24 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import Database +import logging + +from synapse.storage.database import Database, make_conn +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +logger = logging.getLogger(__name__) + class DataStores(object): """The various data stores. These are low level interfaces to physical databases. + + Attributes: + main (DataStore) """ - def __init__(self, main_store_class, db_conn, hs): + def __init__(self, main_store_class, hs): # Note we pass in the main store class here as workers use a different main # store. - database = Database(hs) - # Check that db is correctly configured. - database.engine.check_database(db_conn.cursor()) + self.databases = [] + + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: + logger.info("Preparing database %r...", db_name) + + engine.check_database(db_conn.cursor()) + prepare_database( + db_conn, engine, hs.config, data_stores=database_config.data_stores, + ) + + database = Database(hs, database_config, engine) + + if "main" in database_config.data_stores: + logger.info("Starting 'main' data store") + self.main = main_store_class(database, db_conn, hs) + + db_conn.commit() - prepare_database(db_conn, database.engine, config=hs.config) + self.databases.append(database) - self.main = main_store_class(database, db_conn, hs) + logger.info("Database %r prepared", db_name) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index add3037b69..13f4c9c72e 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _update_client_ips_batch(self): # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.db.is_running(): return to_update = self._batch_row_update diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ec19ae1d9d..1003dd84a5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -24,9 +24,11 @@ from six.moves import intern, range from prometheus_client import Histogram +from twisted.enterprise import adbapi from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater @@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { } +def make_pool( + reactor, db_config: DatabaseConnectionConfig, engine +) -> adbapi.ConnectionPool: + """Get the connection pool for the database. + """ + + return adbapi.ConnectionPool( + db_config.config["name"], + cp_reactor=reactor, + cp_openfun=engine.on_new_connection, + **db_config.config.get("args", {}) + ) + + +def make_conn(db_config: DatabaseConnectionConfig, engine): + """Make a new connection to the database and return it. + + Returns: + Connection + """ + + db_params = { + k: v + for k, v in db_config.config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = engine.module.connect(**db_params) + engine.on_new_connection(db_conn) + return db_conn + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -218,10 +251,11 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() + self._database_config = database_config + self._db_pool = make_pool(hs.get_reactor(), database_config, engine) self.updates = BackgroundUpdater(hs, self) @@ -234,7 +268,7 @@ class Database(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self.engine = hs.database_engine + self.engine = engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) @@ -255,6 +289,11 @@ class Database(object): self._check_safe_to_upsert, ) + def is_running(self): + """Is the database pool currently running + """ + return self._db_pool.running + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index ddad17dc5a..df039a072d 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -16,8 +16,6 @@ import struct import threading -from synapse.storage.prepare_database import prepare_database - class Sqlite3Engine(object): single_threaded = True @@ -25,6 +23,9 @@ class Sqlite3Engine(object): def __init__(self, database_module, database_config): self.module = database_module + database = database_config.get("args", {}).get("database") + self._is_in_memory = database in (None, ":memory:",) + # The current max state_group, or None if we haven't looked # in the DB yet. self._current_state_group_id = None @@ -59,7 +60,16 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - prepare_database(db_conn, self, config=None) + + # We need to import here to avoid an import loop. + from synapse.storage.prepare_database import prepare_database + + if self._is_in_memory: + # In memory databases need to be rebuilt each time. Ideally we'd + # reuse the same connection as we do when starting up, but that + # would involve using adbapi before we have started the reactor. + prepare_database(db_conn, self, config=None) + db_conn.create_function("rank", 1, _rank) def is_deadlock(self, error): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 731e1c9d9c..b4194b44ee 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -41,7 +41,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config): +def prepare_database(db_conn, database_engine, config, data_stores=["main"]): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -54,11 +54,10 @@ def prepare_database(db_conn, database_engine, config): config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already + data_stores (list[str]): The name of the data stores that will be used + with this database. Defaults to all data stores. """ - # For now we only have the one datastore. - data_stores = ["main"] - try: cur = db_conn.cursor() version_info = _get_or_create_schema_state(cur, database_engine) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 92b8726093..596ddc6970 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["put_json"]) mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + datastores = Mock() + datastores.main = Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + "get_device_updates_by_remote", + ] + ) + hs = self.setup_test_homeserver( - datastore=( - Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_device_updates_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - ) - ), - notifier=Mock(), - http_client=mock_federation_client, - keyring=mock_keyring, + notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring ) + hs.datastores = datastores + return hs def prepare(self, reactor, clock, hs): diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 3dae83c543..2a1e7c7166 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,7 +20,7 @@ from synapse.replication.tcp.client import ( ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.storage.database import Database +from synapse.storage.database import make_conn from tests import unittest from tests.server import FakeTransport @@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): + db_config = hs.config.database.get_single_database() self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() + database = hs.get_datastores().databases[0] self.slaved_store = self.STORE_TYPE( - Database(hs), self.hs.get_db_conn(), self.hs + database, make_conn(db_config, database.engine), self.hs ) self.event_id = 0 diff --git a/tests/server.py b/tests/server.py index 2b7cf4242e..a554dfdd57 100644 --- a/tests/server.py +++ b/tests/server.py @@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(cleanup_func, *args, **kwargs).result + server = _sth(cleanup_func, *args, **kwargs) - if isinstance(d, Failure): - d.raiseException() + database = server.config.database.get_single_database() # Make the thread pool synchronous. - clock = d.get_clock() - pool = d.get_db_pool() - - def runWithConnection(func, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs - ) - - def runInteraction(interaction, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - interaction, - *args, - **kwargs - ) + clock = server.get_clock() + + for database in server.get_datastores().databases: + pool = database._db_pool + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction pool.threadpool = ThreadPool(clock._reactor) pool.running = True - return d + + return server def get_clock(): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 2e521e9ab7..fd52512696 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = Database(hs) - self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + self.store = ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - self.db_pool = hs.get_db_pool() - self.engine = hs.database_engine - self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, @@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - database = Database(hs) - self.store = TestTransactionStore(database, hs.get_db_conn(), hs) + # We assume there is only one database in these tests + database = hs.get_datastores().databases[0] + self.db_pool = database._db_pool + self.engine = database.engine + + db_config = hs.config.get_single_database() + self.store = TestTransactionStore( + database, make_conn(db_config, self.engine), hs + ) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) @@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 537cfe9f64..cdee0a9e60 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True config.event_cache_size = 1 - config.database_config = {"name": "sqlite3"} - engine = create_engine(config.database_config) + hs = TestHomeServer("test", config=config) + + sqlite_config = {"name": "sqlite3"} + engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - hs = TestHomeServer( - "test", db_pool=self.db_pool, config=config, database_engine=fake_engine - ) - self.datastore = SQLBaseStore(Database(hs), None, hs) + db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db._db_pool = self.db_pool + + self.datastore = SQLBaseStore(db, None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4578cc3b60..ed5786865a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index 585f305b9a..9f5bf40b4b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,6 +30,7 @@ from twisted.internet import defer, reactor from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server @@ -177,7 +178,6 @@ class TestHomeServer(HomeServer): DATASTORE_CLASS = DataStore -@defer.inlineCallbacks def setup_test_homeserver( cleanup_func, name="test", @@ -214,7 +214,7 @@ def setup_test_homeserver( if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex - config.database_config = { + database_config = { "name": "psycopg2", "args": { "database": test_db, @@ -226,12 +226,15 @@ def setup_test_homeserver( }, } else: - config.database_config = { + database_config = { "name": "sqlite3", "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } - db_engine = create_engine(config.database_config) + database = DatabaseConnectionConfig("master", database_config, ["main"]) + config.database.databases = [database] + + db_engine = create_engine(database.config) # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() @@ -251,11 +254,6 @@ def setup_test_homeserver( cur.close() db_conn.close() - # we need to configure the connection pool to run the on_new_connection - # function, so that we can test code that uses custom sqlite functions - # (like rank). - config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection - if datastore is None: hs = homeserverToUse( name, @@ -267,21 +265,19 @@ def setup_test_homeserver( **kargs ) - # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to - # date db - if not isinstance(db_engine, PostgresEngine): - db_conn = hs.get_db_conn() - yield prepare_database(db_conn, db_engine, config) - db_conn.commit() - db_conn.close() + hs.setup() + if homeserverToUse.__name__ == "TestHomeServer": + hs.setup_master() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] - else: # We need to do cleanup on PostgreSQL def cleanup(): import psycopg2 # Close all the db pools - hs.get_db_pool().close() + database._db_pool.close() dropped = False @@ -320,23 +316,12 @@ def setup_test_homeserver( # Register the cleanup hook cleanup_func(cleanup) - hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() else: - # If we have been given an explicit datastore we probably want to mock - # out the DataStores somehow too. This all feels a bit wrong, but then - # mocking the stores feels wrong too. - datastores = Mock(datastore=datastore) - hs = homeserverToUse( name, - db_pool=None, datastore=datastore, - datastores=datastores, config=config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, -- cgit 1.5.1 From 509e381afa8c656e72f5fef3d651a9819794174a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Feb 2020 07:15:07 -0500 Subject: Clarify list/set/dict/tuple comprehensions and enforce via flake8 (#6957) Ensure good comprehension hygiene using flake8-comprehensions. --- CONTRIBUTING.md | 2 +- changelog.d/6957.misc | 1 + docs/code_style.md | 2 +- scripts-dev/convert_server_keys.py | 2 +- synapse/app/_base.py | 2 +- synapse/app/federation_sender.py | 4 +-- synapse/app/pusher.py | 2 +- synapse/config/server.py | 4 +-- synapse/config/tls.py | 2 +- synapse/crypto/keyring.py | 6 ++-- synapse/federation/send_queue.py | 4 +-- synapse/groups/groups_server.py | 2 +- synapse/handlers/device.py | 2 +- synapse/handlers/directory.py | 4 +-- synapse/handlers/federation.py | 18 +++++----- synapse/handlers/presence.py | 6 ++-- synapse/handlers/receipts.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/search.py | 8 ++--- synapse/handlers/sync.py | 22 ++++++------ synapse/handlers/typing.py | 4 +-- synapse/logging/utils.py | 2 +- synapse/metrics/__init__.py | 2 +- synapse/metrics/background_process_metrics.py | 4 +-- synapse/push/bulk_push_rule_evaluator.py | 8 ++--- synapse/push/emailpusher.py | 2 +- synapse/push/mailer.py | 20 +++++------ synapse/push/pusherpool.py | 2 +- synapse/rest/admin/_base.py | 4 +-- synapse/rest/client/v1/push_rule.py | 6 ++-- synapse/rest/client/v1/pusher.py | 4 +-- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 2 +- synapse/rest/media/v1/_base.py | 40 ++++++++++------------ synapse/state/v1.py | 10 +++--- synapse/state/v2.py | 8 ++--- synapse/storage/_base.py | 2 +- synapse/storage/background_updates.py | 2 +- synapse/storage/data_stores/main/appservice.py | 14 ++++---- synapse/storage/data_stores/main/client_ips.py | 4 +-- synapse/storage/data_stores/main/devices.py | 13 ++++--- .../storage/data_stores/main/event_federation.py | 2 +- synapse/storage/data_stores/main/events.py | 8 ++--- .../storage/data_stores/main/events_bg_updates.py | 2 +- synapse/storage/data_stores/main/events_worker.py | 6 ++-- synapse/storage/data_stores/main/push_rule.py | 8 ++--- synapse/storage/data_stores/main/receipts.py | 4 +-- synapse/storage/data_stores/main/roommember.py | 4 +-- synapse/storage/data_stores/main/state.py | 8 ++--- synapse/storage/data_stores/main/stream.py | 8 ++--- .../storage/data_stores/main/user_erasure_store.py | 4 +-- synapse/storage/data_stores/state/store.py | 4 +-- synapse/storage/database.py | 4 +-- synapse/storage/persist_events.py | 8 ++--- synapse/storage/prepare_database.py | 6 ++-- synapse/util/frozenutils.py | 2 +- synapse/visibility.py | 4 +-- tests/config/test_generate.py | 2 +- tests/federation/test_federation_server.py | 2 +- tests/handlers/test_presence.py | 4 +-- tests/handlers/test_typing.py | 6 ++-- tests/handlers/test_user_directory.py | 12 +++---- tests/push/test_email.py | 6 ++-- tests/push/test_http.py | 8 ++--- tests/rest/client/v2_alpha/test_sync.py | 28 ++++++++------- tests/storage/test__base.py | 4 +-- tests/storage/test_appservice.py | 36 +++++++++---------- tests/storage/test_cleanup_extrems.py | 10 +++--- tests/storage/test_event_metrics.py | 36 +++++++++---------- tests/storage/test_state.py | 2 +- tests/test_state.py | 18 +++------- tests/util/test_stream_change_cache.py | 18 +++------- tox.ini | 1 + 73 files changed, 251 insertions(+), 276 deletions(-) create mode 100644 changelog.d/6957.misc (limited to 'tests/storage/test_appservice.py') diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4b01b6ac8c..253a0ca648 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -60,7 +60,7 @@ python 3.6 and to install each tool: ``` # Install the dependencies -pip install -U black flake8 isort +pip install -U black flake8 flake8-comprehensions isort # Run the linter script ./scripts-dev/lint.sh diff --git a/changelog.d/6957.misc b/changelog.d/6957.misc new file mode 100644 index 0000000000..4f98030110 --- /dev/null +++ b/changelog.d/6957.misc @@ -0,0 +1 @@ +Use flake8-comprehensions to enforce good hygiene of list/set/dict comprehensions. diff --git a/docs/code_style.md b/docs/code_style.md index 71aecd41f7..6ef6f80290 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -30,7 +30,7 @@ The necessary tools are detailed below. Install `flake8` with: - pip install --upgrade flake8 + pip install --upgrade flake8 flake8-comprehensions Check all application and test code with: diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index 179be61c30..06b4c1e2ff 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -103,7 +103,7 @@ def main(): yaml.safe_dump(result, sys.stdout, default_flow_style=False) - rows = list(row for server, json in result.items() for row in rows_v2(server, json)) + rows = [row for server, json in result.items() for row in rows_v2(server, json)] cursor = connection.cursor() cursor.executemany( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 109b1e2fb5..9ffd23c6df 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -141,7 +141,7 @@ def start_reactor( def quit_with_error(error_string): message_lines = error_string.split("\n") - line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 + line_length = max(len(l) for l in message_lines if len(l) < 80) + 2 sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 63a91f1177..b7fcf80ddc 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -262,7 +262,7 @@ class FederationSenderHandler(object): # ... as well as device updates and messages elif stream_name == DeviceListsStream.NAME: - hosts = set(row.destination for row in rows) + hosts = {row.destination for row in rows} for host in hosts: self.federation_sender.send_device_messages(host) @@ -270,7 +270,7 @@ class FederationSenderHandler(object): # The to_device stream includes stuff to be pushed to both local # clients and remote servers, so we ignore entities that start with # '@' (since they'll be local users rather than destinations). - hosts = set(row.entity for row in rows if not row.entity.startswith("@")) + hosts = {row.entity for row in rows if not row.entity.startswith("@")} for host in hosts: self.federation_sender.send_device_messages(host) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index e46b6ac598..84e9f8d5e2 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -158,7 +158,7 @@ class PusherReplicationHandler(ReplicationClientHandler): yield self.pusher_pool.on_new_notifications(token, token) elif stream_name == "receipts": yield self.pusher_pool.on_new_receipts( - token, token, set(row.room_id for row in rows) + token, token, {row.room_id for row in rows} ) except Exception: logger.exception("Error poking pushers") diff --git a/synapse/config/server.py b/synapse/config/server.py index 0ec1b0fadd..7525765fee 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1066,12 +1066,12 @@ KNOWN_RESOURCES = ( def _check_resource_config(listeners): - resource_names = set( + resource_names = { res_name for listener in listeners for res in listener.get("resources", []) for res_name in res.get("names", []) - ) + } for resource in resource_names: if resource not in KNOWN_RESOURCES: diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 97a12d51f6..a65538562b 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -260,7 +260,7 @@ class TlsConfig(Config): crypto.FILETYPE_ASN1, self.tls_certificate ) sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) - sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) + sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints} if sha256_fingerprint not in sha256_fingerprints: self.tls_fingerprints.append({"sha256": sha256_fingerprint}) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 6fe5a6a26a..983f0ead8c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -326,9 +326,7 @@ class Keyring(object): verify_requests (list[VerifyJsonRequest]): list of verify requests """ - remaining_requests = set( - (rq for rq in verify_requests if not rq.key_ready.called) - ) + remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} @defer.inlineCallbacks def do_iterations(): @@ -396,7 +394,7 @@ class Keyring(object): results = yield fetcher.get_keys(missing_keys) - completed = list() + completed = [] for verify_request in remaining_requests: server_name = verify_request.server_name diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 001bb304ae..876fb0e245 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -129,9 +129,9 @@ class FederationRemoteSendQueue(object): for key in keys[:i]: del self.presence_changed[key] - user_ids = set( + user_ids = { user_id for uids in self.presence_changed.values() for user_id in uids - ) + } keys = self.presence_destinations.keys() i = self.presence_destinations.bisect_left(position_to_delete) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index c106abae21..4f0dc0a209 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -608,7 +608,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): user_results = yield self.store.get_users_in_group( group_id, include_private=True ) - if user_id in [user_result["user_id"] for user_result in user_results]: + if user_id in (user_result["user_id"] for user_result in user_results): raise SynapseError(400, "User already in group") content = { diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 50cea3f378..a514c30714 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -742,6 +742,6 @@ class DeviceListUpdater(object): # We clobber the seen updates since we've re-synced from a given # point. - self._seen_updates[user_id] = set([stream_id]) + self._seen_updates[user_id] = {stream_id} defer.returnValue(result) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 921d887b24..0b23ca919a 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -72,7 +72,7 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Check if there is a current association. if not servers: users = yield self.state.get_current_users_in_room(room_id) - servers = set(get_domain_from_id(u) for u in users) + servers = {get_domain_from_id(u) for u in users} if not servers: raise SynapseError(400, "Failed to get server list") @@ -255,7 +255,7 @@ class DirectoryHandler(BaseHandler): ) users = yield self.state.get_current_users_in_room(room_id) - extra_servers = set(get_domain_from_id(u) for u in users) + extra_servers = {get_domain_from_id(u) for u in users} servers = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eb20ef4aec..a689065f89 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -659,11 +659,11 @@ class FederationHandler(BaseHandler): # this can happen if a remote server claims that the state or # auth_events at an event in room A are actually events in room B - bad_events = list( + bad_events = [ (event_id, event.room_id) for event_id, event in fetched_events.items() if event.room_id != room_id - ) + ] for bad_event_id, bad_room_id in bad_events: # This is a bogus situation, but since we may only discover it a long time @@ -856,7 +856,7 @@ class FederationHandler(BaseHandler): # Don't bother processing events we already have. seen_events = await self.store.have_events_in_timeline( - set(e.event_id for e in events) + {e.event_id for e in events} ) events = [e for e in events if e.event_id not in seen_events] @@ -866,7 +866,7 @@ class FederationHandler(BaseHandler): event_map = {e.event_id: e for e in events} - event_ids = set(e.event_id for e in events) + event_ids = {e.event_id for e in events} # build a list of events whose prev_events weren't in the batch. # (XXX: this will include events whose prev_events we already have; that doesn't @@ -892,13 +892,13 @@ class FederationHandler(BaseHandler): state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - required_auth = set( + required_auth = { a_id for event in events + list(state_events.values()) + list(auth_events.values()) for a_id in event.auth_event_ids() - ) + } auth_events.update( {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} ) @@ -1247,7 +1247,7 @@ class FederationHandler(BaseHandler): async def on_event_auth(self, event_id: str) -> List[EventBase]: event = await self.store.get_event(event_id) auth = await self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], include_given=True + list(event.auth_event_ids()), include_given=True ) return list(auth) @@ -2152,7 +2152,7 @@ class FederationHandler(BaseHandler): # Now get the current auth_chain for the event. local_auth_chain = await self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], include_given=True + list(event.auth_event_ids()), include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell @@ -2654,7 +2654,7 @@ class FederationHandler(BaseHandler): member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) else: - destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) + destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} yield self.federation_client.forward_third_party_invite( destinations, room_id, event_dict ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 202aa9294f..0d6cf2b008 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -313,7 +313,7 @@ class PresenceHandler(object): notified_presence_counter.inc(len(to_notify)) yield self._persist_and_notify(list(to_notify.values())) - self.unpersisted_users_changes |= set(s.user_id for s in new_states) + self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) to_federation_ping = { @@ -698,7 +698,7 @@ class PresenceHandler(object): updates = yield self.current_state_for_users(target_user_ids) updates = list(updates.values()) - for user_id in set(target_user_ids) - set(u.user_id for u in updates): + for user_id in set(target_user_ids) - {u.user_id for u in updates}: updates.append(UserPresenceState.default(user_id)) now = self.clock.time_msec() @@ -886,7 +886,7 @@ class PresenceHandler(object): hosts = yield self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. - hosts = set(host for host in hosts if host != self.server_name) + hosts = {host for host in hosts if host != self.server_name} self.federation.send_presence_to_destinations( states=[state], destinations=hosts diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 9283c039e3..8bc100db42 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -94,7 +94,7 @@ class ReceiptsHandler(BaseHandler): # no new receipts return False - affected_room_ids = list(set([r.room_id for r in receipts])) + affected_room_ids = list({r.room_id for r in receipts}) self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 76e8f61b74..8ee870f0bb 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -355,7 +355,7 @@ class RoomCreationHandler(BaseHandler): # If so, mark the new room as non-federatable as well creation_content["m.federate"] = False - initial_state = dict() + initial_state = {} # Replicate relevant room events types_to_copy = ( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 110097eab9..ec1542d416 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -184,7 +184,7 @@ class SearchHandler(BaseHandler): membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) - room_ids = set(r.room_id for r in rooms) + room_ids = {r.room_id for r in rooms} # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well @@ -374,12 +374,12 @@ class SearchHandler(BaseHandler): ).to_string() if include_profile: - senders = set( + senders = { ev.sender for ev in itertools.chain( res["events_before"], [event], res["events_after"] ) - ) + } if res["events_after"]: last_event_id = res["events_after"][-1].event_id @@ -421,7 +421,7 @@ class SearchHandler(BaseHandler): state_results = {} if include_state: - rooms = set(e.room_id for e in allowed_events) + rooms = {e.room_id for e in allowed_events} for room_id in rooms: state = yield self.state_handler.get_current_state(room_id) state_results[room_id] = list(state.values()) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4324bc702e..669dbc8a48 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -682,11 +682,9 @@ class SyncHandler(object): # FIXME: order by stream ordering rather than as returned by SQL if joined_user_ids or invited_user_ids: - summary["m.heroes"] = sorted( - [user_id for user_id in (joined_user_ids + invited_user_ids)] - )[0:5] + summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5] else: - summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5] + summary["m.heroes"] = sorted(gone_user_ids)[0:5] if not sync_config.filter_collection.lazy_load_members(): return summary @@ -697,9 +695,9 @@ class SyncHandler(object): # track which members the client should already know about via LL: # Ones which are already in state... - existing_members = set( + existing_members = { user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member - ) + } # ...or ones which are in the timeline... for ev in batch.events: @@ -773,10 +771,10 @@ class SyncHandler(object): # We only request state for the members needed to display the # timeline: - members_to_fetch = set( + members_to_fetch = { event.sender # FIXME: we also care about invite targets etc. for event in batch.events - ) + } if full_state: # always make sure we LL ourselves so we know we're in the room @@ -1993,10 +1991,10 @@ def _calculate_state( ) } - c_ids = set(e for e in itervalues(current)) - ts_ids = set(e for e in itervalues(timeline_start)) - p_ids = set(e for e in itervalues(previous)) - tc_ids = set(e for e in itervalues(timeline_contains)) + c_ids = set(itervalues(current)) + ts_ids = set(itervalues(timeline_start)) + p_ids = set(itervalues(previous)) + tc_ids = set(itervalues(timeline_contains)) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 5406618431..391bceb0c4 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -198,7 +198,7 @@ class TypingHandler(object): now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) - for domain in set(get_domain_from_id(u) for u in users): + for domain in {get_domain_from_id(u) for u in users}: if domain != self.server_name: logger.debug("sending typing update to %s", domain) self.federation.build_and_send_edu( @@ -231,7 +231,7 @@ class TypingHandler(object): return users = yield self.state.get_current_users_in_room(room_id) - domains = set(get_domain_from_id(u) for u in users) + domains = {get_domain_from_id(u) for u in users} if self.server_name in domains: logger.info("Got typing update from %s: %r", user_id, content) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 6073fc2725..0c2527bd86 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -148,7 +148,7 @@ def trace_function(f): pathname=pathname, lineno=lineno, msg=msg, - args=tuple(), + args=(), exc_info=None, ) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 0b45e1f52a..0dba997a23 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -240,7 +240,7 @@ class BucketCollector(object): res.append(["+Inf", sum(data.values())]) metric = HistogramMetricFamily( - self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()]) + self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items()) ) yield metric diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index c53d2a0d40..b65bcd8806 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -80,13 +80,13 @@ _background_process_db_sched_duration = Counter( # map from description to a counter, so that we can name our logcontexts # incrementally. (It actually duplicates _background_process_start_count, but # it's much simpler to do so than to try to combine them.) -_background_process_counts = dict() # type: dict[str, int] +_background_process_counts = {} # type: dict[str, int] # map from description to the currently running background processes. # # it's kept as a dict of sets rather than a big set so that we can keep track # of process descriptions that no longer have any active processes. -_background_processes = dict() # type: dict[str, set[_BackgroundProcess]] +_background_processes = {} # type: dict[str, set[_BackgroundProcess]] # A lock that covers the above dicts _bg_metrics_lock = threading.Lock() diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7d9f5a38d9..433ca2f416 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -400,11 +400,11 @@ class RulesForRoom(object): if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) - interested_in_user_ids = set( + interested_in_user_ids = { user_id for user_id, membership in itervalues(members) if membership == Membership.JOIN - ) + } logger.debug("Joined: %r", interested_in_user_ids) @@ -412,9 +412,9 @@ class RulesForRoom(object): interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) - user_ids = set( + user_ids = { uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher - ) + } logger.debug("With pushers: %r", user_ids) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 8c818a86bf..ba4551d619 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -204,7 +204,7 @@ class EmailPusher(object): yield self.send_notification(unprocessed, reason) yield self.save_last_stream_ordering_and_success( - max([ea["stream_ordering"] for ea in unprocessed]) + max(ea["stream_ordering"] for ea in unprocessed) ) # we update the throttle on all the possible unprocessed push actions diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index b13b646bfd..4ccaf178ce 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -526,12 +526,10 @@ class Mailer(object): # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" sender_ids = list( - set( - [ - notif_events[n["event_id"]].sender - for n in notifs_by_room[room_id] - ] - ) + { + notif_events[n["event_id"]].sender + for n in notifs_by_room[room_id] + } ) member_events = yield self.store.get_events( @@ -558,12 +556,10 @@ class Mailer(object): # If the reason room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" sender_ids = list( - set( - [ - notif_events[n["event_id"]].sender - for n in notifs_by_room[reason["room_id"]] - ] - ) + { + notif_events[n["event_id"]].sender + for n in notifs_by_room[reason["room_id"]] + } ) member_events = yield self.store.get_events( diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index b9dca5bc63..01789a9fb4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -191,7 +191,7 @@ class PusherPool: min_stream_id - 1, max_stream_id ) # This returns a tuple, user_id is at index 3 - users_affected = set([r[3] for r in updated_receipts]) + users_affected = {r[3] for r in updated_receipts} for u in users_affected: if u in self.pushers: diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index 459482eb6d..a96f75ce26 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -29,7 +29,7 @@ def historical_admin_path_patterns(path_regex): Note that this should only be used for existing endpoints: new ones should just register for the /_synapse/admin path. """ - return list( + return [ re.compile(prefix + path_regex) for prefix in ( "^/_synapse/admin/v1", @@ -37,7 +37,7 @@ def historical_admin_path_patterns(path_regex): "^/_matrix/client/unstable/admin", "^/_matrix/client/r0/admin", ) - ) + ] def admin_patterns(path_regex: str): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 4f74600239..9fd4908136 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -49,7 +49,7 @@ class PushRuleRestServlet(RestServlet): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") - spec = _rule_spec_from_path([x for x in path.split("/")]) + spec = _rule_spec_from_path(path.split("/")) try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: @@ -110,7 +110,7 @@ class PushRuleRestServlet(RestServlet): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") - spec = _rule_spec_from_path([x for x in path.split("/")]) + spec = _rule_spec_from_path(path.split("/")) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -138,7 +138,7 @@ class PushRuleRestServlet(RestServlet): rules = format_push_rules_for_user(requester.user, rules) - path = [x for x in path.split("/")][1:] + path = path.split("/")[1:] if path == []: # we're a reference impl: pedantry is our job. diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 6f6b7aed6e..550a2f1b44 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -54,9 +54,9 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - filtered_pushers = list( + filtered_pushers = [ {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers - ) + ] return 200, {"pushers": filtered_pushers} diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index d8292ce29f..8fa68dd37f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -72,7 +72,7 @@ class SyncRestServlet(RestServlet): """ PATTERNS = client_patterns("/sync$") - ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) + ALLOWED_PRESENCE = {"online", "offline", "unavailable"} def __init__(self, hs): super(SyncRestServlet, self).__init__() diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 9d6813a047..4b6d030a57 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -149,7 +149,7 @@ class RemoteKey(DirectServeResource): time_now_ms = self.clock.time_msec() - cache_misses = dict() # type: Dict[str, Set[str]] + cache_misses = {} # type: Dict[str, Set[str]] for (server_name, key_id, from_server), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 65bbf00073..ba28dd089d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -135,27 +135,25 @@ def add_file_headers(request, media_type, file_size, upload_name): # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = set( - ( - "(", - ")", - "<", - ">", - "@", - ",", - ";", - ":", - "\\", - '"', - "/", - "[", - "]", - "?", - "=", - "{", - "}", - ) -) +_FILENAME_SEPARATOR_CHARS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", +} def _can_encode_filename_as_token(x): diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 24b7c0faef..9bf98d06f2 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -69,9 +69,9 @@ def resolve_events_with_store( unconflicted_state, conflicted_state = _seperate(state_sets) - needed_events = set( + needed_events = { event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids - ) + } needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) @@ -261,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events): def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] + reverse = list(reversed(_ordered_events(events))) - auth_keys = set( + auth_keys = { key for event in events for key in event_auth.auth_types_for_event(event) - ) + } new_auth_events = {} for key in auth_keys: diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 75fe58305a..0ffe6d8c14 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -105,7 +105,7 @@ def resolve_events_with_store( % (room_id, event.event_id, event.room_id,) ) - full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) + full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map} logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) @@ -233,7 +233,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): auth_sets = [] for state_set in state_sets: - auth_ids = set( + auth_ids = { eid for key, eid in iteritems(state_set) if ( @@ -246,7 +246,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): ) ) and eid not in common - ) + } auth_chain = yield state_res_store.get_auth_chain(auth_ids, common) auth_ids.update(auth_chain) @@ -275,7 +275,7 @@ def _seperate(state_sets): conflicted_state = {} for key in set(itertools.chain.from_iterable(state_sets)): - event_ids = set(state_set.get(key) for state_set in state_sets) + event_ids = {state_set.get(key) for state_set in state_sets} if len(event_ids) == 1: unconflicted_state[key] = event_ids.pop() else: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index da3b99f93d..13de5f1f62 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -56,7 +56,7 @@ class SQLBaseStore(metaclass=ABCMeta): members_changed (iterable[str]): The user_ids of members that have changed """ - for host in set(get_domain_from_id(u) for u in members_changed): + for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index bd547f35cf..eb1a7e5002 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -189,7 +189,7 @@ class BackgroundUpdater(object): keyvalues=None, retcols=("update_name", "depends_on"), ) - in_flight = set(update["update_name"] for update in updates) + in_flight = {update["update_name"] for update in updates} for update in updates: if update["depends_on"] not in in_flight: self._background_update_queue.append(update["update_name"]) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index b2f39649fd..efbc06c796 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -135,7 +135,7 @@ class ApplicationServiceTransactionWorkerStore( may be empty. """ results = yield self.db.simple_select_list( - "application_services_state", dict(state=state), ["as_id"] + "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -158,7 +158,7 @@ class ApplicationServiceTransactionWorkerStore( """ result = yield self.db.simple_select_one( "application_services_state", - dict(as_id=service.id), + {"as_id": service.id}, ["state"], allow_none=True, desc="get_appservice_state", @@ -177,7 +177,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves when the state was set successfully. """ return self.db.simple_upsert( - "application_services_state", dict(as_id=service.id), dict(state=state) + "application_services_state", {"as_id": service.id}, {"state": state} ) def create_appservice_txn(self, service, events): @@ -253,13 +253,15 @@ class ApplicationServiceTransactionWorkerStore( self.db.simple_upsert_txn( txn, "application_services_state", - dict(as_id=service.id), - dict(last_txn=txn_id), + {"as_id": service.id}, + {"last_txn": txn_id}, ) # Delete txn self.db.simple_delete_txn( - txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) + txn, + "application_services_txns", + {"txn_id": txn_id, "as_id": service.id}, ) return self.db.runInteraction( diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 13f4c9c72e..e1ccb27142 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -530,7 +530,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) for row in rows ) - return list( + return [ { "access_token": access_token, "ip": ip, @@ -538,7 +538,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "last_seen": last_seen, } for (access_token, ip), (user_agent, last_seen) in iteritems(results) - ) + ] @wrap_as_background_process("prune_old_user_ips") async def _prune_old_user_ips(self): diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index b7617efb80..d55733a4cd 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -137,7 +137,7 @@ class DeviceWorkerStore(SQLBaseStore): # get the cross-signing keys of the users in the list, so that we can # determine which of the device changes were cross-signing keys - users = set(r[0] for r in updates) + users = {r[0] for r in updates} master_key_by_user = {} self_signing_key_by_user = {} for user in users: @@ -446,7 +446,7 @@ class DeviceWorkerStore(SQLBaseStore): a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info """ - user_ids = set(user_id for user_id, _ in query_list) + user_ids = {user_id for user_id, _ in query_list} user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) # We go and check if any of the users need to have their device lists @@ -454,10 +454,9 @@ class DeviceWorkerStore(SQLBaseStore): users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( user_ids ) - user_ids_in_cache = ( - set(user_id for user_id, stream_id in user_map.items() if stream_id) - - users_needing_resync - ) + user_ids_in_cache = { + user_id for user_id, stream_id in user_map.items() if stream_id + } - users_needing_resync user_ids_not_in_cache = user_ids - user_ids_in_cache results = {} @@ -604,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore): rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) - return set(user for row in rows for user in json.loads(row[0])) + return {user for row in rows for user in json.loads(row[0])} else: return set() diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 750ec1b70d..49a7b8b433 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -426,7 +426,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas query, (room_id, event_id, False, limit - len(event_results)) ) - new_results = set(t[0] for t in txn) - seen_events + new_results = {t[0] for t in txn} - seen_events new_front |= new_results seen_events |= new_results diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index c9d0d68c3a..8ae23df00a 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -145,7 +145,7 @@ class EventsStore( return txn.fetchall() res = yield self.db.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) + self._current_forward_extremities_amount = c_counter([x[0] for x in res]) @_retry_on_integrity_error @defer.inlineCallbacks @@ -598,11 +598,11 @@ class EventsStore( # We find out which membership events we may have deleted # and which we have added, then we invlidate the caches for all # those users. - members_changed = set( + members_changed = { state_key for ev_type, state_key in itertools.chain(to_delete, to_insert) if ev_type == EventTypes.Member - ) + } for member in members_changed: txn.call_after( @@ -1615,7 +1615,7 @@ class EventsStore( """ ) - referenced_state_groups = set(sg for sg, in txn) + referenced_state_groups = {sg for sg, in txn} logger.info( "[purge] found %i referenced state groups", len(referenced_state_groups) ) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 5177b71016..f54c8b1ee0 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -402,7 +402,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): keyvalues={}, retcols=("room_id",), ) - room_ids = set(row["room_id"] for row in rows) + room_ids = {row["room_id"] for row in rows} for room_id in room_ids: txn.call_after( self.get_latest_event_ids_in_room.invalidate, (room_id,) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 7251e819f5..47a3a26072 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -494,9 +494,9 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - events_to_fetch = set( + events_to_fetch = { event_id for events, _ in event_list for event_id in events - ) + } row_dict = self.db.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch @@ -804,7 +804,7 @@ class EventsWorkerStore(SQLBaseStore): desc="have_events_in_timeline", ) - return set(r["event_id"] for r in rows) + return {r["event_id"] for r in rows} @defer.inlineCallbacks def have_seen_events(self, event_ids): diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index e2673ae073..62ac88d9f2 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -276,21 +276,21 @@ class PushRulesWorkerStore( # We ignore app service users for now. This is so that we don't fill # up the `get_if_users_have_pushers` cache with AS entries that we # know don't have pushers, nor even read receipts. - local_users_in_room = set( + local_users_in_room = { u for u in users_in_room if self.hs.is_mine_id(u) and not self.get_if_app_services_interested_in_user(u) - ) + } # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( local_users_in_room, on_invalidate=cache_context.invalidate ) - user_ids = set( + user_ids = { uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) + } users_with_receipts = yield self.get_users_with_read_receipts_in_room( room_id, on_invalidate=cache_context.invalidate diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 96e54d145e..0d932a0672 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") - return set(r["user_id"] for r in receipts) + return {r["user_id"] for r in receipts} @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): @@ -283,7 +283,7 @@ class ReceiptsWorkerStore(SQLBaseStore): args.append(limit) txn.execute(sql, args) - return list(r[0:5] + (json.loads(r[5]),) for r in txn) + return [r[0:5] + (json.loads(r[5]),) for r in txn] return self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index d5ced05701..d5bd0cb5cf 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -465,7 +465,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql % (clause,), args) - return set(row[0] for row in txn) + return {row[0] for row in txn} return await self.db.runInteraction( "get_users_server_still_shares_room_with", @@ -826,7 +826,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): GROUP BY room_id, user_id; """ txn.execute(sql, (user_id,)) - return set(row[0] for row in txn if row[1] == 0) + return {row[0] for row in txn if row[1] == 0} return self.db.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 3d34103e67..3a3b9a8e72 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -321,7 +321,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_referenced_state_groups", ) - return set(row["state_group"] for row in rows) + return {row["state_group"] for row in rows} class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): @@ -367,7 +367,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): """ txn.execute(sql, (last_room_id, batch_size)) - room_ids = list(row[0] for row in txn) + room_ids = [row[0] for row in txn] if not room_ids: return True, set() @@ -384,7 +384,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name)) - joined_room_ids = set(row[0] for row in txn) + joined_room_ids = {row[0] for row in txn} left_rooms = set(room_ids) - joined_room_ids @@ -404,7 +404,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): retcols=("state_key",), ) - potentially_left_users = set(row["state_key"] for row in rows) + potentially_left_users = {row["state_key"] for row in rows} # Now lets actually delete the rooms from the DB. self.db.simple_delete_many_txn( diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 056b25b13a..ada5cce6c2 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -346,11 +346,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_key (str): The room_key portion of a StreamToken """ from_key = RoomStreamToken.parse_stream_token(from_key).stream - return set( + return { room_id for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_key) - ) + } @defer.inlineCallbacks def get_room_events_stream_for_room( @@ -679,11 +679,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) events_before = yield self.get_events_as_list( - [e for e in results["before"]["event_ids"]], get_prev_content=True + list(results["before"]["event_ids"]), get_prev_content=True ) events_after = yield self.get_events_as_list( - [e for e in results["after"]["event_ids"]], get_prev_content=True + list(results["after"]["event_ids"]), get_prev_content=True ) return { diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index af8025bc17..ec6b8a4ffd 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -63,9 +63,9 @@ class UserErasureWorkerStore(SQLBaseStore): retcols=("user_id",), desc="are_users_erased", ) - erased_users = set(row["user_id"] for row in rows) + erased_users = {row["user_id"] for row in rows} - res = dict((u, u in erased_users) for u in user_ids) + res = {u: u in erased_users for u in user_ids} return res diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index c4ee9b7ccb..57a5267663 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -520,11 +520,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): retcols=("state_group",), ) - remaining_state_groups = set( + remaining_state_groups = { row["state_group"] for row in rows if row["state_group"] not in state_groups_to_delete - ) + } logger.info( "[purge] de-delta-ing %i remaining state groups", diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6dcb5c04da..1953614401 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -554,8 +554,8 @@ class Database(object): Returns: A list of dicts where the key is the column header. """ - col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list(dict(zip(col_headers, row)) for row in cursor) + col_headers = [intern(str(column[0])) for column in cursor.description] + results = [dict(zip(col_headers, row)) for row in cursor] return results def execute(self, desc, decoder, query, *args): diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index b950550f23..0f9ac1cf09 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -602,14 +602,14 @@ class EventsPersistenceStorage(object): event_id_to_state_group.update(event_to_groups) # State groups of old_latest_event_ids - old_state_groups = set( + old_state_groups = { event_id_to_state_group[evid] for evid in old_latest_event_ids - ) + } # State groups of new_latest_event_ids - new_state_groups = set( + new_state_groups = { event_id_to_state_group[evid] for evid in new_latest_event_ids - ) + } # If they old and new groups are the same then we don't need to do # anything. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index c285ef52a0..fc69c32a0a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -345,9 +345,9 @@ def _upgrade_existing_database( "Could not open delta dir for version %d: %s" % (v, directory) ) - duplicates = set( + duplicates = { file_name for file_name, count in file_name_counter.items() if count > 1 - ) + } if duplicates: # We don't support using the same file name in the same delta version. raise PrepareDatabaseException( @@ -454,7 +454,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ), (modname,), ) - applied_deltas = set(d for d, in cur) + applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: if name in applied_deltas: continue diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 635b897d6c..f2ccd5e7c6 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -30,7 +30,7 @@ def freeze(o): return o try: - return tuple([freeze(i) for i in o]) + return tuple(freeze(i) for i in o) except TypeError: pass diff --git a/synapse/visibility.py b/synapse/visibility.py index d0abd8f04f..e60d9756b7 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -75,7 +75,7 @@ def filter_events_for_client( """ # Filter out events that have been soft failed so that we don't relay them # to clients. - events = list(e for e in events if not e.internal_metadata.is_soft_failed()) + events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) event_id_to_state = yield storage.state.get_state_for_events( @@ -97,7 +97,7 @@ def filter_events_for_client( erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) if apply_retention_policies: - room_ids = set(e.room_id for e in events) + room_ids = {e.room_id for e in events} retention_policies = {} for room_id in room_ids: diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 2684e662de..463855ecc8 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -48,7 +48,7 @@ class ConfigGenerationTestCase(unittest.TestCase): ) self.assertSetEqual( - set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]), + {"homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"}, set(os.listdir(self.dir)), ) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index e7d8699040..296dc887be 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -83,7 +83,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): ) ) - self.assertEqual(members, set(["@user:other.example.com", u1])) + self.assertEqual(members, {"@user:other.example.com", u1}) self.assertEqual(len(channel.json_body["pdus"]), 6) def test_needs_to_be_in_room(self): diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c171038df8..64915bafcd 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -338,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set([user_id]), now=now + state, is_mine=True, syncing_user_ids={user_id}, now=now ) self.assertIsNotNone(new_state) @@ -579,7 +579,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), states=[expected_state] + destinations={"server2", "server3"}, states=[expected_state] ) def _add_new_user(self, room_id, user_id): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 140cc0a3c2..07b204666e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -129,12 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs.get_auth().check_user_in_room = check_user_in_room def get_joined_hosts_for_room(room_id): - return set(member.domain for member in self.room_members) + return {member.domain for member in self.room_members} self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room def get_current_users_in_room(room_id): - return set(str(u) for u in self.room_members) + return {str(u) for u in self.room_members} hs.get_state_handler().get_current_users_in_room = get_current_users_in_room @@ -257,7 +257,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): member = RoomMember(ROOM_ID, U_APPLE.to_string()) self.handler._member_typing_until[member] = 1002000 - self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()]) + self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} self.assertEquals(self.event_source.get_current_key(), 0) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 0a4765fff4..7b92bdbc47 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -114,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -226,7 +226,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -358,12 +358,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), set([(u1, room), (u2, room)])) + self.assertEqual(set(public_users), {(u1, room), (u2, room)}) # User 1 and User 3 share private rooms self.assertEqual( self._compress_shared(shares_private), - set([(u1, u3, private_room), (u3, u1, private_room)]), + {(u1, u3, private_room), (u3, u1, private_room)}, ) def test_initial_share_all_users(self): @@ -398,7 +398,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # No users share rooms self.assertEqual(public_users, []) - self.assertEqual(self._compress_shared(shares_private), set([])) + self.assertEqual(self._compress_shared(shares_private), set()) # Despite not sharing a room, search_all_users means we get a search # result. diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 80187406bc..83032cc9ea 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -163,7 +163,7 @@ class EmailPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -174,7 +174,7 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -192,7 +192,7 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index fe3441f081..baf9c785f4 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -102,7 +102,7 @@ class HTTPPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -113,7 +113,7 @@ class HTTPPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -132,7 +132,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -152,7 +152,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased, again pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 9c13a13786..fa3a3ec1bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -40,16 +40,14 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - [ - "next_batch", - "rooms", - "presence", - "account_data", - "to_device", - "device_lists", - ] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) ) def test_sync_presence_disabled(self): @@ -63,9 +61,13 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - ["next_batch", "rooms", "account_data", "to_device", "device_lists"] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index d491ea2924..e37260a820 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -373,7 +373,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "there")]), + {(1, "user1", "hello"), (2, "user2", "there")}, ) # Update only user2 @@ -400,5 +400,5 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "bleb")]), + {(1, "user1", "hello"), (2, "user2", "bleb")}, ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index fd52512696..31710949a8 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -69,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): pass def _add_appservice(self, as_token, id, url, hs_token, sender): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token=hs_token, - id=id, - sender_localpart=sender, - namespaces={}, - ) + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": hs_token, + "id": id, + "sender_localpart": sender, + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -135,14 +135,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) def _add_service(self, url, as_token, id): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token="something", - id=id, - sender_localpart="a_sender", - namespaces={}, - ) + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": "something", + "id": id, + "sender_localpart": "a_sender", + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -384,8 +384,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) self.assertEquals(2, len(services)) self.assertEquals( - set([self.as_list[2]["id"], self.as_list[0]["id"]]), - set([services[0].id, services[1].id]), + {self.as_list[2]["id"], self.as_list[0]["id"]}, + {services[0].id, services[1].id}, ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 029ac26454..0e04b2cf92 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -134,7 +134,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -172,7 +172,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -227,9 +227,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual( - set(latest_event_ids), set((event_id_a, event_id_b, event_id_c)) - ) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) # Run the background update and check it did the right thing self.run_background_update() @@ -237,7 +235,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c])) + self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) class CleanupExtremDummyEventsTestCase(HomeserverTestCase): diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index f26ff57a18..a7b7fd36d3 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -59,24 +59,22 @@ class ExtremStatisticsTestCase(HomeserverTestCase): ) ) - expected = set( - [ - b'synapse_forward_extremities_bucket{le="1.0"} 0.0', - b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 2.0', - b'synapse_forward_extremities_bucket{le="5.0"} 2.0', - b'synapse_forward_extremities_bucket{le="7.0"} 3.0', - b'synapse_forward_extremities_bucket{le="10.0"} 3.0', - b'synapse_forward_extremities_bucket{le="15.0"} 3.0', - b'synapse_forward_extremities_bucket{le="20.0"} 3.0', - b'synapse_forward_extremities_bucket{le="50.0"} 3.0', - b'synapse_forward_extremities_bucket{le="100.0"} 3.0', - b'synapse_forward_extremities_bucket{le="200.0"} 3.0', - b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b"synapse_forward_extremities_count 3.0", - b"synapse_forward_extremities_sum 10.0", - ] - ) + expected = { + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b"synapse_forward_extremities_count 3.0", + b"synapse_forward_extremities_sum 10.0", + } self.assertEqual(items, expected) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 04d58fbf24..0b88308ff4 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -394,7 +394,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) - self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) + self.assertEqual(known_absent, {(e1.type, e1.state_key)}) self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ diff --git a/tests/test_state.py b/tests/test_state.py index d1578fe581..66f22f6813 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -254,9 +254,7 @@ class StateTestCase(unittest.TestCase): ctx_d = context_store["D"] prev_state_ids = yield ctx_d.get_prev_state_ids() - self.assertSetEqual( - {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} - ) + self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) @@ -313,9 +311,7 @@ class StateTestCase(unittest.TestCase): ctx_e = context_store["E"] prev_state_ids = yield ctx_e.get_prev_state_ids() - self.assertSetEqual( - {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} - ) + self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @@ -388,9 +384,7 @@ class StateTestCase(unittest.TestCase): ctx_d = context_store["D"] prev_state_ids = yield ctx_d.get_prev_state_ids() - self.assertSetEqual( - {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} - ) + self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) @@ -482,7 +476,7 @@ class StateTestCase(unittest.TestCase): current_state_ids = yield context.get_current_state_ids() self.assertEqual( - set([e.event_id for e in old_state]), set(current_state_ids.values()) + {e.event_id for e in old_state}, set(current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -513,9 +507,7 @@ class StateTestCase(unittest.TestCase): prev_state_ids = yield context.get_prev_state_ids() - self.assertEqual( - set([e.event_id for e in old_state]), set(prev_state_ids.values()) - ) + self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) self.assertIsNotNone(context.state_group) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index f2be63706b..72a9de5370 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -67,7 +67,7 @@ class StreamChangeCacheTests(unittest.TestCase): # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) self.assertEqual( - set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key) + {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) def test_get_all_entities_changed(self): @@ -137,7 +137,7 @@ class StreamChangeCacheTests(unittest.TestCase): cache.get_entities_changed( ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 ), - set(["bar@baz.net", "user@elsewhere.org"]), + {"bar@baz.net", "user@elsewhere.org"}, ) # Query all the entries mid-way through the stream, but include one @@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=2, ), - set(["bar@baz.net", "user@elsewhere.org"]), + {"bar@baz.net", "user@elsewhere.org"}, ) # Query all the entries, but before the first known point. We will get @@ -168,21 +168,13 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=0, ), - set( - [ - "user@foo.com", - "bar@baz.net", - "user@elsewhere.org", - "not@here.website", - ] - ), + {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"}, ) # Query a subset of the entries mid-way through the stream. We should # only get back the subset. self.assertEqual( - cache.get_entities_changed(["bar@baz.net"], stream_pos=2), - set(["bar@baz.net"]), + cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"}, ) def test_max_pos(self): diff --git a/tox.ini b/tox.ini index b9132a3177..b715ea0bff 100644 --- a/tox.ini +++ b/tox.ini @@ -123,6 +123,7 @@ skip_install = True basepython = python3.6 deps = flake8 + flake8-comprehensions black==19.10b0 # We pin so that our tests don't start failing on new releases of black. commands = python -m black --check --diff . -- cgit 1.5.1 From 7cb8b4bc67042a39bd1b0e05df46089a2fce1955 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 12 May 2020 03:45:23 +1000 Subject: Allow configuration of Synapse's cache without using synctl or environment variables (#6391) --- changelog.d/6391.feature | 1 + docs/sample_config.yaml | 43 +++++- synapse/api/auth.py | 4 +- synapse/app/homeserver.py | 5 +- synapse/config/cache.py | 164 ++++++++++++++++++++++ synapse/config/database.py | 6 - synapse/config/homeserver.py | 2 + synapse/http/client.py | 6 +- synapse/metrics/_exposition.py | 12 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_rule_evaluator.py | 4 +- synapse/replication/slave/storage/client_ips.py | 3 +- synapse/state/__init__.py | 4 +- synapse/storage/data_stores/main/client_ips.py | 3 +- synapse/storage/data_stores/main/events_worker.py | 5 +- synapse/storage/data_stores/state/store.py | 6 +- synapse/util/caches/__init__.py | 144 ++++++++++--------- synapse/util/caches/descriptors.py | 36 ++++- synapse/util/caches/expiringcache.py | 29 +++- synapse/util/caches/lrucache.py | 52 +++++-- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/stream_change_cache.py | 33 ++++- synapse/util/caches/ttlcache.py | 2 +- tests/config/test_cache.py | 127 +++++++++++++++++ tests/storage/test__base.py | 8 +- tests/storage/test_appservice.py | 10 +- tests/storage/test_base.py | 3 +- tests/test_metrics.py | 34 +++++ tests/util/test_expiring_cache.py | 2 +- tests/util/test_lrucache.py | 6 +- tests/util/test_stream_change_cache.py | 5 +- tests/utils.py | 1 + 32 files changed, 620 insertions(+), 146 deletions(-) create mode 100644 changelog.d/6391.feature create mode 100644 synapse/config/cache.py create mode 100644 tests/config/test_cache.py (limited to 'tests/storage/test_appservice.py') diff --git a/changelog.d/6391.feature b/changelog.d/6391.feature new file mode 100644 index 0000000000..f123426e23 --- /dev/null +++ b/changelog.d/6391.feature @@ -0,0 +1 @@ +Synapse's cache factor can now be configured in `homeserver.yaml` by the `caches.global_factor` setting. Additionally, `caches.per_cache_factors` controls the cache factors for individual caches. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 5abeaf519b..8a8415b9a2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -603,6 +603,45 @@ acme: +## Caching ## + +# Caching can be configured through the following options. +# +# A cache 'factor' is a multiplier that can be applied to each of +# Synapse's caches in order to increase or decrease the maximum +# number of entries that can be stored. + +# The number of events to cache in memory. Not affected by +# caches.global_factor. +# +#event_cache_size: 10K + +caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + ## Database ## # The 'database' setting defines the database that synapse uses to store all of @@ -646,10 +685,6 @@ database: args: database: DATADIR/homeserver.db -# Number of events to cache in memory. -# -#event_cache_size: 10K - ## Logging ## diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1ad5ff9410..e009b1a760 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -37,7 +37,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.types import StateMap, UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -73,7 +73,7 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) + self.token_cache = LruCache(10000) register_cache("cache", "token_cache", self.token_cache) self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bc8695d8dd..d7f337e586 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,6 @@ from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.module_loader import load_module @@ -516,8 +515,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = CACHE_SIZE_FACTOR - stats["event_cache_size"] = hs.config.event_cache_size + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size # # Performance statistics diff --git a/synapse/config/cache.py b/synapse/config/cache.py new file mode 100644 index 0000000000..91036a012e --- /dev/null +++ b/synapse/config/cache.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, Dict + +from ._base import Config, ConfigError + +# The prefix for all cache factor-related environment variables +_CACHES = {} +_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" +_DEFAULT_FACTOR_SIZE = 0.5 +_DEFAULT_EVENT_CACHE_SIZE = "10K" + + +class CacheProperties(object): + def __init__(self): + # The default factor size for all caches + self.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + self.resize_all_caches_func = None + + +properties = CacheProperties() + + +def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): + """Register a cache that's size can dynamically change + + Args: + cache_name: A reference to the cache + cache_resize_callback: A callback function that will be ran whenever + the cache needs to be resized + """ + _CACHES[cache_name.lower()] = cache_resize_callback + + # Ensure all loaded caches are sized appropriately + # + # This method should only run once the config has been read, + # as it uses values read from it + if properties.resize_all_caches_func: + properties.resize_all_caches_func() + + +class CacheConfig(Config): + section = "caches" + _environ = os.environ + + @staticmethod + def reset(): + """Resets the caches to their defaults. Used for tests.""" + properties.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + properties.resize_all_caches_func = None + _CACHES.clear() + + def generate_config_section(self, **kwargs): + return """\ + ## Caching ## + + # Caching can be configured through the following options. + # + # A cache 'factor' is a multiplier that can be applied to each of + # Synapse's caches in order to increase or decrease the maximum + # number of entries that can be stored. + + # The number of events to cache in memory. Not affected by + # caches.global_factor. + # + #event_cache_size: 10K + + caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + """ + + def read_config(self, config, **kwargs): + self.event_cache_size = self.parse_size( + config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) + ) + self.cache_factors = {} # type: Dict[str, float] + + cache_config = config.get("caches") or {} + self.global_factor = cache_config.get( + "global_factor", properties.default_factor_size + ) + if not isinstance(self.global_factor, (int, float)): + raise ConfigError("caches.global_factor must be a number.") + + # Set the global one so that it's reflected in new caches + properties.default_factor_size = self.global_factor + + # Load cache factors from the config + individual_factors = cache_config.get("per_cache_factors") or {} + if not isinstance(individual_factors, dict): + raise ConfigError("caches.per_cache_factors must be a dictionary") + + # Override factors from environment if necessary + individual_factors.update( + { + key[len(_CACHE_PREFIX) + 1 :].lower(): float(val) + for key, val in self._environ.items() + if key.startswith(_CACHE_PREFIX + "_") + } + ) + + for cache, factor in individual_factors.items(): + if not isinstance(factor, (int, float)): + raise ConfigError( + "caches.per_cache_factors.%s must be a number" % (cache.lower(),) + ) + self.cache_factors[cache.lower()] = factor + + # Resize all caches (if necessary) with the new factors we've loaded + self.resize_all_caches() + + # Store this function so that it can be called from other classes without + # needing an instance of Config + properties.resize_all_caches_func = self.resize_all_caches + + def resize_all_caches(self): + """Ensure all cache sizes are up to date + + For each cache, run the mapped callback function with either + a specific cache factor or the default, global one. + """ + for cache_name, callback in _CACHES.items(): + new_factor = self.cache_factors.get(cache_name, self.global_factor) + callback(new_factor) diff --git a/synapse/config/database.py b/synapse/config/database.py index 5b662d1b01..1064c2697b 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -68,10 +68,6 @@ database: name: sqlite3 args: database: %(database_path)s - -# Number of events to cache in memory. -# -#event_cache_size: 10K """ @@ -116,8 +112,6 @@ class DatabaseConfig(Config): self.databases = [] def read_config(self, config, **kwargs): - self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - # We *experimentally* support specifying multiple databases via the # `databases` key. This is a map from a label to database config in the # same format as the `database` config option, plus an extra diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 996d3e6bf7..2c7b3a699f 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -17,6 +17,7 @@ from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig +from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig from .consent_config import ConsentConfig @@ -55,6 +56,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, TlsConfig, + CacheConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, diff --git a/synapse/http/client.py b/synapse/http/client.py index 58eb47c69c..3cef747a4d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -49,7 +49,6 @@ from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred -from synapse.util.caches import CACHE_SIZE_FACTOR logger = logging.getLogger(__name__) @@ -241,7 +240,10 @@ class SimpleHttpClient(object): # tends to do so in batches, so we need to allow the pool to keep # lots of idle connections around. pool = HTTPConnectionPool(self.reactor) - pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + # XXX: The justification for using the cache factor here is that larger instances + # will need both more cache and more connections. + # Still, this should probably be a separate dial + pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) pool.cachedConnectionTimeout = 2 * 60 self.agent = ProxyAgent( diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index a248103191..ab7f948ed4 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -33,6 +33,8 @@ from prometheus_client import REGISTRY from twisted.web.resource import Resource +from synapse.util import caches + try: from prometheus_client.samples import Sample except ImportError: @@ -103,13 +105,15 @@ def nameify_sample(sample): def generate_latest(registry, emit_help=False): - output = [] - for metric in registry.collect(): + # Trigger the cache metrics to be rescraped, which updates the common + # metrics but do not produce metrics themselves + for collector in caches.collectors_by_name.values(): + collector.collect() - if metric.name.startswith("__unused"): - continue + output = [] + for metric in registry.collect(): if not metric.samples: # No samples, don't bother. continue diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 433ca2f416..e75d964ac8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache( "cache", "push_rules_delta_state_cache_metric", cache=[], # Meaningless size, as this isn't a cache that stores values + resizable=False, ) @@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", - cache=[], # Meaningless size, as this isn't a cache that stores values + cache=[], # Meaningless size, as this isn't a cache that stores values, + resizable=False, ) @defer.inlineCallbacks diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 4cd702b5fa..11032491af 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -22,7 +22,7 @@ from six import string_types from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -165,7 +165,7 @@ class PushRuleEvaluatorForEvent(object): # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) +regex_cache = LruCache(50000) register_cache("cache", "regex_push_cache", regex_cache) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index fbf996e33a..1a38f53dfb 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,6 @@ from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.database import Database -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore @@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore): super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4afefc6b1d..2fa529fcd0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -35,7 +35,6 @@ from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import StateMap from synapse.util.async_helpers import Linearizer -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func @@ -53,7 +52,6 @@ state_groups_histogram = Histogram( KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache") EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -447,7 +445,7 @@ class StateResolutionHandler(object): self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, - max_len=SIZE_OF_CACHE, + max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 92bc06919b..71f8d43a76 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -22,7 +22,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database, make_tuple_comparison_clause -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -361,7 +360,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) super(ClientIpStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 73df6b33ba..b8c1bbdf99 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -75,7 +75,10 @@ class EventsWorkerStore(SQLBaseStore): super(EventsWorkerStore, self).__init__(database, db_conn, hs) self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + "*getEvent*", + keylen=3, + max_entries=hs.config.caches.event_cache_size, + apply_cache_factor_from_config=False, ) self._event_fetch_lock = threading.Condition() diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 57a5267663..f3ad1e4369 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -28,7 +28,6 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateSt from synapse.storage.database import Database from synapse.storage.state import StateFilter from synapse.types import StateMap -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -90,11 +89,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_cache = DictionaryCache( "*stateGroupCache*", # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), + 50000, ) self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), + "*stateGroupMembersCache*", 500000, ) @cached(max_entries=10000, iterable=True) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index da5077b471..4b8a0c7a8f 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019, 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,27 +15,17 @@ # limitations under the License. import logging -import os -from typing import Dict +from typing import Callable, Dict, Optional import six from six.moves import intern -from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily - -logger = logging.getLogger(__name__) - -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) +import attr +from prometheus_client.core import Gauge +from synapse.config.cache import add_resizable_cache -def get_cache_factor_for(cache_name): - env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper() - factor = os.environ.get(env_var) - if factor: - return float(factor) - - return CACHE_SIZE_FACTOR - +logger = logging.getLogger(__name__) caches_by_name = {} collectors_by_name = {} # type: Dict @@ -44,6 +34,7 @@ cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"]) cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) +cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"]) response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) @@ -53,67 +44,82 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -def register_cache(cache_type, cache_name, cache, collect_callback=None): - """Register a cache object for metric collection. +@attr.s +class CacheMetric(object): + + _cache = attr.ib() + _cache_type = attr.ib(type=str) + _cache_name = attr.ib(type=str) + _collect_callback = attr.ib(type=Optional[Callable]) + + hits = attr.ib(default=0) + misses = attr.ib(default=0) + evicted_size = attr.ib(default=0) + + def inc_hits(self): + self.hits += 1 + + def inc_misses(self): + self.misses += 1 + + def inc_evictions(self, size=1): + self.evicted_size += size + + def describe(self): + return [] + + def collect(self): + try: + if self._cache_type == "response_cache": + response_cache_size.labels(self._cache_name).set(len(self._cache)) + response_cache_hits.labels(self._cache_name).set(self.hits) + response_cache_evicted.labels(self._cache_name).set(self.evicted_size) + response_cache_total.labels(self._cache_name).set( + self.hits + self.misses + ) + else: + cache_size.labels(self._cache_name).set(len(self._cache)) + cache_hits.labels(self._cache_name).set(self.hits) + cache_evicted.labels(self._cache_name).set(self.evicted_size) + cache_total.labels(self._cache_name).set(self.hits + self.misses) + if getattr(self._cache, "max_size", None): + cache_max_size.labels(self._cache_name).set(self._cache.max_size) + if self._collect_callback: + self._collect_callback() + except Exception as e: + logger.warning("Error calculating metrics for %s: %s", self._cache_name, e) + raise + + +def register_cache( + cache_type: str, + cache_name: str, + cache, + collect_callback: Optional[Callable] = None, + resizable: bool = True, + resize_callback: Optional[Callable] = None, +) -> CacheMetric: + """Register a cache object for metric collection and resizing. Args: - cache_type (str): - cache_name (str): name of the cache - cache (object): cache itself - collect_callback (callable|None): if not None, a function which is called during - metric collection to update additional metrics. + cache_type + cache_name: name of the cache + cache: cache itself + collect_callback: If given, a function which is called during metric + collection to update additional metrics. + resizable: Whether this cache supports being resized. + resize_callback: A function which can be called to resize the cache. Returns: CacheMetric: an object which provides inc_{hits,misses,evictions} methods """ + if resizable: + if not resize_callback: + resize_callback = getattr(cache, "set_cache_factor") + add_resizable_cache(cache_name, resize_callback) - # Check if the metric is already registered. Unregister it, if so. - # This usually happens during tests, as at runtime these caches are - # effectively singletons. + metric = CacheMetric(cache, cache_type, cache_name, collect_callback) metric_name = "cache_%s_%s" % (cache_type, cache_name) - if metric_name in collectors_by_name.keys(): - REGISTRY.unregister(collectors_by_name[metric_name]) - - class CacheMetric(object): - - hits = 0 - misses = 0 - evicted_size = 0 - - def inc_hits(self): - self.hits += 1 - - def inc_misses(self): - self.misses += 1 - - def inc_evictions(self, size=1): - self.evicted_size += size - - def describe(self): - return [] - - def collect(self): - try: - if cache_type == "response_cache": - response_cache_size.labels(cache_name).set(len(cache)) - response_cache_hits.labels(cache_name).set(self.hits) - response_cache_evicted.labels(cache_name).set(self.evicted_size) - response_cache_total.labels(cache_name).set(self.hits + self.misses) - else: - cache_size.labels(cache_name).set(len(cache)) - cache_hits.labels(cache_name).set(self.hits) - cache_evicted.labels(cache_name).set(self.evicted_size) - cache_total.labels(cache_name).set(self.hits + self.misses) - if collect_callback: - collect_callback() - except Exception as e: - logger.warning("Error calculating metrics for %s: %s", cache_name, e) - raise - - yield GaugeMetricFamily("__unused", "") - - metric = CacheMetric() - REGISTRY.register(metric) caches_by_name[cache_name] = cache collectors_by_name[metric_name] = metric return metric diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 2e8f6543e5..cd48262420 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import functools import inspect import logging @@ -30,7 +31,6 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -81,7 +81,6 @@ class CacheEntry(object): class Cache(object): __slots__ = ( "cache", - "max_entries", "name", "keylen", "thread", @@ -89,7 +88,29 @@ class Cache(object): "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + + Returns: + Cache + """ cache_type = TreeCache if tree else dict self._pending_deferred_cache = cache_type() @@ -99,6 +120,7 @@ class Cache(object): cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, ) self.name = name @@ -111,6 +133,10 @@ class Cache(object): collect_callback=self._metrics_collection_callback, ) + @property + def max_entries(self): + return self.cache.max_size + def _on_evicted(self, evicted_count): self.metrics.inc_evictions(evicted_count) @@ -370,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=cache_context, ) - max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) - self.max_entries = max_entries self.tree = tree self.iterable = iterable - def __get__(self, obj, objtype=None): + def __get__(self, obj, owner): cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index cddf1ed515..2726b67b6d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -18,6 +18,7 @@ from collections import OrderedDict from six import iteritems, itervalues +from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache @@ -51,15 +52,16 @@ class ExpiringCache(object): an item on access. Defaults to False. iterable (bool): If true, the size is calculated by summing the sizes of all entries, rather than the number of entries. - """ self._cache_name = cache_name + self._original_max_size = max_len + + self._max_size = int(max_len * cache_config.properties.default_factor_size) + self._clock = clock - self._max_len = max_len self._expiry_ms = expiry_ms - self._reset_expiry_on_get = reset_expiry_on_get self._cache = OrderedDict() @@ -82,9 +84,11 @@ class ExpiringCache(object): def __setitem__(self, key, value): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) + self.evict() + def evict(self): # Evict if there are now too many items - while self._max_len and len(self) > self._max_len: + while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) if self.iterable: self.metrics.inc_evictions(len(value.value)) @@ -170,6 +174,23 @@ class ExpiringCache(object): else: return len(self._cache) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self._max_size: + self._max_size = new_size + self.evict() + return True + return False + class _CacheEntry(object): __slots__ = ["time", "value"] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1536cb64f3..29fabac3cd 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - import threading from functools import wraps +from typing import Callable, Optional, Type, Union +from synapse.config import cache as cache_config from synapse.util.caches.treecache import TreeCache @@ -52,17 +53,18 @@ class LruCache(object): def __init__( self, - max_size, - keylen=1, - cache_type=dict, - size_callback=None, - evicted_callback=None, + max_size: int, + keylen: int = 1, + cache_type: Type[Union[dict, TreeCache]] = dict, + size_callback: Optional[Callable] = None, + evicted_callback: Optional[Callable] = None, + apply_cache_factor_from_config: bool = True, ): """ Args: - max_size (int): + max_size: The maximum amount of entries the cache can hold - keylen (int): + keylen: The length of the tuple used as the cache key cache_type (type): type of underlying cache to be used. Typically one of dict @@ -73,9 +75,23 @@ class LruCache(object): evicted_callback (func(int)|None): if not None, called on eviction with the size of the evicted entry + + apply_cache_factor_from_config (bool): If true, `max_size` will be + multiplied by a cache factor derived from the homeserver config """ cache = cache_type() self.cache = cache # Used for introspection. + + # Save the original max size, and apply the default size factor. + self._original_max_size = max_size + # We previously didn't apply the cache factor here, and as such some caches were + # not affected by the global cache factor. Add an option here to disable applying + # the cache factor when a cache is created + if apply_cache_factor_from_config: + self.max_size = int(max_size * cache_config.properties.default_factor_size) + else: + self.max_size = int(max_size) + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -83,7 +99,7 @@ class LruCache(object): lock = threading.Lock() def evict(): - while cache_len() > max_size: + while cache_len() > self.max_size: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) @@ -236,6 +252,7 @@ class LruCache(object): return key in cache self.sentinel = object() + self._on_resize = evict self.get = cache_get self.set = cache_set self.setdefault = cache_set_default @@ -266,3 +283,20 @@ class LruCache(object): def __contains__(self, key): return self.contains(key) + + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self.max_size: + self.max_size = new_size + self._on_resize() + return True + return False diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index b68f9fe0d4..a6c60888e5 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -38,7 +38,7 @@ class ResponseCache(object): self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache("response_cache", name, self) + self._metrics = register_cache("response_cache", name, self, resizable=False) def size(self): return len(self.pending_result_cache) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index e54f80d76e..2a161bf244 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import math from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union from six import integer_types @@ -46,7 +47,8 @@ class StreamChangeCache: max_size=10000, prefilled_cache: Optional[Mapping[EntityType, int]] = None, ): - self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) + self._original_max_size = max_size + self._max_size = math.floor(max_size) self._entity_to_key = {} # type: Dict[EntityType, int] # map from stream id to the a set of entities which changed at that stream id. @@ -58,12 +60,31 @@ class StreamChangeCache: # self._earliest_known_stream_pos = current_stream_pos self.name = name - self.metrics = caches.register_cache("cache", self.name, self._cache) + self.metrics = caches.register_cache( + "cache", self.name, self._cache, resize_callback=self.set_cache_factor + ) if prefilled_cache: for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = math.floor(self._original_max_size * factor) + if new_size != self._max_size: + self.max_size = new_size + self._evict() + return True + return False + def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ @@ -171,6 +192,7 @@ class StreamChangeCache: e1 = self._cache[stream_pos] = set() e1.add(entity) self._entity_to_key[entity] = stream_pos + self._evict() # if the cache is too big, remove entries while len(self._cache) > self._max_size: @@ -179,6 +201,13 @@ class StreamChangeCache: for entity in r: del self._entity_to_key[entity] + def _evict(self): + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + self._entity_to_key.pop(entity, None) + def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 99646c7cf0..6437aa907e 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -38,7 +38,7 @@ class TTLCache(object): self._timer = timer - self._metrics = register_cache("ttl", cache_name, self) + self._metrics = register_cache("ttl", cache_name, self, resizable=False) def set(self, key, value, ttl): """Add/update an entry in the cache diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py new file mode 100644 index 0000000000..2920279125 --- /dev/null +++ b/tests/config/test_cache.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config, RootConfig +from synapse.config.cache import CacheConfig, add_resizable_cache +from synapse.util.caches.lrucache import LruCache + +from tests.unittest import TestCase + + +class FakeServer(Config): + section = "server" + + +class TestConfig(RootConfig): + config_classes = [FakeServer, CacheConfig] + + +class CacheConfigTests(TestCase): + def setUp(self): + # Reset caches before each test + TestConfig().caches.reset() + + def test_individual_caches_from_environ(self): + """ + Individual cache factors will be loaded from the environment. + """ + config = {} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_NOT_CACHE": "BLAH", + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + + def test_config_overrides_environ(self): + """ + Individual cache factors defined in the environment will take precedence + over those in the config. + """ + config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_CACHE_FACTOR_FOO": 1, + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual( + dict(t.caches.cache_factors), + {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, + ) + + def test_individual_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized once the config + is loaded. + """ + cache = LruCache(100) + + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"per_cache_factors": {"foo": 3}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 300) + + def test_individual_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the per_cache_factor if + there is one. + """ + config = {"caches": {"per_cache_factors": {"foo": 2}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 200) + + def test_global_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized to the new + default cache size once the config is loaded. + """ + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"global_factor": 4}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 400) + + def test_global_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the global factor if there + is no per-cache factor. + """ + config = {"caches": {"global_factor": 1.5}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 150) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index e37260a820..5a50e4fdd4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached from tests import unittest -class CacheTestCase(unittest.TestCase): - def setUp(self): +class CacheTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): self.cache = Cache("test") def test_empty(self): @@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) -class CacheDecoratorTestCase(unittest.TestCase): +class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=4) # HACK: This makes it 2 due to cache factor + @cached(max_entries=2) def func(self, key): callcount[0] += 1 return key diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 31710949a8..ef296e7dab 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_token = "token1" @@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_list = [ @@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] database = hs.get_datastores().databases[0] @@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: @@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index cdee0a9e60..278961c331 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True - config.event_cache_size = 1 + config.caches = Mock() + config.caches.event_cache_size = 1 hs = TestHomeServer("test", config=config) sqlite_config = {"name": "sqlite3"} diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 270f853d60..f5f63d8ed6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.descriptors import Cache from tests import unittest @@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase): self.assertTrue(b"osversion=" in items[0]) self.assertTrue(b"pythonversion=" in items[0]) self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache = Cache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 50bc7702d2..49ffeebd0e 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -21,7 +21,7 @@ from tests.utils import MockClock from .. import unittest -class ExpiringCacheTestCase(unittest.TestCase): +class ExpiringCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): clock = MockClock() cache = ExpiringCache("test", clock, max_len=1) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 786947375d..0adb2174af 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache from .. import unittest -class LruCacheTestCase(unittest.TestCase): +class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" @@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(len(cache), 0) -class LruCacheCallbacksTestCase(unittest.TestCase): +class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): m = Mock() cache = LruCache(1) @@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m3.call_count, 1) -class LruCacheSizedTestCase(unittest.TestCase): +class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_evict(self): cache = LruCache(5, size_callback=len) cache["key1"] = [0] diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 6857933540..13b753e367 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,11 +1,9 @@ -from mock import patch - from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest -class StreamChangeCacheTests(unittest.TestCase): +class StreamChangeCacheTests(unittest.HomeserverTestCase): """ Tests for StreamChangeCache. """ @@ -54,7 +52,6 @@ class StreamChangeCacheTests(unittest.TestCase): self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) - @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and diff --git a/tests/utils.py b/tests/utils.py index f9be62b499..59c020a051 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,6 +167,7 @@ def default_config(name, parse=False): # disable user directory updates, because they get done in the # background, which upsets the test runner. "update_user_directory": False, + "caches": {"global_factor": 1}, } if parse: -- cgit 1.5.1