diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f45cbd37a0..f9e2533e96 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,72 +14,30 @@
# limitations under the License.
import logging
-from typing import Dict, Optional
+from typing import Optional
-import six
-
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
-
-from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
logger = logging.getLogger(__name__)
-def __func__(inp):
- if six.PY3:
- return inp
- else:
- return inp.__func__
-
-
-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
- self._cache_id_gen = SlavedIdTracker(
- db_conn, "cache_invalidation_stream", "stream_id"
- ) # type: Optional[SlavedIdTracker]
+ self._cache_id_gen = MultiWriterIdGenerator(
+ db_conn,
+ database,
+ instance_name=hs.get_instance_name(),
+ table="cache_invalidation_stream_by_instance",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="cache_invalidation_stream_seq",
+ ) # type: Optional[MultiWriterIdGenerator]
else:
self._cache_id_gen = None
self.hs = hs
-
- def stream_positions(self) -> Dict[str, int]:
- """
- Get the current positions of all the streams this store wants to subscribe to
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- pos = {}
- if self._cache_id_gen:
- pos["caches"] = self._cache_id_gen.get_current_token()
- return pos
-
- def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "caches":
- if self._cache_id_gen:
- self._cache_id_gen.advance(token)
- for row in rows:
- if row.cache_func == CURRENT_STATE_CACHE_NAME:
- if row.keys is None:
- raise Exception(
- "Can't send an 'invalidate all' for current state cache"
- )
-
- room_id = row.keys[0]
- members_changed = set(row.keys[1:])
- self._invalidate_state_caches(room_id, members_changed)
- else:
- self._attempt_to_invalidate_cache(row.cache_func, row.keys)
-
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
- txn.call_after(cache_func.invalidate, keys)
- txn.call_after(self._send_invalidation_poke, cache_func, keys)
-
- def _send_invalidation_poke(self, cache_func, keys):
- self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
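The substantive change in _base.py above: the cache-invalidation position is no longer tracked with a single SlavedIdTracker but with a MultiWriterIdGenerator keyed by instance name, so that more than one process can write to the cache invalidation stream. A minimal sketch of the underlying idea, using a toy in-memory tracker (this is not Synapse's MultiWriterIdGenerator, which is backed by the cache_invalidation_stream_by_instance table and the cache_invalidation_stream_seq sequence named in the diff):

class MultiWriterPositionSketch:
    """Toy tracker: one monotonic position per writer instance."""

    def __init__(self, initial_positions):
        # e.g. {"master": 1023, "worker1": 1020}
        self._positions = dict(initial_positions)

    def advance(self, instance_name, new_id):
        # Record that `instance_name` has written rows up to `new_id`.
        current = self._positions.get(instance_name, 0)
        self._positions[instance_name] = max(current, new_id)

    def get_current_token(self):
        # A reader may only rely on IDs that *every* writer has passed,
        # i.e. the position of the slowest writer.
        return min(self._positions.values()) if self._positions else 0


tracker = MultiWriterPositionSketch({"master": 1023, "worker1": 1020})
tracker.advance("worker1", 1025)
assert tracker.get_current_token() == 1023  # master is still the laggard

This is why the generator needs the instance name: the safe "current token" is defined by the slowest known writer, not by a single global counter.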
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index ebe94909cb..9db6c62bc7 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -24,7 +24,13 @@ from synapse.storage.database import Database
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
- db_conn, "account_data_max_stream_id", "stream_id"
+ db_conn,
+ "account_data",
+ "stream_id",
+ extra_tables=[
+ ("room_account_data", "stream_id"),
+ ("room_tags_revisions", "stream_id"),
+ ],
)
super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
@@ -32,15 +38,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedAccountDataStore, self).stream_positions()
- position = self._account_data_id_gen.get_current_token()
- result["user_account_data"] = position
- result["room_account_data"] = position
- result["tag_account_data"] = position
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)
for row in rows:
@@ -59,6 +57,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- return super(SlavedAccountDataStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
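account_data.py now seeds its SlavedIdTracker from the account_data table plus room_account_data and room_tags_revisions via extra_tables, rather than from the old account_data_max_stream_id table. Seeding from several tables amounts to taking the largest stream ID found in any of them; a hedged sketch assuming a DB-API style cursor and trusted table names (the real logic lives in _slaved_id_tracker.py, imported above):

def load_current_stream_id(cursor, table, column, extra_tables=()):
    # Take the largest stream ID seen in any of the tables.
    current = 1
    for tbl, col in [(table, column)] + list(extra_tables):
        cursor.execute("SELECT COALESCE(MAX(%s), 1) FROM %s" % (col, tbl))
        (max_id,) = cursor.fetchone()
        current = max(current, max_id)
    return current

# e.g. load_current_stream_id(
#     cur, "account_data", "stream_id",
#     extra_tables=[("room_account_data", "stream_id"),
#                   ("room_tags_revisions", "stream_id")],
# )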
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index fbf996e33a..1a38f53dfb 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -15,7 +15,6 @@
from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.storage.database import Database
-from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
from ._base import BaseSlavedStore
@@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore):
super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
- name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+ name="client_ip_last_seen", keylen=4, max_entries=50000
)
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
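client_ips.py drops the CACHE_SIZE_FACTOR scaling and pins the last-seen cache at a flat 50,000 entries. The cache's job is to suppress repeated "client was seen" writes inside the LAST_SEEN_GRANULARITY window; a rough stand-in for that behaviour using a plain OrderedDict rather than Synapse's Cache class (the window constant and key composition below are illustrative assumptions):

from collections import OrderedDict

ASSUMED_GRANULARITY_MS = 2 * 60 * 1000  # stand-in for LAST_SEEN_GRANULARITY


class LastSeenCacheSketch:
    def __init__(self, max_entries=50000):
        self._seen = OrderedDict()  # key tuple -> last-written ms timestamp
        self._max_entries = max_entries

    def should_write(self, key, now_ms):
        # True only if this client key has not been written within the window;
        # oldest entries are evicted once max_entries is exceeded.
        last = self._seen.get(key)
        if last is not None and now_ms - last < ASSUMED_GRANULARITY_MS:
            self._seen.move_to_end(key)
            return False
        self._seen[key] = now_ms
        self._seen.move_to_end(key)
        while len(self._seen) > self._max_entries:
            self._seen.popitem(last=False)
        return True


cache = LastSeenCacheSketch()
assert cache.should_write(("@user:hs", "token", "1.2.3.4", "UA"), now_ms=0)
assert not cache.should_write(("@user:hs", "token", "1.2.3.4", "UA"), now_ms=60_000)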
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 0c237c6e0f..6e7fd259d4 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,12 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
- def stream_positions(self):
- result = super(SlavedDeviceInboxStore, self).stream_positions()
- result["to_device"] = self._device_inbox_id_gen.get_current_token()
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)
for row in rows:
@@ -60,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token
)
- return super(SlavedDeviceInboxStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
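The entity_has_changed() calls in deviceinbox.py (and in most of the stores below) feed a StreamChangeCache: it remembers the latest stream position at which each entity changed, so "has anything for this user changed since token X?" can usually be answered without a database query. A toy version of the idea, not the real synapse.util.caches.stream_change_cache.StreamChangeCache:

class StreamChangeCacheSketch:
    def __init__(self, earliest_known_token):
        self._earliest = earliest_known_token
        self._last_change = {}  # entity -> latest stream position it changed at

    def entity_has_changed(self, entity, token):
        self._last_change[entity] = max(self._last_change.get(entity, 0), token)

    def has_entity_changed(self, entity, since_token):
        if since_token < self._earliest:
            return True  # cache can't say; caller must fall back to the database
        return self._last_change.get(entity, self._earliest) > since_token


cache = StreamChangeCacheSketch(earliest_known_token=100)
cache.entity_has_changed("@alice:example.com", 105)
assert cache.has_entity_changed("@alice:example.com", 101)
assert not cache.has_entity_changed("@bob:example.com", 101)
assert cache.has_entity_changed("@bob:example.com", 50)  # before the cache's horizon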
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..9d8067342f 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
- db_conn, "device_lists_stream", "stream_id"
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ],
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
@@ -42,36 +48,28 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
- def stream_positions(self):
- result = super(SlavedDeviceStore, self).stream_positions()
- # The user signature stream uses the same stream ID generator as the
- # device list stream, so set them both to the device list ID
- # generator's current token.
- current_token = self._device_list_id_gen.get_current_token()
- result[DeviceListsStream.NAME] = current_token
- result[UserSignatureStream.NAME] = current_token
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
- for row in rows:
- self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
+ self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
- return super(SlavedDeviceStore, self).process_replication_rows(
- stream_name, token, rows
- )
-
- def _invalidate_caches_for_devices(self, token, user_id, destination):
- self._device_list_stream_cache.entity_has_changed(user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
- if destination:
- self._device_list_federation_stream_cache.entity_has_changed(
- destination, token
- )
+ def _invalidate_caches_for_devices(self, token, rows):
+ for row in rows:
+ # The entities are either user IDs (starting with '@') whose devices
+ # have changed, or remote servers that we need to tell about
+ # changes.
+ if row.entity.startswith("@"):
+ self._device_list_stream_cache.entity_has_changed(row.entity, token)
+ self.get_cached_devices_for_user.invalidate((row.entity,))
+ self._get_cached_user_device.invalidate_many((row.entity,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
- self.get_cached_devices_for_user.invalidate((user_id,))
- self._get_cached_user_device.invalidate_many((user_id,))
- self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+ else:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ row.entity, token
+ )
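devices.py reflects a change in the device-lists stream rows: instead of separate user_id/destination fields, each row carries a single entity that is either a user ID (Matrix user IDs always start with '@') or a remote server to notify. A small illustration of that dispatch with a hypothetical row type:

from collections import namedtuple

DeviceListsRowSketch = namedtuple("DeviceListsRowSketch", ["entity"])


def split_device_list_rows(rows):
    """Partition rows into users whose device lists changed and servers to poke."""
    users, destinations = set(), set()
    for row in rows:
        if row.entity.startswith("@"):
            users.add(row.entity)
        else:
            destinations.add(row.entity)
    return users, destinations


users, dests = split_device_list_rows(
    [DeviceListsRowSketch("@alice:example.com"), DeviceListsRowSketch("matrix.org")]
)
assert users == {"@alice:example.com"} and dests == {"matrix.org"}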
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index e73342c657..1a1a50a24f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,11 +15,6 @@
# limitations under the License.
import logging
-from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import (
- EventsStreamCurrentStateRow,
- EventsStreamEventRow,
-)
from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
from synapse.storage.data_stores.main.event_push_actions import (
EventPushActionsWorkerStore,
@@ -35,7 +30,6 @@ from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__)
@@ -62,11 +56,6 @@ class SlavedEventStore(
BaseSlavedStore,
):
def __init__(self, database: Database, db_conn, hs):
- self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
- self._backfill_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering", step=-1
- )
-
super(SlavedEventStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
@@ -92,89 +81,3 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
-
- def stream_positions(self):
- result = super(SlavedEventStore, self).stream_positions()
- result["events"] = self._stream_id_gen.get_current_token()
- result["backfill"] = -self._backfill_id_gen.get_current_token()
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "events":
- self._stream_id_gen.advance(token)
- for row in rows:
- self._process_event_stream_row(token, row)
- elif stream_name == "backfill":
- self._backfill_id_gen.advance(-token)
- for row in rows:
- self.invalidate_caches_for_event(
- -token,
- row.event_id,
- row.room_id,
- row.type,
- row.state_key,
- row.redacts,
- row.relates_to,
- backfilled=True,
- )
- return super(SlavedEventStore, self).process_replication_rows(
- stream_name, token, rows
- )
-
- def _process_event_stream_row(self, token, row):
- data = row.data
-
- if row.type == EventsStreamEventRow.TypeId:
- self.invalidate_caches_for_event(
- token,
- data.event_id,
- data.room_id,
- data.type,
- data.state_key,
- data.redacts,
- data.relates_to,
- backfilled=False,
- )
- elif row.type == EventsStreamCurrentStateRow.TypeId:
- self._curr_state_delta_stream_cache.entity_has_changed(
- row.data.room_id, token
- )
-
- if data.type == EventTypes.Member:
- self.get_rooms_for_user_with_stream_ordering.invalidate(
- (data.state_key,)
- )
- else:
- raise Exception("Unknown events stream row type %s" % (row.type,))
-
- def invalidate_caches_for_event(
- self,
- stream_ordering,
- event_id,
- room_id,
- etype,
- state_key,
- redacts,
- relates_to,
- backfilled,
- ):
- self._invalidate_get_event_cache(event_id)
-
- self.get_latest_event_ids_in_room.invalidate((room_id,))
-
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
-
- if not backfilled:
- self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
-
- if redacts:
- self._invalidate_get_event_cache(redacts)
-
- if etype == EventTypes.Member:
- self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
- self.get_invited_rooms_for_local_user.invalidate((state_key,))
-
- if relates_to:
- self.get_relations_for_event.invalidate_many((relates_to,))
- self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
- self.get_applicable_edit.invalidate((relates_to,))
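SlavedEventStore no longer builds its own ID trackers or invalidation helpers; the events_max lookup that remains as context implies the parent worker stores now provide _stream_id_gen and _backfill_id_gen, and the deleted per-event invalidation has presumably moved into the shared worker stores (note the CacheInvalidationWorkerStore import in _base.py above). One detail worth keeping from the deleted code: the backfill stream advances backwards, which is why positions were negated both when published ("backfill": -token) and when advancing (advance(-token)). A toy generator with step=-1 showing why, not Synapse's actual ID generator:

class BackwardsIdGeneratorSketch:
    def __init__(self, current=0, step=-1):
        self._current = current
        self._step = step

    def get_next(self):
        self._current += self._step
        return self._current

    def get_current_token(self):
        return self._current


gen = BackwardsIdGeneratorSketch()
assert [gen.get_next() for _ in range(3)] == [-1, -2, -3]
# Hence the removed code published -get_current_token() and advanced with -token.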
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 2d4fd08cf5..1851e7d525 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,17 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedGroupServerStore, self).stream_positions()
- result["groups"] = self._group_updates_id_gen.get_current_token()
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
- return super(SlavedGroupServerStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index ad8f0c15a9..4e0124842d 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -18,7 +18,7 @@ from synapse.storage.data_stores.main.presence import PresenceStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import BaseSlavedStore, __func__
+from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -27,35 +27,24 @@ class SlavedPresenceStore(BaseSlavedStore):
super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
- self._presence_on_startup = self._get_active_presence(db_conn)
+ self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
- _get_active_presence = __func__(DataStore._get_active_presence)
- take_presence_startup_info = __func__(DataStore.take_presence_startup_info)
+ _get_active_presence = DataStore._get_active_presence
+ take_presence_startup_info = DataStore.take_presence_startup_info
_get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedPresenceStore, self).stream_positions()
-
- if self.hs.config.use_presence:
- position = self._presence_id_gen.get_current_token()
- result["presence"] = position
-
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
- return super(SlavedPresenceStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
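presence.py can drop the six/__func__ shim because, on Python 3, looking a function up on a class yields a plain function that can simply be assigned as a method of another class (on Python 2 it yielded an unbound method, hence the old wrapper). A short demonstration:

class Donor:
    def greet(self):
        return "hello from %s" % (type(self).__name__,)


class Borrower:
    # On Python 3 this is a plain function, so no __func__ unwrapping is needed.
    greet = Donor.greet


assert Borrower().greet() == "hello from Borrower"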
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index eebd5a1fb6..6adb19463a 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,19 +15,11 @@
# limitations under the License.
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
-from synapse.storage.database import Database
-from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
- self._push_rules_stream_id_gen = SlavedIdTracker(
- db_conn, "push_rules_stream", "stream_id"
- )
- super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
-
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
@@ -37,18 +29,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedPushRuleStore, self).stream_positions()
- result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
- return super(SlavedPushRuleStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index f22c2d44a3..cb78b49acb 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -28,14 +28,10 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
- def stream_positions(self):
- result = super(SlavedPusherStore, self).stream_positions()
- result["pushers"] = self._pushers_id_gen.get_current_token()
- return result
+ def get_pushers_stream_token(self):
+ return self._pushers_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)
- return super(SlavedPusherStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
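pushers.py is typical of the stream_positions() removals throughout this diff: rather than every store reporting "where to start" per stream, stores expose plain token getters (get_pushers_stream_token here) and the replication machinery asks each stream for its current token. A hedged sketch of that inversion; the names below are illustrative, not the actual Stream classes under synapse/replication/tcp/streams/:

from typing import Callable, Dict, List, NamedTuple


class StreamSketch(NamedTuple):
    name: str
    current_token: Callable[[], int]


def starting_positions(streams: List[StreamSketch]) -> Dict[str, int]:
    # What stream_positions() used to assemble by hand, derived instead from
    # each store's own token getter.
    return {stream.name: stream.current_token() for stream in streams}


# e.g. starting_positions([StreamSketch("pushers", store.get_pushers_stream_token)])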
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index d40dc6e1f5..be716cc558 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedReceiptsStore, self).stream_positions()
- result["receipts"] = self._receipts_id_gen.get_current_token()
- return result
-
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
@@ -56,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "receipts":
self._receipts_id_gen.advance(token)
for row in rows:
@@ -65,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
- return super(SlavedReceiptsStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 3a20f45316..8873bf37e5 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,13 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- def stream_positions(self):
- result = super(RoomStore, self).stream_positions()
- result["public_rooms"] = self._public_room_id_gen.get_current_token()
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
- return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
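Every store in this diff ends process_replication_rows() by delegating to super() with the new (stream_name, instance_name, token, rows) signature; with Python's MRO that call threads through each worker-store mixin before reaching BaseSlavedStore, so every interested mixin sees every batch of rows. A toy illustration of the cooperative chain:

class Base:
    def process_replication_rows(self, stream_name, instance_name, token, rows):
        calls.append("base")


class RoomMixin(Base):
    def process_replication_rows(self, stream_name, instance_name, token, rows):
        calls.append("room")
        return super().process_replication_rows(stream_name, instance_name, token, rows)


class ReceiptsMixin(Base):
    def process_replication_rows(self, stream_name, instance_name, token, rows):
        calls.append("receipts")
        return super().process_replication_rows(stream_name, instance_name, token, rows)


class CombinedStore(RoomMixin, ReceiptsMixin):
    pass


calls = []
CombinedStore().process_replication_rows("receipts", "master", 1, [])
assert calls == ["room", "receipts", "base"]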