diff options
-rw-r--r-- | changelog.d/16432.feature | 1 | ||||
-rw-r--r-- | synapse/config/workers.py | 4 | ||||
-rw-r--r-- | synapse/handlers/appservice.py | 42 | ||||
-rw-r--r-- | synapse/handlers/initial_sync.py | 2 | ||||
-rw-r--r-- | synapse/handlers/receipts.py | 19 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 7 | ||||
-rw-r--r-- | synapse/notifier.py | 45 | ||||
-rw-r--r-- | synapse/replication/tcp/client.py | 3 | ||||
-rw-r--r-- | synapse/storage/databases/main/receipts.py | 148 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 4 | ||||
-rw-r--r-- | synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite | 17 | ||||
-rw-r--r-- | synapse/streams/events.py | 4 | ||||
-rw-r--r-- | synapse/types/__init__.py | 137 | ||||
-rw-r--r-- | tests/handlers/test_appservice.py | 17 | ||||
-rw-r--r-- | tests/replication/test_sharded_receipts.py | 243 |
15 files changed, 604 insertions, 89 deletions
diff --git a/changelog.d/16432.feature b/changelog.d/16432.feature new file mode 100644 index 0000000000..9a76e85592 --- /dev/null +++ b/changelog.d/16432.feature @@ -0,0 +1 @@ +Allow multiple workers to write to receipts stream. diff --git a/synapse/config/workers.py b/synapse/config/workers.py index f1766088fc..6d67a8cd5c 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -358,9 +358,9 @@ class WorkerConfig(Config): "Must only specify one instance to handle `account_data` messages." ) - if len(self.writers.receipts) != 1: + if len(self.writers.receipts) == 0: raise ConfigError( - "Must only specify one instance to handle `receipts` messages." + "Must specify at least one instance to handle `receipts` messages." ) if len(self.writers.events) == 0: diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index c200a45f3a..873dadc3bd 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -47,6 +47,7 @@ from synapse.types import ( DeviceListUpdates, JsonDict, JsonMapping, + MultiWriterStreamToken, RoomAlias, RoomStreamToken, StreamKeyType, @@ -217,7 +218,7 @@ class ApplicationServicesHandler: def notify_interested_services_ephemeral( self, stream_key: StreamKeyType, - new_token: Union[int, RoomStreamToken], + new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], users: Collection[Union[str, UserID]], ) -> None: """ @@ -259,19 +260,6 @@ class ApplicationServicesHandler: ): return - # Assert that new_token is an integer (and not a RoomStreamToken). - # All of the supported streams that this function handles use an - # integer to track progress (rather than a RoomStreamToken - a - # vector clock implementation) as they don't support multiple - # stream writers. - # - # As a result, we simply assert that new_token is an integer. - # If we do end up needing to pass a RoomStreamToken down here - # in the future, using RoomStreamToken.stream (the minimum stream - # position) to convert to an ascending integer value should work. - # Additional context: https://github.com/matrix-org/synapse/pull/11137 - assert isinstance(new_token, int) - # Ignore to-device messages if the feature flag is not enabled if ( stream_key == StreamKeyType.TO_DEVICE @@ -286,6 +274,9 @@ class ApplicationServicesHandler: ): return + # We know we're not a `RoomStreamToken` at this point. + assert not isinstance(new_token, RoomStreamToken) + # Check whether there are any appservices which have registered to receive # ephemeral events. # @@ -327,7 +318,7 @@ class ApplicationServicesHandler: self, services: List[ApplicationService], stream_key: StreamKeyType, - new_token: int, + new_token: Union[int, MultiWriterStreamToken], users: Collection[Union[str, UserID]], ) -> None: logger.debug("Checking interested services for %s", stream_key) @@ -340,6 +331,7 @@ class ApplicationServicesHandler: # # Instead we simply grab the latest typing updates in _handle_typing # and, if they apply to this application service, send it off. + assert isinstance(new_token, int) events = await self._handle_typing(service, new_token) if events: self.scheduler.enqueue_for_appservice(service, ephemeral=events) @@ -350,15 +342,23 @@ class ApplicationServicesHandler: (service.id, stream_key) ): if stream_key == StreamKeyType.RECEIPT: + assert isinstance(new_token, MultiWriterStreamToken) + + # We store appservice tokens as integers, so we ignore + # the `instance_map` components and instead simply + # follow the base stream position. + new_token = MultiWriterStreamToken(stream=new_token.stream) + events = await self._handle_receipts(service, new_token) self.scheduler.enqueue_for_appservice(service, ephemeral=events) # Persist the latest handled stream token for this appservice await self.store.set_appservice_stream_type_pos( - service, "read_receipt", new_token + service, "read_receipt", new_token.stream ) elif stream_key == StreamKeyType.PRESENCE: + assert isinstance(new_token, int) events = await self._handle_presence(service, users, new_token) self.scheduler.enqueue_for_appservice(service, ephemeral=events) @@ -368,6 +368,7 @@ class ApplicationServicesHandler: ) elif stream_key == StreamKeyType.TO_DEVICE: + assert isinstance(new_token, int) # Retrieve a list of to-device message events, as well as the # maximum stream token of the messages we were able to retrieve. to_device_messages = await self._get_to_device_messages( @@ -383,6 +384,7 @@ class ApplicationServicesHandler: ) elif stream_key == StreamKeyType.DEVICE_LIST: + assert isinstance(new_token, int) device_list_summary = await self._get_device_list_summary( service, new_token ) @@ -432,7 +434,7 @@ class ApplicationServicesHandler: return typing async def _handle_receipts( - self, service: ApplicationService, new_token: int + self, service: ApplicationService, new_token: MultiWriterStreamToken ) -> List[JsonMapping]: """ Return the latest read receipts that the given application service should receive. @@ -455,15 +457,17 @@ class ApplicationServicesHandler: from_key = await self.store.get_type_stream_id_for_appservice( service, "read_receipt" ) - if new_token is not None and new_token <= from_key: + if new_token is not None and new_token.stream <= from_key: logger.debug( "Rejecting token lower than or equal to stored: %s" % (new_token,) ) return [] + from_token = MultiWriterStreamToken(stream=from_key) + receipts_source = self.event_sources.sources.receipt receipts, _ = await receipts_source.get_new_events_as( - service=service, from_key=from_key, to_key=new_token + service=service, from_key=from_token, to_key=new_token ) return receipts diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index c34bd7db95..b1d8be866f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -145,7 +145,7 @@ class InitialSyncHandler: joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] receipt = await self.store.get_linearized_receipts_for_rooms( joined_rooms, - to_key=int(now_token.receipt_key), + to_key=now_token.receipt_key, ) receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 69ac468f75..b5f7a8b47e 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -20,6 +20,7 @@ from synapse.streams import EventSource from synapse.types import ( JsonDict, JsonMapping, + MultiWriterStreamToken, ReadReceipt, StreamKeyType, UserID, @@ -200,7 +201,7 @@ class ReceiptsHandler: await self.federation_sender.send_read_receipt(receipt) -class ReceiptEventSource(EventSource[int, JsonMapping]): +class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.config = hs.config @@ -273,13 +274,12 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): async def get_new_events( self, user: UserID, - from_key: int, + from_key: MultiWriterStreamToken, limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[JsonMapping], int]: - from_key = int(from_key) + ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]: to_key = self.get_current_key() if from_key == to_key: @@ -296,8 +296,11 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): return events, to_key async def get_new_events_as( - self, from_key: int, to_key: int, service: ApplicationService - ) -> Tuple[List[JsonMapping], int]: + self, + from_key: MultiWriterStreamToken, + to_key: MultiWriterStreamToken, + service: ApplicationService, + ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]: """Returns a set of new read receipt events that an appservice may be interested in. @@ -312,8 +315,6 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): appservice may be interested in. * The current read receipt stream token. """ - from_key = int(from_key) - if from_key == to_key: return [], to_key @@ -333,5 +334,5 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): return events, to_key - def get_current_key(self) -> int: + def get_current_key(self) -> MultiWriterStreamToken: return self.store.get_max_receipt_stream_id() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f131c0e8e0..f75c1548ca 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -57,6 +57,7 @@ from synapse.types import ( DeviceListUpdates, JsonDict, JsonMapping, + MultiWriterStreamToken, MutableStateMap, Requester, RoomStreamToken, @@ -477,7 +478,11 @@ class SyncHandler: event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) - receipt_key = since_token.receipt_key if since_token else 0 + receipt_key = ( + since_token.receipt_key + if since_token + else MultiWriterStreamToken(stream=0) + ) receipt_source = self.event_sources.sources.receipt receipts, receipt_key = await receipt_source.get_new_events( diff --git a/synapse/notifier.py b/synapse/notifier.py index 99e7715896..ee0bd84f1e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -21,11 +21,13 @@ from typing import ( Dict, Iterable, List, + Literal, Optional, Set, Tuple, TypeVar, Union, + overload, ) import attr @@ -44,6 +46,7 @@ from synapse.metrics import LaterGauge from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, + MultiWriterStreamToken, PersistedEventPosition, RoomStreamToken, StrCollection, @@ -127,7 +130,7 @@ class _NotifierUserStream: def notify( self, stream_key: StreamKeyType, - stream_id: Union[int, RoomStreamToken], + stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken], time_now_ms: int, ) -> None: """Notify any listeners for this user of a new event from an @@ -452,10 +455,48 @@ class Notifier: except Exception: logger.exception("Error pusher pool of event") + @overload + def on_new_event( + self, + stream_key: Literal[StreamKeyType.ROOM], + new_token: RoomStreamToken, + users: Optional[Collection[Union[str, UserID]]] = None, + rooms: Optional[StrCollection] = None, + ) -> None: + ... + + @overload + def on_new_event( + self, + stream_key: Literal[StreamKeyType.RECEIPT], + new_token: MultiWriterStreamToken, + users: Optional[Collection[Union[str, UserID]]] = None, + rooms: Optional[StrCollection] = None, + ) -> None: + ... + + @overload + def on_new_event( + self, + stream_key: Literal[ + StreamKeyType.ACCOUNT_DATA, + StreamKeyType.DEVICE_LIST, + StreamKeyType.PRESENCE, + StreamKeyType.PUSH_RULES, + StreamKeyType.TO_DEVICE, + StreamKeyType.TYPING, + StreamKeyType.UN_PARTIAL_STATED_ROOMS, + ], + new_token: int, + users: Optional[Collection[Union[str, UserID]]] = None, + rooms: Optional[StrCollection] = None, + ) -> None: + ... + def on_new_event( self, stream_key: StreamKeyType, - new_token: Union[int, RoomStreamToken], + new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], users: Optional[Collection[Union[str, UserID]]] = None, rooms: Optional[StrCollection] = None, ) -> None: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 384355698d..1312b6f21e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -126,8 +126,9 @@ class ReplicationDataHandler: StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] ) elif stream_name == ReceiptsStream.NAME: + new_token = self.store.get_max_receipt_stream_id() self.notifier.on_new_event( - StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows] + StreamKeyType.RECEIPT, new_token, rooms=[row.room_id for row in rows] ) await self._pusher_pool.on_new_receipts({row.user_id for row in rows}) elif stream_name == ToDeviceStream.NAME: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index b2645ab43c..56e8eb16a8 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -28,6 +28,8 @@ from typing import ( cast, ) +from immutabledict import immutabledict + from synapse.api.constants import EduTypes from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, JsonMapping +from synapse.types import ( + JsonDict, + JsonMapping, + MultiWriterStreamToken, + PersistedPosition, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "receipts_linearized", entity_column="room_id", stream_column="stream_id", - max_value=max_receipts_stream_id, + max_value=max_receipts_stream_id.stream, limit=10000, ) self._receipts_stream_cache = StreamChangeCache( @@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) - def get_max_receipt_stream_id(self) -> int: + def get_max_receipt_stream_id(self) -> MultiWriterStreamToken: """Get the current max stream ID for receipts stream""" - return self._receipts_id_gen.get_current_token() + + min_pos = self._receipts_id_gen.get_current_token() + + positions = {} + if isinstance(self._receipts_id_gen, MultiWriterIdGenerator): + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. See MultiWriterIdGenerator for details) + positions = { + i: p + for i, p in self._receipts_id_gen.get_positions().items() + if p > min_pos + } + + return MultiWriterStreamToken( + stream=min_pos, instance_map=immutabledict(positions) + ) + + def get_receipt_stream_id_for_instance(self, instance_name: str) -> int: + return self._receipts_id_gen.get_current_token_for_writer(instance_name) def get_last_unthreaded_receipt_for_user_txn( self, @@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore): } async def get_linearized_receipts_for_rooms( - self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None + self, + room_ids: Iterable[str], + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> List[JsonMapping]: """Get receipts for multiple rooms for sending to clients. @@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # Only ask the database about rooms where there have been new # receipts added since `from_key` room_ids = self._receipts_stream_cache.get_entities_changed( - room_ids, from_key + room_ids, from_key.stream ) results = await self._get_linearized_receipts_for_rooms( @@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return [ev for res in results.values() for ev in res] async def get_linearized_receipts_for_room( - self, room_id: str, to_key: int, from_key: Optional[int] = None + self, + room_id: str, + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Sequence[JsonMapping]: """Get receipts for a single room for sending to clients. @@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore): if from_key is not None: # Check the cache first to see if any new receipts have been added # since`from_key`. If not we can no-op. - if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): + if not self._receipts_stream_cache.has_entity_changed( + room_id, from_key.stream + ): return [] return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) @cached(tree=True) async def _get_linearized_receipts_for_room( - self, room_id: str, to_key: int, from_key: Optional[int] = None + self, + room_id: str, + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Sequence[JsonMapping]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]: if from_key: - sql = ( - "SELECT receipt_type, user_id, event_id, data" - " FROM receipts_linearized WHERE" - " room_id = ? AND stream_id > ? AND stream_id <= ?" - ) + sql = """ + SELECT stream_id, instance_name, receipt_type, user_id, event_id, data + FROM receipts_linearized + WHERE room_id = ? AND stream_id > ? AND stream_id <= ? + """ - txn.execute(sql, (room_id, from_key, to_key)) - else: - sql = ( - "SELECT receipt_type, user_id, event_id, data" - " FROM receipts_linearized WHERE" - " room_id = ? AND stream_id <= ?" + txn.execute( + sql, (room_id, from_key.stream, to_key.get_max_stream_pos()) ) + else: + sql = """ + SELECT stream_id, instance_name, receipt_type, user_id, event_id, data + FROM receipts_linearized WHERE + room_id = ? AND stream_id <= ? + """ - txn.execute(sql, (room_id, to_key)) + txn.execute(sql, (room_id, to_key.get_max_stream_pos())) - return cast(List[Tuple[str, str, str, str]], txn.fetchall()) + return [ + (receipt_type, user_id, event_id, data) + for stream_id, instance_name, receipt_type, user_id, event_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + from_key, to_key, instance_name, stream_id + ) + ] rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) @@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore): num_args=3, ) async def _get_linearized_receipts_for_rooms( - self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + self, + room_ids: Collection[str], + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Mapping[str, Sequence[JsonMapping]]: if not room_ids: return {} @@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore): ) -> List[Tuple[str, str, str, str, Optional[str], str]]: if from_key: sql = """ - SELECT room_id, receipt_type, user_id, event_id, thread_id, data + SELECT stream_id, instance_name, room_id, receipt_type, + user_id, event_id, thread_id, data FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? AND """ @@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore): self.database_engine, "room_id", room_ids ) - txn.execute(sql + clause, [from_key, to_key] + list(args)) + txn.execute( + sql + clause, + [from_key.stream, to_key.get_max_stream_pos()] + list(args), + ) else: sql = """ - SELECT room_id, receipt_type, user_id, event_id, thread_id, data + SELECT stream_id, instance_name, room_id, receipt_type, + user_id, event_id, thread_id, data FROM receipts_linearized WHERE stream_id <= ? AND """ @@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore): self.database_engine, "room_id", room_ids ) - txn.execute(sql + clause, [to_key] + list(args)) + txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args)) - return cast( - List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall() - ) + return [ + (room_id, receipt_type, user_id, event_id, thread_id, data) + for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + from_key, to_key, instance_name, stream_id + ) + ] txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f @@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore): num_args=2, ) async def get_linearized_receipts_for_all_rooms( - self, to_key: int, from_key: Optional[int] = None + self, + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Mapping[str, JsonMapping]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. @@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore): def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: if from_key: sql = """ - SELECT room_id, receipt_type, user_id, event_id, data + SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? ORDER BY stream_id DESC LIMIT 100 """ - txn.execute(sql, [from_key, to_key]) + txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()]) else: sql = """ - SELECT room_id, receipt_type, user_id, event_id, data + SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized WHERE stream_id <= ? ORDER BY stream_id DESC LIMIT 100 """ - txn.execute(sql, [to_key]) + txn.execute(sql, [to_key.get_max_stream_pos()]) - return cast(List[Tuple[str, str, str, str, str]], txn.fetchall()) + return [ + (room_id, receipt_type, user_id, event_id, data) + for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + from_key, to_key, instance_name, stream_id + ) + ] txn_results = await self.db_pool.runInteraction( "get_linearized_receipts_for_all_rooms", f @@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore): SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? + AND instance_name = ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (last_id, current_id, limit)) + txn.execute(sql, (last_id, current_id, instance_name, limit)) updates = cast( List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], @@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore): keyvalues=keyvalues, values={ "stream_id": stream_id, + "instance_name": self._instance_name, "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), @@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore): event_ids: List[str], thread_id: Optional[str], data: dict, - ) -> Optional[int]: + ) -> Optional[PersistedPosition]: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph @@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore): data, ) - return stream_id + return PersistedPosition(self._instance_name, stream_id) async def _insert_graph_receipt( self, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7f40e2c446..ce7bfd5146 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -47,7 +47,7 @@ from synapse.storage.databases.main.stream import ( generate_pagination_where_clause, ) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, StreamKeyType, StreamToken +from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -314,7 +314,7 @@ class RelationsWorkerStore(SQLBaseStore): room_key=next_key, presence_key=0, typing_key=0, - receipt_key=0, + receipt_key=MultiWriterStreamToken(stream=0), account_data_key=0, push_rules_key=0, to_device_key=0, diff --git a/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite new file mode 100644 index 0000000000..6c7ad0fd37 --- /dev/null +++ b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite @@ -0,0 +1,17 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This already exists on Postgres. +ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT; diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 609a0978a9..d0bb83b184 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.logging.opentracing import trace from synapse.streams import EventSource -from synapse.types import StreamKeyType, StreamToken +from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer @@ -111,7 +111,7 @@ class EventSources: room_key=await self.sources.room.get_current_key_for_room(room_id), presence_key=0, typing_key=0, - receipt_key=0, + receipt_key=MultiWriterStreamToken(stream=0), account_data_key=0, push_rules_key=0, to_device_key=0, diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 09a88c86a7..4c5b26ad93 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -695,6 +695,90 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): return "s%d" % (self.stream,) +@attr.s(frozen=True, slots=True, order=False) +class MultiWriterStreamToken(AbstractMultiWriterStreamToken): + """A basic stream token class for streams that supports multiple writers.""" + + @classmethod + async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken": + try: + if string[0].isdigit(): + return cls(stream=int(string)) + if string[0] == "m": + parts = string[1:].split("~") + stream = int(parts[0]) + + instance_map = {} + for part in parts[1:]: + key, value = part.split(".") + instance_id = int(key) + pos = int(value) + + instance_name = await store.get_name_from_instance_id(instance_id) + instance_map[instance_name] = pos + + return cls( + stream=stream, + instance_map=immutabledict(instance_map), + ) + except CancelledError: + raise + except Exception: + pass + raise SynapseError(400, "Invalid stream token %r" % (string,)) + + async def to_string(self, store: "DataStore") -> str: + if self.instance_map: + entries = [] + for name, pos in self.instance_map.items(): + if pos <= self.stream: + # Ignore instances who are below the minimum stream position + # (we might know they've advanced without seeing a recent + # write from them). + continue + + instance_id = await store.get_id_for_instance(name) + entries.append(f"{instance_id}.{pos}") + + encoded_map = "~".join(entries) + return f"m{self.stream}~{encoded_map}" + else: + return str(self.stream) + + @staticmethod + def is_stream_position_in_range( + low: Optional["AbstractMultiWriterStreamToken"], + high: Optional["AbstractMultiWriterStreamToken"], + instance_name: Optional[str], + pos: int, + ) -> bool: + """Checks if a given persisted position is between the two given tokens. + + If `instance_name` is None then the row was persisted before multi + writer support. + """ + + if low: + if instance_name: + low_stream = low.instance_map.get(instance_name, low.stream) + else: + low_stream = low.stream + + if pos <= low_stream: + return False + + if high: + if instance_name: + high_stream = high.instance_map.get(instance_name, high.stream) + else: + high_stream = high.stream + + if high_stream < pos: + return False + + return True + + class StreamKeyType(Enum): """Known stream types. @@ -776,7 +860,9 @@ class StreamToken: ) presence_key: int typing_key: int - receipt_key: int + receipt_key: MultiWriterStreamToken = attr.ib( + validator=attr.validators.instance_of(MultiWriterStreamToken) + ) account_data_key: int push_rules_key: int to_device_key: int @@ -799,8 +885,31 @@ class StreamToken: while len(keys) < len(attr.fields(cls)): # i.e. old token from before receipt_key keys.append("0") + + ( + room_key, + presence_key, + typing_key, + receipt_key, + account_data_key, + push_rules_key, + to_device_key, + device_list_key, + groups_key, + un_partial_stated_rooms_key, + ) = keys + return cls( - await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + room_key=await RoomStreamToken.parse(store, room_key), + presence_key=int(presence_key), + typing_key=int(typing_key), + receipt_key=await MultiWriterStreamToken.parse(store, receipt_key), + account_data_key=int(account_data_key), + push_rules_key=int(push_rules_key), + to_device_key=int(to_device_key), + device_list_key=int(device_list_key), + groups_key=int(groups_key), + un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), ) except CancelledError: raise @@ -813,7 +922,7 @@ class StreamToken: await self.room_key.to_string(store), str(self.presence_key), str(self.typing_key), - str(self.receipt_key), + await self.receipt_key.to_string(store), str(self.account_data_key), str(self.push_rules_key), str(self.to_device_key), @@ -841,6 +950,11 @@ class StreamToken: StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) ) return new_token + elif key == StreamKeyType.RECEIPT: + new_token = self.copy_and_replace( + StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value) + ) + return new_token new_token = self.copy_and_replace(key, new_value) new_id = new_token.get_field(key) @@ -859,6 +973,10 @@ class StreamToken: ... @overload + def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken: + ... + + @overload def get_field( self, key: Literal[ @@ -866,7 +984,6 @@ class StreamToken: StreamKeyType.DEVICE_LIST, StreamKeyType.PRESENCE, StreamKeyType.PUSH_RULES, - StreamKeyType.RECEIPT, StreamKeyType.TO_DEVICE, StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, @@ -875,15 +992,21 @@ class StreamToken: ... @overload - def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: + def get_field( + self, key: StreamKeyType + ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: ... - def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: + def get_field( + self, key: StreamKeyType + ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: """Returns the stream ID for the given key.""" return getattr(self, key.value) -StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0) +StreamToken.START = StreamToken( + RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0 +) @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index c888d1ff01..78646cb5dc 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -31,7 +31,12 @@ from synapse.appservice import ( from synapse.handlers.appservice import ApplicationServicesHandler from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.server import HomeServer -from synapse.types import JsonDict, RoomStreamToken, StreamKeyType +from synapse.types import ( + JsonDict, + MultiWriterStreamToken, + RoomStreamToken, + StreamKeyType, +) from synapse.util import Clock from synapse.util.stringutils import random_string @@ -305,7 +310,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.handler.notify_interested_services_ephemeral( - StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] + StreamKeyType.RECEIPT, + MultiWriterStreamToken(stream=580), + ["@fakerecipient:example.com"], ) self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( interested_service, ephemeral=[event] @@ -333,7 +340,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.handler.notify_interested_services_ephemeral( - StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] + StreamKeyType.RECEIPT, + MultiWriterStreamToken(stream=580), + ["@fakerecipient:example.com"], ) # This method will be called, but with an empty list of events self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( @@ -636,7 +645,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.hs.get_application_service_handler()._notify_interested_services_ephemeral( services=[interested_appservice], stream_key=StreamKeyType.RECEIPT, - new_token=stream_token, + new_token=MultiWriterStreamToken(stream=stream_token), users=[self.exclusive_as_user], ) ) diff --git a/tests/replication/test_sharded_receipts.py b/tests/replication/test_sharded_receipts.py new file mode 100644 index 0000000000..41876b36de --- /dev/null +++ b/tests/replication/test_sharded_receipts.py @@ -0,0 +1,243 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import ReceiptTypes +from synapse.rest import admin +from synapse.rest.client import login, receipts, room, sync +from synapse.server import HomeServer +from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.types import StreamToken +from synapse.util import Clock + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import make_request + +logger = logging.getLogger(__name__) + + +class ReceiptsShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks receipts sharding works""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Register a user who sends a message that we'll get notified about + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + self.room_creator = self.hs.get_room_creation_handler() + self.store = hs.get_datastores().main + + def default_config(self) -> dict: + conf = super().default_config() + conf["stream_writers"] = {"receipts": ["worker1", "worker2"]} + conf["instance_map"] = { + "main": {"host": "testserv", "port": 8765}, + "worker1": {"host": "testserv", "port": 1001}, + "worker2": {"host": "testserv", "port": 1002}, + } + return conf + + def test_basic(self) -> None: + """Simple test to ensure that receipts can be sent on multiple + workers. + """ + + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "worker1"}, + ) + worker1_site = self._hs_to_site[worker1] + + worker2 = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "worker2"}, + ) + worker2_site = self._hs_to_site[worker2] + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Create a room + room_id = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room_id, user=self.other_user_id, tok=self.other_access_token + ) + + # First user sends a message, the other users sends a receipt. + response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) + event_id = response["event_id"] + + channel = make_request( + reactor=self.reactor, + site=worker1_site, + method="POST", + path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}", + access_token=access_token, + content={}, + ) + self.assertEqual(200, channel.code) + + # Now we do it again using the second worker + response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) + event_id = response["event_id"] + + channel = make_request( + reactor=self.reactor, + site=worker2_site, + method="POST", + path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}", + access_token=access_token, + content={}, + ) + self.assertEqual(200, channel.code) + + def test_vector_clock_token(self) -> None: + """Tests that using a stream token with a vector clock component works + correctly with basic /sync usage. + """ + + worker_hs1 = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "worker1"}, + ) + worker1_site = self._hs_to_site[worker_hs1] + + worker_hs2 = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "worker2"}, + ) + worker2_site = self._hs_to_site[worker_hs2] + + sync_hs = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "sync"}, + ) + sync_hs_site = self._hs_to_site[sync_hs] + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + store = self.hs.get_datastores().main + + room_id = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room_id, user=self.other_user_id, tok=self.other_access_token + ) + + response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) + first_event = response["event_id"] + + # Do an initial sync so that we're up to date. + channel = make_request( + self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token + ) + next_batch = channel.json_body["next_batch"] + + # We now gut wrench into the events stream MultiWriterIdGenerator on + # worker2 to mimic it getting stuck persisting a receipt. This ensures + # that when we send an event on worker1 we end up in a state where + # worker2 events stream position lags that on worker1, resulting in a + # receipts token with a non-empty instance map component. + # + # Worker2's receipts stream position will not advance until we call + # __aexit__ again. + worker_store2 = worker_hs2.get_datastores().main + assert isinstance(worker_store2._receipts_id_gen, MultiWriterIdGenerator) + + actx = worker_store2._receipts_id_gen.get_next() + self.get_success(actx.__aenter__()) + + channel = make_request( + reactor=self.reactor, + site=worker1_site, + method="POST", + path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{first_event}", + access_token=access_token, + content={}, + ) + self.assertEqual(200, channel.code) + + # Assert that the current stream token has an instance map component, as + # we are trying to test vector clock tokens. + receipts_token = store.get_max_receipt_stream_id() + self.assertGreater(len(receipts_token.instance_map), 0) + + # Check that syncing still gets the new receipt, despite the gap in the + # stream IDs. + channel = make_request( + self.reactor, + sync_hs_site, + "GET", + f"/sync?since={next_batch}", + access_token=access_token, + ) + + # We should only see the new event and nothing else + self.assertIn(room_id, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"] + self.assertEqual(len(events), 1) + self.assertIn(first_event, events[0]["content"]) + + # Get the next batch and makes sure its a vector clock style token. + vector_clock_token = channel.json_body["next_batch"] + parsed_token = self.get_success( + StreamToken.from_string(store, vector_clock_token) + ) + self.assertGreaterEqual(len(parsed_token.receipt_key.instance_map), 1) + + # Now that we've got a vector clock token we finish the fake persisting + # a receipt we started above. + self.get_success(actx.__aexit__(None, None, None)) + + # Now try and send another receipts to the other worker. + response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) + second_event = response["event_id"] + + channel = make_request( + reactor=self.reactor, + site=worker2_site, + method="POST", + path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{second_event}", + access_token=access_token, + content={}, + ) + + channel = make_request( + self.reactor, + sync_hs_site, + "GET", + f"/sync?since={vector_clock_token}", + access_token=access_token, + ) + + self.assertIn(room_id, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"] + self.assertEqual(len(events), 1) + self.assertIn(second_event, events[0]["content"]) |