summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17200.misc1
-rw-r--r--synapse/handlers/sync.py65
-rw-r--r--synapse/rest/client/sync.py2
-rw-r--r--tests/events/test_presence_router.py4
-rw-r--r--tests/handlers/test_sync.py81
5 files changed, 128 insertions, 25 deletions
diff --git a/changelog.d/17200.misc b/changelog.d/17200.misc
new file mode 100644
index 0000000000..a02b315041
--- /dev/null
+++ b/changelog.d/17200.misc
@@ -0,0 +1 @@
+Prepare sync handler to be able to return different sync responses (`SyncVersion`).
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0bef58351c..53fe2a6a53 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -20,6 +20,7 @@
 #
 import itertools
 import logging
+from enum import Enum
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -112,6 +113,23 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
 SyncRequestKey = Tuple[Any, ...]
 
 
+class SyncVersion(Enum):
+    """
+    Enum for specifying the version of sync request. This is used to key which type of
+    sync response that we are generating.
+
+    This is different than the `sync_type` you might see used in other code below; which
+    specifies the sub-type sync request (e.g. initial_sync, full_state_sync,
+    incremental_sync) and is really only relevant for the `/sync` v2 endpoint.
+    """
+
+    # These string values are semantically significant because they are used in the the
+    # metrics
+
+    # Traditional `/sync` endpoint
+    SYNC_V2 = "sync_v2"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class SyncConfig:
     user: UserID
@@ -309,6 +327,7 @@ class SyncHandler:
         self,
         requester: Requester,
         sync_config: SyncConfig,
+        sync_version: SyncVersion,
         since_token: Optional[StreamToken] = None,
         timeout: int = 0,
         full_state: bool = False,
@@ -316,6 +335,17 @@ class SyncHandler:
         """Get the sync for a client if we have new data for it now. Otherwise
         wait for new data to arrive on the server. If the timeout expires, then
         return an empty sync result.
+
+        Args:
+            requester: The user requesting the sync response.
+            sync_config: Config/info necessary to process the sync request.
+            sync_version: Determines what kind of sync response to generate.
+            since_token: The point in the stream to sync from.
+            timeout: How long to wait for new data to arrive before giving up.
+            full_state: Whether to return the full state for each room.
+
+        Returns:
+            When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
         """
         # If the user is not part of the mau group, then check that limits have
         # not been exceeded (if not part of the group by this point, almost certain
@@ -327,6 +357,7 @@ class SyncHandler:
             sync_config.request_key,
             self._wait_for_sync_for_user,
             sync_config,
+            sync_version,
             since_token,
             timeout,
             full_state,
@@ -338,6 +369,7 @@ class SyncHandler:
     async def _wait_for_sync_for_user(
         self,
         sync_config: SyncConfig,
+        sync_version: SyncVersion,
         since_token: Optional[StreamToken],
         timeout: int,
         full_state: bool,
@@ -363,9 +395,11 @@ class SyncHandler:
         else:
             sync_type = "incremental_sync"
 
+        sync_label = f"{sync_version}:{sync_type}"
+
         context = current_context()
         if context:
-            context.tag = sync_type
+            context.tag = sync_label
 
         # if we have a since token, delete any to-device messages before that token
         # (since we now know that the device has received them)
@@ -384,14 +418,16 @@ class SyncHandler:
             # we are going to return immediately, so don't bother calling
             # notifier.wait_for_events.
             result: SyncResult = await self.current_sync_for_user(
-                sync_config, since_token, full_state=full_state
+                sync_config, sync_version, since_token, full_state=full_state
             )
         else:
             # Otherwise, we wait for something to happen and report it to the user.
             async def current_sync_callback(
                 before_token: StreamToken, after_token: StreamToken
             ) -> SyncResult:
-                return await self.current_sync_for_user(sync_config, since_token)
+                return await self.current_sync_for_user(
+                    sync_config, sync_version, since_token
+                )
 
             result = await self.notifier.wait_for_events(
                 sync_config.user.to_string(),
@@ -416,13 +452,14 @@ class SyncHandler:
                 lazy_loaded = "true"
             else:
                 lazy_loaded = "false"
-            non_empty_sync_counter.labels(sync_type, lazy_loaded).inc()
+            non_empty_sync_counter.labels(sync_label, lazy_loaded).inc()
 
         return result
 
     async def current_sync_for_user(
         self,
         sync_config: SyncConfig,
+        sync_version: SyncVersion,
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
     ) -> SyncResult:
@@ -431,12 +468,26 @@ class SyncHandler:
         This is a wrapper around `generate_sync_result` which starts an open tracing
         span to track the sync. See `generate_sync_result` for the next part of your
         indoctrination.
+
+        Args:
+            sync_config: Config/info necessary to process the sync request.
+            sync_version: Determines what kind of sync response to generate.
+            since_token: The point in the stream to sync from.p.
+            full_state: Whether to return the full state for each room.
+        Returns:
+            When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
         """
         with start_active_span("sync.current_sync_for_user"):
             log_kv({"since_token": since_token})
-            sync_result = await self.generate_sync_result(
-                sync_config, since_token, full_state
-            )
+            # Go through the `/sync` v2 path
+            if sync_version == SyncVersion.SYNC_V2:
+                sync_result: SyncResult = await self.generate_sync_result(
+                    sync_config, since_token, full_state
+                )
+            else:
+                raise Exception(
+                    f"Unknown sync_version (this is a Synapse problem): {sync_version}"
+                )
 
             set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
             return sync_result
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d19aaf0e22..d0713536e1 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -40,6 +40,7 @@ from synapse.handlers.sync import (
     KnockedSyncResult,
     SyncConfig,
     SyncResult,
+    SyncVersion,
 )
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
@@ -232,6 +233,7 @@ class SyncRestServlet(RestServlet):
             sync_result = await self.sync_handler.wait_for_sync_for_user(
                 requester,
                 sync_config,
+                SyncVersion.SYNC_V2,
                 since_token=since_token,
                 timeout=timeout,
                 full_state=full_state,
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index e51cdf01ab..aa67afa695 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -36,7 +36,7 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict, StreamToken, create_requester
 from synapse.util import Clock
 
-from tests.handlers.test_sync import generate_sync_config
+from tests.handlers.test_sync import SyncVersion, generate_sync_config
 from tests.unittest import (
     FederatingHomeserverTestCase,
     HomeserverTestCase,
@@ -521,7 +521,7 @@ def sync_presence(
     sync_config = generate_sync_config(requester.user.to_string())
     sync_result = testcase.get_success(
         testcase.hs.get_sync_handler().wait_for_sync_for_user(
-            requester, sync_config, since_token
+            requester, sync_config, SyncVersion.SYNC_V2, since_token
         )
     )
 
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 2780d29cad..9c12a11e3a 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -31,7 +31,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import event_from_pdu_json
-from synapse.handlers.sync import SyncConfig, SyncResult
+from synapse.handlers.sync import SyncConfig, SyncResult, SyncVersion
 from synapse.rest import admin
 from synapse.rest.client import knock, login, room
 from synapse.server import HomeServer
@@ -73,13 +73,21 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Check that the happy case does not throw errors
         self.get_success(self.store.upsert_monthly_active_user(user_id1))
         self.get_success(
-            self.sync_handler.wait_for_sync_for_user(requester, sync_config)
+            self.sync_handler.wait_for_sync_for_user(
+                requester,
+                sync_config,
+                sync_version=SyncVersion.SYNC_V2,
+            )
         )
 
         # Test that global lock works
         self.auth_blocking._hs_disabled = True
         e = self.get_failure(
-            self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+            self.sync_handler.wait_for_sync_for_user(
+                requester,
+                sync_config,
+                sync_version=SyncVersion.SYNC_V2,
+            ),
             ResourceLimitError,
         )
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@@ -90,7 +98,11 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         requester = create_requester(user_id2)
 
         e = self.get_failure(
-            self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+            self.sync_handler.wait_for_sync_for_user(
+                requester,
+                sync_config,
+                sync_version=SyncVersion.SYNC_V2,
+            ),
             ResourceLimitError,
         )
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@@ -109,7 +121,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         requester = create_requester(user)
         initial_result = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                requester, sync_config=generate_sync_config(user, device_id="dev")
+                requester,
+                sync_config=generate_sync_config(user, device_id="dev"),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
 
@@ -140,7 +154,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # The rooms should appear in the sync response.
         result = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                requester, sync_config=generate_sync_config(user)
+                requester,
+                sync_config=generate_sync_config(user),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         self.assertIn(joined_room, [r.room_id for r in result.joined])
@@ -152,6 +168,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             self.sync_handler.wait_for_sync_for_user(
                 requester,
                 sync_config=generate_sync_config(user, device_id="dev"),
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=initial_result.next_batch,
             )
         )
@@ -180,7 +197,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Get a new request key.
         result = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                requester, sync_config=generate_sync_config(user)
+                requester,
+                sync_config=generate_sync_config(user),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         self.assertNotIn(joined_room, [r.room_id for r in result.joined])
@@ -192,6 +211,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             self.sync_handler.wait_for_sync_for_user(
                 requester,
                 sync_config=generate_sync_config(user, device_id="dev"),
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=initial_result.next_batch,
             )
         )
@@ -231,7 +251,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Do a sync as Alice to get the latest event in the room.
         alice_sync_result: SyncResult = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                create_requester(owner), generate_sync_config(owner)
+                create_requester(owner),
+                generate_sync_config(owner),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         self.assertEqual(len(alice_sync_result.joined), 1)
@@ -251,7 +273,11 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         eve_requester = create_requester(eve)
         eve_sync_config = generate_sync_config(eve)
         eve_sync_after_ban: SyncResult = self.get_success(
-            self.sync_handler.wait_for_sync_for_user(eve_requester, eve_sync_config)
+            self.sync_handler.wait_for_sync_for_user(
+                eve_requester,
+                eve_sync_config,
+                sync_version=SyncVersion.SYNC_V2,
+            )
         )
 
         # Sanity check this sync result. We shouldn't be joined to the room.
@@ -268,6 +294,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             self.sync_handler.wait_for_sync_for_user(
                 eve_requester,
                 eve_sync_config,
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=eve_sync_after_ban.next_batch,
             )
         )
@@ -279,6 +306,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             self.sync_handler.wait_for_sync_for_user(
                 eve_requester,
                 eve_sync_config,
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=None,
             )
         )
@@ -310,7 +338,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Do an initial sync as Alice to get a known starting point.
         initial_sync_result = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                alice_requester, generate_sync_config(alice)
+                alice_requester,
+                generate_sync_config(alice),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         last_room_creation_event_id = (
@@ -338,6 +368,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
                         self.hs, {"room": {"timeline": {"limit": 2}}}
                     ),
                 ),
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=initial_sync_result.next_batch,
             )
         )
@@ -380,7 +411,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Do an initial sync as Alice to get a known starting point.
         initial_sync_result = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                alice_requester, generate_sync_config(alice)
+                alice_requester,
+                generate_sync_config(alice),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         last_room_creation_event_id = (
@@ -418,6 +451,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
                         },
                     ),
                 ),
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=initial_sync_result.next_batch,
             )
         )
@@ -461,7 +495,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Do an initial sync as Alice to get a known starting point.
         initial_sync_result = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                alice_requester, generate_sync_config(alice)
+                alice_requester,
+                generate_sync_config(alice),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         last_room_creation_event_id = (
@@ -486,6 +522,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
                         self.hs, {"room": {"timeline": {"limit": 1}}}
                     ),
                 ),
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=initial_sync_result.next_batch,
             )
         )
@@ -515,6 +552,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
                         self.hs, {"room": {"timeline": {"limit": 1}}}
                     ),
                 ),
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=incremental_sync.next_batch,
             )
         )
@@ -574,7 +612,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Do an initial sync to get a known starting point.
         initial_sync_result = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                alice_requester, generate_sync_config(alice)
+                alice_requester,
+                generate_sync_config(alice),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         last_room_creation_event_id = (
@@ -598,6 +638,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
                         self.hs, {"room": {"timeline": {"limit": 1}}}
                     ),
                 ),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         room_sync = initial_sync_result.joined[0]
@@ -618,6 +659,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             self.sync_handler.wait_for_sync_for_user(
                 alice_requester,
                 generate_sync_config(alice),
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=initial_sync_result.next_batch,
             )
         )
@@ -668,7 +710,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
         initial_sync_result = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                bob_requester, generate_sync_config(bob)
+                bob_requester,
+                generate_sync_config(bob),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
 
@@ -699,6 +743,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
                 generate_sync_config(
                     bob, filter_collection=FilterCollection(self.hs, filter_dict)
                 ),
+                sync_version=SyncVersion.SYNC_V2,
                 since_token=None if initial_sync else initial_sync_result.next_batch,
             )
         ).archived[0]
@@ -791,7 +836,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # but that it does not come down /sync in public room
         sync_result: SyncResult = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                create_requester(user), generate_sync_config(user)
+                create_requester(user),
+                generate_sync_config(user),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         event_ids = []
@@ -837,7 +884,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
         private_sync_result: SyncResult = self.get_success(
             self.sync_handler.wait_for_sync_for_user(
-                create_requester(user2), generate_sync_config(user2)
+                create_requester(user2),
+                generate_sync_config(user2),
+                sync_version=SyncVersion.SYNC_V2,
             )
         )
         priv_event_ids = []