summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11866.feature1
-rw-r--r--synapse/handlers/room_member.py13
-rw-r--r--synapse/rest/client/room.py34
-rw-r--r--tests/rest/client/test_rooms.py119
4 files changed, 152 insertions, 15 deletions
diff --git a/changelog.d/11866.feature b/changelog.d/11866.feature
new file mode 100644
index 0000000000..0b52caf805
--- /dev/null
+++ b/changelog.d/11866.feature
@@ -0,0 +1 @@
+Allow application services to set the `origin_server_ts` of a state event by providing the query parameter `ts` in `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`, per [MSC3316](https://github.com/matrix-org/matrix-doc/pull/3316). Contributed by @lukasdenk.
\ No newline at end of file
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ee669eb30f..6ad2b38b8f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -322,6 +322,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         require_consent: bool = True,
         outlier: bool = False,
         historical: bool = False,
+        origin_server_ts: Optional[int] = None,
     ) -> Tuple[str, int]:
         """
         Internal membership update function to get an existing event or create
@@ -361,6 +362,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             historical: Indicates whether the message is being inserted
                 back in time around some existing events. This is used to skip
                 a few checks and mark the event as backfilled.
+            origin_server_ts: The origin_server_ts to use if a new event is created. Uses
+                the current timestamp if set to None.
 
         Returns:
             Tuple of event ID and stream ordering position
@@ -399,6 +402,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 "state_key": user_id,
                 # For backwards compatibility:
                 "membership": membership,
+                "origin_server_ts": origin_server_ts,
             },
             txn_id=txn_id,
             allow_no_prev_events=allow_no_prev_events,
@@ -504,6 +508,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
         depth: Optional[int] = None,
+        origin_server_ts: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Update a user's membership in a room.
 
@@ -542,6 +547,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
+            origin_server_ts: The origin_server_ts to use if a new event is created. Uses
+                the current timestamp if set to None.
 
         Returns:
             A tuple of the new event ID and stream ID.
@@ -583,6 +590,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                         prev_event_ids=prev_event_ids,
                         state_event_ids=state_event_ids,
                         depth=depth,
+                        origin_server_ts=origin_server_ts,
                     )
 
         return result
@@ -606,6 +614,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
         depth: Optional[int] = None,
+        origin_server_ts: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Helper for update_membership.
 
@@ -646,6 +655,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
+            origin_server_ts: The origin_server_ts to use if a new event is created. Uses
+                the current timestamp if set to None.
 
         Returns:
             A tuple of the new event ID and stream ID.
@@ -785,6 +796,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 require_consent=require_consent,
                 outlier=outlier,
                 historical=historical,
+                origin_server_ts=origin_server_ts,
             )
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
@@ -1030,6 +1042,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             content=content,
             require_consent=require_consent,
             outlier=outlier,
+            origin_server_ts=origin_server_ts,
         )
 
     async def _should_perform_remote_join(
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 0bca012535..b6dedbed04 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -268,15 +268,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
 
         content = parse_json_object_from_request(request)
 
-        event_dict = {
-            "type": event_type,
-            "content": content,
-            "room_id": room_id,
-            "sender": requester.user.to_string(),
-        }
-
-        if state_key is not None:
-            event_dict["state_key"] = state_key
+        origin_server_ts = None
+        if requester.app_service:
+            origin_server_ts = parse_integer(request, "ts")
 
         try:
             if event_type == EventTypes.Member:
@@ -287,8 +281,22 @@ class RoomStateEventRestServlet(TransactionRestServlet):
                     room_id=room_id,
                     action=membership,
                     content=content,
+                    origin_server_ts=origin_server_ts,
                 )
             else:
+                event_dict: JsonDict = {
+                    "type": event_type,
+                    "content": content,
+                    "room_id": room_id,
+                    "sender": requester.user.to_string(),
+                }
+
+                if state_key is not None:
+                    event_dict["state_key"] = state_key
+
+                if origin_server_ts is not None:
+                    event_dict["origin_server_ts"] = origin_server_ts
+
                 (
                     event,
                     _,
@@ -333,10 +341,10 @@ class RoomSendEventRestServlet(TransactionRestServlet):
             "sender": requester.user.to_string(),
         }
 
-        # Twisted will have processed the args by now.
-        assert request.args is not None
-        if b"ts" in request.args and requester.app_service:
-            event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
+        if requester.app_service:
+            origin_server_ts = parse_integer(request, "ts")
+            if origin_server_ts is not None:
+                event_dict["origin_server_ts"] = origin_server_ts
 
         try:
             (
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 7f8cf4fab0..5e66b5b26c 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -20,7 +20,7 @@
 import json
 from http import HTTPStatus
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from unittest.mock import Mock, call
+from unittest.mock import Mock, call, patch
 from urllib import parse as urlparse
 
 from parameterized import param, parameterized
@@ -39,9 +39,10 @@ from synapse.api.constants import (
     RoomTypes,
 )
 from synapse.api.errors import Codes, HttpResponseException
+from synapse.appservice import ApplicationService
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
-from synapse.rest.client import account, directory, login, profile, room, sync
+from synapse.rest.client import account, directory, login, profile, register, room, sync
 from synapse.server import HomeServer
 from synapse.types import JsonDict, RoomAlias, UserID, create_requester
 from synapse.util import Clock
@@ -1252,6 +1253,120 @@ class RoomJoinTestCase(RoomBase):
         )
 
 
+class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        room.register_servlets,
+        synapse.rest.admin.register_servlets,
+        register.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.appservice_user, _ = self.register_appservice_user(
+            "as_user_potato", self.appservice.token
+        )
+
+        # Create a room as the appservice user.
+        args = {
+            "access_token": self.appservice.token,
+            "user_id": self.appservice_user,
+        }
+        channel = self.make_request(
+            "POST",
+            f"/_matrix/client/r0/createRoom?{urlparse.urlencode(args)}",
+            content={"visibility": "public"},
+        )
+
+        assert channel.code == 200
+        self.room = channel.json_body["room_id"]
+
+        self.main_store = self.hs.get_datastores().main
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+
+        self.appservice = ApplicationService(
+            token="i_am_an_app_service",
+            id="1234",
+            namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+            # Note: this user does not have to match the regex above
+            sender="@as_main:test",
+        )
+
+        mock_load_appservices = Mock(return_value=[self.appservice])
+        with patch(
+            "synapse.storage.databases.main.appservice.load_appservices",
+            mock_load_appservices,
+        ):
+            hs = self.setup_test_homeserver(config=config)
+        return hs
+
+    def test_send_event_ts(self) -> None:
+        """Test sending a non-state event with a custom timestamp."""
+        ts = 1
+
+        url_params = {
+            "user_id": self.appservice_user,
+            "ts": ts,
+        }
+        channel = self.make_request(
+            "PUT",
+            path=f"/_matrix/client/r0/rooms/{self.room}/send/m.room.message/1234?"
+            + urlparse.urlencode(url_params),
+            content={"body": "test", "msgtype": "m.text"},
+            access_token=self.appservice.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        event_id = channel.json_body["event_id"]
+
+        # Ensure the event was persisted with the correct timestamp.
+        res = self.get_success(self.main_store.get_event(event_id))
+        self.assertEquals(ts, res.origin_server_ts)
+
+    def test_send_state_event_ts(self) -> None:
+        """Test sending a state event with a custom timestamp."""
+        ts = 1
+
+        url_params = {
+            "user_id": self.appservice_user,
+            "ts": ts,
+        }
+        channel = self.make_request(
+            "PUT",
+            path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.name?"
+            + urlparse.urlencode(url_params),
+            content={"name": "test"},
+            access_token=self.appservice.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        event_id = channel.json_body["event_id"]
+
+        # Ensure the event was persisted with the correct timestamp.
+        res = self.get_success(self.main_store.get_event(event_id))
+        self.assertEquals(ts, res.origin_server_ts)
+
+    def test_send_membership_event_ts(self) -> None:
+        """Test sending a membership event with a custom timestamp."""
+        ts = 1
+
+        url_params = {
+            "user_id": self.appservice_user,
+            "ts": ts,
+        }
+        channel = self.make_request(
+            "PUT",
+            path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.member/{self.appservice_user}?"
+            + urlparse.urlencode(url_params),
+            content={"membership": "join", "display_name": "test"},
+            access_token=self.appservice.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        event_id = channel.json_body["event_id"]
+
+        # Ensure the event was persisted with the correct timestamp.
+        res = self.get_success(self.main_store.get_event(event_id))
+        self.assertEquals(ts, res.origin_server_ts)
+
+
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"