summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11171.misc1
-rw-r--r--synapse/app/admin_cmd.py14
-rw-r--r--synapse/handlers/admin.py22
-rw-r--r--tests/handlers/test_admin.py35
-rw-r--r--tests/rest/client/utils.py29
5 files changed, 99 insertions, 2 deletions
diff --git a/changelog.d/11171.misc b/changelog.d/11171.misc
new file mode 100644
index 0000000000..b6a41a96da
--- /dev/null
+++ b/changelog.d/11171.misc
@@ -0,0 +1 @@
+Add knock information in admin export. Contributed by Rafael Gonçalves.
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 2fc848596d..ad20b1d6aa 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -145,6 +145,20 @@ class FileExfiltrationWriter(ExfiltrationWriter):
             for event in state.values():
                 print(json.dumps(event), file=f)
 
+    def write_knock(self, room_id, event, state):
+        self.write_events(room_id, [event])
+
+        # We write the knock state somewhere else as they aren't full events
+        # and are only a subset of the state at the event.
+        room_directory = os.path.join(self.base_directory, "rooms", room_id)
+        os.makedirs(room_directory, exist_ok=True)
+
+        knock_state = os.path.join(room_directory, "knock_state")
+
+        with open(knock_state, "a") as f:
+            for event in state.values():
+                print(json.dumps(event), file=f)
+
     def finished(self):
         return self.base_directory
 
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a53cd62d3c..be3203ac80 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -90,6 +90,7 @@ class AdminHandler:
                 Membership.LEAVE,
                 Membership.BAN,
                 Membership.INVITE,
+                Membership.KNOCK,
             ),
         )
 
@@ -122,6 +123,13 @@ class AdminHandler:
                         invited_state = invite.unsigned["invite_room_state"]
                         writer.write_invite(room_id, invite, invited_state)
 
+                if room.membership == Membership.KNOCK:
+                    event_id = room.event_id
+                    knock = await self.store.get_event(event_id, allow_none=True)
+                    if knock:
+                        knock_state = knock.unsigned["knock_room_state"]
+                        writer.write_knock(room_id, knock, knock_state)
+
                 continue
 
             # We only want to bother fetching events up to the last time they
@@ -239,6 +247,20 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    def write_knock(
+        self, room_id: str, event: EventBase, state: StateMap[dict]
+    ) -> None:
+        """Write a knock for the room, with associated knock state.
+
+        Args:
+            room_id: The room ID the knock is for.
+            event: The knock event.
+            state: A subset of the state at the knock, with a subset of the
+                event keys (type, state_key content and sender).
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
     def finished(self) -> Any:
         """Called when all data has successfully been exported and written.
 
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 59de1142b1..abf2a0fe0d 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -17,8 +17,9 @@ from unittest.mock import Mock
 
 import synapse.rest.admin
 import synapse.storage
-from synapse.api.constants import EventTypes
-from synapse.rest.client import login, room
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.room_versions import RoomVersions
+from synapse.rest.client import knock, login, room
 
 from tests import unittest
 
@@ -28,6 +29,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
         room.register_servlets,
+        knock.register_servlets,
     ]
 
     def prepare(self, reactor, clock, hs):
@@ -201,3 +203,32 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(args[0], room_id)
         self.assertEqual(args[1].content["membership"], "invite")
         self.assertTrue(args[2])  # Assert there is at least one bit of state
+
+    def test_knock(self):
+        """Tests that knock get handled correctly."""
+        # create a knockable v7 room
+        room_id = self.helper.create_room_as(
+            self.user1, room_version=RoomVersions.V7.identifier, tok=self.token1
+        )
+        self.helper.send_state(
+            room_id,
+            EventTypes.JoinRules,
+            {"join_rule": JoinRules.KNOCK},
+            tok=self.token1,
+        )
+
+        self.helper.send(room_id, body="Hello!", tok=self.token1)
+        self.helper.knock(room_id, self.user2, tok=self.token2)
+
+        writer = Mock()
+
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+        writer.write_events.assert_not_called()
+        writer.write_state.assert_not_called()
+        writer.write_knock.assert_called_once()
+
+        args = writer.write_knock.call_args[0]
+        self.assertEqual(args[0], room_id)
+        self.assertEqual(args[1].content["membership"], "knock")
+        self.assertTrue(args[2])  # Assert there is at least one bit of state
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 71fa87ce92..ec0979850b 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -120,6 +120,35 @@ class RestHelper:
             expect_code=expect_code,
         )
 
+    def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None):
+        temp_id = self.auth_user_id
+        self.auth_user_id = user
+        path = "/knock/%s" % room
+        if tok:
+            path = path + "?access_token=%s" % tok
+
+        data = {}
+        if reason:
+            data["reason"] = reason
+
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            path,
+            json.dumps(data).encode("utf8"),
+        )
+
+        assert (
+            int(channel.result["code"]) == expect_code
+        ), "Expected: %d, got: %d, resp: %r" % (
+            expect_code,
+            int(channel.result["code"]),
+            channel.result["body"],
+        )
+
+        self.auth_user_id = temp_id
+
     def leave(self, room=None, user=None, expect_code=200, tok=None):
         self.change_membership(
             room=room,