summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17452.misc1
-rw-r--r--synapse/handlers/sliding_sync.py30
-rw-r--r--synapse/rest/client/sync.py6
-rw-r--r--synapse/types/__init__.py43
-rw-r--r--synapse/types/handlers/__init__.py13
-rw-r--r--tests/rest/client/test_sync.py416
6 files changed, 301 insertions, 208 deletions
diff --git a/changelog.d/17452.misc b/changelog.d/17452.misc
new file mode 100644
index 0000000000..4fd07f617b
--- /dev/null
+++ b/changelog.d/17452.misc
@@ -0,0 +1 @@
+Change sliding sync to use their own token format in preparation for storing per-connection state.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 554ab59bf3..36665db8e1 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -49,6 +49,7 @@ from synapse.types import (
     PersistedEventPosition,
     Requester,
     RoomStreamToken,
+    SlidingSyncStreamToken,
     StateMap,
     StreamKeyType,
     StreamToken,
@@ -362,7 +363,7 @@ class SlidingSyncHandler:
         self,
         requester: Requester,
         sync_config: SlidingSyncConfig,
-        from_token: Optional[StreamToken] = None,
+        from_token: Optional[SlidingSyncStreamToken] = None,
         timeout_ms: int = 0,
     ) -> SlidingSyncResult:
         """
@@ -393,7 +394,7 @@ class SlidingSyncHandler:
             # this returns false, it means we timed out waiting, and we should
             # just return an empty response.
             before_wait_ts = self.clock.time_msec()
-            if not await self.notifier.wait_for_stream_token(from_token):
+            if not await self.notifier.wait_for_stream_token(from_token.stream_token):
                 logger.warning(
                     "Timed out waiting for worker to catch up. Returning empty response"
                 )
@@ -431,7 +432,7 @@ class SlidingSyncHandler:
                 sync_config.user.to_string(),
                 timeout_ms,
                 current_sync_callback,
-                from_token=from_token,
+                from_token=from_token.stream_token,
             )
 
         return result
@@ -440,7 +441,7 @@ class SlidingSyncHandler:
         self,
         sync_config: SlidingSyncConfig,
         to_token: StreamToken,
-        from_token: Optional[StreamToken] = None,
+        from_token: Optional[SlidingSyncStreamToken] = None,
     ) -> SlidingSyncResult:
         """
         Generates the response body of a Sliding Sync result, represented as a
@@ -473,7 +474,7 @@ class SlidingSyncHandler:
                 await self.get_room_membership_for_user_at_to_token(
                     user=sync_config.user,
                     to_token=to_token,
-                    from_token=from_token,
+                    from_token=from_token.stream_token if from_token else None,
                 )
             )
 
@@ -631,8 +632,11 @@ class SlidingSyncHandler:
             to_token=to_token,
         )
 
+        # TODO: Update this when we implement per-connection state
+        connection_token = 0
+
         return SlidingSyncResult(
-            next_pos=to_token,
+            next_pos=SlidingSyncStreamToken(to_token, connection_token),
             lists=lists,
             rooms=rooms,
             extensions=extensions,
@@ -1367,7 +1371,7 @@ class SlidingSyncHandler:
         room_id: str,
         room_sync_config: RoomSyncConfig,
         room_membership_for_user_at_to_token: _RoomMembershipForUser,
-        from_token: Optional[StreamToken],
+        from_token: Optional[SlidingSyncStreamToken],
         to_token: StreamToken,
     ) -> SlidingSyncResult.RoomResult:
         """
@@ -1431,7 +1435,7 @@ class SlidingSyncHandler:
             #  - TODO: For an incremental sync where we haven't sent it down this
             #    connection before
             to_bound = (
-                from_token.room_key
+                from_token.stream_token.room_key
                 if from_token is not None
                 and not room_membership_for_user_at_to_token.newly_joined
                 else None
@@ -1498,7 +1502,9 @@ class SlidingSyncHandler:
                         instance_name=timeline_event.internal_metadata.instance_name,
                         stream=timeline_event.internal_metadata.stream_ordering,
                     )
-                    if persisted_position.persisted_after(from_token.room_key):
+                    if persisted_position.persisted_after(
+                        from_token.stream_token.room_key
+                    ):
                         num_live += 1
                     else:
                         # Since we're iterating over the timeline events in
@@ -1786,7 +1792,7 @@ class SlidingSyncHandler:
         self,
         sync_config: SlidingSyncConfig,
         to_token: StreamToken,
-        from_token: Optional[StreamToken],
+        from_token: Optional[SlidingSyncStreamToken],
     ) -> SlidingSyncResult.Extensions:
         """Handle extension requests.
 
@@ -1900,7 +1906,7 @@ class SlidingSyncHandler:
         sync_config: SlidingSyncConfig,
         e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
         to_token: StreamToken,
-        from_token: Optional[StreamToken],
+        from_token: Optional[SlidingSyncStreamToken],
     ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
         """Handle E2EE device extension (MSC3884)
 
@@ -1922,7 +1928,7 @@ class SlidingSyncHandler:
             # TODO: This should take into account the `from_token` and `to_token`
             device_list_updates = await self.device_handler.get_user_ids_changed(
                 user_id=user_id,
-                from_token=from_token,
+                from_token=from_token.stream_token,
             )
 
         device_one_time_keys_count: Mapping[str, int] = {}
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 93fe1d439e..d72dfa2b10 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -54,7 +54,7 @@ from synapse.http.servlet import (
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import trace_with_opname
 from synapse.rest.admin.experimental_features import ExperimentalFeature
-from synapse.types import JsonDict, Requester, StreamToken
+from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
 from synapse.types.rest.client import SlidingSyncBody
 from synapse.util import json_decoder
 from synapse.util.caches.lrucache import LruCache
@@ -889,7 +889,9 @@ class SlidingSyncRestServlet(RestServlet):
 
         from_token = None
         if from_token_string is not None:
-            from_token = await StreamToken.from_string(self.store, from_token_string)
+            from_token = await SlidingSyncStreamToken.from_string(
+                self.store, from_token_string
+            )
 
         # TODO: We currently don't know whether we're going to use sticky params or
         # maybe some filters like sync v2  where they are built up once and referenced
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index c0d30ac2a3..5259550f1c 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -1161,6 +1161,49 @@ StreamToken.START = StreamToken(
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
+class SlidingSyncStreamToken:
+    """The same as a `StreamToken`, but includes an extra field at the start for
+    the sliding sync connection token (separated by a '/'). This is used to
+    store per-connection state.
+
+    This then looks something like:
+        5/s2633508_17_338_6732159_1082514_541479_274711_265584_1_379
+
+    Attributes:
+        stream_token: Token representing the position of all the standard
+            streams.
+        connection_position: Token used by sliding sync to track updates to any
+            per-connection state stored by Synapse.
+    """
+
+    stream_token: StreamToken
+    connection_position: int
+
+    @staticmethod
+    @cancellable
+    async def from_string(store: "DataStore", string: str) -> "SlidingSyncStreamToken":
+        """Creates a SlidingSyncStreamToken from its textual representation."""
+        try:
+            connection_position_str, stream_token_str = string.split("/", 1)
+            connection_position = int(connection_position_str)
+            stream_token = await StreamToken.from_string(store, stream_token_str)
+
+            return SlidingSyncStreamToken(
+                stream_token=stream_token,
+                connection_position=connection_position,
+            )
+        except CancelledError:
+            raise
+        except Exception:
+            raise SynapseError(400, "Invalid stream token")
+
+    async def to_string(self, store: "DataStore") -> str:
+        """Serializes the token to a string"""
+        stream_token_str = await self.stream_token.to_string(store)
+        return f"{self.connection_position}/{stream_token_str}"
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class PersistedPosition:
     """Position of a newly persisted row with instance that persisted it."""
 
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 4c6c42db04..59eb0963ee 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -31,7 +31,14 @@ else:
     from pydantic import Extra
 
 from synapse.events import EventBase
-from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, StreamToken, UserID
+from synapse.types import (
+    DeviceListUpdates,
+    JsonDict,
+    JsonMapping,
+    SlidingSyncStreamToken,
+    StreamToken,
+    UserID,
+)
 from synapse.types.rest.client import SlidingSyncBody
 
 if TYPE_CHECKING:
@@ -329,7 +336,7 @@ class SlidingSyncResult:
         def __bool__(self) -> bool:
             return bool(self.to_device or self.e2ee)
 
-    next_pos: StreamToken
+    next_pos: SlidingSyncStreamToken
     lists: Dict[str, SlidingWindowList]
     rooms: Dict[str, RoomResult]
     extensions: Extensions
@@ -342,7 +349,7 @@ class SlidingSyncResult:
         return bool(self.lists or self.rooms or self.extensions)
 
     @staticmethod
-    def empty(next_pos: StreamToken) -> "SlidingSyncResult":
+    def empty(next_pos: SlidingSyncStreamToken) -> "SlidingSyncResult":
         "Return a new empty result"
         return SlidingSyncResult(
             next_pos=next_pos,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 2628869de6..65c5f8ccae 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -21,7 +21,7 @@
 import json
 import logging
 from http import HTTPStatus
-from typing import Any, Dict, Iterable, List
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from parameterized import parameterized, parameterized_class
 
@@ -50,7 +50,14 @@ from synapse.rest.client import (
     sync,
 )
 from synapse.server import HomeServer
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
+from synapse.types import (
+    JsonDict,
+    RoomStreamToken,
+    SlidingSyncStreamToken,
+    StreamKeyType,
+    StreamToken,
+    UserID,
+)
 from synapse.util import Clock
 
 from tests import unittest
@@ -1225,7 +1232,43 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
         self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
 
 
-class SlidingSyncTestCase(unittest.HomeserverTestCase):
+class SlidingSyncBase(unittest.HomeserverTestCase):
+    """Base class for sliding sync test cases"""
+
+    sync_endpoint = "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+
+    def do_sync(
+        self, sync_body: JsonDict, *, since: Optional[str] = None, tok: str
+    ) -> Tuple[JsonDict, str]:
+        """Do a sliding sync request with given body.
+
+        Asserts the request was successful.
+
+        Attributes:
+            sync_body: The full request body to use
+            since: Optional since token
+            tok: Access token to use
+
+        Returns:
+            A tuple of the response body and the `pos` field.
+        """
+
+        sync_path = self.sync_endpoint
+        if since:
+            sync_path += f"?pos={since}"
+
+        channel = self.make_request(
+            method="POST",
+            path=sync_path,
+            content=sync_body,
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        return channel.json_body, channel.json_body["pos"]
+
+
+class SlidingSyncTestCase(SlidingSyncBase):
     """
     Tests regarding MSC3575 Sliding Sync `/sync` endpoint.
     """
@@ -1246,10 +1289,6 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
-        self.sync_endpoint = (
-            "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
-        )
-        self.store = hs.get_datastores().main
         self.event_sources = hs.get_event_sources()
         self.storage_controllers = hs.get_storage_controllers()
         self.account_data_handler = hs.get_account_data_handler()
@@ -1496,7 +1535,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         )
 
         future_position_token_serialized = self.get_success(
-            future_position_token.to_string(self.store)
+            SlidingSyncStreamToken(future_position_token, 0).to_string(self.store)
         )
 
         # Make the Sliding Sync request
@@ -1544,23 +1583,22 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
         self.helper.join(room_id, user1_id, tok=user1_tok)
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 0]],
+                    "required_state": [],
+                    "timeline_limit": 1,
+                }
+            }
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + "?timeout=10000"
-            + f"&pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {
-                    "foo-list": {
-                        "ranges": [[0, 0]],
-                        "required_state": [],
-                        "timeline_limit": 1,
-                    }
-                }
-            },
+            self.sync_endpoint + f"?timeout=10000&pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
             await_result=False,
         )
@@ -2771,7 +2809,20 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             room_id1, "activity before token2", tok=user2_tok
         )
 
-        from_token = self.event_sources.get_current_token()
+        # The `timeline_limit` is set to 4 so we can at least see one historical event
+        # before the `from_token`. We should see historical events because this is a
+        # `newly_joined` room.
+        timeline_limit = 4
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    "timeline_limit": timeline_limit,
+                }
+            }
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Join the room after the `from_token` which will make us consider this room as
         # `newly_joined`.
@@ -2786,24 +2837,11 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             room_id1, "activity after token4", tok=user2_tok
         )
 
-        # The `timeline_limit` is set to 4 so we can at least see one historical event
-        # before the `from_token`. We should see historical events because this is a
-        # `newly_joined` room.
-        timeline_limit = 4
         # Make an incremental Sliding Sync request (what we're trying to test)
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + f"?pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {
-                    "foo-list": {
-                        "ranges": [[0, 1]],
-                        "required_state": [],
-                        "timeline_limit": timeline_limit,
-                    }
-                }
-            },
+            self.sync_endpoint + f"?pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -2980,7 +3018,16 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         self.helper.send(room_id1, "activity after invite3", tok=user2_tok)
         self.helper.send(room_id1, "activity after invite4", tok=user2_tok)
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    "timeline_limit": 3,
+                }
+            }
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         self.helper.send(room_id1, "activity after token5", tok=user2_tok)
         self.helper.send(room_id1, "activity after toekn6", tok=user2_tok)
@@ -2988,17 +3035,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + f"?pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {
-                    "foo-list": {
-                        "ranges": [[0, 1]],
-                        "required_state": [],
-                        "timeline_limit": 3,
-                    }
-                }
-            },
+            self.sync_endpoint + f"?pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -3237,7 +3275,17 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         self.helper.send(room_id1, "activity after invite3", tok=user2_tok)
         self.helper.send(room_id1, "activity after invite4", tok=user2_tok)
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    # Large enough to see the latest events and before the invite
+                    "timeline_limit": 4,
+                }
+            }
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         self.helper.send(room_id1, "activity after token5", tok=user2_tok)
         self.helper.send(room_id1, "activity after toekn6", tok=user2_tok)
@@ -3245,18 +3293,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + f"?pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {
-                    "foo-list": {
-                        "ranges": [[0, 1]],
-                        "required_state": [],
-                        # Large enough to see the latest events and before the invite
-                        "timeline_limit": 4,
-                    }
-                }
-            },
+            self.sync_endpoint + f"?pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -3402,7 +3440,16 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         self.helper.send(room_id1, "activity before2", tok=user2_tok)
         self.helper.join(room_id1, user1_id, tok=user1_tok)
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    "timeline_limit": 4,
+                }
+            }
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok)
         event_response4 = self.helper.send(room_id1, "activity after4", tok=user2_tok)
@@ -3418,17 +3465,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + f"?pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {
-                    "foo-list": {
-                        "ranges": [[0, 1]],
-                        "required_state": [],
-                        "timeline_limit": 4,
-                    }
-                }
-            },
+            self.sync_endpoint + f"?pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -3479,24 +3517,24 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
 
         self.helper.send(room_id1, "activity after3", tok=user2_tok)
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    "timeline_limit": 4,
+                }
+            }
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         self.helper.send(room_id1, "activity after4", tok=user2_tok)
 
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + f"?pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {
-                    "foo-list": {
-                        "ranges": [[0, 1]],
-                        "required_state": [],
-                        "timeline_limit": 4,
-                    }
-                }
-            },
+            self.sync_endpoint + f"?pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -3614,27 +3652,27 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
         self.helper.join(room_id1, user1_id, tok=user1_tok)
 
-        after_room_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        [EventTypes.Create, ""],
+                        [EventTypes.RoomHistoryVisibility, ""],
+                        # This one doesn't exist in the room
+                        [EventTypes.Tombstone, ""],
+                    ],
+                    "timeline_limit": 0,
+                }
+            }
+        }
+        _, after_room_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + f"?pos={self.get_success(after_room_token.to_string(self.store))}",
-            {
-                "lists": {
-                    "foo-list": {
-                        "ranges": [[0, 1]],
-                        "required_state": [
-                            [EventTypes.Create, ""],
-                            [EventTypes.RoomHistoryVisibility, ""],
-                            # This one doesn't exist in the room
-                            [EventTypes.Tombstone, ""],
-                        ],
-                        "timeline_limit": 0,
-                    }
-                }
-            },
+            self.sync_endpoint + f"?pos={after_room_token}",
+            content=sync_body,
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -3966,7 +4004,20 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         user3_id = self.register_user("user3", "pass")
         user3_tok = self.login(user3_id, "pass")
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        [EventTypes.Create, ""],
+                        [EventTypes.Member, "*"],
+                        ["org.matrix.foo_state", ""],
+                    ],
+                    "timeline_limit": 3,
+                }
+            }
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
         self.helper.join(room_id1, user1_id, tok=user1_tok)
@@ -4004,21 +4055,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         # Make the Sliding Sync request with lazy loading for the room members
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + f"?pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {
-                    "foo-list": {
-                        "ranges": [[0, 1]],
-                        "required_state": [
-                            [EventTypes.Create, ""],
-                            [EventTypes.Member, "*"],
-                            ["org.matrix.foo_state", ""],
-                        ],
-                        "timeline_limit": 3,
-                    }
-                }
-            },
+            self.sync_endpoint + f"?pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -4468,7 +4506,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         )
 
 
-class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
+class SlidingSyncToDeviceExtensionTestCase(SlidingSyncBase):
     """Tests for the to-device sliding sync extension"""
 
     servlets = [
@@ -4714,22 +4752,21 @@ class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
         user2_id = self.register_user("u2", "pass")
         user2_tok = self.login(user2_id, "pass", "d2")
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {},
+            "extensions": {
+                "to_device": {
+                    "enabled": True,
+                }
+            },
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + "?timeout=10000"
-            + f"&pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {},
-                "extensions": {
-                    "to_device": {
-                        "enabled": True,
-                    }
-                },
-            },
+            self.sync_endpoint + "?timeout=10000" + f"&pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
             await_result=False,
         )
@@ -4765,22 +4802,21 @@ class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
         user1_id = self.register_user("user1", "pass")
         user1_tok = self.login(user1_id, "pass")
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {},
+            "extensions": {
+                "to_device": {
+                    "enabled": True,
+                }
+            },
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + "?timeout=10000"
-            + f"&pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {},
-                "extensions": {
-                    "to_device": {
-                        "enabled": True,
-                    }
-                },
-            },
+            self.sync_endpoint + "?timeout=10000" + f"&pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
             await_result=False,
         )
@@ -4801,7 +4837,7 @@ class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
         self._assert_to_device_response(channel, [])
 
 
-class SlidingSyncE2eeExtensionTestCase(unittest.HomeserverTestCase):
+class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
     """Tests for the e2ee sliding sync extension"""
 
     servlets = [
@@ -4924,21 +4960,21 @@ class SlidingSyncE2eeExtensionTestCase(unittest.HomeserverTestCase):
         user1_id = self.register_user("user1", "pass")
         user1_tok = self.login(user1_id, "pass")
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {},
+            "extensions": {
+                "e2ee": {
+                    "enabled": True,
+                }
+            },
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Make an incremental Sliding Sync request with the e2ee extension enabled
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + f"?pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {},
-                "extensions": {
-                    "e2ee": {
-                        "enabled": True,
-                    }
-                },
-            },
+            self.sync_endpoint + f"?pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -4992,22 +5028,21 @@ class SlidingSyncE2eeExtensionTestCase(unittest.HomeserverTestCase):
         self.helper.join(room_id, user1_id, tok=user1_tok)
         self.helper.join(room_id, user3_id, tok=user3_tok)
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {},
+            "extensions": {
+                "e2ee": {
+                    "enabled": True,
+                }
+            },
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + "?timeout=10000"
-            + f"&pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {},
-                "extensions": {
-                    "e2ee": {
-                        "enabled": True,
-                    }
-                },
-            },
+            self.sync_endpoint + "?timeout=10000" + f"&pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
             await_result=False,
         )
@@ -5053,22 +5088,21 @@ class SlidingSyncE2eeExtensionTestCase(unittest.HomeserverTestCase):
         user1_id = self.register_user("user1", "pass")
         user1_tok = self.login(user1_id, "pass")
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {},
+            "extensions": {
+                "e2ee": {
+                    "enabled": True,
+                }
+            },
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Make the Sliding Sync request
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + "?timeout=10000"
-            + f"&pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {},
-                "extensions": {
-                    "e2ee": {
-                        "enabled": True,
-                    }
-                },
-            },
+            self.sync_endpoint + f"?timeout=10000&pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
             await_result=False,
         )
@@ -5138,7 +5172,15 @@ class SlidingSyncE2eeExtensionTestCase(unittest.HomeserverTestCase):
         self.helper.join(room_id, user3_id, tok=user3_tok)
         self.helper.join(room_id, user4_id, tok=user4_tok)
 
-        from_token = self.event_sources.get_current_token()
+        sync_body = {
+            "lists": {},
+            "extensions": {
+                "e2ee": {
+                    "enabled": True,
+                }
+            },
+        }
+        _, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Have user3 update their device list
         channel = self.make_request(
@@ -5157,16 +5199,8 @@ class SlidingSyncE2eeExtensionTestCase(unittest.HomeserverTestCase):
         # Make an incremental Sliding Sync request with the e2ee extension enabled
         channel = self.make_request(
             "POST",
-            self.sync_endpoint
-            + f"?pos={self.get_success(from_token.to_string(self.store))}",
-            {
-                "lists": {},
-                "extensions": {
-                    "e2ee": {
-                        "enabled": True,
-                    }
-                },
-            },
+            self.sync_endpoint + f"?pos={from_token}",
+            content=sync_body,
             access_token=user1_tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)