summary refs log tree commit diff
path: root/synapse/types
diff options
context:
space:
mode:
authorErik Johnston <erikj@matrix.org>2023-10-25 16:16:19 +0100
committerGitHub <noreply@github.com>2023-10-25 16:16:19 +0100
commitba47fea5286e084ec70d568aa62eb4820b857c47 (patch)
tree6e2c608feb1ea0c23b2b9cc40d11211cc3a10aa5 /synapse/types
parentFix tests on Twisted trunk. (#16528) (diff)
downloadsynapse-ba47fea5286e084ec70d568aa62eb4820b857c47.tar.xz
Allow multiple workers to write to receipts stream. (#16432)
Fixes #16417
Diffstat (limited to 'synapse/types')
-rw-r--r--synapse/types/__init__.py137
1 files changed, 130 insertions, 7 deletions
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 09a88c86a7..4c5b26ad93 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -695,6 +695,90 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
             return "s%d" % (self.stream,)
 
 
+@attr.s(frozen=True, slots=True, order=False)
+class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
+    """A basic stream token class for streams that supports multiple writers."""
+
+    @classmethod
+    async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken":
+        try:
+            if string[0].isdigit():
+                return cls(stream=int(string))
+            if string[0] == "m":
+                parts = string[1:].split("~")
+                stream = int(parts[0])
+
+                instance_map = {}
+                for part in parts[1:]:
+                    key, value = part.split(".")
+                    instance_id = int(key)
+                    pos = int(value)
+
+                    instance_name = await store.get_name_from_instance_id(instance_id)
+                    instance_map[instance_name] = pos
+
+                return cls(
+                    stream=stream,
+                    instance_map=immutabledict(instance_map),
+                )
+        except CancelledError:
+            raise
+        except Exception:
+            pass
+        raise SynapseError(400, "Invalid stream token %r" % (string,))
+
+    async def to_string(self, store: "DataStore") -> str:
+        if self.instance_map:
+            entries = []
+            for name, pos in self.instance_map.items():
+                if pos <= self.stream:
+                    # Ignore instances who are below the minimum stream position
+                    # (we might know they've advanced without seeing a recent
+                    # write from them).
+                    continue
+
+                instance_id = await store.get_id_for_instance(name)
+                entries.append(f"{instance_id}.{pos}")
+
+            encoded_map = "~".join(entries)
+            return f"m{self.stream}~{encoded_map}"
+        else:
+            return str(self.stream)
+
+    @staticmethod
+    def is_stream_position_in_range(
+        low: Optional["AbstractMultiWriterStreamToken"],
+        high: Optional["AbstractMultiWriterStreamToken"],
+        instance_name: Optional[str],
+        pos: int,
+    ) -> bool:
+        """Checks if a given persisted position is between the two given tokens.
+
+        If `instance_name` is None then the row was persisted before multi
+        writer support.
+        """
+
+        if low:
+            if instance_name:
+                low_stream = low.instance_map.get(instance_name, low.stream)
+            else:
+                low_stream = low.stream
+
+            if pos <= low_stream:
+                return False
+
+        if high:
+            if instance_name:
+                high_stream = high.instance_map.get(instance_name, high.stream)
+            else:
+                high_stream = high.stream
+
+            if high_stream < pos:
+                return False
+
+        return True
+
+
 class StreamKeyType(Enum):
     """Known stream types.
 
@@ -776,7 +860,9 @@ class StreamToken:
     )
     presence_key: int
     typing_key: int
-    receipt_key: int
+    receipt_key: MultiWriterStreamToken = attr.ib(
+        validator=attr.validators.instance_of(MultiWriterStreamToken)
+    )
     account_data_key: int
     push_rules_key: int
     to_device_key: int
@@ -799,8 +885,31 @@ class StreamToken:
             while len(keys) < len(attr.fields(cls)):
                 # i.e. old token from before receipt_key
                 keys.append("0")
+
+            (
+                room_key,
+                presence_key,
+                typing_key,
+                receipt_key,
+                account_data_key,
+                push_rules_key,
+                to_device_key,
+                device_list_key,
+                groups_key,
+                un_partial_stated_rooms_key,
+            ) = keys
+
             return cls(
-                await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
+                room_key=await RoomStreamToken.parse(store, room_key),
+                presence_key=int(presence_key),
+                typing_key=int(typing_key),
+                receipt_key=await MultiWriterStreamToken.parse(store, receipt_key),
+                account_data_key=int(account_data_key),
+                push_rules_key=int(push_rules_key),
+                to_device_key=int(to_device_key),
+                device_list_key=int(device_list_key),
+                groups_key=int(groups_key),
+                un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
             )
         except CancelledError:
             raise
@@ -813,7 +922,7 @@ class StreamToken:
                 await self.room_key.to_string(store),
                 str(self.presence_key),
                 str(self.typing_key),
-                str(self.receipt_key),
+                await self.receipt_key.to_string(store),
                 str(self.account_data_key),
                 str(self.push_rules_key),
                 str(self.to_device_key),
@@ -841,6 +950,11 @@ class StreamToken:
                 StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value)
             )
             return new_token
+        elif key == StreamKeyType.RECEIPT:
+            new_token = self.copy_and_replace(
+                StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value)
+            )
+            return new_token
 
         new_token = self.copy_and_replace(key, new_value)
         new_id = new_token.get_field(key)
@@ -859,6 +973,10 @@ class StreamToken:
         ...
 
     @overload
+    def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken:
+        ...
+
+    @overload
     def get_field(
         self,
         key: Literal[
@@ -866,7 +984,6 @@ class StreamToken:
             StreamKeyType.DEVICE_LIST,
             StreamKeyType.PRESENCE,
             StreamKeyType.PUSH_RULES,
-            StreamKeyType.RECEIPT,
             StreamKeyType.TO_DEVICE,
             StreamKeyType.TYPING,
             StreamKeyType.UN_PARTIAL_STATED_ROOMS,
@@ -875,15 +992,21 @@ class StreamToken:
         ...
 
     @overload
-    def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+    def get_field(
+        self, key: StreamKeyType
+    ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
         ...
 
-    def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+    def get_field(
+        self, key: StreamKeyType
+    ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
         """Returns the stream ID for the given key."""
         return getattr(self, key.value)
 
 
-StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(
+    RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
+)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)