From c16e192e2f9970cc62adfd758034244631968102 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Nov 2019 15:49:43 +0000 Subject: Fix caching devices for remote servers in worker. When the `/keys/query` API is hit on client_reader worker Synapse may decide that it needs to resync some remote deivces. Usually this happens on master, and then gets cached. However, that fails on workers and so it falls back to fetching devices from remotes directly, which may in turn fail if the remote is down. --- synapse/replication/http/__init__.py | 10 +++++- synapse/replication/http/devices.py | 69 ++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 synapse/replication/http/devices.py (limited to 'synapse/replication') diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 81b85352b1..28dbc6fcba 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -14,7 +14,14 @@ # limitations under the License. from synapse.http.server import JsonResource -from synapse.replication.http import federation, login, membership, register, send_event +from synapse.replication.http import ( + devices, + federation, + login, + membership, + register, + send_event, +) REPLICATION_PREFIX = "/_synapse/replication" @@ -30,3 +37,4 @@ class ReplicationRestResource(JsonResource): federation.register_servlets(hs, self) login.register_servlets(hs, self) register.register_servlets(hs, self) + devices.register_servlets(hs, self) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py new file mode 100644 index 0000000000..795ca7b65e --- /dev/null +++ b/synapse/replication/http/devices.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): + """Notifies that a user has joined or left the room + + Request format: + + POST /_synapse/replication/user_device_resync/:user_id + + {} + + Response is equivalent to ` /_matrix/federation/v1/user/devices/:user_id` + response, e.g.: + + { + "user_id": "@alice:example.org", + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": { ... }, + "device_display_name": "Alice's Mobile Phone" + } + ] + } + """ + + NAME = "user_device_resync" + PATH_ARGS = ("user_id",) + CACHE = False + + def __init__(self, hs): + super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs) + + self.device_list_updater = hs.get_device_handler().device_list_updater + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @staticmethod + def _serialize_payload(user_id): + return {} + + async def _handle_request(self, request, user_id): + user_devices = await self.device_list_updater.user_device_resync(user_id) + + return 200, user_devices + + +def register_servlets(hs, http_server): + ReplicationUserDevicesResyncRestServlet(hs).register(http_server) -- cgit 1.5.1 From c4bdf2d7855dcddb7e294e9dba458884db2be7e0 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 8 Nov 2019 11:42:55 +0000 Subject: Remove content from being sent for account data rdata stream --- synapse/replication/tcp/streams/_base.py | 6 +++--- synapse/storage/data_stores/main/account_data.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9e45429d49..dd87733842 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -89,7 +89,7 @@ TagAccountDataStreamRow = namedtuple( ) AccountDataStreamRow = namedtuple( "AccountDataStream", - ("user_id", "room_id", "data_type", "data"), # str # str # str # dict + ("user_id", "room_id", "data_type"), # str # str # str ) GroupsStreamRow = namedtuple( "GroupsStreamRow", @@ -421,8 +421,8 @@ class AccountDataStream(Stream): results = list(room_results) results.extend( - (stream_id, user_id, None, account_data_type, content) - for stream_id, user_id, account_data_type, content in global_results + (stream_id, user_id, None, account_data_type) + for stream_id, user_id, account_data_type in global_results ) return results diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 6afbfc0d74..22093484ed 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -184,14 +184,14 @@ class AccountDataWorkerStore(SQLBaseStore): current_id(int): The position to fetch up to. Returns: A deferred pair of lists of tuples of stream_id int, user_id string, - room_id string, type string, and content string. + room_id string, and type string. """ if last_room_id == current_id and last_global_id == current_id: return defer.succeed(([], [])) def get_updated_account_data_txn(txn): sql = ( - "SELECT stream_id, user_id, account_data_type, content" + "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) @@ -199,7 +199,7 @@ class AccountDataWorkerStore(SQLBaseStore): global_results = txn.fetchall() sql = ( - "SELECT stream_id, user_id, room_id, account_data_type, content" + "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) -- cgit 1.5.1 From cd96b4586f8dfa8a6857e1bee7de2bd9660238d5 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 8 Nov 2019 15:45:45 +0000 Subject: lint --- synapse/replication/tcp/streams/_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index dd87733842..8512923eae 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -88,8 +88,7 @@ TagAccountDataStreamRow = namedtuple( "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict ) AccountDataStreamRow = namedtuple( - "AccountDataStream", - ("user_id", "room_id", "data_type"), # str # str # str + "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str ) GroupsStreamRow = namedtuple( "GroupsStreamRow", -- cgit 1.5.1 From 35f9165e96f6261e15aadb439a5d2199bede3c99 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Nov 2019 12:04:48 +0000 Subject: Fixup docs --- changelog.d/6332.bugfix | 2 +- synapse/replication/http/devices.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) (limited to 'synapse/replication') diff --git a/changelog.d/6332.bugfix b/changelog.d/6332.bugfix index b14bd7e43c..67d5170ba0 100644 --- a/changelog.d/6332.bugfix +++ b/changelog.d/6332.bugfix @@ -1 +1 @@ -Fix caching devices for remote users when using workers. +Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 795ca7b65e..e32aac0a25 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -21,7 +21,11 @@ logger = logging.getLogger(__name__) class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): - """Notifies that a user has joined or left the room + """Ask master to resync the device list for a user by contacting their + server. + + This must happen on master so that the results can be correctly cached in + the database and streamed to workers. Request format: -- cgit 1.5.1 From 2173785f0d9124037ca841b568349ad0424b39cd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Nov 2019 11:31:56 +0000 Subject: Propagate reason in remotely rejected invites --- synapse/handlers/federation.py | 4 ++-- synapse/handlers/room_member.py | 13 +++++++++---- synapse/handlers/room_member_worker.py | 5 ++++- synapse/replication/http/membership.py | 7 +++++-- 4 files changed, 20 insertions(+), 9 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a5ae7b77d1..d3267734f7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1428,9 +1428,9 @@ class FederationHandler(BaseHandler): return event @defer.inlineCallbacks - def do_remotely_reject_invite(self, target_hosts, room_id, user_id): + def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, room_id, user_id, "leave" + target_hosts, room_id, user_id, "leave", content=content, ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 6cfee4b361..7b7270fc61 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -94,7 +94,9 @@ class RoomMemberHandler(object): raise NotImplementedError() @abc.abstractmethod - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. @@ -104,6 +106,7 @@ class RoomMemberHandler(object): reject invite room_id (str) target (UserID): The user rejecting the invite + content (dict): The content for the rejection event Returns: Deferred[dict]: A dictionary to be returned to the client, may @@ -471,7 +474,7 @@ class RoomMemberHandler(object): # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] res = yield self._remote_reject_invite( - requester, remote_room_hosts, room_id, target + requester, remote_room_hosts, room_id, target, content, ) return res @@ -971,13 +974,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) @defer.inlineCallbacks - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, target.to_string() + remote_room_hosts, room_id, target.to_string(), content=content, ) return ret except Exception as e: diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 75e96ae1a2..69be86893b 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -55,7 +55,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Implements RoomMemberHandler._remote_reject_invite """ return self._remote_reject_client( @@ -63,6 +65,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), + content=content, ) def _user_joined_room(self, target, room_id): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index cc1f249740..3577611fd7 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -93,6 +93,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): { "requester": ..., "remote_room_hosts": [...], + "content": { ... } } """ @@ -107,7 +108,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts): + def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): """ Args: requester(Requester) @@ -118,12 +119,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): return { "requester": requester.serialize(), "remote_room_hosts": remote_room_hosts, + "content": content, } async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] + event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) @@ -134,7 +137,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): try: event = await self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, user_id + remote_room_hosts, room_id, user_id, event_content, ) ret = event.get_pdu_json() except Exception as e: -- cgit 1.5.1 From 1056d6885a7b96be85c5ff19e26eba2ed3f90dd4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Dec 2019 14:28:46 +0000 Subject: Move cache invalidation to main data store --- synapse/replication/slave/storage/_base.py | 3 +- synapse/storage/_base.py | 104 -------------------- synapse/storage/data_stores/main/__init__.py | 2 + synapse/storage/data_stores/main/cache.py | 131 +++++++++++++++++++++++++ synapse/storage/data_stores/main/client_ips.py | 2 +- synapse/storage/data_stores/main/devices.py | 14 +-- 6 files changed, 143 insertions(+), 113 deletions(-) create mode 100644 synapse/storage/data_stores/main/cache.py (limited to 'synapse/replication') diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 456bc005a0..71e5877aca 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,7 +18,8 @@ from typing import Dict import six -from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.cache import _CURRENT_STATE_CACHE_NAME from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6b8120a608..c02248cfe9 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging import random import sys @@ -34,7 +33,6 @@ from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id -from synapse.util import batch_iter from synapse.util.stringutils import exception_to_unicode # import a function which will return a monotonic time, in seconds @@ -77,10 +75,6 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_search": "event_search_event_id_idx", } -# This is a special cache name we use to batch multiple invalidations of caches -# based on the current state when notifying workers over replication. -_CURRENT_STATE_CACHE_NAME = "cs_cache_fake" - class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object @@ -1322,47 +1316,6 @@ class SQLBaseStore(object): return cache, min_val - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. - - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - txn.call_after(cache_func.invalidate, keys) - self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - - def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): - """Special case invalidation of caches based on current state. - - We special case this so that we can batch the cache invalidations into a - single replication poke. - - Args: - txn - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have changed - """ - txn.call_after(self._invalidate_state_caches, room_id, members_changed) - - if members_changed: - # We need to be careful that the size of the `members_changed` list - # isn't so large that it causes problems sending over replication, so we - # send them in chunks. - # Max line length is 16K, and max user ID length is 255, so 50 should - # be safe. - for chunk in batch_iter(members_changed, 50): - keys = itertools.chain([room_id], chunk) - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, keys - ) - else: - # if no members changed, we still need to invalidate the other caches. - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, [room_id] - ) - def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. @@ -1396,63 +1349,6 @@ class SQLBaseStore(object): # which is fine. pass - def _send_invalidation_to_replication(self, txn, cache_name, keys): - """Notifies replication that given cache has been invalidated. - - Note that this does *not* invalidate the cache locally. - - Args: - txn - cache_name (str) - keys (iterable[str]) - """ - - if isinstance(self.database_engine, PostgresEngine): - # get_next() returns a context manager which is designed to wrap - # the transaction. However, we want to only get an ID when we want - # to use it, here, so we need to call __enter__ manually, and have - # __exit__ called after the transaction finishes. - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_on_exception(ctx.__exit__, None, None, None) - txn.call_after(ctx.__exit__, None, None, None) - txn.call_after(self.hs.get_notifier().on_new_replication_data) - - self._simple_insert_txn( - txn, - table="cache_invalidation_stream", - values={ - "stream_id": stream_id, - "cache_func": cache_name, - "keys": list(keys), - "invalidation_ts": self.clock.time_msec(), - }, - ) - - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - def _simple_select_list_paginate( self, table, diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 10c940df1e..474924c68f 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -32,6 +32,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .cache import CacheInvalidationStore from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore @@ -110,6 +111,7 @@ class DataStore( MonthlyActiveUsersStore, StatsStore, RelationsStore, + CacheInvalidationStore, ): def __init__(self, db_conn, hs): self.hs = hs diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py new file mode 100644 index 0000000000..6efcc5f3b0 --- /dev/null +++ b/synapse/storage/data_stores/main/cache.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine +from synapse.util import batch_iter + +logger = logging.getLogger(__name__) + + +# This is a special cache name we use to batch multiple invalidations of caches +# based on the current state when notifying workers over replication. +_CURRENT_STATE_CACHE_NAME = "cs_cache_fake" + + +class CacheInvalidationStore(SQLBaseStore): + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + """Special case invalidation of caches based on current state. + + We special case this so that we can batch the cache invalidations into a + single replication poke. + + Args: + txn + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have changed + """ + txn.call_after(self._invalidate_state_caches, room_id, members_changed) + + if members_changed: + # We need to be careful that the size of the `members_changed` list + # isn't so large that it causes problems sending over replication, so we + # send them in chunks. + # Max line length is 16K, and max user ID length is 255, so 50 should + # be safe. + for chunk in batch_iter(members_changed, 50): + keys = itertools.chain([room_id], chunk) + self._send_invalidation_to_replication( + txn, _CURRENT_STATE_CACHE_NAME, keys + ) + else: + # if no members changed, we still need to invalidate the other caches. + self._send_invalidation_to_replication( + txn, _CURRENT_STATE_CACHE_NAME, [room_id] + ) + + def _send_invalidation_to_replication(self, txn, cache_name, keys): + """Notifies replication that given cache has been invalidated. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name (str) + keys (iterable[str]) + """ + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. + ctx = self._cache_id_gen.get_next() + stream_id = ctx.__enter__() + txn.call_on_exception(ctx.__exit__, None, None, None) + txn.call_after(ctx.__exit__, None, None, None) + txn.call_after(self.hs.get_notifier().on_new_replication_data) + + self._simple_insert_txn( + txn, + table="cache_invalidation_stream", + values={ + "stream_id": stream_id, + "cache_func": cache_name, + "keys": list(keys), + "invalidation_ts": self.clock.time_msec(), + }, + ) + + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) + + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 7931b876ce..cae93b0e22 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -21,8 +21,8 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage import background_updates -from synapse.util.caches.descriptors import Cache from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index b50ee026a2..a3ad23e783 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -30,16 +30,16 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) -from synapse.util.caches.descriptors import Cache +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import ( + Cache, + cached, + cachedInlineCallbacks, + cachedList, +) logger = logging.getLogger(__name__) -- cgit 1.5.1 From a7f20500ff39399634d4623e284fb2f9892776ae Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 15:45:42 +0000 Subject: _CURRENT_STATE_CACHE_NAME is public --- synapse/replication/slave/storage/_base.py | 4 ++-- synapse/storage/data_stores/main/cache.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 71e5877aca..6ece1d6745 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -19,7 +19,7 @@ from typing import Dict import six from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.cache import _CURRENT_STATE_CACHE_NAME +from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker @@ -63,7 +63,7 @@ class BaseSlavedStore(SQLBaseStore): if stream_name == "caches": self._cache_id_gen.advance(token) for row in rows: - if row.cache_func == _CURRENT_STATE_CACHE_NAME: + if row.cache_func == CURRENT_STATE_CACHE_NAME: room_id = row.keys[0] members_changed = set(row.keys[1:]) self._invalidate_state_caches(room_id, members_changed) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 6efcc5f3b0..258c08722a 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) # This is a special cache name we use to batch multiple invalidations of caches # based on the current state when notifying workers over replication. -_CURRENT_STATE_CACHE_NAME = "cs_cache_fake" +CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationStore(SQLBaseStore): @@ -65,12 +65,12 @@ class CacheInvalidationStore(SQLBaseStore): for chunk in batch_iter(members_changed, 50): keys = itertools.chain([room_id], chunk) self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, keys + txn, CURRENT_STATE_CACHE_NAME, keys ) else: # if no members changed, we still need to invalidate the other caches. self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, [room_id] + txn, CURRENT_STATE_CACHE_NAME, [room_id] ) def _send_invalidation_to_replication(self, txn, cache_name, keys): -- cgit 1.5.1 From 9a4fb457cf5918c85068ea249cd2d58b3e2e3cfc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 13:08:40 +0000 Subject: Change DataStores to accept 'database' param. --- synapse/app/federation_sender.py | 5 +++-- synapse/app/user_dir.py | 5 +++-- synapse/replication/slave/storage/_base.py | 5 +++-- synapse/replication/slave/storage/account_data.py | 5 +++-- synapse/replication/slave/storage/client_ips.py | 5 +++-- synapse/replication/slave/storage/deviceinbox.py | 5 +++-- synapse/replication/slave/storage/devices.py | 5 +++-- synapse/replication/slave/storage/events.py | 5 +++-- synapse/replication/slave/storage/filtering.py | 5 +++-- synapse/replication/slave/storage/groups.py | 5 +++-- synapse/replication/slave/storage/presence.py | 5 +++-- synapse/replication/slave/storage/push_rule.py | 5 +++-- synapse/replication/slave/storage/pushers.py | 5 +++-- synapse/replication/slave/storage/receipts.py | 5 +++-- synapse/replication/slave/storage/room.py | 5 +++-- synapse/storage/_base.py | 2 +- synapse/storage/data_stores/main/__init__.py | 5 +++-- synapse/storage/data_stores/main/account_data.py | 9 +++++---- synapse/storage/data_stores/main/appservice.py | 5 +++-- synapse/storage/data_stores/main/client_ips.py | 9 +++++---- synapse/storage/data_stores/main/deviceinbox.py | 9 +++++---- synapse/storage/data_stores/main/devices.py | 9 +++++---- synapse/storage/data_stores/main/event_federation.py | 5 +++-- synapse/storage/data_stores/main/event_push_actions.py | 9 +++++---- synapse/storage/data_stores/main/events.py | 5 +++-- synapse/storage/data_stores/main/events_bg_updates.py | 5 +++-- synapse/storage/data_stores/main/events_worker.py | 5 +++-- synapse/storage/data_stores/main/media_repository.py | 11 +++++++---- synapse/storage/data_stores/main/monthly_active_users.py | 7 ++++--- synapse/storage/data_stores/main/push_rule.py | 5 +++-- synapse/storage/data_stores/main/receipts.py | 9 +++++---- synapse/storage/data_stores/main/registration.py | 13 +++++++------ synapse/storage/data_stores/main/room.py | 9 +++++---- synapse/storage/data_stores/main/roommember.py | 13 +++++++------ synapse/storage/data_stores/main/search.py | 9 +++++---- synapse/storage/data_stores/main/state.py | 13 +++++++------ synapse/storage/data_stores/main/stats.py | 5 +++-- synapse/storage/data_stores/main/stream.py | 5 +++-- synapse/storage/data_stores/main/transactions.py | 5 +++-- synapse/storage/data_stores/main/user_directory.py | 9 +++++---- tests/storage/test_appservice.py | 5 +++-- 41 files changed, 156 insertions(+), 114 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 448e45e00f..f24920a7d6 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer @@ -59,8 +60,8 @@ class FederationSenderSlaveStore( SlavedDeviceStore, SlavedPresenceStore, ): - def __init__(self, db_conn, hs): - super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs) # We pull out the current federation stream position now so that we # always have a known value for the federation position in memory so diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index b6d4481725..c01fb34a9b 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import ( from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree @@ -60,8 +61,8 @@ class UserDirectorySlaveStore( UserDirectoryStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): - super(UserDirectorySlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 6ece1d6745..b91a528245 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -20,6 +20,7 @@ import six from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker @@ -35,8 +36,8 @@ def __func__(inp): class BaseSlavedStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id" diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index bc2f6a12ae..ebe94909cb 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.storage.data_stores.main.tags import TagsWorkerStore +from synapse.storage.database import Database class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id" ) - super(SlavedAccountDataStore, self).__init__(db_conn, hs) + super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index b4f58cea19..fbf996e33a 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -21,8 +22,8 @@ from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedClientIpStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 9fb6c5c6ff..0c237c6e0f 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -16,13 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_max_stream_id", "stream_id" ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index de50748c30..dc625e0d7a 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage.data_stores.main.devices import DeviceWorkerStore from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d0a0eaf75b..29f35b9915 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.data_stores.main.stream import StreamWorkerStore from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -59,13 +60,13 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - super(SlavedEventStore, self).__init__(db_conn, hs) + super(SlavedEventStore, self).__init__(database, db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 5c84ebd125..bcb0688954 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -14,13 +14,14 @@ # limitations under the License. from synapse.storage.data_stores.main.filtering import FilteringStore +from synapse.storage.database import Database from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedFilteringStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 28a46edd28..69a4ae42f9 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage import DataStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedGroupServerStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 747ced0c84..f552e7c972 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -15,6 +15,7 @@ from synapse.storage import DataStore from synapse.storage.data_stores.main.presence import PresenceStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPresenceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 3655f05e54..eebd5a1fb6 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,17 +15,18 @@ # limitations under the License. from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.storage.database import Database from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" ) - super(SlavedPushRuleStore, self).__init__(db_conn, hs) + super(SlavedPushRuleStore, self).__init__(database, db_conn, hs) def get_push_rules_stream_token(self): return ( diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index b4331d0799..f22c2d44a3 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -15,14 +15,15 @@ # limitations under the License. from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPusherStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 43d823c601..d40dc6e1f5 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(db_conn, hs) + super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index d9ad386b28..3a20f45316 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -14,14 +14,15 @@ # limitations under the License. from synapse.storage.data_stores.main.room import RoomWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7e27d4e97..f9e7f9a71e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -37,7 +37,7 @@ class SQLBaseStore(object): per data store (and not one per physical database). """ - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 6adb8adb04..7f5fd81bcf 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -20,6 +20,7 @@ import logging import time from synapse.api.constants import PresenceState +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( ChainedIdGenerator, @@ -111,7 +112,7 @@ class DataStore( RelationsStore, CacheInvalidationStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine @@ -169,7 +170,7 @@ class DataStore( else: self._cache_id_gen = None - super(DataStore, self).__init__(db_conn, hs) + super(DataStore, self).__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index a96fe9485c..44d20c19bf 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(db_conn, hs) + super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore): class AccountDataStore(AccountDataWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) - super(AccountDataStore, self).__init__(db_conn, hs) + super(AccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 6b2e12719c..b2f39649fd 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 7b470a58f1..320c5b0f07 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -33,8 +34,8 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "user_ips_device_index", @@ -363,13 +364,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) - super(ClientIpStore, self).__init__(db_conn, hs) + super(ClientIpStore, self).__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.user_ips_max_age diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 3c9f09301a..85cfa16850 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -210,8 +211,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "device_inbox_stream_index", @@ -241,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 91ddaf137e..9a828231c4 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -31,6 +31,7 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import ( @@ -642,8 +643,8 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "device_lists_stream_idx", @@ -692,8 +693,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 31d2e8eb28..1f517e8fad 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -491,8 +492,8 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, db_conn, hs): - super(EventFederationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventFederationStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index eec054cd48..9988a6d3fc 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -611,8 +612,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d644c82784..da1529f6ea 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -41,6 +41,7 @@ from synapse.storage._base import make_in_list_sql_clause from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.database import Database from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -95,8 +96,8 @@ def _retry_on_integrity_error(func): class EventsStore( StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, ): - def __init__(self, db_conn, hs): - super(EventsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsStore, self).__init__(database, db_conn, hs) # Collect metrics on the number of forward extremities that exist. # Counter of number of extremities to count diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index cb1fc30c31..efee17b929 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventContentFields from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -33,8 +34,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e041fc5eac..9ee117ce0f 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -33,6 +33,7 @@ from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import Cache @@ -55,8 +56,8 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) class EventsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsWorkerStore, self).__init__(database, db_conn, hs) self._get_event_cache = Cache( "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 03c9c6f8ae..80ca36dedf 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__( + database, db_conn, hs + ) self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", @@ -31,8 +34,8 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, db_conn, hs): - super(MediaRepositoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryStore, self).__init__(database, db_conn, hs) def get_local_media(self, media_id): """Get the metadata for a local piece of media diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 34bf3a1880..27158534cb 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersStore(SQLBaseStore): - def __init__(self, dbconn, hs): - super(MonthlyActiveUsersStore, self).__init__(None, hs) + def __init__(self, database: Database, db_conn, hs): + super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number self.db.new_transaction( - dbconn, + db_conn, "initialise_mau_threepids", [], [], diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index de682cc63a..5ba13aa973 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.database import Database from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -72,8 +73,8 @@ class PushRulesWorkerStore( # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index ac2d45bd5c..96e54d145e 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -315,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(db_conn, hs) + super(ReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 1ef143c6d8..5e8ecac0ea 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -27,6 +27,7 @@ from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -36,8 +37,8 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() @@ -794,8 +795,8 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config @@ -920,8 +921,8 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index da42dae243..0148be20d3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.search import SearchStore +from synapse.storage.database import Database from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -361,8 +362,8 @@ class RoomWorkerStore(SQLBaseStore): class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.config = hs.config @@ -440,8 +441,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 929f6b0d39..92e3b9c512 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -32,6 +32,7 @@ from synapse.storage._base import ( make_in_list_sql_clause, ) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( GetRoomsForUserWithStreamOrdering, @@ -54,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? @@ -835,8 +836,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -991,8 +992,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberStore, self).__init__(database, db_conn, hs) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index ffa1817e64..4eec2fae5e 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -42,8 +43,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) if not hs.config.enable_search: return @@ -342,8 +343,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(SearchStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchStore, self).__init__(database, db_conn, hs) def store_event_search_txn(self, txn, event, key, value): """Add event to the search table diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 7d5a9f8128..9ef7b48c74 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -28,6 +28,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter from synapse.util.caches import get_cache_factor_for, intern_string @@ -213,8 +214,8 @@ class StateGroupWorkerStore( STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -1029,8 +1030,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - def __init__(self, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, @@ -1245,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateStore, self).__init__(database, db_conn, hs) def _store_event_state_mappings_txn( self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 40579bf965..7bc186e9a1 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached @@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): - def __init__(self, db_conn, hs): - super(StatsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StatsStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 2ff8c57109..140da8dad6 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -47,6 +47,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -251,8 +252,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(StreamWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StreamWorkerStore, self).__init__(database, db_conn, hs) events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self.db.get_cache_dict( diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index c0d155a43c..5b07c2fbc0 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache # py2 sqlite has buffer hardcoded as only binary type, so we must use it, @@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, db_conn, hs): - super(TransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TransactionStore, self).__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 62ffb34b29..90c180ec6d 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.storage.data_stores.main.state import StateFilter from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -37,8 +38,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -549,8 +550,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryStore, self).__init__(database, db_conn, hs) def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index dfeea24599..1679112d82 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.storage.database import Database from tests import unittest from tests.utils import setup_test_homeserver @@ -382,8 +383,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, db_conn, hs): - super(TestTransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TestTransactionStore, self).__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): -- cgit 1.5.1