diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f9e2533e96..60f2e1245f 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -16,8 +16,8 @@
import logging
from typing import Optional
-from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(CacheInvalidationWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = MultiWriterIdGenerator(
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..d43eaf3a29 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
int
"""
return self._current
+
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+
+ For streams with single writers this is equivalent to
+ `get_current_token`.
+ """
+ return self.get_current_token()
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 9db6c62bc7..154f0e687c 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,13 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
-from synapse.storage.data_stores.main.tags import TagsWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"account_data",
@@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "tag_account_data":
+ if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- elif stream_name == "account_data":
+ elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
if not row.room_id:
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index a67fbeffb7..0f8d7037bd 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.databases.main.appservice import (
ApplicationServiceTransactionWorkerStore,
ApplicationServiceWorkerStore,
)
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1a38f53dfb..a6fdedde63 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.descriptors import Cache
from ._base import BaseSlavedStore
class SlavedClientIpStore(BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)
- def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+ async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6e7fd259d4..ee7f69a918 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,17 +15,18 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import ToDeviceStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
- db_conn, "device_max_stream_id", "stream_id"
+ db_conn, "device_inbox", "stream_id"
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
@@ -44,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "to_device":
+ if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(token)
for row in rows:
if row.entity.startswith("@"):
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 9d8067342f..722f3745e9 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -16,14 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.data_stores.main.devices import DeviceWorkerStore
-from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.devices import DeviceWorkerStore
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
self.hs = hs
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 8b9717c46f..1945bcf9a8 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
+from synapse.storage.databases.main.directory import DirectoryWorkerStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 1a1a50a24f..da1cc836cf 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,18 +15,18 @@
# limitations under the License.
import logging
-from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
-from synapse.storage.data_stores.main.event_push_actions import (
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
+from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.relations import RelationsWorkerStore
-from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
-from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
-from synapse.storage.data_stores.main.stream import StreamWorkerStore
-from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.relations import RelationsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
@@ -55,11 +55,11 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedEventStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index bcb0688954..2562b6fc38 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.filtering import FilteringStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore
class SlavedFilteringStore(BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 1851e7d525..3291558c7a 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -15,13 +15,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import GroupServerStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
self.hs = hs
@@ -38,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "groups":
+ if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index 3def367ae9..961579751c 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.keys import KeyStore
+from synapse.storage.databases.main.keys import KeyStore
# KeyStore isn't really safe to use from a worker, but for now we do so and hope that
# the races it creates aren't too bad.
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 4e0124842d..a912c04360 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PresenceStream
from synapse.storage import DataStore
-from synapse.storage.data_stores.main.presence import PresenceStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.presence import PresenceStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
@@ -23,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
@@ -42,7 +43,7 @@ class SlavedPresenceStore(BaseSlavedStore):
return self._presence_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "presence":
+ if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
index 28c508aad3..f85b20a071 100644
--- a/synapse/replication/slave/storage/profile.py
+++ b/synapse/replication/slave/storage/profile.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.data_stores.main.profile import ProfileWorkerStore
+from synapse.storage.databases.main.profile import ProfileWorkerStore
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6adb19463a..90d90833f9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,23 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import PushRulesStream
+from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def get_push_rules_stream_token(self):
- return (
- self._push_rules_stream_id_gen.get_current_token(),
- self._stream_id_gen.get_current_token(),
- )
-
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "push_rules":
+ # We assert this for the benefit of mypy
+ assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
+ if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cb78b49acb..63300e5da6 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,15 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.pusher import PusherWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import PushersStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.pusher import PusherWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPusherStore, self).__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "pushers":
+ if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index be716cc558..17ba1f22ac 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,23 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import ReceiptsStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-# So, um, we want to borrow a load of functions intended for reading from
-# a DataStore, but we don't want to take functions that either write to the
-# DataStore or are cached and don't have cache invalidation logic.
-#
-# Rather than write duplicate versions of those functions, or lift them to
-# a common base class, we going to grab the underlying __func__ object from
-# the method descriptor on the DataStore and chuck them into our class.
-
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
@@ -52,7 +45,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "receipts":
+ if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(token)
for row in rows:
self.invalidate_caches_for_receipt(
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 4b8553e250..a40f064e2b 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.registration import RegistrationWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8873bf37e5..427c81772b 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.room import RoomWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import PublicRoomsStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.room import RoomWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(RoomWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomStore, self).__init__(database, db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
return self._public_room_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "public_rooms":
+ if stream_name == PublicRoomsStream.NAME:
self._public_room_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index ac88e6b8c3..2091ac0df6 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.transactions import TransactionStore
+from synapse.storage.databases.main.transactions import TransactionStore
from ._base import BaseSlavedStore
|