diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index cb202bda44..e60988fa4a 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -100,7 +100,16 @@ from synapse.rest.client.v1.profile import (
)
from synapse.rest.client.v1.push_rule import PushRuleRestServlet
from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory
+from synapse.rest.client.v2_alpha import (
+ account_data,
+ groups,
+ read_marker,
+ receipts,
+ room_keys,
+ sync,
+ tags,
+ user_directory,
+)
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
from synapse.rest.client.v2_alpha.account_data import (
@@ -531,6 +540,10 @@ class GenericWorkerServer(HomeServer):
room.register_deprecated_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
room_keys.register_servlets(self, resource)
+ tags.register_servlets(self, resource)
+ account_data.register_servlets(self, resource)
+ receipts.register_servlets(self, resource)
+ read_marker.register_servlets(self, resource)
SendToDeviceRestServlet(self).register(resource)
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 364583f48b..f10e33f7b8 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -56,6 +56,12 @@ class WriterLocations:
to_device = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
)
+ account_data = attr.ib(
+ default=["master"], type=List[str], converter=_instance_to_list_converter,
+ )
+ receipts = attr.ib(
+ default=["master"], type=List[str], converter=_instance_to_list_converter,
+ )
class WorkerConfig(Config):
@@ -127,7 +133,7 @@ class WorkerConfig(Config):
# Check that the configured writers for events and typing also appears in
# `instance_map`.
- for stream in ("events", "typing", "to_device"):
+ for stream in ("events", "typing", "to_device", "account_data", "receipts"):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
@@ -141,6 +147,16 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `to_device` messages."
)
+ if len(self.writers.account_data) != 1:
+ raise ConfigError(
+ "Must only specify one instance to handle `account_data` messages."
+ )
+
+ if len(self.writers.receipts) != 1:
+ raise ConfigError(
+ "Must only specify one instance to handle `receipts` messages."
+ )
+
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
# Whether this worker should run background tasks or not.
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 341135822e..b1a5df9638 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,14 +13,157 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import random
from typing import TYPE_CHECKING, List, Tuple
+from synapse.replication.http.account_data import (
+ ReplicationAddTagRestServlet,
+ ReplicationRemoveTagRestServlet,
+ ReplicationRoomAccountDataRestServlet,
+ ReplicationUserAccountDataRestServlet,
+)
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
+class AccountDataHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastore()
+ self._instance_name = hs.get_instance_name()
+ self._notifier = hs.get_notifier()
+
+ self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+ self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+ self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+ self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+ self._account_data_writers = hs.config.worker.writers.account_data
+
+ async def add_account_data_to_room(
+ self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
+ """Add some account_data to a room for a user.
+
+ Args:
+            user_id: The user to add the account data for.
+            room_id: The room to add the account data to.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the account data.
+
+ Returns:
+ The maximum stream ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_account_data_to_room(
+ user_id, room_id, account_data_type, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+
+ return max_stream_id
+ else:
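+            # The config currently enforces a single account data writer, so
+            # this picks that one instance.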
+ response = await self._room_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ account_data_type=account_data_type,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def add_account_data_for_user(
+ self, user_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
+ """Add some account_data to a room for a user.
+
+ Args:
+ user_id: The user to add a tag for.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the tag.
+
+ Returns:
+ The maximum stream ID.
+ """
+
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_account_data_for_user(
+ user_id, account_data_type, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._user_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ account_data_type=account_data_type,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def add_tag_to_room(
+ self, user_id: str, room_id: str, tag: str, content: JsonDict
+ ) -> int:
+ """Add a tag to a room for a user.
+
+ Args:
+ user_id: The user to add a tag for.
+ room_id: The room to add a tag for.
+ tag: The tag name to add.
+ content: A json object to associate with the tag.
+
+ Returns:
+ The next account data ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_tag_to_room(
+ user_id, room_id, tag, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._add_tag_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ tag=tag,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+ """Remove a tag from a room for a user.
+
+ Returns:
+ The next account data ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.remove_tag_from_room(
+ user_id, room_id, tag
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._remove_tag_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ tag=tag,
+ )
+ return response["max_stream_id"]
+
+
class AccountDataEventSource:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index a7550806e6..6bb2fd936b 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
+ self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker")
- self.notifier = hs.get_notifier()
async def received_client_read_marker(
self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
if should_update:
content = {"event_id": event_id}
- max_id = await self.store.add_account_data_to_room(
+ await self.account_data_handler.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a9abdf42e0..cc21fc2284 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs
- self.federation = hs.get_federation_sender()
- hs.get_federation_registry().register_edu_handler(
- "m.receipt", self._received_remote_receipt
- )
+
+        # We only need to poke the federation sender explicitly if it's on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when they see
+        # the receipt in the receipts stream.
+ self.federation_sender = None
+ if hs.should_send_federation():
+ self.federation_sender = hs.get_federation_sender()
+
+ # If we can handle the receipt EDUs we do so, otherwise we route them
+ # to the appropriate worker.
+ if hs.get_instance_name() in hs.config.worker.writers.receipts:
+ hs.get_federation_registry().register_edu_handler(
+ "m.receipt", self._received_remote_receipt
+ )
+ else:
+ hs.get_federation_registry().register_instances_for_edu(
+ "m.receipt", hs.config.worker.writers.receipts,
+ )
+
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
@@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
if not is_new:
return
- await self.federation.send_read_receipt(receipt)
+ if self.federation_sender:
+ await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index cb5a29bc7e..e001e418f9 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.account_data_handler = hs.get_account_data_handler()
self.member_linearizer = Linearizer(name="member")
@@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data
- await self.store.add_account_data_for_user(
+ await self.account_data_handler.add_account_data_for_user(
user_id, AccountDataTypes.DIRECT, direct_rooms
)
break
@@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
- await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+ await self.account_data_handler.add_tag_to_room(
+ user_id, new_room_id, tag, tag_content
+ )
async def update_membership(
self,
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index a84a064c8d..dd527e807f 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -15,6 +15,7 @@
from synapse.http.server import JsonResource
from synapse.replication.http import (
+ account_data,
devices,
federation,
login,
@@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)
streams.register_servlets(hs, self)
+ account_data.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
new file mode 100644
index 0000000000..52d32528ee
--- /dev/null
+++ b/synapse/replication/http/account_data.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+ """Add user account data on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+ {
+ "content": { ... },
+ }
+
+ """
+
+ NAME = "add_user_account_data"
+ PATH_ARGS = ("user_id", "account_data_type")
+ CACHE = False
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_id, account_data_type, content):
+ payload = {
+ "content": content,
+ }
+
+ return payload
+
+ async def _handle_request(self, request, user_id, account_data_type):
+ content = parse_json_object_from_request(request)
+
+ max_stream_id = await self.handler.add_account_data_for_user(
+ user_id, account_data_type, content["content"]
+ )
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+ """Add room account data on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+ {
+ "content": { ... },
+ }
+
+ """
+
+ NAME = "add_room_account_data"
+ PATH_ARGS = ("user_id", "room_id", "account_data_type")
+ CACHE = False
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_id, room_id, account_data_type, content):
+ payload = {
+ "content": content,
+ }
+
+ return payload
+
+ async def _handle_request(self, request, user_id, room_id, account_data_type):
+ content = parse_json_object_from_request(request)
+
+ max_stream_id = await self.handler.add_account_data_to_room(
+ user_id, room_id, account_data_type, content["content"]
+ )
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+ """Add tag on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+ {
+ "content": { ... },
+ }
+
+ """
+
+ NAME = "add_tag"
+ PATH_ARGS = ("user_id", "room_id", "tag")
+ CACHE = False
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_id, room_id, tag, content):
+ payload = {
+ "content": content,
+ }
+
+ return payload
+
+ async def _handle_request(self, request, user_id, room_id, tag):
+ content = parse_json_object_from_request(request)
+
+ max_stream_id = await self.handler.add_tag_to_room(
+ user_id, room_id, tag, content["content"]
+ )
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+ """Remove tag on the appropriate account data worker.
+
+ Request format:
+
+ POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+ {}
+
+ """
+
+ NAME = "remove_tag"
+ PATH_ARGS = (
+ "user_id",
+ "room_id",
+ "tag",
+ )
+ CACHE = False
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.handler = hs.get_account_data_handler()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_id, room_id, tag):
+
+ return {}
+
+ async def _handle_request(self, request, user_id, room_id, tag):
+ max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+
+ return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+ ReplicationUserAccountDataRestServlet(hs).register(http_server)
+ ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+ ReplicationAddTagRestServlet(hs).register(http_server)
+ ReplicationRemoveTagRestServlet(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d0089fe06c..693c9ab901 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
database,
stream_name="caches",
instance_name=hs.get_instance_name(),
- table="cache_invalidation_stream_by_instance",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[
+ (
+ "cache_invalidation_stream_by_instance",
+ "instance_name",
+ "stream_id",
+ )
+ ],
sequence_name="cache_invalidation_stream_seq",
writers=[],
) # type: Optional[MultiWriterIdGenerator]
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 4268565fc8..21afe5f155 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,47 +15,9 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- self._account_data_id_gen = SlavedIdTracker(
- db_conn,
- "account_data",
- "stream_id",
- extra_tables=[
- ("room_account_data", "stream_id"),
- ("room_tags_revisions", "stream_id"),
- ],
- )
-
- super().__init__(database, db_conn, hs)
-
- def get_max_account_data_stream_id(self):
- return self._account_data_id_gen.get_current_token()
-
- def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- for row in rows:
- self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- elif stream_name == AccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- for row in rows:
- if not row.room_id:
- self.get_global_account_data_by_type_for_user.invalidate(
- (row.data_type, row.user_id)
- )
- self.get_account_data_for_user.invalidate((row.user_id,))
- self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
- self.get_account_data_for_room_and_type.invalidate(
- (row.user_id, row.room_id, row.data_type)
- )
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+ pass
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6195917376..3dfdd9961d 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,43 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- # We instantiate this first as the ReceiptsWorkerStore constructor
- # needs to be able to call get_max_receipt_stream_id
- self._receipts_id_gen = SlavedIdTracker(
- db_conn, "receipts_linearized", "stream_id"
- )
-
- super().__init__(database, db_conn, hs)
-
- def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_current_token()
-
- def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
- self.get_receipts_for_user.invalidate((user_id, receipt_type))
- self._get_linearized_receipts_for_room.invalidate_many((room_id,))
- self.get_last_receipt_event_id_for_user.invalidate(
- (user_id, room_id, receipt_type)
- )
- self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
- self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
- def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == ReceiptsStream.NAME:
- self._receipts_id_gen.advance(instance_name, token)
- for row in rows:
- self.invalidate_caches_for_receipt(
- row.room_id, row.receipt_type, row.user_id
- )
- self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+ pass
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1f89249475..317796d5e0 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import (
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
+ AccountDataStream,
BackfillStream,
CachesStream,
EventsStream,
FederationStream,
+ ReceiptsStream,
Stream,
+ TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
@@ -132,6 +135,22 @@ class ReplicationCommandHandler:
continue
+ if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+ # Only add AccountDataStream and TagAccountDataStream as a source on the
+ # instance in charge of account_data persistence.
+ if hs.get_instance_name() in hs.config.worker.writers.account_data:
+ self._streams_to_replicate.append(stream)
+
+ continue
+
+ if isinstance(stream, ReceiptsStream):
+ # Only add ReceiptsStream as a source on the instance in charge of
+ # receipts.
+ if hs.get_instance_name() in hs.config.worker.writers.receipts:
+ self._streams_to_replicate.append(stream)
+
+ continue
+
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..3f28c0bc3e 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self._is_worker = hs.config.worker_app is not None
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, account_data_type):
- if self._is_worker:
- raise Exception("Cannot handle PUT /account_data on worker")
-
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request)
- max_id = await self.store.add_account_data_for_user(
- user_id, account_data_type, body
- )
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.add_account_data_for_user(user_id, account_data_type, body)
return 200, {}
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self._is_worker = hs.config.worker_app is not None
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, account_data_type):
- if self._is_worker:
- raise Exception("Cannot handle PUT /account_data on worker")
-
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers",
)
- max_id = await self.store.add_account_data_to_room(
+ await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type):
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.auth = hs.get_auth()
- self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, tag):
requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
body = parse_json_object_from_request(request)
- max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.add_tag_to_room(user_id, room_id, tag, body)
return 200, {}
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
- max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.remove_tag_from_room(user_id, room_id, tag)
return 200, {}
diff --git a/synapse/server.py b/synapse/server.py
index d4c235cda5..9cdda83aa1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender
from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
+from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.acme import AcmeHandler
from synapse.handlers.admin import AdminHandler
@@ -711,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_module_api(self) -> ModuleApi:
return ModuleApi(self, self.get_auth_handler())
+ @cache_in_self
+ def get_account_data_handler(self) -> AccountDataHandler:
+ return AccountDataHandler(self)
+
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index c4de07a0a8..ae561a2da3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,9 +160,13 @@ class DataStore(
database,
stream_name="caches",
instance_name=hs.get_instance_name(),
- table="cache_invalidation_stream_by_instance",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[
+ (
+ "cache_invalidation_stream_by_instance",
+ "instance_name",
+ "stream_id",
+ )
+ ],
sequence_name="cache_invalidation_stream_seq",
writers=[],
)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index bad8260892..68896f34af 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,14 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
from typing import Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
+class AccountDataWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
+ self._instance_name = hs.get_instance_name()
+
+ if isinstance(database.engine, PostgresEngine):
+ self._can_write_to_account_data = (
+ self._instance_name in hs.config.worker.writers.account_data
+ )
+
+ self._account_data_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="account_data",
+ instance_name=self._instance_name,
+ tables=[
+ ("room_account_data", "instance_name", "stream_id"),
+ ("room_tags_revisions", "instance_name", "stream_id"),
+ ("account_data", "instance_name", "stream_id"),
+ ],
+ sequence_name="account_data_sequence",
+ writers=hs.config.worker.writers.account_data,
+ )
+ else:
+ self._can_write_to_account_data = True
+
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+ # to support it for unit tests.
+ #
+            # If this process is the writer then we need to use
+ # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.account_data:
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn,
+ "room_account_data",
+ "stream_id",
+ extra_tables=[("room_tags_revisions", "stream_id")],
+ )
+ else:
+ self._account_data_id_gen = SlavedIdTracker(
+ db_conn,
+ "room_account_data",
+ "stream_id",
+ extra_tables=[("room_tags_revisions", "stream_id")],
+ )
+
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
@@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
super().__init__(database, db_conn, hs)
- @abc.abstractmethod
- def get_max_account_data_stream_id(self):
+ def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
Returns:
int
"""
- raise NotImplementedError()
+ return self._account_data_id_gen.get_current_token()
@cached()
async def get_account_data_for_user(
@@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
)
)
-
-class AccountDataStore(AccountDataWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- self._account_data_id_gen = StreamIdGenerator(
- db_conn,
- "room_account_data",
- "stream_id",
- extra_tables=[("room_tags_revisions", "stream_id")],
- )
-
- super().__init__(database, db_conn, hs)
-
- def get_max_account_data_stream_id(self) -> int:
- """Get the current max stream id for the private user data stream
-
- Returns:
- The maximum stream ID.
- """
- return self._account_data_id_gen.get_current_token()
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == TagAccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ elif stream_name == AccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ if not row.room_id:
+ self.get_global_account_data_by_type_for_user.invalidate(
+ (row.data_type, row.user_id)
+ )
+ self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
+ self.get_account_data_for_room_and_type.invalidate(
+ (row.user_id, row.room_id, row.data_type)
+ )
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
+ assert self._can_write_to_account_data
+
content_json = json_encoder.encode(content)
async with self._account_data_id_gen.get_next() as next_id:
@@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
+ assert self._can_write_to_account_data
+
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"add_user_account_data",
@@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore):
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+
+
+class AccountDataStore(AccountDataWorkerStore):
+ pass
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 58d3f71e45..31f70ac5ef 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
db=database,
stream_name="to_device",
instance_name=self._instance_name,
- table="device_inbox",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[("device_inbox", "instance_name", "stream_id")],
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
)
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index e5c03cc609..1b657191a9 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
(rotate_to_stream_ordering,),
)
+ def _remove_old_push_actions_before_txn(
+ self, txn, room_id, user_id, stream_ordering
+ ):
+ """
+ Purges old push actions for a user and room before a given
+ stream_ordering.
+
+        We however keep a month's worth of highlighted notifications, so that
+ users can still get a list of recent highlights.
+
+ Args:
+            txn: The transaction
+ room_id: Room ID to delete from
+ user_id: user ID to delete for
+ stream_ordering: The lowest stream ordering which will
+ not be deleted.
+ """
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (room_id, user_id),
+ )
+
+ # We need to join on the events table to get the received_ts for
+ # event_push_actions and sqlite won't let us use a join in a delete so
+ # we can't just delete where received_ts < x. Furthermore we can
+ # only identify event_push_actions by a tuple of room_id, event_id
+        # so we can't use a subquery.
+ # Instead, we look up the stream ordering for the last event in that
+ # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_ordering before that.
+ txn.execute(
+ "DELETE FROM event_push_actions "
+ " WHERE user_id = ? AND room_id = ? AND "
+ " stream_ordering <= ?"
+ " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+ )
+
+ txn.execute(
+ """
+ DELETE FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+ """,
+ (room_id, user_id, stream_ordering),
+ )
+
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
- def _remove_old_push_actions_before_txn(
- self, txn, room_id, user_id, stream_ordering
- ):
- """
- Purges old push actions for a user and room before a given
- stream_ordering.
-
- We however keep a months worth of highlighted notifications, so that
- users can still get a list of recent highlights.
-
- Args:
- txn: The transcation
- room_id: Room ID to delete from
- user_id: user ID to delete for
- stream_ordering: The lowest stream ordering which will
- not be deleted.
- """
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id, user_id),
- )
-
- # We need to join on the events table to get the received_ts for
- # event_push_actions and sqlite won't let us use a join in a delete so
- # we can't just delete where received_ts < x. Furthermore we can
- # only identify event_push_actions by a tuple of room_id, event_id
- # we we can't use a subquery.
- # Instead, we look up the stream ordering for the last event in that
- # room received before the threshold time and delete event_push_actions
- # in the room with a stream_odering before that.
- txn.execute(
- "DELETE FROM event_push_actions "
- " WHERE user_id = ? AND room_id = ? AND "
- " stream_ordering <= ?"
- " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
- )
-
- txn.execute(
- """
- DELETE FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
- """,
- (room_id, user_id, stream_ordering),
- )
-
def _action_has_highlight(actions):
for action in actions:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4732685f6e..71d823be72 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
stream_name="events",
instance_name=hs.get_instance_name(),
- table="events",
- instance_column="instance_name",
- id_column="stream_ordering",
+ tables=[("events", "instance_name", "stream_ordering")],
sequence_name="events_stream_seq",
writers=hs.config.worker.writers.events,
)
@@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
stream_name="backfill",
instance_name=hs.get_instance_name(),
- table="events",
- instance_column="instance_name",
- id_column="stream_ordering",
+ tables=[("events", "instance_name", "stream_ordering")],
sequence_name="events_backfill_stream_seq",
positive=False,
writers=hs.config.worker.writers.events,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1e7949a323..e0e57f0578 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,15 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
- """This is an abstract base class where subclasses must implement
- `get_max_receipt_stream_id` which can be called in the initializer.
- """
-
+class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
+ self._instance_name = hs.get_instance_name()
+
+ if isinstance(database.engine, PostgresEngine):
+ self._can_write_to_receipts = (
+ self._instance_name in hs.config.worker.writers.receipts
+ )
+
+ self._receipts_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+                stream_name="receipts",
+ instance_name=self._instance_name,
+ tables=[("receipts_linearized", "instance_name", "stream_id")],
+ sequence_name="receipts_sequence",
+ writers=hs.config.worker.writers.receipts,
+ )
+ else:
+ self._can_write_to_receipts = True
+
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+ # to support it for unit tests.
+ #
+            # If this process is the writer then we need to use
+ # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.receipts:
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+ else:
+ self._receipts_id_gen = SlavedIdTracker(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
super().__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
- @abc.abstractmethod
def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream
Returns:
int
"""
- raise NotImplementedError()
+ return self._receipts_id_gen.get_current_token()
@cached()
async def get_users_with_read_receipts_in_room(self, room_id):
@@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
-class ReceiptsStore(ReceiptsWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- # We instantiate this first as the ReceiptsWorkerStore constructor
- # needs to be able to call get_max_receipt_stream_id
- self._receipts_id_gen = StreamIdGenerator(
- db_conn, "receipts_linearized", "stream_id"
+ def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+ self.get_receipts_for_user.invalidate((user_id, receipt_type))
+ self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+ self.get_last_receipt_event_id_for_user.invalidate(
+ (user_id, room_id, receipt_type)
)
+ self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+ self.get_receipts_for_room.invalidate((room_id, receipt_type))
+
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == ReceiptsStream.NAME:
+ self._receipts_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.invalidate_caches_for_receipt(
+ row.room_id, row.receipt_type, row.user_id
+ )
+ self._receipts_stream_cache.entity_has_changed(row.room_id, token)
- super().__init__(database, db_conn, hs)
-
- def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_current_token()
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
+ assert self._can_write_to_receipts
+
res = self.db_pool.simple_select_one_txn(
txn,
table="events",
@@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
)
return None
- txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
- txn.call_after(
- self._invalidate_get_users_with_receipts_in_room,
- room_id,
- receipt_type,
- user_id,
- )
- txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
- # FIXME: This shouldn't invalidate the whole cache
txn.call_after(
- self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
+ self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
)
txn.call_after(
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
)
- txn.call_after(
- self.get_last_receipt_event_id_for_user.invalidate,
- (user_id, room_id, receipt_type),
- )
-
self.db_pool.simple_upsert_txn(
txn,
table="receipts_linearized",
@@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
Automatically does conversion between linearized and graph
representations.
"""
+ assert self._can_write_to_receipts
+
if not event_ids:
return None
@@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data
):
+ assert self._can_write_to_receipts
+
return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
@@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data
):
+ assert self._can_write_to_receipts
+
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
@@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"data": json_encoder.encode(data),
},
)
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+ pass
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
new file mode 100644
index 0000000000..46abf8d562
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
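+-- Add an instance_name column to each sharded stream table so that rows can
+-- be attributed to the worker that wrote them (as used by the
+-- MultiWriterIdGenerator).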
+ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
+ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
+ALTER TABLE account_data ADD COLUMN instance_name TEXT;
+
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
new file mode 100644
index 0000000000..4a6e6c74f5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
@@ -0,0 +1,32 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
+
+-- We need to take the max across all the account_data tables as they share the
+-- ID generator
+SELECT setval('account_data_sequence', (
+ SELECT GREATEST(
+ (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
+ )
+));
+
+CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
+
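+-- Bring the new sequence up to the current maximum receipt stream ID so that
+-- newly generated IDs do not collide with existing rows.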
+SELECT setval('receipts_sequence', (
+ SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
+));
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 74da9c49f2..50067eabfc 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
return {row["tag"]: db_to_json(row["content"]) for row in rows}
-
-class TagsStore(TagsWorkerStore):
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
@@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
+ assert self._can_write_to_account_data
+
content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
@@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
+ assert self._can_write_to_account_data
def remove_tag_txn(txn, next_id):
sql = (
@@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore):
room_id: The ID of the room.
next_id: The the revision to advance to.
"""
+ assert self._can_write_to_account_data
txn.call_after(
self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore):
# which stream_id ends up in the table, as long as it is higher
# than the id that the client has.
pass
+
+
+class TagsStore(TagsWorkerStore):
+ pass
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 133c0e7a28..39a3ab1162 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -17,7 +17,7 @@ import logging
import threading
from collections import deque
from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
import attr
from typing_extensions import Deque
@@ -186,11 +186,12 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
- stream_name: A name for the stream.
+ stream_name: A name for the stream, for use in the `stream_positions`
+ table. (Does not need to be the same as the replication stream name)
instance_name: The name of this instance.
- table: Database table associated with stream.
- instance_column: Column that stores the row's writer's instance name
- id_column: Column that stores the stream ID.
+ tables: List of tables associated with the stream. Tuple of table
+ name, column name that stores the writer's instance name, and
+ column name that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
writers: A list of known writers to use to populate current positions
@@ -206,9 +207,7 @@ class MultiWriterIdGenerator:
db: DatabasePool,
stream_name: str,
instance_name: str,
- table: str,
- instance_column: str,
- id_column: str,
+ tables: List[Tuple[str, str, str]],
sequence_name: str,
writers: List[str],
positive: bool = True,
@@ -260,15 +259,16 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
- self._sequence_gen.check_consistency(
- db_conn, table=table, id_column=id_column, positive=positive
- )
+ for table, _, id_column in tables:
+ self._sequence_gen.check_consistency(
+ db_conn, table=table, id_column=id_column, positive=positive
+ )
# This goes and fills out the above state from the database.
- self._load_current_ids(db_conn, table, instance_column, id_column)
+ self._load_current_ids(db_conn, tables)
def _load_current_ids(
- self, db_conn, table: str, instance_column: str, id_column: str
+ self, db_conn, tables: List[Tuple[str, str, str]],
):
cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -306,17 +306,22 @@ class MultiWriterIdGenerator:
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
- sql = """
- SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
- FROM %(table)s
- """ % {
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
- cur.execute(sql)
- (stream_id,) = cur.fetchone()
- self._persisted_upto_position = stream_id
+ max_stream_id = 1
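+            # All the given tables share the same ID generator, so take the
+            # maximum ID seen across them.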
+ for table, _, id_column in tables:
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+
+ max_stream_id = max(max_stream_id, stream_id)
+
+ self._persisted_upto_position = max_stream_id
else:
# If we have a min_stream_id then we pull out everything greater
# than it from the DB so that we can prefill
@@ -329,21 +334,28 @@ class MultiWriterIdGenerator:
# stream positions table before restart (or the stream position
# table otherwise got out of date).
- sql = """
- SELECT %(instance)s, %(id)s FROM %(table)s
- WHERE ? %(cmp)s %(id)s
- """ % {
- "id": id_column,
- "table": table,
- "instance": instance_column,
- "cmp": "<=" if self._positive else ">=",
- }
- cur.execute(sql, (min_stream_id * self._return_factor,))
-
self._persisted_upto_position = min_stream_id
+ rows = []
+ for table, instance_column, id_column in tables:
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ cur.execute(sql, (min_stream_id * self._return_factor,))
+
+ rows.extend(cur)
+
+ # Sort so that we handle rows in order for each instance.
+ rows.sort()
+
with self._lock:
- for (instance, stream_id,) in cur:
+ for (instance, stream_id,) in rows:
stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id)
|