author    Erik Johnston <erik@matrix.org>  2020-09-30 20:29:19 +0100
committer GitHub <noreply@github.com>      2020-09-30 20:29:19 +0100
commit    7941372ec84786f85ae6d75fd2d7a4af5b72ac98 (patch)
tree      7871841ee56daa554a283a8b2d409b8739047128
parent    Merge pull request #8425 from matrix-org/rav/extremity_metrics (diff)
download  synapse-7941372ec84786f85ae6d75fd2d7a4af5b72ac98.tar.xz
Make token serializing/deserializing async (#8427)
The idea is that in future tokens will encode a mapping of instance to position. However, we don't want to include the full instance name in the string representation, so we'll instead keep a mapping in the DB between each instance name and an immutable integer ID, and use the ID in the token. We'll then do the lookup when we serialize/deserialize the token (we could alternatively pass around an `Instance` type that includes both the name and ID, but that turns out to be a lot more invasive).
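
To illustrate the direction of travel, below is a rough sketch of what an instance-aware token could look like once the instance-name/integer-ID mapping exists. It is not part of this commit: the token format shown and the store helpers `get_id_for_instance` / `get_instance_name_for_id` are assumed names for the DB mapping described above, used here only to show why (de)serialization has to be async.

    # Hypothetical sketch only: illustrates the planned instance-aware token.
    # The store lookups (get_id_for_instance / get_instance_name_for_id) are
    # assumed helpers for the instance-name <-> integer-ID mapping; they are
    # not part of this commit.
    import attr


    @attr.s(slots=True, frozen=True)
    class InstanceAwareToken:
        stream = attr.ib(type=int)
        # Mapping of instance name -> position, for instances ahead of `stream`.
        instance_map = attr.ib(type=dict, factory=dict)

        @classmethod
        async def parse(cls, store, string: str) -> "InstanceAwareToken":
            # Assumed format: "m<stream>~<instance_id>.<pos>~<instance_id>.<pos>..."
            parts = string[1:].split("~")
            stream = int(parts[0])
            instance_map = {}
            for part in parts[1:]:
                instance_id, pos = map(int, part.split("."))
                # Resolve the integer ID back to an instance name via the DB.
                name = await store.get_instance_name_for_id(instance_id)
                instance_map[name] = pos
            return cls(stream=stream, instance_map=instance_map)

        async def to_string(self, store) -> str:
            entries = []
            for name, pos in self.instance_map.items():
                # Look up the immutable integer ID allocated for this instance.
                instance_id = await store.get_id_for_instance(name)
                entries.append("%d.%d" % (instance_id, pos))
            return "m%d" % (self.stream,) + "".join("~" + e for e in entries)

Because the integer IDs are immutable once allocated, the string representation stays stable while still avoiding full instance names; the DB lookup is what forces `to_string`/`parse` (and everything that calls them) to become async, which is the change this commit makes.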
-rw-r--r--  changelog.d/8427.misc                           |  1
-rw-r--r--  synapse/handlers/events.py                      |  4
-rw-r--r--  synapse/handlers/initial_sync.py                | 14
-rw-r--r--  synapse/handlers/pagination.py                  |  8
-rw-r--r--  synapse/handlers/room.py                        |  8
-rw-r--r--  synapse/handlers/search.py                      |  8
-rw-r--r--  synapse/rest/admin/__init__.py                  |  2
-rw-r--r--  synapse/rest/client/v1/events.py                |  3
-rw-r--r--  synapse/rest/client/v1/initial_sync.py          |  3
-rw-r--r--  synapse/rest/client/v1/room.py                  | 11
-rw-r--r--  synapse/rest/client/v2_alpha/keys.py            |  3
-rw-r--r--  synapse/rest/client/v2_alpha/sync.py            | 10
-rw-r--r--  synapse/storage/databases/main/purge_events.py  |  8
-rw-r--r--  synapse/streams/config.py                       |  9
-rw-r--r--  synapse/types.py                                | 43
-rw-r--r--  tests/rest/client/v1/test_rooms.py              | 30
-rw-r--r--  tests/storage/test_purge.py                     |  9
17 files changed, 115 insertions, 59 deletions
diff --git a/changelog.d/8427.misc b/changelog.d/8427.misc
new file mode 100644
index 0000000000..c9656b9112
--- /dev/null
+++ b/changelog.d/8427.misc
@@ -0,0 +1 @@
+Make stream token serializing/deserializing async.
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 0875b74ea8..539b4fc32e 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -133,8 +133,8 @@ class EventStreamHandler(BaseHandler):
 
             chunk = {
                 "chunk": chunks,
-                "start": tokens[0].to_string(),
-                "end": tokens[1].to_string(),
+                "start": await tokens[0].to_string(self.store),
+                "end": await tokens[1].to_string(self.store),
             }
 
             return chunk
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 43f15435de..39a85801c1 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -203,8 +203,8 @@ class InitialSyncHandler(BaseHandler):
                             messages, time_now=time_now, as_client_event=as_client_event
                         )
                     ),
-                    "start": start_token.to_string(),
-                    "end": end_token.to_string(),
+                    "start": await start_token.to_string(self.store),
+                    "end": await end_token.to_string(self.store),
                 }
 
                 d["state"] = await self._event_serializer.serialize_events(
@@ -249,7 +249,7 @@ class InitialSyncHandler(BaseHandler):
             ],
             "account_data": account_data_events,
             "receipts": receipt,
-            "end": now_token.to_string(),
+            "end": await now_token.to_string(self.store),
         }
 
         return ret
@@ -348,8 +348,8 @@ class InitialSyncHandler(BaseHandler):
                 "chunk": (
                     await self._event_serializer.serialize_events(messages, time_now)
                 ),
-                "start": start_token.to_string(),
-                "end": end_token.to_string(),
+                "start": await start_token.to_string(self.store),
+                "end": await end_token.to_string(self.store),
             },
             "state": (
                 await self._event_serializer.serialize_events(
@@ -447,8 +447,8 @@ class InitialSyncHandler(BaseHandler):
                 "chunk": (
                     await self._event_serializer.serialize_events(messages, time_now)
                 ),
-                "start": start_token.to_string(),
-                "end": end_token.to_string(),
+                "start": await start_token.to_string(self.store),
+                "end": await end_token.to_string(self.store),
             },
             "state": state,
             "presence": presence,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index d6779a4b44..2c2a633938 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -413,8 +413,8 @@ class PaginationHandler:
         if not events:
             return {
                 "chunk": [],
-                "start": from_token.to_string(),
-                "end": next_token.to_string(),
+                "start": await from_token.to_string(self.store),
+                "end": await next_token.to_string(self.store),
             }
 
         state = None
@@ -442,8 +442,8 @@ class PaginationHandler:
                     events, time_now, as_client_event=as_client_event
                 )
             ),
-            "start": from_token.to_string(),
-            "end": next_token.to_string(),
+            "start": await from_token.to_string(self.store),
+            "end": await next_token.to_string(self.store),
         }
 
         if state:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 836b3f381a..d5f7c78edf 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1077,11 +1077,13 @@ class RoomContextHandler:
         # the token, which we replace.
         token = StreamToken.START
 
-        results["start"] = token.copy_and_replace(
+        results["start"] = await token.copy_and_replace(
             "room_key", results["start"]
-        ).to_string()
+        ).to_string(self.store)
 
-        results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
+        results["end"] = await token.copy_and_replace(
+            "room_key", results["end"]
+        ).to_string(self.store)
 
         return results
 
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 6a76c20d79..e9402e6e2e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -362,13 +362,13 @@ class SearchHandler(BaseHandler):
                     self.storage, user.to_string(), res["events_after"]
                 )
 
-                res["start"] = now_token.copy_and_replace(
+                res["start"] = await now_token.copy_and_replace(
                     "room_key", res["start"]
-                ).to_string()
+                ).to_string(self.store)
 
-                res["end"] = now_token.copy_and_replace(
+                res["end"] = await now_token.copy_and_replace(
                     "room_key", res["end"]
-                ).to_string()
+                ).to_string(self.store)
 
                 if include_profile:
                     senders = {
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index ba53f66f02..57cac22252 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -110,7 +110,7 @@ class PurgeHistoryRestServlet(RestServlet):
                 raise SynapseError(400, "Event is for wrong room.")
 
             room_token = await self.store.get_topological_token_for_event(event_id)
-            token = str(room_token)
+            token = await room_token.to_string(self.store)
 
             logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
         elif "purge_up_to_ts" in body:
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 985d994f6b..1ecb77aa26 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -33,6 +33,7 @@ class EventStreamRestServlet(RestServlet):
         super().__init__()
         self.event_stream_handler = hs.get_event_stream_handler()
         self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
 
     async def on_GET(self, request):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet):
         if b"room_id" in request.args:
             room_id = request.args[b"room_id"][0].decode("ascii")
 
-        pagin_config = PaginationConfig.from_request(request)
+        pagin_config = await PaginationConfig.from_request(self.store, request)
         timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
         if b"timeout" in request.args:
             try:
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index d7042786ce..91da0ee573 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -27,11 +27,12 @@ class InitialSyncRestServlet(RestServlet):
         super().__init__()
         self.initial_sync_handler = hs.get_initial_sync_handler()
         self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
 
     async def on_GET(self, request):
         requester = await self.auth.get_user_by_req(request)
         as_client_event = b"raw" not in request.args
-        pagination_config = PaginationConfig.from_request(request)
+        pagination_config = await PaginationConfig.from_request(self.store, request)
         include_archived = parse_boolean(request, "archived", default=False)
         content = await self.initial_sync_handler.snapshot_all_rooms(
             user_id=requester.user.to_string(),
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 7e64a2e0fe..b63389e5fe 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -451,6 +451,7 @@ class RoomMemberListRestServlet(RestServlet):
         super().__init__()
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
 
     async def on_GET(self, request, room_id):
         # TODO support Pagination stream API (limit/tokens)
@@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet):
         if at_token_string is None:
             at_token = None
         else:
-            at_token = StreamToken.from_string(at_token_string)
+            at_token = await StreamToken.from_string(self.store, at_token_string)
 
         # let you filter down on particular memberships.
         # XXX: this may not be the best shape for this API - we could pass in a filter
@@ -521,10 +522,13 @@ class RoomMessageListRestServlet(RestServlet):
         super().__init__()
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
 
     async def on_GET(self, request, room_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        pagination_config = PaginationConfig.from_request(request, default_limit=10)
+        pagination_config = await PaginationConfig.from_request(
+            self.store, request, default_limit=10
+        )
         as_client_event = b"raw" not in request.args
         filter_str = parse_string(request, b"filter", encoding="utf-8")
         if filter_str:
@@ -580,10 +584,11 @@ class RoomInitialSyncRestServlet(RestServlet):
         super().__init__()
         self.initial_sync_handler = hs.get_initial_sync_handler()
         self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
 
     async def on_GET(self, request, room_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        pagination_config = PaginationConfig.from_request(request)
+        pagination_config = await PaginationConfig.from_request(self.store, request)
         content = await self.initial_sync_handler.room_initial_sync(
             room_id=room_id, requester=requester, pagin_config=pagination_config
         )
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 7abd6ff333..55c4606569 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -180,6 +180,7 @@ class KeyChangesServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
+        self.store = hs.get_datastore()
 
     async def on_GET(self, request):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet):
         # changes after the "to" as well as before.
         set_tag("to", parse_string(request, "to"))
 
-        from_token = StreamToken.from_string(from_token_string)
+        from_token = await StreamToken.from_string(self.store, from_token_string)
 
         user_id = requester.user.to_string()
 
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 51e395cc64..6779df952f 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -77,6 +77,7 @@ class SyncRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
         self.sync_handler = hs.get_sync_handler()
         self.clock = hs.get_clock()
         self.filtering = hs.get_filtering()
@@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet):
             device_id=device_id,
         )
 
+        since_token = None
         if since is not None:
-            since_token = StreamToken.from_string(since)
-        else:
-            since_token = None
+            since_token = await StreamToken.from_string(self.store, since)
 
         # send any outstanding server notices to the user.
         await self._server_notices_sender.on_user_syncing(user.to_string())
@@ -236,7 +236,7 @@ class SyncRestServlet(RestServlet):
                 "leave": sync_result.groups.leave,
             },
             "device_one_time_keys_count": sync_result.device_one_time_keys_count,
-            "next_batch": sync_result.next_batch.to_string(),
+            "next_batch": await sync_result.next_batch.to_string(self.store),
         }
 
     @staticmethod
@@ -413,7 +413,7 @@ class SyncRestServlet(RestServlet):
         result = {
             "timeline": {
                 "events": serialized_timeline,
-                "prev_batch": room.timeline.prev_batch.to_string(),
+                "prev_batch": await room.timeline.prev_batch.to_string(self.store),
                 "limited": room.timeline.limited,
             },
             "state": {"events": serialized_state},
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index d7a03cbf7d..ecfc6717b3 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             The set of state groups that are referenced by deleted events.
         """
 
+        parsed_token = await RoomStreamToken.parse(self, token)
+
         return await self.db_pool.runInteraction(
             "purge_history",
             self._purge_history_txn,
             room_id,
-            token,
+            parsed_token,
             delete_local_events,
         )
 
-    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
-        token = RoomStreamToken.parse(token_str)
-
+    def _purge_history_txn(self, txn, room_id, token, delete_local_events):
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 0bdf846edf..fdda21d165 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import Optional
 
@@ -21,6 +20,7 @@ import attr
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import parse_integer, parse_string
 from synapse.http.site import SynapseRequest
+from synapse.storage.databases.main import DataStore
 from synapse.types import StreamToken
 
 logger = logging.getLogger(__name__)
@@ -39,8 +39,9 @@ class PaginationConfig:
     limit = attr.ib(type=Optional[int])
 
     @classmethod
-    def from_request(
+    async def from_request(
         cls,
+        store: "DataStore",
         request: SynapseRequest,
         raise_invalid_params: bool = True,
         default_limit: Optional[int] = None,
@@ -54,13 +55,13 @@ class PaginationConfig:
             if from_tok == "END":
                 from_tok = None  # For backwards compat.
             elif from_tok:
-                from_tok = StreamToken.from_string(from_tok)
+                from_tok = await StreamToken.from_string(store, from_tok)
         except Exception:
             raise SynapseError(400, "'from' parameter is invalid")
 
         try:
             if to_tok:
-                to_tok = StreamToken.from_string(to_tok)
+                to_tok = await StreamToken.from_string(store, to_tok)
         except Exception:
             raise SynapseError(400, "'to' parameter is invalid")
 
diff --git a/synapse/types.py b/synapse/types.py
index 02bcc197ec..bd271f9f16 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,7 +18,17 @@ import re
 import string
 import sys
 from collections import namedtuple
-from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+)
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -26,6 +36,9 @@ from unpaddedbase64 import decode_base64
 
 from synapse.api.errors import Codes, SynapseError
 
+if TYPE_CHECKING:
+    from synapse.storage.databases.main import DataStore
+
 # define a version of typing.Collection that works on python 3.5
 if sys.version_info[:3] >= (3, 6, 0):
     from typing import Collection
@@ -393,7 +406,7 @@ class RoomStreamToken:
     stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
 
     @classmethod
-    def parse(cls, string: str) -> "RoomStreamToken":
+    async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
         try:
             if string[0] == "s":
                 return cls(topological=None, stream=int(string[1:]))
@@ -428,7 +441,7 @@ class RoomStreamToken:
     def as_tuple(self) -> Tuple[Optional[int], int]:
         return (self.topological, self.stream)
 
-    def __str__(self) -> str:
+    async def to_string(self, store: "DataStore") -> str:
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
         else:
@@ -453,18 +466,32 @@ class StreamToken:
     START = None  # type: StreamToken
 
     @classmethod
-    def from_string(cls, string):
+    async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
         try:
             keys = string.split(cls._SEPARATOR)
             while len(keys) < len(attr.fields(cls)):
                 # i.e. old token from before receipt_key
                 keys.append("0")
-            return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
+            return cls(
+                await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
+            )
         except Exception:
             raise SynapseError(400, "Invalid Token")
 
-    def to_string(self):
-        return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
+    async def to_string(self, store: "DataStore") -> str:
+        return self._SEPARATOR.join(
+            [
+                await self.room_key.to_string(store),
+                str(self.presence_key),
+                str(self.typing_key),
+                str(self.receipt_key),
+                str(self.account_data_key),
+                str(self.push_rules_key),
+                str(self.to_device_key),
+                str(self.device_list_key),
+                str(self.groups_key),
+            ]
+        )
 
     @property
     def room_stream_id(self):
@@ -493,7 +520,7 @@ class StreamToken:
         return attr.evolve(self, **{key: new_value})
 
 
-StreamToken.START = StreamToken.from_string("s0_0")
+StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
 
 
 @attr.s(slots=True, frozen=True)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index a3287011e9..0d809d25d5 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -902,16 +902,18 @@ class RoomMessageListTestCase(RoomBase):
 
         # Send a first message in the room, which will be removed by the purge.
         first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
-        first_token = str(
-            self.get_success(store.get_topological_token_for_event(first_event_id))
+        first_token = self.get_success(
+            store.get_topological_token_for_event(first_event_id)
         )
+        first_token_str = self.get_success(first_token.to_string(store))
 
         # Send a second message in the room, which won't be removed, and which we'll
         # use as the marker to purge events before.
         second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
-        second_token = str(
-            self.get_success(store.get_topological_token_for_event(second_event_id))
+        second_token = self.get_success(
+            store.get_topological_token_for_event(second_event_id)
         )
+        second_token_str = self.get_success(second_token.to_string(store))
 
         # Send a third event in the room to ensure we don't fall under any edge case
         # due to our marker being the latest forward extremity in the room.
@@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
         request, channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
-            % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+            % (
+                self.room_id,
+                second_token_str,
+                json.dumps({"types": [EventTypes.Message]}),
+            ),
         )
         self.render(request)
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
             pagination_handler._purge_history(
                 purge_id=purge_id,
                 room_id=self.room_id,
-                token=second_token,
+                token=second_token_str,
                 delete_local_events=True,
             )
         )
@@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
         request, channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
-            % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+            % (
+                self.room_id,
+                second_token_str,
+                json.dumps({"types": [EventTypes.Message]}),
+            ),
         )
         self.render(request)
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
         request, channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
-            % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+            % (
+                self.room_id,
+                first_token_str,
+                json.dumps({"types": [EventTypes.Message]}),
+            ),
         )
         self.render(request)
         self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 723cd28933..cc1f3c53c5 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase):
         storage = self.hs.get_storage()
 
         # Get the topological token
-        event = str(
-            self.get_success(store.get_topological_token_for_event(last["event_id"]))
+        token = self.get_success(
+            store.get_topological_token_for_event(last["event_id"])
         )
+        token_str = self.get_success(token.to_string(self.hs.get_datastore()))
 
         # Purge everything before this topological token
-        self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
+        self.get_success(
+            storage.purge_events.purge_history(self.room_id, token_str, True)
+        )
 
         # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
         # and last is not.