diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index f1766088fc..6d67a8cd5c 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -358,9 +358,9 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `account_data` messages."
)
- if len(self.writers.receipts) != 1:
+ if len(self.writers.receipts) == 0:
raise ConfigError(
- "Must only specify one instance to handle `receipts` messages."
+ "Must specify at least one instance to handle `receipts` messages."
)
if len(self.writers.events) == 0:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index c200a45f3a..873dadc3bd 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -47,6 +47,7 @@ from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
+ MultiWriterStreamToken,
RoomAlias,
RoomStreamToken,
StreamKeyType,
@@ -217,7 +218,7 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
stream_key: StreamKeyType,
- new_token: Union[int, RoomStreamToken],
+ new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
"""
@@ -259,19 +260,6 @@ class ApplicationServicesHandler:
):
return
- # Assert that new_token is an integer (and not a RoomStreamToken).
- # All of the supported streams that this function handles use an
- # integer to track progress (rather than a RoomStreamToken - a
- # vector clock implementation) as they don't support multiple
- # stream writers.
- #
- # As a result, we simply assert that new_token is an integer.
- # If we do end up needing to pass a RoomStreamToken down here
- # in the future, using RoomStreamToken.stream (the minimum stream
- # position) to convert to an ascending integer value should work.
- # Additional context: https://github.com/matrix-org/synapse/pull/11137
- assert isinstance(new_token, int)
-
# Ignore to-device messages if the feature flag is not enabled
if (
stream_key == StreamKeyType.TO_DEVICE
@@ -286,6 +274,9 @@ class ApplicationServicesHandler:
):
return
+ # We know we're not a `RoomStreamToken` at this point.
+ assert not isinstance(new_token, RoomStreamToken)
+
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
@@ -327,7 +318,7 @@ class ApplicationServicesHandler:
self,
services: List[ApplicationService],
stream_key: StreamKeyType,
- new_token: int,
+ new_token: Union[int, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
logger.debug("Checking interested services for %s", stream_key)
@@ -340,6 +331,7 @@ class ApplicationServicesHandler:
#
# Instead we simply grab the latest typing updates in _handle_typing
# and, if they apply to this application service, send it off.
+ assert isinstance(new_token, int)
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -350,15 +342,23 @@ class ApplicationServicesHandler:
(service.id, stream_key)
):
if stream_key == StreamKeyType.RECEIPT:
+ assert isinstance(new_token, MultiWriterStreamToken)
+
+ # We store appservice tokens as integers, so we ignore
+ # the `instance_map` components and instead simply
+ # follow the base stream position.
+ new_token = MultiWriterStreamToken(stream=new_token.stream)
+
events = await self._handle_receipts(service, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
- service, "read_receipt", new_token
+ service, "read_receipt", new_token.stream
)
elif stream_key == StreamKeyType.PRESENCE:
+ assert isinstance(new_token, int)
events = await self._handle_presence(service, users, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -368,6 +368,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.TO_DEVICE:
+ assert isinstance(new_token, int)
# Retrieve a list of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
to_device_messages = await self._get_to_device_messages(
@@ -383,6 +384,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.DEVICE_LIST:
+ assert isinstance(new_token, int)
device_list_summary = await self._get_device_list_summary(
service, new_token
)
@@ -432,7 +434,7 @@ class ApplicationServicesHandler:
return typing
async def _handle_receipts(
- self, service: ApplicationService, new_token: int
+ self, service: ApplicationService, new_token: MultiWriterStreamToken
) -> List[JsonMapping]:
"""
Return the latest read receipts that the given application service should receive.
@@ -455,15 +457,17 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
- if new_token is not None and new_token <= from_key:
+ if new_token is not None and new_token.stream <= from_key:
logger.debug(
"Rejecting token lower than or equal to stored: %s" % (new_token,)
)
return []
+ from_token = MultiWriterStreamToken(stream=from_key)
+
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
- service=service, from_key=from_key, to_key=new_token
+ service=service, from_key=from_token, to_key=new_token
)
return receipts
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index c34bd7db95..b1d8be866f 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -145,7 +145,7 @@ class InitialSyncHandler:
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
receipt = await self.store.get_linearized_receipts_for_rooms(
joined_rooms,
- to_key=int(now_token.receipt_key),
+ to_key=now_token.receipt_key,
)
receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 69ac468f75..b5f7a8b47e 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -20,6 +20,7 @@ from synapse.streams import EventSource
from synapse.types import (
JsonDict,
JsonMapping,
+ MultiWriterStreamToken,
ReadReceipt,
StreamKeyType,
UserID,
@@ -200,7 +201,7 @@ class ReceiptsHandler:
await self.federation_sender.send_read_receipt(receipt)
-class ReceiptEventSource(EventSource[int, JsonMapping]):
+class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.config = hs.config
@@ -273,13 +274,12 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
async def get_new_events(
self,
user: UserID,
- from_key: int,
+ from_key: MultiWriterStreamToken,
limit: int,
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
- ) -> Tuple[List[JsonMapping], int]:
- from_key = int(from_key)
+ ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
to_key = self.get_current_key()
if from_key == to_key:
@@ -296,8 +296,11 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
return events, to_key
async def get_new_events_as(
- self, from_key: int, to_key: int, service: ApplicationService
- ) -> Tuple[List[JsonMapping], int]:
+ self,
+ from_key: MultiWriterStreamToken,
+ to_key: MultiWriterStreamToken,
+ service: ApplicationService,
+ ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
"""Returns a set of new read receipt events that an appservice
may be interested in.
@@ -312,8 +315,6 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
appservice may be interested in.
* The current read receipt stream token.
"""
- from_key = int(from_key)
-
if from_key == to_key:
return [], to_key
@@ -333,5 +334,5 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
return events, to_key
- def get_current_key(self) -> int:
+ def get_current_key(self) -> MultiWriterStreamToken:
return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f131c0e8e0..f75c1548ca 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -57,6 +57,7 @@ from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
+ MultiWriterStreamToken,
MutableStateMap,
Requester,
RoomStreamToken,
@@ -477,7 +478,11 @@ class SyncHandler:
event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
- receipt_key = since_token.receipt_key if since_token else 0
+ receipt_key = (
+ since_token.receipt_key
+ if since_token
+ else MultiWriterStreamToken(stream=0)
+ )
receipt_source = self.event_sources.sources.receipt
receipts, receipt_key = await receipt_source.get_new_events(
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 99e7715896..ee0bd84f1e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -21,11 +21,13 @@ from typing import (
Dict,
Iterable,
List,
+ Literal,
Optional,
Set,
Tuple,
TypeVar,
Union,
+ overload,
)
import attr
@@ -44,6 +46,7 @@ from synapse.metrics import LaterGauge
from synapse.streams.config import PaginationConfig
from synapse.types import (
JsonDict,
+ MultiWriterStreamToken,
PersistedEventPosition,
RoomStreamToken,
StrCollection,
@@ -127,7 +130,7 @@ class _NotifierUserStream:
def notify(
self,
stream_key: StreamKeyType,
- stream_id: Union[int, RoomStreamToken],
+ stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken],
time_now_ms: int,
) -> None:
"""Notify any listeners for this user of a new event from an
@@ -452,10 +455,48 @@ class Notifier:
except Exception:
logger.exception("Error pusher pool of event")
+ @overload
+ def on_new_event(
+ self,
+ stream_key: Literal[StreamKeyType.ROOM],
+ new_token: RoomStreamToken,
+ users: Optional[Collection[Union[str, UserID]]] = None,
+ rooms: Optional[StrCollection] = None,
+ ) -> None:
+ ...
+
+ @overload
+ def on_new_event(
+ self,
+ stream_key: Literal[StreamKeyType.RECEIPT],
+ new_token: MultiWriterStreamToken,
+ users: Optional[Collection[Union[str, UserID]]] = None,
+ rooms: Optional[StrCollection] = None,
+ ) -> None:
+ ...
+
+ @overload
+ def on_new_event(
+ self,
+ stream_key: Literal[
+ StreamKeyType.ACCOUNT_DATA,
+ StreamKeyType.DEVICE_LIST,
+ StreamKeyType.PRESENCE,
+ StreamKeyType.PUSH_RULES,
+ StreamKeyType.TO_DEVICE,
+ StreamKeyType.TYPING,
+ StreamKeyType.UN_PARTIAL_STATED_ROOMS,
+ ],
+ new_token: int,
+ users: Optional[Collection[Union[str, UserID]]] = None,
+ rooms: Optional[StrCollection] = None,
+ ) -> None:
+ ...
+
def on_new_event(
self,
stream_key: StreamKeyType,
- new_token: Union[int, RoomStreamToken],
+ new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
) -> None:
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 384355698d..1312b6f21e 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -126,8 +126,9 @@ class ReplicationDataHandler:
StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
)
elif stream_name == ReceiptsStream.NAME:
+ new_token = self.store.get_max_receipt_stream_id()
self.notifier.on_new_event(
- StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
+ StreamKeyType.RECEIPT, new_token, rooms=[row.room_id for row in rows]
)
await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
elif stream_name == ToDeviceStream.NAME:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index b2645ab43c..56e8eb16a8 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -28,6 +28,8 @@ from typing import (
cast,
)
+from immutabledict import immutabledict
+
from synapse.api.constants import EduTypes
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ MultiWriterStreamToken,
+ PersistedPosition,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"receipts_linearized",
entity_column="room_id",
stream_column="stream_id",
- max_value=max_receipts_stream_id,
+ max_value=max_receipts_stream_id.stream,
limit=10000,
)
self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
prefilled_cache=receipts_stream_prefill,
)
- def get_max_receipt_stream_id(self) -> int:
+ def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
"""Get the current max stream ID for receipts stream"""
- return self._receipts_id_gen.get_current_token()
+
+ min_pos = self._receipts_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._receipts_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return MultiWriterStreamToken(
+ stream=min_pos, instance_map=immutabledict(positions)
+ )
+
+ def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+ return self._receipts_id_gen.get_current_token_for_writer(instance_name)
def get_last_unthreaded_receipt_for_user_txn(
self,
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
async def get_linearized_receipts_for_rooms(
- self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Iterable[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> List[JsonMapping]:
"""Get receipts for multiple rooms for sending to clients.
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
room_ids = self._receipts_stream_cache.get_entities_changed(
- room_ids, from_key
+ room_ids, from_key.stream
)
results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [ev for res in results.values() for ev in res]
async def get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""Get receipts for a single room for sending to clients.
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
- if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+ if not self._receipts_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ ):
return []
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cached(tree=True)
async def _get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
- sql = (
- "SELECT receipt_type, user_id, event_id, data"
- " FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id > ? AND stream_id <= ?"
- )
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized
+ WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, from_key, to_key))
- else:
- sql = (
- "SELECT receipt_type, user_id, event_id, data"
- " FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id <= ?"
+ txn.execute(
+ sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
)
+ else:
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
+ room_id = ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, to_key))
+ txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
- return cast(List[Tuple[str, str, str, str]], txn.fetchall())
+ return [
+ (receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=3,
)
async def _get_linearized_receipts_for_rooms(
- self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Collection[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, Sequence[JsonMapping]]:
if not room_ids:
return {}
@@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
@@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [from_key, to_key] + list(args))
+ txn.execute(
+ sql + clause,
+ [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+ )
else:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id <= ? AND
"""
@@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [to_key] + list(args))
+ txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
- return cast(
- List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
- )
+ return [
+ (room_id, receipt_type, user_id, event_id, thread_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
@@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=2,
)
async def get_linearized_receipts_for_all_rooms(
- self, to_key: int, from_key: Optional[int] = None
+ self,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, JsonMapping]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
@@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [from_key, to_key])
+ txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
else:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [to_key])
+ txn.execute(sql, [to_key.get_max_stream_pos()])
- return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
+ return [
+ (room_id, receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
@@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
+ AND instance_name = ?
ORDER BY stream_id ASC
LIMIT ?
"""
- txn.execute(sql, (last_id, current_id, limit))
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
updates = cast(
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
keyvalues=keyvalues,
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"event_id": event_id,
"event_stream_ordering": stream_ordering,
"data": json_encoder.encode(data),
@@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str],
thread_id: Optional[str],
data: dict,
- ) -> Optional[int]:
+ ) -> Optional[PersistedPosition]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- return stream_id
+ return PersistedPosition(self._instance_name, stream_id)
async def _insert_graph_receipt(
self,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7f40e2c446..ce7bfd5146 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -47,7 +47,7 @@ from synapse.storage.databases.main.stream import (
generate_pagination_where_clause,
)
from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, StreamKeyType, StreamToken
+from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -314,7 +314,7 @@ class RelationsWorkerStore(SQLBaseStore):
room_key=next_key,
presence_key=0,
typing_key=0,
- receipt_key=0,
+ receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
diff --git a/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite
new file mode 100644
index 0000000000..6c7ad0fd37
--- /dev/null
+++ b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite
@@ -0,0 +1,17 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This already exists on Postgres.
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 609a0978a9..d0bb83b184 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource
-from synapse.types import StreamKeyType, StreamToken
+from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -111,7 +111,7 @@ class EventSources:
room_key=await self.sources.room.get_current_key_for_room(room_id),
presence_key=0,
typing_key=0,
- receipt_key=0,
+ receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 09a88c86a7..4c5b26ad93 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -695,6 +695,90 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
return "s%d" % (self.stream,)
+@attr.s(frozen=True, slots=True, order=False)
+class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
+ """A basic stream token class for streams that supports multiple writers."""
+
+ @classmethod
+ async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken":
+ try:
+ if string[0].isdigit():
+ return cls(stream=int(string))
+ if string[0] == "m":
+ parts = string[1:].split("~")
+ stream = int(parts[0])
+
+ instance_map = {}
+ for part in parts[1:]:
+ key, value = part.split(".")
+ instance_id = int(key)
+ pos = int(value)
+
+ instance_name = await store.get_name_from_instance_id(instance_id)
+ instance_map[instance_name] = pos
+
+ return cls(
+ stream=stream,
+ instance_map=immutabledict(instance_map),
+ )
+ except CancelledError:
+ raise
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid stream token %r" % (string,))
+
+ async def to_string(self, store: "DataStore") -> str:
+ if self.instance_map:
+ entries = []
+ for name, pos in self.instance_map.items():
+ if pos <= self.stream:
+ # Ignore instances who are below the minimum stream position
+ # (we might know they've advanced without seeing a recent
+ # write from them).
+ continue
+
+ instance_id = await store.get_id_for_instance(name)
+ entries.append(f"{instance_id}.{pos}")
+
+ encoded_map = "~".join(entries)
+ return f"m{self.stream}~{encoded_map}"
+ else:
+ return str(self.stream)
+
+ @staticmethod
+ def is_stream_position_in_range(
+ low: Optional["AbstractMultiWriterStreamToken"],
+ high: Optional["AbstractMultiWriterStreamToken"],
+ instance_name: Optional[str],
+ pos: int,
+ ) -> bool:
+ """Checks if a given persisted position is between the two given tokens.
+
+ If `instance_name` is None then the row was persisted before multi
+ writer support.
+ """
+
+ if low:
+ if instance_name:
+ low_stream = low.instance_map.get(instance_name, low.stream)
+ else:
+ low_stream = low.stream
+
+ if pos <= low_stream:
+ return False
+
+ if high:
+ if instance_name:
+ high_stream = high.instance_map.get(instance_name, high.stream)
+ else:
+ high_stream = high.stream
+
+ if high_stream < pos:
+ return False
+
+ return True
+
+
class StreamKeyType(Enum):
"""Known stream types.
@@ -776,7 +860,9 @@ class StreamToken:
)
presence_key: int
typing_key: int
- receipt_key: int
+ receipt_key: MultiWriterStreamToken = attr.ib(
+ validator=attr.validators.instance_of(MultiWriterStreamToken)
+ )
account_data_key: int
push_rules_key: int
to_device_key: int
@@ -799,8 +885,31 @@ class StreamToken:
while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
+
+ (
+ room_key,
+ presence_key,
+ typing_key,
+ receipt_key,
+ account_data_key,
+ push_rules_key,
+ to_device_key,
+ device_list_key,
+ groups_key,
+ un_partial_stated_rooms_key,
+ ) = keys
+
return cls(
- await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
+ room_key=await RoomStreamToken.parse(store, room_key),
+ presence_key=int(presence_key),
+ typing_key=int(typing_key),
+ receipt_key=await MultiWriterStreamToken.parse(store, receipt_key),
+ account_data_key=int(account_data_key),
+ push_rules_key=int(push_rules_key),
+ to_device_key=int(to_device_key),
+ device_list_key=int(device_list_key),
+ groups_key=int(groups_key),
+ un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
)
except CancelledError:
raise
@@ -813,7 +922,7 @@ class StreamToken:
await self.room_key.to_string(store),
str(self.presence_key),
str(self.typing_key),
- str(self.receipt_key),
+ await self.receipt_key.to_string(store),
str(self.account_data_key),
str(self.push_rules_key),
str(self.to_device_key),
@@ -841,6 +950,11 @@ class StreamToken:
StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value)
)
return new_token
+ elif key == StreamKeyType.RECEIPT:
+ new_token = self.copy_and_replace(
+ StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value)
+ )
+ return new_token
new_token = self.copy_and_replace(key, new_value)
new_id = new_token.get_field(key)
@@ -859,6 +973,10 @@ class StreamToken:
...
@overload
+ def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken:
+ ...
+
+ @overload
def get_field(
self,
key: Literal[
@@ -866,7 +984,6 @@ class StreamToken:
StreamKeyType.DEVICE_LIST,
StreamKeyType.PRESENCE,
StreamKeyType.PUSH_RULES,
- StreamKeyType.RECEIPT,
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
@@ -875,15 +992,21 @@ class StreamToken:
...
@overload
- def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+ def get_field(
+ self, key: StreamKeyType
+ ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
...
- def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+ def get_field(
+ self, key: StreamKeyType
+ ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
"""Returns the stream ID for the given key."""
return getattr(self, key.value)
-StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(
+ RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
+)
@attr.s(slots=True, frozen=True, auto_attribs=True)
|