From 010457011c83d4db96d84c43687ee5e291f7685f Mon Sep 17 00:00:00 2001
From: "Olivier Wilkinson (reivilibre)"
Date: Wed, 2 Mar 2022 11:17:36 +0000
Subject: Apply suggestions to changelog

---
 CHANGES.md | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 5485e8d47e..a81d0a4b14 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -9,14 +9,12 @@ Features
 --------
 
 - Add support for [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617))
-- Make a `POST` to `/rooms/<room id>/receipt/m.read/<event id>` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. ([\#11835](https://github.com/matrix-org/synapse/issues/11835))
-- Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985))
+- Improve the preview that is produced when generating URL previews for some web pages. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985))
 - Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000))
 - Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). ([\#12001](https://github.com/matrix-org/synapse/issues/12001), [\#12067](https://github.com/matrix-org/synapse/issues/12067))
 - Enable modules to set a custom display name when registering a user. ([\#12009](https://github.com/matrix-org/synapse/issues/12009))
-- Advertise Matrix 1.1 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020))
+- Advertise Matrix 1.1 and 1.2 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020), [\#12022](https://github.com/matrix-org/synapse/issues/12022))
 - Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. ([\#12021](https://github.com/matrix-org/synapse/issues/12021))
-- Advertise Matrix 1.2 support on `/_matrix/client/versions`. ([\#12022](https://github.com/matrix-org/synapse/issues/12022))
 - Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). ([\#12058](https://github.com/matrix-org/synapse/issues/12058))
 - Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. ([\#12062](https://github.com/matrix-org/synapse/issues/12062))
 
@@ -24,16 +22,17 @@ Bugfixes
 --------
 
-- Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992))
-- Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999))
+- Fix a bug introduced in Synapse 1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992))
+- Fix long-standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999))
 - Fix a 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024))
-- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037))
-- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. ([\#12056](https://github.com/matrix-org/synapse/issues/12056))
+- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an `argument of type 'int' is not iterable` error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037))
+- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens in version 1.38.0. ([\#12056](https://github.com/matrix-org/synapse/issues/12056))
 - Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077))
-- Fix occasional 'Unhandled error in Deferred' error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089))
-- Fix a bug introduced in Synapse 1.51.0rc1 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098))
+- Fix occasional `Unhandled error in Deferred` error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089))
+- Fix a bug introduced in Synapse 1.51.0 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098))
 - Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. ([\#12100](https://github.com/matrix-org/synapse/issues/12100))
 - Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. ([\#12105](https://github.com/matrix-org/synapse/issues/12105))
+- Make a `POST` to `/rooms/<room id>/receipt/m.read/<event id>` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. This reduces server load and load on the receiving device. ([\#11835](https://github.com/matrix-org/synapse/issues/11835))
 
 
 Updates to the Docker image
 ---------------------------
-- cgit 1.4.1


From 6adb89ff007500ea9c41fb5bd1a9e644cc6397cd Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Wed, 2 Mar 2022 06:56:16 -0500
Subject: Improve and refactor the tests for relations. (#12113)

* Modernizes code (f-strings, etc.)
* Fixes incorrect comments.
* Splits the test case into two.
* Factors out some duplicated code.
---
 changelog.d/12113.misc              |   1 +
 tests/rest/client/test_relations.py | 386 +++++++++++++++++-------------
 2 files changed, 179 insertions(+), 208 deletions(-)
 create mode 100644 changelog.d/12113.misc

diff --git a/changelog.d/12113.misc b/changelog.d/12113.misc
new file mode 100644
index 0000000000..102e064053
--- /dev/null
+++ b/changelog.d/12113.misc
@@ -0,0 +1 @@
+Refactor the tests for event relations.
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c8db45719e..a087cd7b21 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -34,7 +34,7 @@ from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import inject_event
 
 
-class RelationsTestCase(unittest.HomeserverTestCase):
+class BaseRelationsTestCase(unittest.HomeserverTestCase):
     servlets = [
         relations.register_servlets,
         room.register_servlets,
@@ -48,7 +48,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
     def default_config(self) -> dict:
         # We need to enable msc1849 support for aggregations
         config = super().default_config()
-        config["experimental_msc1849_support_enabled"] = True
 
         # We enable frozen dicts as relations/edits change event contents, so we
         # want to test that we don't modify the events in the caches.
@@ -67,10 +66,62 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
         self.parent_id = res["event_id"]
 
-    def test_send_relation(self) -> None:
-        """Tests that sending a relation using the new /send_relation works
-        creates the right shape of event.
+    def _create_user(self, localpart: str) -> Tuple[str, str]:
+        user_id = self.register_user(localpart, "abc123")
+        access_token = self.login(localpart, "abc123")
+
+        return user_id, access_token
+
+    def _send_relation(
+        self,
+        relation_type: str,
+        event_type: str,
+        key: Optional[str] = None,
+        content: Optional[dict] = None,
+        access_token: Optional[str] = None,
+        parent_id: Optional[str] = None,
+    ) -> FakeChannel:
+        """Helper function to send a relation pointing at `self.parent_id`
+
+        Args:
+            relation_type: One of `RelationTypes`
+            event_type: The type of the event to create
+            key: The aggregation key used for m.annotation relation type.
+            content: The content of the created event. Will be modified to configure
+                the m.relates_to key based on the other provided parameters.
+            access_token: The access token used to send the relation, defaults
+                to `self.user_token`
+            parent_id: The event_id this relation relates to.
If None, then self.parent_id + + Returns: + FakeChannel """ + if not access_token: + access_token = self.user_token + + original_id = parent_id if parent_id else self.parent_id + + if content is None: + content = {} + content["m.relates_to"] = { + "event_id": original_id, + "rel_type": relation_type, + } + if key is not None: + content["m.relates_to"]["key"] = key + + channel = self.make_request( + "POST", + f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}", + content, + access_token=access_token, + ) + return channel + + +class RelationsTestCase(BaseRelationsTestCase): + def test_send_relation(self) -> None: + """Tests that sending a relation works.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="πŸ‘") self.assertEqual(200, channel.code, channel.json_body) @@ -79,7 +130,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, event_id), + f"/rooms/{self.room}/event/{event_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -317,9 +368,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # Request /sync, limiting it such that only the latest event is returned # (and not the relation). - filter = urllib.parse.quote_plus( - '{"room": {"timeline": {"limit": 1}}}'.encode() - ) + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) @@ -404,8 +453,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -544,8 +592,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -560,47 +607,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) - def test_aggregation_redactions(self) -> None: - """Test that annotations get correctly aggregated after a redaction.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - to_redact_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Now lets redact one of the 'a' reactions - channel = self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), - access_token=self.user_token, - content={}, - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual( - channel.json_body, - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, - ) - def test_aggregation_must_be_annotation(self) -> None: """Test that aggregations must be 
annotations.""" channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" - % (self.room, self.parent_id, RelationTypes.REPLACE), + f"/_matrix/client/unstable/rooms/{self.room}/aggregations" + f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", access_token=self.user_token, ) self.assertEqual(400, channel.code, channel.json_body) @@ -986,9 +999,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # Request sync, but limit the timeline so it becomes limited (and includes # bundled aggregations). - filter = urllib.parse.quote_plus( - '{"room": {"timeline": {"limit": 2}}}'.encode() - ) + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 2}}}') channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) @@ -1053,7 +1064,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1096,7 +1107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, reply), + f"/rooms/{self.room}/event/{reply}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1198,7 +1209,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # Request the original event. channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1217,102 +1228,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_relations_redaction_redacts_edits(self) -> None: - """Test that edits of an event are redacted when the original event - is redacted. 
- """ - # Send a new event - res = self.helper.send(self.room, body="Heyo!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - parent_id=original_event_id, - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "First edit"}, - }, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Check the relation is returned - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - # Redact the original event - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), - access_token=self.user_token, - content="{}", - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Try to check for remaining m.replace relations - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Check that no relations are returned - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - - def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None: - """Test that annotations of an event are redacted when the original event - is redacted. - """ - # Send a new event - res = self.helper.send(self.room, body="Hello!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", key="πŸ‘", parent_id=original_event_id - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Redact the original - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % ( - self.room, - original_event_id, - "test_aggregations_redaction_prevents_access_to_aggregations", - ), - access_token=self.user_token, - content="{}", - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Check that aggregations returns zero - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") @@ -1321,8 +1236,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" - % (self.room, self.parent_id), + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1343,7 +1257,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # When bundling the unknown relation is not included. 
channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1352,8 +1266,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # But unknown relations can be directly queried. channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1" - % (self.room, self.parent_id), + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1369,58 +1282,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): raise AssertionError(f"Event {self.parent_id} not found in chunk") - def _send_relation( - self, - relation_type: str, - event_type: str, - key: Optional[str] = None, - content: Optional[dict] = None, - access_token: Optional[str] = None, - parent_id: Optional[str] = None, - ) -> FakeChannel: - """Helper function to send a relation pointing at `self.parent_id` - - Args: - relation_type: One of `RelationTypes` - event_type: The type of the event to create - key: The aggregation key used for m.annotation relation type. - content: The content of the created event. Will be modified to configure - the m.relates_to key based on the other provided parameters. - access_token: The access token used to send the relation, defaults - to `self.user_token` - parent_id: The event_id this relation relates to. If None, then self.parent_id - - Returns: - FakeChannel - """ - if not access_token: - access_token = self.user_token - - original_id = parent_id if parent_id else self.parent_id - - if content is None: - content = {} - content["m.relates_to"] = { - "event_id": original_id, - "rel_type": relation_type, - } - if key is not None: - content["m.relates_to"]["key"] = key - - channel = self.make_request( - "POST", - f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}", - content, - access_token=access_token, - ) - return channel - - def _create_user(self, localpart: str) -> Tuple[str, str]: - user_id = self.register_user(localpart, "abc123") - access_token = self.login(localpart, "abc123") - - return user_id, access_token - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="πŸ‘") @@ -1482,3 +1343,112 @@ class RelationsTestCase(unittest.HomeserverTestCase): [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good, thread_event_id], ) + + +class RelationRedactionTestCase(BaseRelationsTestCase): + """Test the behaviour of relations when the parent or child event is redacted.""" + + def _redact(self, event_id: str) -> None: + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room}/redact/{event_id}", + access_token=self.user_token, + content={}, + ) + self.assertEqual(200, channel.code, channel.json_body) + + def test_redact_relation_annotation(self) -> None: + """Test that annotations of an event are properly handled after the + annotation is redacted. 
+ """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEqual(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Redact one of the reactions. + self._redact(to_redact_event_id) + + # Ensure that the aggregations are correct. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual( + channel.json_body, + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + def test_redact_relation_edit(self) -> None: + """Test that edits of an event are redacted when the original event + is redacted. + """ + # Add a relation + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + parent_id=self.parent_id, + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Check the relation is returned + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}/m.replace/m.room.message", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + # Redact the original event + self._redact(self.parent_id) + + # Try to check for remaining m.replace relations + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}/m.replace/m.room.message", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Check that no relations are returned + self.assertIn("chunk", channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + def test_redact_parent(self) -> None: + """Test that annotations of an event are redacted when the original event + is redacted. + """ + # Add a relation + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="πŸ‘") + self.assertEqual(200, channel.code, channel.json_body) + + # Redact the original event. + self._redact(self.parent_id) + + # Check that aggregations returns zero + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) -- cgit 1.4.1 From 7317b0be82e31d7b41be64e2fea92aad428283d8 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 2 Mar 2022 11:59:53 +0000 Subject: Tweak changelog --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index a81d0a4b14..542390592f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -23,7 +23,7 @@ Bugfixes -------- - Fix a bug introduced in Synapse 1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. 
([\#11992](https://github.com/matrix-org/synapse/issues/11992))
-- Fix long-standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999))
+- Fix long-standing bug where the `get_rooms_for_user` cache was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999))
 - Fix a 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024))
 - Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an `argument of type 'int' is not iterable` error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037))
 - Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens in version 1.38.0. ([\#12056](https://github.com/matrix-org/synapse/issues/12056))
-- cgit 1.4.1


From 3b9142f7f462f23eeb754eca6003f127bcc62271 Mon Sep 17 00:00:00 2001
From: "Olivier Wilkinson (reivilibre)"
Date: Wed, 2 Mar 2022 12:09:48 +0000
Subject: Reword changelog line about URL previews

---
 CHANGES.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGES.md b/CHANGES.md
index 542390592f..0a87f5cd42 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -9,7 +9,7 @@ Features
 --------
 
 - Add support for [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617))
-- Improve the preview that is produced when generating URL previews for some web pages. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985))
+- Improve the generated URL previews for some web pages. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985))
 - Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000))
 - Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). ([\#12001](https://github.com/matrix-org/synapse/issues/12001), [\#12067](https://github.com/matrix-org/synapse/issues/12067))
 - Enable modules to set a custom display name when registering a user. ([\#12009](https://github.com/matrix-org/synapse/issues/12009))
-- cgit 1.4.1

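The next patch in the series moves Synapse's maintenance scripts into the `synapse._scripts` package and, per its commit message, exposes them as setuptools entry points; that is why its updated CI scripts invoke `synapse_port_db` and `update_synapse_database` as bare commands instead of paths under `scripts/`. The `setup.py` hunk itself is not reproduced in this excerpt, so the following is only a minimal sketch of a `console_scripts` declaration of that shape — the entry names and `main` targets shown are illustrative assumptions rather than the commit's actual contents:

# Hypothetical sketch only: the real setup.py change is not shown in this excerpt.
from setuptools import setup

setup(
    name="matrix-synapse",
    # ... packages, dependencies, and other arguments elided ...
    entry_points={
        "console_scripts": [
            # Each entry installs a command on $PATH that calls main() in the
            # corresponding synapse/_scripts/*.py module.
            "synapse_port_db = synapse._scripts.synapse_port_db:main",
            "update_synapse_database = synapse._scripts.update_synapse_database:main",
            "register_new_matrix_user = synapse._scripts.register_new_matrix_user:main",
        ]
    },
)
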
From f3f0ab10fe766c766dedf9d80e4ef198e3e45c09 Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Wed, 2 Mar 2022 13:00:16 +0000
Subject: Move scripts directory inside synapse, exposing as setuptools entry_points (#12118)

* Two scripts are basically entry_points already
* Move and rename scripts/* to synapse/_scripts/*.py
* Delete sync_room_to_group.pl
* Expose entry points in setup.py
* Update linter script and config
* Fixup scripts & docs mentioning scripts that moved

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
---
 .ci/scripts/test_export_data_command.sh            |    4 +-
 .ci/scripts/test_synapse_port_db.sh                |   12 +-
 .dockerignore                                      |    1 -
 MANIFEST.in                                        |    1 -
 changelog.d/12118.misc                             |    1 +
 docker/Dockerfile                                  |    1 -
 docs/development/database_schema.md                |    6 +-
 docs/usage/administration/admin_api/README.md      |    2 +-
 mypy.ini                                           |    4 +
 scripts-dev/generate_sample_config                 |   10 +-
 scripts-dev/lint.sh                                |    7 -
 scripts-dev/make_full_schema.sh                    |    6 +-
 scripts/export_signing_key                         |  100 --
 scripts/generate_config                            |   78 --
 scripts/generate_log_config                        |   44 -
 scripts/generate_signing_key.py                    |   36 -
 scripts/hash_password                              |   79 --
 scripts/move_remote_media_to_new_store.py          |  118 --
 scripts/register_new_matrix_user                   |   19 -
 scripts/synapse_port_db                            | 1253 -------------------
 scripts/synapse_review_recent_signups              |   19 -
 scripts/sync_room_to_group.pl                      |   45 -
 scripts/update_synapse_database                    |  117 --
 setup.py                                           |   14 +-
 snap/snapcraft.yaml                                |    2 +-
 synapse/_scripts/export_signing_key.py             |  103 ++
 synapse/_scripts/generate_config.py                |   83 ++
 synapse/_scripts/generate_log_config.py            |   49 +
 synapse/_scripts/generate_signing_key.py           |   41 +
 synapse/_scripts/hash_password.py                  |   83 ++
 synapse/_scripts/move_remote_media_to_new_store.py |  118 ++
 synapse/_scripts/synapse_port_db.py                | 1257 ++++++++++++++++++++
 synapse/_scripts/update_synapse_database.py        |  117 ++
 synapse/config/_base.py                            |    2 +-
 tox.ini                                            |    8 -
 35 files changed, 1891 insertions(+), 1949 deletions(-)
 create mode 100644 changelog.d/12118.misc
 delete mode 100755 scripts/export_signing_key
 delete mode 100755 scripts/generate_config
 delete mode 100755 scripts/generate_log_config
 delete mode 100755 scripts/generate_signing_key.py
 delete mode 100755 scripts/hash_password
 delete mode 100755 scripts/move_remote_media_to_new_store.py
 delete mode 100755 scripts/register_new_matrix_user
 delete mode 100755 scripts/synapse_port_db
 delete mode 100755 scripts/synapse_review_recent_signups
 delete mode 100755 scripts/sync_room_to_group.pl
 delete mode 100755 scripts/update_synapse_database
 create mode 100755 synapse/_scripts/export_signing_key.py
 create mode 100755 synapse/_scripts/generate_config.py
 create mode 100755 synapse/_scripts/generate_log_config.py
 create mode 100755 synapse/_scripts/generate_signing_key.py
 create mode 100755 synapse/_scripts/hash_password.py
 create mode 100755 synapse/_scripts/move_remote_media_to_new_store.py
 create mode 100755 synapse/_scripts/synapse_port_db.py
 create mode 100755 synapse/_scripts/update_synapse_database.py

diff --git a/.ci/scripts/test_export_data_command.sh b/.ci/scripts/test_export_data_command.sh
index ab96387a0a..224cae9216 100755
--- a/.ci/scripts/test_export_data_command.sh
+++ b/.ci/scripts/test_export_data_command.sh
@@ -21,7 +21,7 @@ python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml
 echo "--- Prepare test database"
 
 # Make sure the SQLite3 database is using the latest schema and has no pending background update.
-scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates +update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates # Run the export-data command on the sqlite test database python -m synapse.app.admin_cmd -c .ci/sqlite-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \ @@ -41,7 +41,7 @@ fi # Port the SQLite databse to postgres so we can check command works against postgres echo "+++ Port SQLite3 databse to postgres" -scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml +synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml # Run the export-data command on postgres database python -m synapse.app.admin_cmd -c .ci/postgres-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \ diff --git a/.ci/scripts/test_synapse_port_db.sh b/.ci/scripts/test_synapse_port_db.sh index 797904e64c..91bd966f32 100755 --- a/.ci/scripts/test_synapse_port_db.sh +++ b/.ci/scripts/test_synapse_port_db.sh @@ -25,17 +25,19 @@ python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml echo "--- Prepare test database" # Make sure the SQLite3 database is using the latest schema and has no pending background update. -scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates +update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates # Create the PostgreSQL database. .ci/scripts/postgres_exec.py "CREATE DATABASE synapse" echo "+++ Run synapse_port_db against test database" -coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml +# TODO: this invocation of synapse_port_db (and others below) used to be prepended with `coverage run`, +# but coverage seems unable to find the entrypoints installed by `pip install -e .`. +synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml # We should be able to run twice against the same database. echo "+++ Run synapse_port_db a second time" -coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml +synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml ##### @@ -46,7 +48,7 @@ echo "--- Prepare empty SQLite database" # we do this by deleting the sqlite db, and then doing the same again. rm .ci/test_db.db -scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates +update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates # re-create the PostgreSQL database. 
.ci/scripts/postgres_exec.py \ @@ -54,4 +56,4 @@ scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-b "CREATE DATABASE synapse" echo "+++ Run synapse_port_db against empty database" -coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml +synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml diff --git a/.dockerignore b/.dockerignore index f6c638b0a2..617f701597 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,7 +3,6 @@ # things to include !docker -!scripts !synapse !MANIFEST.in !README.rst diff --git a/MANIFEST.in b/MANIFEST.in index 76d14eb642..7e903518e1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -17,7 +17,6 @@ recursive-include synapse/storage *.txt recursive-include synapse/storage *.md recursive-include docs * -recursive-include scripts * recursive-include scripts-dev * recursive-include synapse *.pyi recursive-include tests *.py diff --git a/changelog.d/12118.misc b/changelog.d/12118.misc new file mode 100644 index 0000000000..a2c397d907 --- /dev/null +++ b/changelog.d/12118.misc @@ -0,0 +1 @@ +Move scripts to Synapse package and expose as setuptools entry points. diff --git a/docker/Dockerfile b/docker/Dockerfile index a8bb9b0e7f..327275a9ca 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -46,7 +46,6 @@ RUN \ && rm -rf /var/lib/apt/lists/* # Copy just what we need to pip install -COPY scripts /synapse/scripts/ COPY MANIFEST.in README.rst setup.py synctl /synapse/ COPY synapse/__init__.py /synapse/synapse/__init__.py COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py diff --git a/docs/development/database_schema.md b/docs/development/database_schema.md index a767d3af9f..d996a7caa2 100644 --- a/docs/development/database_schema.md +++ b/docs/development/database_schema.md @@ -158,9 +158,9 @@ same as integers. There are three separate aspects to this: * Any new boolean column must be added to the `BOOLEAN_COLUMNS` list in - `scripts/synapse_port_db`. This tells the port script to cast the integer - value from SQLite to a boolean before writing the value to the postgres - database. + `synapse/_scripts/synapse_port_db.py`. This tells the port script to cast + the integer value from SQLite to a boolean before writing the value to the + postgres database. * Before SQLite 3.23, `TRUE` and `FALSE` were not recognised as constants by SQLite, and the `IS [NOT] TRUE`/`IS [NOT] FALSE` operators were not diff --git a/docs/usage/administration/admin_api/README.md b/docs/usage/administration/admin_api/README.md index 2fca96f8be..3cbedc5dfa 100644 --- a/docs/usage/administration/admin_api/README.md +++ b/docs/usage/administration/admin_api/README.md @@ -12,7 +12,7 @@ UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'; ``` A new server admin user can also be created using the `register_new_matrix_user` -command. This is a script that is located in the `scripts/` directory, or possibly +command. This is a script that is distributed as part of synapse. It is possibly already on your `$PATH` depending on how Synapse was installed. Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings. 
diff --git a/mypy.ini b/mypy.ini index 38ff787609..6b1e995e64 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,6 +23,10 @@ files = # https://docs.python.org/3/library/re.html#re.X exclude = (?x) ^( + |synapse/_scripts/export_signing_key.py + |synapse/_scripts/move_remote_media_to_new_store.py + |synapse/_scripts/synapse_port_db.py + |synapse/_scripts/update_synapse_database.py |synapse/storage/databases/__init__.py |synapse/storage/databases/main/__init__.py |synapse/storage/databases/main/cache.py diff --git a/scripts-dev/generate_sample_config b/scripts-dev/generate_sample_config index 4cd1d1d5b8..185e277933 100755 --- a/scripts-dev/generate_sample_config +++ b/scripts-dev/generate_sample_config @@ -10,19 +10,19 @@ SAMPLE_CONFIG="docs/sample_config.yaml" SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml" check() { - diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || return 1 + diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || return 1 } if [ "$1" == "--check" ]; then - diff -u "$SAMPLE_CONFIG" <(./scripts/generate_config --header-file docs/.sample_config_header.yaml) >/dev/null || { + diff -u "$SAMPLE_CONFIG" <(synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml) >/dev/null || { echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 exit 1 } - diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || { + diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || { echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 exit 1 } else - ./scripts/generate_config --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG" - ./scripts/generate_log_config -o "$SAMPLE_LOG_CONFIG" + synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG" + synapse/_scripts/generate_log_config.py -o "$SAMPLE_LOG_CONFIG" fi diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index b6554a73c1..df4d4934d0 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -84,13 +84,6 @@ else files=( "synapse" "docker" "tests" # annoyingly, black doesn't find these so we have to list them - "scripts/export_signing_key" - "scripts/generate_config" - "scripts/generate_log_config" - "scripts/hash_password" - "scripts/register_new_matrix_user" - "scripts/synapse_port_db" - "scripts/update_synapse_database" "scripts-dev" "scripts-dev/build_debian_packages" "scripts-dev/sign_json" diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh index c3c90f4ec6..f0e22d4ca2 100755 --- a/scripts-dev/make_full_schema.sh +++ b/scripts-dev/make_full_schema.sh @@ -147,7 +147,7 @@ python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG" # Make sure the SQLite3 database is using the latest schema and has no pending background update. echo "Running db background jobs..." -scripts/update_synapse_database --database-config --run-background-updates "$SQLITE_CONFIG" +synapse/_scripts/update_synapse_database.py --database-config --run-background-updates "$SQLITE_CONFIG" # Create the PostgreSQL database. echo "Creating postgres database..." @@ -156,10 +156,10 @@ createdb --lc-collate=C --lc-ctype=C --template=template0 "$POSTGRES_DB_NAME" echo "Copying data from SQLite3 to Postgres with synapse_port_db..." 
if [ -z "$COVERAGE" ]; then # No coverage needed - scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" + synapse/_scripts/synapse_port_db.py --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" else # Coverage desired - coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" + coverage run synapse/_scripts/synapse_port_db.py --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" fi # Delete schema_version, applied_schema_deltas and applied_module_schemas tables diff --git a/scripts/export_signing_key b/scripts/export_signing_key deleted file mode 100755 index bf0139bd64..0000000000 --- a/scripts/export_signing_key +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import sys -import time -from typing import Optional - -import nacl.signing -from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys - - -def exit(status: int = 0, message: Optional[str] = None): - if message: - print(message, file=sys.stderr) - sys.exit(status) - - -def format_plain(public_key: nacl.signing.VerifyKey): - print( - "%s:%s %s" - % ( - public_key.alg, - public_key.version, - encode_verify_key_base64(public_key), - ) - ) - - -def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): - print( - ' "%s:%s": { key: "%s", expired_ts: %i }' - % ( - public_key.alg, - public_key.version, - encode_verify_key_base64(public_key), - expiry_ts, - ) - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "key_file", - nargs="+", - type=argparse.FileType("r"), - help="The key file to read", - ) - - parser.add_argument( - "-x", - action="store_true", - dest="for_config", - help="format the output for inclusion in the old_signing_keys config setting", - ) - - parser.add_argument( - "--expiry-ts", - type=int, - default=int(time.time() * 1000) + 6 * 3600000, - help=( - "The expiry time to use for -x, in milliseconds since 1970. The default " - "is (now+6h)." 
- ), - ) - - args = parser.parse_args() - - formatter = ( - (lambda k: format_for_config(k, args.expiry_ts)) - if args.for_config - else format_plain - ) - - keys = [] - for file in args.key_file: - try: - res = read_signing_keys(file) - except Exception as e: - exit( - status=1, - message="Error reading key from file %s: %s %s" - % (file.name, type(e), e), - ) - res = [] - for key in res: - formatter(get_verify_key(key)) diff --git a/scripts/generate_config b/scripts/generate_config deleted file mode 100755 index 931b40c045..0000000000 --- a/scripts/generate_config +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import shutil -import sys - -from synapse.config.homeserver import HomeServerConfig - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--config-dir", - default="CONFDIR", - help="The path where the config files are kept. Used to create filenames for " - "things like the log config and the signing key. Default: %(default)s", - ) - - parser.add_argument( - "--data-dir", - default="DATADIR", - help="The path where the data files are kept. Used to create filenames for " - "things like the database and media store. Default: %(default)s", - ) - - parser.add_argument( - "--server-name", - default="SERVERNAME", - help="The server name. Used to initialise the server_name config param, but also " - "used in the names of some of the config files. Default: %(default)s", - ) - - parser.add_argument( - "--report-stats", - action="store", - help="Whether the generated config reports anonymized usage statistics", - choices=["yes", "no"], - ) - - parser.add_argument( - "--generate-secrets", - action="store_true", - help="Enable generation of new secrets for things like the macaroon_secret_key." - "By default, these parameters will be left unset.", - ) - - parser.add_argument( - "-o", - "--output-file", - type=argparse.FileType("w"), - default=sys.stdout, - help="File to write the configuration to. Default: stdout", - ) - - parser.add_argument( - "--header-file", - type=argparse.FileType("r"), - help="File from which to read a header, which will be printed before the " - "generated config.", - ) - - args = parser.parse_args() - - report_stats = args.report_stats - if report_stats is not None: - report_stats = report_stats == "yes" - - conf = HomeServerConfig().generate_config( - config_dir_path=args.config_dir, - data_dir_path=args.data_dir, - server_name=args.server_name, - generate_secrets=args.generate_secrets, - report_stats=report_stats, - ) - - if args.header_file: - shutil.copyfileobj(args.header_file, args.output_file) - - args.output_file.write(conf) diff --git a/scripts/generate_log_config b/scripts/generate_log_config deleted file mode 100755 index e72a0dafb7..0000000000 --- a/scripts/generate_log_config +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import sys - -from synapse.config.logger import DEFAULT_LOG_CONFIG - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "-o", - "--output-file", - type=argparse.FileType("w"), - default=sys.stdout, - help="File to write the configuration to. Default: stdout", - ) - - parser.add_argument( - "-f", - "--log-file", - type=str, - default="/var/log/matrix-synapse/homeserver.log", - help="name of the log file", - ) - - args = parser.parse_args() - out = args.output_file - out.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file)) - out.flush() diff --git a/scripts/generate_signing_key.py b/scripts/generate_signing_key.py deleted file mode 100755 index 07df25a809..0000000000 --- a/scripts/generate_signing_key.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env python -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import sys - -from signedjson.key import generate_signing_key, write_signing_keys - -from synapse.util.stringutils import random_string - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "-o", - "--output_file", - type=argparse.FileType("w"), - default=sys.stdout, - help="Where to write the output to", - ) - args = parser.parse_args() - - key_id = "a_" + random_string(4) - key = (generate_signing_key(key_id),) - write_signing_keys(args.output_file, key) diff --git a/scripts/hash_password b/scripts/hash_password deleted file mode 100755 index 1d6fb0d700..0000000000 --- a/scripts/hash_password +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env python - -import argparse -import getpass -import sys -import unicodedata - -import bcrypt -import yaml - -bcrypt_rounds = 12 -password_pepper = "" - - -def prompt_for_pass(): - password = getpass.getpass("Password: ") - - if not password: - raise Exception("Password cannot be blank.") - - confirm_password = getpass.getpass("Confirm password: ") - - if password != confirm_password: - raise Exception("Passwords do not match.") - - return password - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=( - "Calculate the hash of a new password, so that passwords can be reset" - ) - ) - parser.add_argument( - "-p", - "--password", - default=None, - help="New password for user. Will prompt if omitted.", - ) - parser.add_argument( - "-c", - "--config", - type=argparse.FileType("r"), - help=( - "Path to server config file. " - "Used to read in bcrypt_rounds and password_pepper." 
- ), - ) - - args = parser.parse_args() - if "config" in args and args.config: - config = yaml.safe_load(args.config) - bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds) - password_config = config.get("password_config", None) or {} - password_pepper = password_config.get("pepper", password_pepper) - password = args.password - - if not password: - password = prompt_for_pass() - - # On Python 2, make sure we decode it to Unicode before we normalise it - if isinstance(password, bytes): - try: - password = password.decode(sys.stdin.encoding) - except UnicodeDecodeError: - print( - "ERROR! Your password is not decodable using your terminal encoding (%s)." - % (sys.stdin.encoding,) - ) - - pw = unicodedata.normalize("NFKC", password) - - hashed = bcrypt.hashpw( - pw.encode("utf8") + password_pepper.encode("utf8"), - bcrypt.gensalt(bcrypt_rounds), - ).decode("ascii") - - print(hashed) diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py deleted file mode 100755 index 875aa4781f..0000000000 --- a/scripts/move_remote_media_to_new_store.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Moves a list of remote media from one media store to another. - -The input should be a list of media files to be moved, one per line. Each line -should be formatted:: - - | - -This can be extracted from postgres with:: - - psql --tuples-only -A -c "select media_origin, filesystem_id from - matrix.remote_media_cache where ..." - -To use, pipe the above into:: - - PYTHON_PATH=. 
./scripts/move_remote_media_to_new_store.py -""" - -import argparse -import logging -import os -import shutil -import sys - -from synapse.rest.media.v1.filepath import MediaFilePaths - -logger = logging.getLogger() - - -def main(src_repo, dest_repo): - src_paths = MediaFilePaths(src_repo) - dest_paths = MediaFilePaths(dest_repo) - for line in sys.stdin: - line = line.strip() - parts = line.split("|") - if len(parts) != 2: - print("Unable to parse input line %s" % line, file=sys.stderr) - sys.exit(1) - - move_media(parts[0], parts[1], src_paths, dest_paths) - - -def move_media(origin_server, file_id, src_paths, dest_paths): - """Move the given file, and any thumbnails, to the dest repo - - Args: - origin_server (str): - file_id (str): - src_paths (MediaFilePaths): - dest_paths (MediaFilePaths): - """ - logger.info("%s/%s", origin_server, file_id) - - # check that the original exists - original_file = src_paths.remote_media_filepath(origin_server, file_id) - if not os.path.exists(original_file): - logger.warning( - "Original for %s/%s (%s) does not exist", - origin_server, - file_id, - original_file, - ) - else: - mkdir_and_move( - original_file, dest_paths.remote_media_filepath(origin_server, file_id) - ) - - # now look for thumbnails - original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id) - if not os.path.exists(original_thumb_dir): - return - - mkdir_and_move( - original_thumb_dir, - dest_paths.remote_media_thumbnail_dir(origin_server, file_id), - ) - - -def mkdir_and_move(original_file, dest_file): - dirname = os.path.dirname(dest_file) - if not os.path.exists(dirname): - logger.debug("mkdir %s", dirname) - os.makedirs(dirname) - logger.debug("mv %s %s", original_file, dest_file) - shutil.move(original_file, dest_file) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter - ) - parser.add_argument("-v", action="store_true", help="enable debug logging") - parser.add_argument("src_repo", help="Path to source content repo") - parser.add_argument("dest_repo", help="Path to source content repo") - args = parser.parse_args() - - logging_config = { - "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", - } - logging.basicConfig(**logging_config) - - main(args.src_repo, args.dest_repo) diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user deleted file mode 100755 index 00104b9d62..0000000000 --- a/scripts/register_new_matrix_user +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from synapse._scripts.register_new_matrix_user import main - -if __name__ == "__main__": - main() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db deleted file mode 100755 index db354b3c8c..0000000000 --- a/scripts/synapse_port_db +++ /dev/null @@ -1,1253 +0,0 @@ -#!/usr/bin/env python -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import curses -import logging -import sys -import time -import traceback -from typing import Dict, Iterable, Optional, Set - -import yaml -from matrix_common.versionstring import get_distribution_version_string - -from twisted.internet import defer, reactor - -from synapse.config.database import DatabaseConnectionConfig -from synapse.config.homeserver import HomeServerConfig -from synapse.logging.context import ( - LoggingContext, - make_deferred_yieldable, - run_in_background, -) -from synapse.storage.database import DatabasePool, make_conn -from synapse.storage.databases.main import PushRuleStore -from synapse.storage.databases.main.account_data import AccountDataWorkerStore -from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore -from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore -from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore -from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore -from synapse.storage.databases.main.events_bg_updates import ( - EventsBackgroundUpdatesStore, -) -from synapse.storage.databases.main.group_server import GroupServerWorkerStore -from synapse.storage.databases.main.media_repository import ( - MediaRepositoryBackgroundUpdateStore, -) -from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore -from synapse.storage.databases.main.pusher import PusherWorkerStore -from synapse.storage.databases.main.registration import ( - RegistrationBackgroundUpdateStore, - find_max_generated_user_id_localpart, -) -from synapse.storage.databases.main.room import RoomBackgroundUpdateStore -from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore -from synapse.storage.databases.main.search import SearchBackgroundUpdateStore -from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore -from synapse.storage.databases.main.stats import StatsStore -from synapse.storage.databases.main.user_directory import ( - UserDirectoryBackgroundUpdateStore, -) -from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database -from synapse.util import Clock - -logger = logging.getLogger("synapse_port_db") - - -BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url"], - "rooms": ["is_public", "has_auth_chain_index"], - "event_edges": ["is_state"], - "presence_list": 
["accepted"], - "presence_stream": ["currently_active"], - "public_room_list_stream": ["visibility"], - "devices": ["hidden"], - "device_lists_outbound_pokes": ["sent"], - "users_who_share_rooms": ["share_private"], - "groups": ["is_public"], - "group_rooms": ["is_public"], - "group_users": ["is_public", "is_admin"], - "group_summary_rooms": ["is_public"], - "group_room_categories": ["is_public"], - "group_summary_users": ["is_public"], - "group_roles": ["is_public"], - "local_group_membership": ["is_publicised", "is_admin"], - "e2e_room_keys": ["is_verified"], - "account_validity": ["email_sent"], - "redactions": ["have_censored"], - "room_stats_state": ["is_federatable"], - "local_media_repository": ["safe_from_quarantine"], - "users": ["shadow_banned"], - "e2e_fallback_keys_json": ["used"], - "access_tokens": ["used"], -} - - -APPEND_ONLY_TABLES = [ - "event_reference_hashes", - "events", - "event_json", - "state_events", - "room_memberships", - "topics", - "room_names", - "rooms", - "local_media_repository", - "local_media_repository_thumbnails", - "remote_media_cache", - "remote_media_cache_thumbnails", - "redactions", - "event_edges", - "event_auth", - "received_transactions", - "sent_transactions", - "transaction_id_to_pdu", - "users", - "state_groups", - "state_groups_state", - "event_to_state_groups", - "rejections", - "event_search", - "presence_stream", - "push_rules_stream", - "ex_outlier_stream", - "cache_invalidation_stream_by_instance", - "public_room_list_stream", - "state_group_edges", - "stream_ordering_to_exterm", -] - - -IGNORED_TABLES = { - # We don't port these tables, as they're a faff and we can regenerate - # them anyway. - "user_directory", - "user_directory_search", - "user_directory_search_content", - "user_directory_search_docsize", - "user_directory_search_segdir", - "user_directory_search_segments", - "user_directory_search_stat", - "user_directory_search_pos", - "users_who_share_private_rooms", - "users_in_public_room", - # UI auth sessions have foreign keys so additional care needs to be taken, - # the sessions are transient anyway, so ignore them. - "ui_auth_sessions", - "ui_auth_sessions_credentials", - "ui_auth_sessions_ips", -} - - -# Error returned by the run function. Used at the top-level part of the script to -# handle errors and return codes. -end_error = None # type: Optional[str] -# The exec_info for the error, if any. If error is defined but not exec_info the script -# will show only the error message without the stacktrace, if exec_info is defined but -# not the error then the script will show nothing outside of what's printed in the run -# function. If both are defined, the script will print both the error and the stacktrace. 
-end_error_exec_info = None - - -class Store( - ClientIpBackgroundUpdateStore, - DeviceInboxBackgroundUpdateStore, - DeviceBackgroundUpdateStore, - EventsBackgroundUpdatesStore, - MediaRepositoryBackgroundUpdateStore, - RegistrationBackgroundUpdateStore, - RoomBackgroundUpdateStore, - RoomMemberBackgroundUpdateStore, - SearchBackgroundUpdateStore, - StateBackgroundUpdateStore, - MainStateBackgroundUpdateStore, - UserDirectoryBackgroundUpdateStore, - EndToEndKeyBackgroundStore, - StatsStore, - AccountDataWorkerStore, - PushRuleStore, - PusherWorkerStore, - PresenceBackgroundUpdateStore, - GroupServerWorkerStore, -): - def execute(self, f, *args, **kwargs): - return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) - - def execute_sql(self, sql, *args): - def r(txn): - txn.execute(sql, args) - return txn.fetchall() - - return self.db_pool.runInteraction("execute_sql", r) - - def insert_many_txn(self, txn, table, headers, rows): - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in headers), - ", ".join("%s" for _ in headers), - ) - - try: - txn.executemany(sql, rows) - except Exception: - logger.exception("Failed to insert: %s", table) - raise - - def set_room_is_public(self, room_id, is_public): - raise Exception( - "Attempt to set room_is_public during port_db: database not empty?" - ) - - -class MockHomeserver: - def __init__(self, config): - self.clock = Clock(reactor) - self.config = config - self.hostname = config.server.server_name - self.version_string = "Synapse/" + get_distribution_version_string( - "matrix-synapse" - ) - - def get_clock(self): - return self.clock - - def get_reactor(self): - return reactor - - def get_instance_name(self): - return "master" - - -class Porter(object): - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - async def setup_table(self, table): - if table in APPEND_ONLY_TABLES: - # It's safe to just carry on inserting. 
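For append-only tables, `setup_table` keeps a forward and a backward cursor per table in `port_from_sqlite3`, so an interrupted port resumes instead of restarting: rows with `rowid >= forward_rowid` still need copying forwards, rows with `rowid <= backward_rowid` backwards, and everything in between is already done. A simplified, self-contained illustration of that bookkeeping (invented names, not the patch's code)::

    # Each completed batch widens the "already ported" window [backward+1, forward-1].
    progress = {"forward_rowid": 1, "backward_rowid": 0}

    def record_forward_batch(rowids):
        if rowids:
            progress["forward_rowid"] = max(rowids) + 1

    def record_backward_batch(rowids):
        if rowids:
            progress["backward_rowid"] = min(rowids) - 1

    record_forward_batch([1, 2, 3])  # rows 1-3 ported
    record_backward_batch([])        # nothing older left to copy
    assert progress == {"forward_rowid": 4, "backward_rowid": 0}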
- row = await self.postgres_store.db_pool.simple_select_one( - table="port_from_sqlite3", - keyvalues={"table_name": table}, - retcols=("forward_rowid", "backward_rowid"), - allow_none=True, - ) - - total_to_port = None - if row is None: - if table == "sent_transactions": - ( - forward_chunk, - already_ported, - total_to_port, - ) = await self._setup_sent_transactions() - backward_chunk = 0 - else: - await self.postgres_store.db_pool.simple_insert( - table="port_from_sqlite3", - values={ - "table_name": table, - "forward_rowid": 1, - "backward_rowid": 0, - }, - ) - - forward_chunk = 1 - backward_chunk = 0 - already_ported = 0 - else: - forward_chunk = row["forward_rowid"] - backward_chunk = row["backward_rowid"] - - if total_to_port is None: - already_ported, total_to_port = await self._get_total_count_to_port( - table, forward_chunk, backward_chunk - ) - else: - - def delete_all(txn): - txn.execute( - "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,) - ) - txn.execute("TRUNCATE %s CASCADE" % (table,)) - - await self.postgres_store.execute(delete_all) - - await self.postgres_store.db_pool.simple_insert( - table="port_from_sqlite3", - values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, - ) - - forward_chunk = 1 - backward_chunk = 0 - - already_ported, total_to_port = await self._get_total_count_to_port( - table, forward_chunk, backward_chunk - ) - - return table, already_ported, total_to_port, forward_chunk, backward_chunk - - async def get_table_constraints(self) -> Dict[str, Set[str]]: - """Returns a map of tables that have foreign key constraints to tables they depend on.""" - - def _get_constraints(txn): - # We can pull the information about foreign key constraints out from - # the postgres schema tables. - sql = """ - SELECT DISTINCT - tc.table_name, - ccu.table_name AS foreign_table_name - FROM - information_schema.table_constraints AS tc - INNER JOIN information_schema.constraint_column_usage AS ccu - USING (table_schema, constraint_name) - WHERE tc.constraint_type = 'FOREIGN KEY' - AND tc.table_name != ccu.table_name; - """ - txn.execute(sql) - - results = {} - for table, foreign_table in txn: - results.setdefault(table, set()).add(foreign_table) - return results - - return await self.postgres_store.db_pool.runInteraction( - "get_table_constraints", _get_constraints - ) - - async def handle_table( - self, table, postgres_size, table_size, forward_chunk, backward_chunk - ): - logger.info( - "Table %s: %i/%i (rows %i-%i) already ported", - table, - postgres_size, - table_size, - backward_chunk + 1, - forward_chunk - 1, - ) - - if not table_size: - return - - self.progress.add_table(table, postgres_size, table_size) - - if table == "event_search": - await self.handle_search_table( - postgres_size, table_size, forward_chunk, backward_chunk - ) - return - - if table in IGNORED_TABLES: - self.progress.update(table, table_size) # Mark table as done - return - - if table == "user_directory_stream_pos": - # We need to make sure there is a single row, `(X, null), as that is - # what synapse expects to be there. - await self.postgres_store.db_pool.simple_insert( - table=table, values={"stream_id": None} - ) - self.progress.update(table, table_size) # Mark table as done - return - - forward_select = ( - "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,) - ) - - backward_select = ( - "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" 
% (table,) - ) - - do_forward = [True] - do_backward = [True] - - while True: - - def r(txn): - forward_rows = [] - backward_rows = [] - if do_forward[0]: - txn.execute(forward_select, (forward_chunk, self.batch_size)) - forward_rows = txn.fetchall() - if not forward_rows: - do_forward[0] = False - - if do_backward[0]: - txn.execute(backward_select, (backward_chunk, self.batch_size)) - backward_rows = txn.fetchall() - if not backward_rows: - do_backward[0] = False - - if forward_rows or backward_rows: - headers = [column[0] for column in txn.description] - else: - headers = None - - return headers, forward_rows, backward_rows - - headers, frows, brows = await self.sqlite_store.db_pool.runInteraction( - "select", r - ) - - if frows or brows: - if frows: - forward_chunk = max(row[0] for row in frows) + 1 - if brows: - backward_chunk = min(row[0] for row in brows) - 1 - - rows = frows + brows - rows = self._convert_rows(table, headers, rows) - - def insert(txn): - self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - - self.postgres_store.db_pool.simple_update_one_txn( - txn, - table="port_from_sqlite3", - keyvalues={"table_name": table}, - updatevalues={ - "forward_rowid": forward_chunk, - "backward_rowid": backward_chunk, - }, - ) - - await self.postgres_store.execute(insert) - - postgres_size += len(rows) - - self.progress.update(table, postgres_size) - else: - return - - async def handle_search_table( - self, postgres_size, table_size, forward_chunk, backward_chunk - ): - select = ( - "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" - " FROM event_search as es" - " INNER JOIN events AS e USING (event_id, room_id)" - " WHERE es.rowid >= ?" - " ORDER BY es.rowid LIMIT ?" - ) - - while True: - - def r(txn): - txn.execute(select, (forward_chunk, self.batch_size)) - rows = txn.fetchall() - headers = [column[0] for column in txn.description] - - return headers, rows - - headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) - - if rows: - forward_chunk = rows[-1][0] + 1 - - # We have to treat event_search differently since it has a - # different structure in the two different databases. - def insert(txn): - sql = ( - "INSERT INTO event_search (event_id, room_id, key," - " sender, vector, origin_server_ts, stream_ordering)" - " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)" - ) - - rows_dict = [] - for row in rows: - d = dict(zip(headers, row)) - if "\0" in d["value"]: - logger.warning("dropping search row %s", d) - else: - rows_dict.append(d) - - txn.executemany( - sql, - [ - ( - row["event_id"], - row["room_id"], - row["key"], - row["sender"], - row["value"], - row["origin_server_ts"], - row["stream_ordering"], - ) - for row in rows_dict - ], - ) - - self.postgres_store.db_pool.simple_update_one_txn( - txn, - table="port_from_sqlite3", - keyvalues={"table_name": "event_search"}, - updatevalues={ - "forward_rowid": forward_chunk, - "backward_rowid": backward_chunk, - }, - ) - - await self.postgres_store.execute(insert) - - postgres_size += len(rows) - - self.progress.update("event_search", postgres_size) - - else: - return - - def build_db_store( - self, - db_config: DatabaseConnectionConfig, - allow_outdated_version: bool = False, - ): - """Builds and returns a database store using the provided configuration. - - Args: - db_config: The database configuration - allow_outdated_version: True to suppress errors about the database server - version being too old to run a complete synapse - - Returns: - The built Store object. 
- """ - self.progress.set_state("Preparing %s" % db_config.config["name"]) - - engine = create_engine(db_config.config) - - hs = MockHomeserver(self.hs_config) - - with make_conn(db_config, engine, "portdb") as db_conn: - engine.check_database( - db_conn, allow_outdated_version=allow_outdated_version - ) - prepare_database(db_conn, engine, config=self.hs_config) - store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) - db_conn.commit() - - return store - - async def run_background_updates_on_postgres(self): - # Manually apply all background updates on the PostgreSQL database. - postgres_ready = ( - await self.postgres_store.db_pool.updates.has_completed_background_updates() - ) - - if not postgres_ready: - # Only say that we're running background updates when there are background - # updates to run. - self.progress.set_state("Running background updates on PostgreSQL") - - while not postgres_ready: - await self.postgres_store.db_pool.updates.do_next_background_update(100) - postgres_ready = await ( - self.postgres_store.db_pool.updates.has_completed_background_updates() - ) - - async def run(self): - """Ports the SQLite database to a PostgreSQL database. - - When a fatal error is met, its message is assigned to the global "end_error" - variable. When this error comes with a stacktrace, its exec_info is assigned to - the global "end_error_exec_info" variable. - """ - global end_error - - try: - # we allow people to port away from outdated versions of sqlite. - self.sqlite_store = self.build_db_store( - DatabaseConnectionConfig("master-sqlite", self.sqlite_config), - allow_outdated_version=True, - ) - - # Check if all background updates are done, abort if not. - updates_complete = ( - await self.sqlite_store.db_pool.updates.has_completed_background_updates() - ) - if not updates_complete: - end_error = ( - "Pending background updates exist in the SQLite3 database." - " Please start Synapse again and wait until every update has finished" - " before running this script.\n" - ) - return - - self.postgres_store = self.build_db_store( - self.hs_config.database.get_single_database() - ) - - await self.run_background_updates_on_postgres() - - self.progress.set_state("Creating port tables") - - def create_port_table(txn): - txn.execute( - "CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" - " table_name varchar(100) NOT NULL UNIQUE," - " forward_rowid bigint NOT NULL," - " backward_rowid bigint NOT NULL" - ")" - ) - - # The old port script created a table with just a "rowid" column. - # We want people to be able to rerun this script from an old port - # so that they can pick up any missing events that were not - # ported across. - def alter_table(txn): - txn.execute( - "ALTER TABLE IF EXISTS port_from_sqlite3" - " RENAME rowid TO forward_rowid" - ) - txn.execute( - "ALTER TABLE IF EXISTS port_from_sqlite3" - " ADD backward_rowid bigint NOT NULL DEFAULT 0" - ) - - try: - await self.postgres_store.db_pool.runInteraction( - "alter_table", alter_table - ) - except Exception: - # On Error Resume Next - pass - - await self.postgres_store.db_pool.runInteraction( - "create_port_table", create_port_table - ) - - # Step 2. Set up sequences - # - # We do this before porting the tables so that event if we fail half - # way through the postgres DB always have sequences that are greater - # than their respective tables. If we don't then creating the - # `DataStore` object will fail due to the inconsistency. 
- self.progress.set_state("Setting up sequence generators") - await self._setup_state_group_id_seq() - await self._setup_user_id_seq() - await self._setup_events_stream_seqs() - await self._setup_sequence( - "device_inbox_sequence", ("device_inbox", "device_federation_outbox") - ) - await self._setup_sequence( - "account_data_sequence", - ("room_account_data", "room_tags_revisions", "account_data"), - ) - await self._setup_sequence("receipts_sequence", ("receipts_linearized",)) - await self._setup_sequence("presence_stream_sequence", ("presence_stream",)) - await self._setup_auth_chain_sequence() - - # Step 3. Get tables. - self.progress.set_state("Fetching tables") - sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol( - table="sqlite_master", keyvalues={"type": "table"}, retcol="name" - ) - - postgres_tables = await self.postgres_store.db_pool.simple_select_onecol( - table="information_schema.tables", - keyvalues={}, - retcol="distinct table_name", - ) - - tables = set(sqlite_tables) & set(postgres_tables) - logger.info("Found %d tables", len(tables)) - - # Step 4. Figure out what still needs copying - self.progress.set_state("Checking on port progress") - setup_res = await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background(self.setup_table, table) - for table in tables - if table not in ["schema_version", "applied_schema_deltas"] - and not table.startswith("sqlite_") - ], - consumeErrors=True, - ) - ) - # Map from table name to args passed to `handle_table`, i.e. a tuple - # of: `postgres_size`, `table_size`, `forward_chunk`, `backward_chunk`. - tables_to_port_info_map = {r[0]: r[1:] for r in setup_res} - - # Step 5. Do the copying. - # - # This is slightly convoluted as we need to ensure tables are ported - # in the correct order due to foreign key constraints. - self.progress.set_state("Copying to postgres") - - constraints = await self.get_table_constraints() - tables_ported = set() # type: Set[str] - - while tables_to_port_info_map: - # Pulls out all tables that are still to be ported and which - # only depend on tables that are already ported (if any). 
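The selection that follows implements a simple topological batching over the foreign-key graph returned by `get_table_constraints`: on each pass, every table whose referenced tables have all been ported becomes eligible. A self-contained toy run of the same loop::

    # table -> set of tables it references via foreign keys
    constraints = {"b": {"a"}, "c": {"a", "b"}}
    remaining = {"a", "b", "c"}
    ported = set()
    while remaining:
        ready = [t for t in sorted(remaining) if not constraints.get(t, set()) - ported]
        print(ready)  # ['a'], then ['b'], then ['c']
        ported.update(ready)
        remaining.difference_update(ready)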
- tables_to_port = [ - table - for table in tables_to_port_info_map - if not constraints.get(table, set()) - tables_ported - ] - - await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.handle_table, - table, - *tables_to_port_info_map.pop(table), - ) - for table in tables_to_port - ], - consumeErrors=True, - ) - ) - - tables_ported.update(tables_to_port) - - self.progress.done() - except Exception as e: - global end_error_exec_info - end_error = str(e) - end_error_exec_info = sys.exc_info() - logger.exception("") - finally: - reactor.stop() - - def _convert_rows(self, table, headers, rows): - bool_col_names = BOOLEAN_COLUMNS.get(table, []) - - bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names] - - class BadValueException(Exception): - pass - - def conv(j, col): - if j in bool_cols: - return bool(col) - if isinstance(col, bytes): - return bytearray(col) - elif isinstance(col, str) and "\0" in col: - logger.warning( - "DROPPING ROW: NUL value in table %s col %s: %r", - table, - headers[j], - col, - ) - raise BadValueException() - return col - - outrows = [] - for row in rows: - try: - outrows.append( - tuple(conv(j, col) for j, col in enumerate(row) if j > 0) - ) - except BadValueException: - pass - - return outrows - - async def _setup_sent_transactions(self): - # Only save things from the last day - yesterday = int(time.time() * 1000) - 86400000 - - # And save the max transaction id from each destination - select = ( - "SELECT rowid, * FROM sent_transactions WHERE rowid IN (" - "SELECT max(rowid) FROM sent_transactions" - " GROUP BY destination" - ")" - ) - - def r(txn): - txn.execute(select) - rows = txn.fetchall() - headers = [column[0] for column in txn.description] - - ts_ind = headers.index("ts") - - return headers, [r for r in rows if r[ts_ind] < yesterday] - - headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) - - rows = self._convert_rows("sent_transactions", headers, rows) - - inserted_rows = len(rows) - if inserted_rows: - max_inserted_rowid = max(r[0] for r in rows) - - def insert(txn): - self.postgres_store.insert_many_txn( - txn, "sent_transactions", headers[1:], rows - ) - - await self.postgres_store.execute(insert) - else: - max_inserted_rowid = 0 - - def get_start_id(txn): - txn.execute( - "SELECT rowid FROM sent_transactions WHERE ts >= ?" - " ORDER BY rowid ASC LIMIT 1", - (yesterday,), - ) - - rows = txn.fetchall() - if rows: - return rows[0][0] - else: - return 1 - - next_chunk = await self.sqlite_store.execute(get_start_id) - next_chunk = max(max_inserted_rowid + 1, next_chunk) - - await self.postgres_store.db_pool.simple_insert( - table="port_from_sqlite3", - values={ - "table_name": "sent_transactions", - "forward_rowid": next_chunk, - "backward_rowid": 0, - }, - ) - - def get_sent_table_size(txn): - txn.execute( - "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) - ) - (size,) = txn.fetchone() - return int(size) - - remaining_count = await self.sqlite_store.execute(get_sent_table_size) - - total_count = remaining_count + inserted_rows - - return next_chunk, inserted_rows, total_count - - async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): - frows = await self.sqlite_store.execute_sql( - "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk - ) - - brows = await self.sqlite_store.execute_sql( - "SELECT count(*) FROM %s WHERE rowid <= ?" 
% (table,), backward_chunk - ) - - return frows[0][0] + brows[0][0] - - async def _get_already_ported_count(self, table): - rows = await self.postgres_store.execute_sql( - "SELECT count(*) FROM %s" % (table,) - ) - - return rows[0][0] - - async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): - remaining, done = await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self._get_remaining_count_to_port, - table, - forward_chunk, - backward_chunk, - ), - run_in_background(self._get_already_ported_count, table), - ], - ) - ) - - remaining = int(remaining) if remaining else 0 - done = int(done) if done else 0 - - return done, remaining + done - - async def _setup_state_group_id_seq(self) -> None: - curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol( - table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True - ) - - if not curr_id: - return - - def r(txn): - next_id = curr_id + 1 - txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - - await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) - - async def _setup_user_id_seq(self) -> None: - curr_id = await self.sqlite_store.db_pool.runInteraction( - "setup_user_id_seq", find_max_generated_user_id_localpart - ) - - def r(txn): - next_id = curr_id + 1 - txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) - - await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) - - async def _setup_events_stream_seqs(self) -> None: - """Set the event stream sequences to the correct values.""" - - # We get called before we've ported the events table, so we need to - # fetch the current positions from the SQLite store. - curr_forward_id = await self.sqlite_store.db_pool.simple_select_one_onecol( - table="events", keyvalues={}, retcol="MAX(stream_ordering)", allow_none=True - ) - - curr_backward_id = await self.sqlite_store.db_pool.simple_select_one_onecol( - table="events", - keyvalues={}, - retcol="MAX(-MIN(stream_ordering), 1)", - allow_none=True, - ) - - def _setup_events_stream_seqs_set_pos(txn): - if curr_forward_id: - txn.execute( - "ALTER SEQUENCE events_stream_seq RESTART WITH %s", - (curr_forward_id + 1,), - ) - - if curr_backward_id: - txn.execute( - "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s", - (curr_backward_id + 1,), - ) - - await self.postgres_store.db_pool.runInteraction( - "_setup_events_stream_seqs", - _setup_events_stream_seqs_set_pos, - ) - - async def _setup_sequence( - self, sequence_name: str, stream_id_tables: Iterable[str] - ) -> None: - """Set a sequence to the correct value.""" - current_stream_ids = [] - for stream_id_table in stream_id_tables: - max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol( - table=stream_id_table, - keyvalues={}, - retcol="COALESCE(MAX(stream_id), 1)", - allow_none=True, - ) - current_stream_ids.append(max_stream_id) - - next_id = max(current_stream_ids) + 1 - - def r(txn): - sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,) - txn.execute(sql + " %s", (next_id,)) - - await self.postgres_store.db_pool.runInteraction( - "_setup_%s" % (sequence_name,), r - ) - - async def _setup_auth_chain_sequence(self) -> None: - curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol( - table="event_auth_chains", - keyvalues={}, - retcol="MAX(chain_id)", - allow_none=True, - ) - - def r(txn): - txn.execute( - "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s", - (curr_chain_id + 1,), - ) - - if 
curr_chain_id is not None: - await self.postgres_store.db_pool.runInteraction( - "_setup_event_auth_chain_id", - r, - ) - - -############################################## -# The following is simply UI stuff -############################################## - - -class Progress(object): - """Used to report progress of the port""" - - def __init__(self): - self.tables = {} - - self.start_time = int(time.time()) - - def add_table(self, table, cur, size): - self.tables[table] = { - "start": cur, - "num_done": cur, - "total": size, - "perc": int(cur * 100 / size), - } - - def update(self, table, num_done): - data = self.tables[table] - data["num_done"] = num_done - data["perc"] = int(num_done * 100 / data["total"]) - - def done(self): - pass - - -class CursesProgress(Progress): - """Reports progress to a curses window""" - - def __init__(self, stdscr): - self.stdscr = stdscr - - curses.use_default_colors() - curses.curs_set(0) - - curses.init_pair(1, curses.COLOR_RED, -1) - curses.init_pair(2, curses.COLOR_GREEN, -1) - - self.last_update = 0 - - self.finished = False - - self.total_processed = 0 - self.total_remaining = 0 - - super(CursesProgress, self).__init__() - - def update(self, table, num_done): - super(CursesProgress, self).update(table, num_done) - - self.total_processed = 0 - self.total_remaining = 0 - for data in self.tables.values(): - self.total_processed += data["num_done"] - data["start"] - self.total_remaining += data["total"] - data["num_done"] - - self.render() - - def render(self, force=False): - now = time.time() - - if not force and now - self.last_update < 0.2: - # reactor.callLater(1, self.render) - return - - self.stdscr.clear() - - rows, cols = self.stdscr.getmaxyx() - - duration = int(now) - int(self.start_time) - - minutes, seconds = divmod(duration, 60) - duration_str = "%02dm %02ds" % (minutes, seconds) - - if self.finished: - status = "Time spent: %s (Done!)" % (duration_str,) - else: - - if self.total_processed > 0: - left = float(self.total_remaining) / self.total_processed - - est_remaining = (int(now) - self.start_time) * left - est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60) - else: - est_remaining_str = "Unknown" - status = "Time spent: %s (est. 
remaining: %s)" % ( - duration_str, - est_remaining_str, - ) - - self.stdscr.addstr(0, 0, status, curses.A_BOLD) - - max_len = max(len(t) for t in self.tables.keys()) - - left_margin = 5 - middle_space = 1 - - items = self.tables.items() - items = sorted(items, key=lambda i: (i[1]["perc"], i[0])) - - for i, (table, data) in enumerate(items): - if i + 2 >= rows: - break - - perc = data["perc"] - - color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) - - self.stdscr.addstr( - i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color - ) - - size = 20 - - progress = "[%s%s]" % ( - "#" * int(perc * size / 100), - " " * (size - int(perc * size / 100)), - ) - - self.stdscr.addstr( - i + 2, - left_margin + max_len + middle_space, - "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), - ) - - if self.finished: - self.stdscr.addstr(rows - 1, 0, "Press any key to exit...") - - self.stdscr.refresh() - self.last_update = time.time() - - def done(self): - self.finished = True - self.render(True) - self.stdscr.getch() - - def set_state(self, state): - self.stdscr.clear() - self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD) - self.stdscr.refresh() - - -class TerminalProgress(Progress): - """Just prints progress to the terminal""" - - def update(self, table, num_done): - super(TerminalProgress, self).update(table, num_done) - - data = self.tables[table] - - print( - "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"]) - ) - - def set_state(self, state): - print(state + "...") - - -############################################## -############################################## - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="A script to port an existing synapse SQLite database to" - " a new PostgreSQL database." - ) - parser.add_argument("-v", action="store_true") - parser.add_argument( - "--sqlite-database", - required=True, - help="The snapshot of the SQLite database file. 
This must not be" - " currently used by a running synapse server", - ) - parser.add_argument( - "--postgres-config", - type=argparse.FileType("r"), - required=True, - help="The database config file for the PostgreSQL database", - ) - parser.add_argument( - "--curses", action="store_true", help="display a curses based progress UI" - ) - - parser.add_argument( - "--batch-size", - type=int, - default=1000, - help="The number of rows to select from the SQLite table each" - " iteration [default=1000]", - ) - - args = parser.parse_args() - - logging_config = { - "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", - } - - if args.curses: - logging_config["filename"] = "port-synapse.log" - - logging.basicConfig(**logging_config) - - sqlite_config = { - "name": "sqlite3", - "args": { - "database": args.sqlite_database, - "cp_min": 1, - "cp_max": 1, - "check_same_thread": False, - }, - } - - hs_config = yaml.safe_load(args.postgres_config) - - if "database" not in hs_config: - sys.stderr.write("The configuration file must have a 'database' section.\n") - sys.exit(4) - - postgres_config = hs_config["database"] - - if "name" not in postgres_config: - sys.stderr.write("Malformed database config: no 'name'\n") - sys.exit(2) - if postgres_config["name"] != "psycopg2": - sys.stderr.write("Database must use the 'psycopg2' connector.\n") - sys.exit(3) - - config = HomeServerConfig() - config.parse_config_dict(hs_config, "", "") - - def start(stdscr=None): - if stdscr: - progress = CursesProgress(stdscr) - else: - progress = TerminalProgress() - - porter = Porter( - sqlite_config=sqlite_config, - progress=progress, - batch_size=args.batch_size, - hs_config=config, - ) - - @defer.inlineCallbacks - def run(): - with LoggingContext("synapse_port_db_run"): - yield defer.ensureDeferred(porter.run()) - - reactor.callWhenRunning(run) - - reactor.run() - - if args.curses: - curses.wrapper(start) - else: - start() - - if end_error: - if end_error_exec_info: - exc_type, exc_value, exc_traceback = end_error_exec_info - traceback.print_exception(exc_type, exc_value, exc_traceback) - - sys.stderr.write(end_error) - - sys.exit(5) diff --git a/scripts/synapse_review_recent_signups b/scripts/synapse_review_recent_signups deleted file mode 100755 index a36d46e14c..0000000000 --- a/scripts/synapse_review_recent_signups +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from synapse._scripts.review_recent_signups import main - -if __name__ == "__main__": - main() diff --git a/scripts/sync_room_to_group.pl b/scripts/sync_room_to_group.pl deleted file mode 100755 index f0c2dfadfa..0000000000 --- a/scripts/sync_room_to_group.pl +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env perl - -use strict; -use warnings; - -use JSON::XS; -use LWP::UserAgent; -use URI::Escape; - -if (@ARGV < 4) { - die "usage: $0 \n"; -} - -my ($hs, $access_token, $room_id, $group_id) = @ARGV; -my $ua = LWP::UserAgent->new(); -$ua->timeout(10); - -if ($room_id =~ /^#/) { - $room_id = uri_escape($room_id); - $room_id = decode_json($ua->get("${hs}/_matrix/client/r0/directory/room/${room_id}?access_token=${access_token}")->decoded_content)->{room_id}; -} - -my $room_users = [ keys %{decode_json($ua->get("${hs}/_matrix/client/r0/rooms/${room_id}/joined_members?access_token=${access_token}")->decoded_content)->{joined}} ]; -my $group_users = [ - (map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/users?access_token=${access_token}" )->decoded_content)->{chunk}}), - (map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/invited_users?access_token=${access_token}" )->decoded_content)->{chunk}}), -]; - -die "refusing to sync from empty room" unless (@$room_users); -die "refusing to sync to empty group" unless (@$group_users); - -my $diff = {}; -foreach my $user (@$room_users) { $diff->{$user}++ } -foreach my $user (@$group_users) { $diff->{$user}-- } - -foreach my $user (keys %$diff) { - if ($diff->{$user} == 1) { - warn "inviting $user"; - print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/invite/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n"; - } - elsif ($diff->{$user} == -1) { - warn "removing $user"; - print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/remove/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n"; - } -} diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database deleted file mode 100755 index f43676afaa..0000000000 --- a/scripts/update_synapse_database +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env python -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import logging -import sys - -import yaml -from matrix_common.versionstring import get_distribution_version_string - -from twisted.internet import defer, reactor - -from synapse.config.homeserver import HomeServerConfig -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.server import HomeServer -from synapse.storage import DataStore - -logger = logging.getLogger("update_database") - - -class MockHomeserver(HomeServer): - DATASTORE_CLASS = DataStore - - def __init__(self, config, **kwargs): - super(MockHomeserver, self).__init__( - config.server.server_name, reactor=reactor, config=config, **kwargs - ) - - self.version_string = "Synapse/" + get_distribution_version_string( - "matrix-synapse" - ) - - -def run_background_updates(hs): - store = hs.get_datastores().main - - async def run_background_updates(): - await store.db_pool.updates.run_background_updates(sleep=False) - # Stop the reactor to exit the script once every background update is run. - reactor.stop() - - def run(): - # Apply all background updates on the database. - defer.ensureDeferred( - run_as_background_process("background_updates", run_background_updates) - ) - - reactor.callWhenRunning(run) - - reactor.run() - - -def main(): - parser = argparse.ArgumentParser( - description=( - "Updates a synapse database to the latest schema and optionally runs background updates" - " on it." - ) - ) - parser.add_argument("-v", action="store_true") - parser.add_argument( - "--database-config", - type=argparse.FileType("r"), - required=True, - help="Synapse configuration file, giving the details of the database to be updated", - ) - parser.add_argument( - "--run-background-updates", - action="store_true", - required=False, - help="run background updates after upgrading the database schema", - ) - - args = parser.parse_args() - - logging_config = { - "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", - } - - logging.basicConfig(**logging_config) - - # Load, process and sanity-check the config. - hs_config = yaml.safe_load(args.database_config) - - if "database" not in hs_config: - sys.stderr.write("The configuration file must have a 'database' section.\n") - sys.exit(4) - - config = HomeServerConfig() - config.parse_config_dict(hs_config, "", "") - - # Instantiate and initialise the homeserver object. - hs = MockHomeserver(config) - - # Setup instantiates the store within the homeserver object and updates the - # DB. - hs.setup() - - if args.run_background_updates: - run_background_updates(hs) - - -if __name__ == "__main__": - main() diff --git a/setup.py b/setup.py index 26f4650348..318df16766 100755 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
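Here setup.py stops installing everything under scripts/ wholesale (the glob in `scripts=`) and instead declares each tool as a console_scripts entry point pointing at the new synapse._scripts modules. At install time the packaging tooling then generates a small launcher per entry point, roughly equivalent to the following (illustrative only; the exact generated file varies by installer)::

    #!/usr/bin/env python
    # Rough shape of a generated console_scripts launcher.
    import sys

    from synapse._scripts.register_new_matrix_user import main

    if __name__ == "__main__":
        sys.exit(main())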
-import glob
 import os
 from typing import Any, Dict
 
@@ -153,8 +152,19 @@ setup(
     python_requires="~=3.7",
     entry_points={
         "console_scripts": [
+            # Application
             "synapse_homeserver = synapse.app.homeserver:main",
             "synapse_worker = synapse.app.generic_worker:main",
+            # Scripts
+            "export_signing_key = synapse._scripts.export_signing_key:main",
+            "generate_config = synapse._scripts.generate_config:main",
+            "generate_log_config = synapse._scripts.generate_log_config:main",
+            "generate_signing_key = synapse._scripts.generate_signing_key:main",
+            "hash_password = synapse._scripts.hash_password:main",
+            "register_new_matrix_user = synapse._scripts.register_new_matrix_user:main",
+            "synapse_port_db = synapse._scripts.synapse_port_db:main",
+            "synapse_review_recent_signups = synapse._scripts.review_recent_signups:main",
+            "update_synapse_database = synapse._scripts.update_synapse_database:main",
         ]
     },
     classifiers=[
@@ -167,6 +177,6 @@
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
     ],
-    scripts=["synctl"] + glob.glob("scripts/*"),
+    scripts=["synctl"],
     cmdclass={"test": TestCommand},
 )
diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml
index 9a01152c15..dd4c8478d5 100644
--- a/snap/snapcraft.yaml
+++ b/snap/snapcraft.yaml
@@ -20,7 +20,7 @@ apps:
   generate-config:
     command: generate_config
   generate-signing-key:
-    command: generate_signing_key.py
+    command: generate_signing_key
   register-new-matrix-user:
     command: register_new_matrix_user
     plugs: [network]
diff --git a/synapse/_scripts/export_signing_key.py b/synapse/_scripts/export_signing_key.py
new file mode 100755
index 0000000000..3d254348f1
--- /dev/null
+++ b/synapse/_scripts/export_signing_key.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
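The new export_signing_key.py below is built from a handful of signedjson helpers. A quick round-trip showing how those helpers fit together (a sketch; it assumes only the signedjson dependency)::

    from io import StringIO

    from signedjson.key import (
        encode_verify_key_base64,
        generate_signing_key,
        get_verify_key,
        read_signing_keys,
        write_signing_keys,
    )

    # Write a fresh signing key, read it back, and print the public half --
    # the same data the script's format_plain() output is assembled from.
    buf = StringIO()
    write_signing_keys(buf, (generate_signing_key("a_test"),))
    for key in read_signing_keys(StringIO(buf.getvalue())):
        verify_key = get_verify_key(key)
        print(
            "%s:%s %s"
            % (verify_key.alg, verify_key.version, encode_verify_key_base64(verify_key))
        )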
+import argparse
+import sys
+import time
+from typing import Optional
+
+import nacl.signing
+from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
+
+
+def exit(status: int = 0, message: Optional[str] = None):
+    if message:
+        print(message, file=sys.stderr)
+    sys.exit(status)
+
+
+def format_plain(public_key: nacl.signing.VerifyKey):
+    print(
+        "%s:%s %s"
+        % (
+            public_key.alg,
+            public_key.version,
+            encode_verify_key_base64(public_key),
+        )
+    )
+
+
+def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
+    print(
+        '  "%s:%s": { key: "%s", expired_ts: %i }'
+        % (
+            public_key.alg,
+            public_key.version,
+            encode_verify_key_base64(public_key),
+            expiry_ts,
+        )
+    )
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "key_file",
+        nargs="+",
+        type=argparse.FileType("r"),
+        help="The key file to read",
+    )
+
+    parser.add_argument(
+        "-x",
+        action="store_true",
+        dest="for_config",
+        help="format the output for inclusion in the old_signing_keys config setting",
+    )
+
+    parser.add_argument(
+        "--expiry-ts",
+        type=int,
+        default=int(time.time() * 1000) + 6 * 3600000,
+        help=(
+            "The expiry time to use for -x, in milliseconds since 1970. The default "
+            "is (now+6h)."
+        ),
+    )
+
+    args = parser.parse_args()
+
+    formatter = (
+        (lambda k: format_for_config(k, args.expiry_ts))
+        if args.for_config
+        else format_plain
+    )
+
+    for file in args.key_file:
+        try:
+            res = read_signing_keys(file)
+        except Exception as e:
+            exit(
+                status=1,
+                message="Error reading key from file %s: %s %s"
+                % (file.name, type(e), e),
+            )
+            res = []
+        for key in res:
+            formatter(get_verify_key(key))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/_scripts/generate_config.py b/synapse/_scripts/generate_config.py
new file mode 100755
index 0000000000..75fce20b12
--- /dev/null
+++ b/synapse/_scripts/generate_config.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+
+import argparse
+import shutil
+import sys
+
+from synapse.config.homeserver import HomeServerConfig
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config-dir",
+        default="CONFDIR",
+        help="The path where the config files are kept. Used to create filenames for "
+        "things like the log config and the signing key. Default: %(default)s",
+    )
+
+    parser.add_argument(
+        "--data-dir",
+        default="DATADIR",
+        help="The path where the data files are kept. Used to create filenames for "
+        "things like the database and media store. Default: %(default)s",
+    )
+
+    parser.add_argument(
+        "--server-name",
+        default="SERVERNAME",
+        help="The server name. Used to initialise the server_name config param, but also "
+        "used in the names of some of the config files. Default: %(default)s",
+    )
+
+    parser.add_argument(
+        "--report-stats",
+        action="store",
+        help="Whether the generated config reports anonymized usage statistics",
+        choices=["yes", "no"],
+    )
+
+    parser.add_argument(
+        "--generate-secrets",
+        action="store_true",
+        help="Enable generation of new secrets for things like the macaroon_secret_key. "
+        "By default, these parameters will be left unset.",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--output-file",
+        type=argparse.FileType("w"),
+        default=sys.stdout,
+        help="File to write the configuration to. 
Default: stdout", + ) + + parser.add_argument( + "--header-file", + type=argparse.FileType("r"), + help="File from which to read a header, which will be printed before the " + "generated config.", + ) + + args = parser.parse_args() + + report_stats = args.report_stats + if report_stats is not None: + report_stats = report_stats == "yes" + + conf = HomeServerConfig().generate_config( + config_dir_path=args.config_dir, + data_dir_path=args.data_dir, + server_name=args.server_name, + generate_secrets=args.generate_secrets, + report_stats=report_stats, + ) + + if args.header_file: + shutil.copyfileobj(args.header_file, args.output_file) + + args.output_file.write(conf) + + +if __name__ == "__main__": + main() diff --git a/synapse/_scripts/generate_log_config.py b/synapse/_scripts/generate_log_config.py new file mode 100755 index 0000000000..82fc763140 --- /dev/null +++ b/synapse/_scripts/generate_log_config.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys + +from synapse.config.logger import DEFAULT_LOG_CONFIG + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-o", + "--output-file", + type=argparse.FileType("w"), + default=sys.stdout, + help="File to write the configuration to. Default: stdout", + ) + + parser.add_argument( + "-f", + "--log-file", + type=str, + default="/var/log/matrix-synapse/homeserver.log", + help="name of the log file", + ) + + args = parser.parse_args() + out = args.output_file + out.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file)) + out.flush() + + +if __name__ == "__main__": + main() diff --git a/synapse/_scripts/generate_signing_key.py b/synapse/_scripts/generate_signing_key.py new file mode 100755 index 0000000000..bc26d25bfd --- /dev/null +++ b/synapse/_scripts/generate_signing_key.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
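Two more of the relocated tools follow: generate_signing_key.py, a thin wrapper around `signedjson.key.generate_signing_key`, and hash_password.py, whose core operation combines NFKC normalisation, a server-wide pepper, and bcrypt. That core reduces to (a sketch, assuming the bcrypt package; the function name is made up)::

    import unicodedata

    import bcrypt

    def hash_password(password: str, pepper: str = "", rounds: int = 12) -> str:
        # Normalise first so visually identical passwords hash identically,
        # then append the configured pepper before handing off to bcrypt.
        pw = unicodedata.normalize("NFKC", password)
        return bcrypt.hashpw(
            pw.encode("utf8") + pepper.encode("utf8"), bcrypt.gensalt(rounds)
        ).decode("ascii")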
+import argparse +import sys + +from signedjson.key import generate_signing_key, write_signing_keys + +from synapse.util.stringutils import random_string + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-o", + "--output_file", + type=argparse.FileType("w"), + default=sys.stdout, + help="Where to write the output to", + ) + args = parser.parse_args() + + key_id = "a_" + random_string(4) + key = (generate_signing_key(key_id),) + write_signing_keys(args.output_file, key) + + +if __name__ == "__main__": + main() diff --git a/synapse/_scripts/hash_password.py b/synapse/_scripts/hash_password.py new file mode 100755 index 0000000000..708640c7de --- /dev/null +++ b/synapse/_scripts/hash_password.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python + +import argparse +import getpass +import sys +import unicodedata + +import bcrypt +import yaml + + +def prompt_for_pass(): + password = getpass.getpass("Password: ") + + if not password: + raise Exception("Password cannot be blank.") + + confirm_password = getpass.getpass("Confirm password: ") + + if password != confirm_password: + raise Exception("Passwords do not match.") + + return password + + +def main(): + bcrypt_rounds = 12 + password_pepper = "" + + parser = argparse.ArgumentParser( + description=( + "Calculate the hash of a new password, so that passwords can be reset" + ) + ) + parser.add_argument( + "-p", + "--password", + default=None, + help="New password for user. Will prompt if omitted.", + ) + parser.add_argument( + "-c", + "--config", + type=argparse.FileType("r"), + help=( + "Path to server config file. " + "Used to read in bcrypt_rounds and password_pepper." + ), + ) + + args = parser.parse_args() + if "config" in args and args.config: + config = yaml.safe_load(args.config) + bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds) + password_config = config.get("password_config", None) or {} + password_pepper = password_config.get("pepper", password_pepper) + password = args.password + + if not password: + password = prompt_for_pass() + + # On Python 2, make sure we decode it to Unicode before we normalise it + if isinstance(password, bytes): + try: + password = password.decode(sys.stdin.encoding) + except UnicodeDecodeError: + print( + "ERROR! Your password is not decodable using your terminal encoding (%s)." + % (sys.stdin.encoding,) + ) + + pw = unicodedata.normalize("NFKC", password) + + hashed = bcrypt.hashpw( + pw.encode("utf8") + password_pepper.encode("utf8"), + bcrypt.gensalt(bcrypt_rounds), + ).decode("ascii") + + print(hashed) + + +if __name__ == "__main__": + main() diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py new file mode 100755 index 0000000000..9667d95dfe --- /dev/null +++ b/synapse/_scripts/move_remote_media_to_new_store.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Moves a list of remote media from one media store to another. 
+
+The input should be a list of media files to be moved, one per line. Each line
+should be formatted::
+
+    <origin server>|<file id>
+
+This can be extracted from postgres with::
+
+    psql --tuples-only -A -c "select media_origin, filesystem_id from
+        matrix.remote_media_cache where ..."
+
+To use, pipe the above into::
+
+    PYTHON_PATH=. synapse/_scripts/move_remote_media_to_new_store.py
+"""
+
+import argparse
+import logging
+import os
+import shutil
+import sys
+
+from synapse.rest.media.v1.filepath import MediaFilePaths
+
+logger = logging.getLogger()
+
+
+def main(src_repo, dest_repo):
+    src_paths = MediaFilePaths(src_repo)
+    dest_paths = MediaFilePaths(dest_repo)
+    for line in sys.stdin:
+        line = line.strip()
+        parts = line.split("|")
+        if len(parts) != 2:
+            print("Unable to parse input line %s" % line, file=sys.stderr)
+            sys.exit(1)
+
+        move_media(parts[0], parts[1], src_paths, dest_paths)
+
+
+def move_media(origin_server, file_id, src_paths, dest_paths):
+    """Move the given file, and any thumbnails, to the dest repo
+
+    Args:
+        origin_server (str):
+        file_id (str):
+        src_paths (MediaFilePaths):
+        dest_paths (MediaFilePaths):
+    """
+    logger.info("%s/%s", origin_server, file_id)
+
+    # check that the original exists
+    original_file = src_paths.remote_media_filepath(origin_server, file_id)
+    if not os.path.exists(original_file):
+        logger.warning(
+            "Original for %s/%s (%s) does not exist",
+            origin_server,
+            file_id,
+            original_file,
+        )
+    else:
+        mkdir_and_move(
+            original_file, dest_paths.remote_media_filepath(origin_server, file_id)
+        )
+
+    # now look for thumbnails
+    original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id)
+    if not os.path.exists(original_thumb_dir):
+        return
+
+    mkdir_and_move(
+        original_thumb_dir,
+        dest_paths.remote_media_thumbnail_dir(origin_server, file_id),
+    )
+
+
+def mkdir_and_move(original_file, dest_file):
+    dirname = os.path.dirname(dest_file)
+    if not os.path.exists(dirname):
+        logger.debug("mkdir %s", dirname)
+        os.makedirs(dirname)
+    logger.debug("mv %s %s", original_file, dest_file)
+    shutil.move(original_file, dest_file)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
+    )
+    parser.add_argument("-v", action="store_true", help="enable debug logging")
+    parser.add_argument("src_repo", help="Path to source content repo")
+    parser.add_argument("dest_repo", help="Path to destination content repo")
+    args = parser.parse_args()
+
+    logging_config = {
+        "level": logging.DEBUG if args.v else logging.INFO,
+        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+    }
+    logging.basicConfig(**logging_config)
+
+    main(args.src_repo, args.dest_repo)
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
new file mode 100755
index 0000000000..c38666da18
--- /dev/null
+++ b/synapse/_scripts/synapse_port_db.py
@@ -0,0 +1,1257 @@
+#!/usr/bin/env python
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import curses +import logging +import sys +import time +import traceback +from typing import Dict, Iterable, Optional, Set + +import yaml +from matrix_common.versionstring import get_distribution_version_string + +from twisted.internet import defer, reactor + +from synapse.config.database import DatabaseConnectionConfig +from synapse.config.homeserver import HomeServerConfig +from synapse.logging.context import ( + LoggingContext, + make_deferred_yieldable, + run_in_background, +) +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main import PushRuleStore +from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore +from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore +from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore +from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore +from synapse.storage.databases.main.events_bg_updates import ( + EventsBackgroundUpdatesStore, +) +from synapse.storage.databases.main.group_server import GroupServerWorkerStore +from synapse.storage.databases.main.media_repository import ( + MediaRepositoryBackgroundUpdateStore, +) +from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore +from synapse.storage.databases.main.pusher import PusherWorkerStore +from synapse.storage.databases.main.registration import ( + RegistrationBackgroundUpdateStore, + find_max_generated_user_id_localpart, +) +from synapse.storage.databases.main.room import RoomBackgroundUpdateStore +from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore +from synapse.storage.databases.main.search import SearchBackgroundUpdateStore +from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore +from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.user_directory import ( + UserDirectoryBackgroundUpdateStore, +) +from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database +from synapse.util import Clock + +logger = logging.getLogger("synapse_port_db") + + +BOOLEAN_COLUMNS = { + "events": ["processed", "outlier", "contains_url"], + "rooms": ["is_public", "has_auth_chain_index"], + "event_edges": ["is_state"], + "presence_list": ["accepted"], + "presence_stream": ["currently_active"], + "public_room_list_stream": ["visibility"], + "devices": ["hidden"], + "device_lists_outbound_pokes": ["sent"], + "users_who_share_rooms": ["share_private"], + "groups": ["is_public"], + "group_rooms": ["is_public"], + "group_users": ["is_public", "is_admin"], + "group_summary_rooms": ["is_public"], + "group_room_categories": ["is_public"], + "group_summary_users": ["is_public"], + "group_roles": ["is_public"], + "local_group_membership": ["is_publicised", "is_admin"], + "e2e_room_keys": ["is_verified"], + "account_validity": ["email_sent"], + "redactions": ["have_censored"], + "room_stats_state": ["is_federatable"], + "local_media_repository": ["safe_from_quarantine"], + "users": ["shadow_banned"], + "e2e_fallback_keys_json": ["used"], + "access_tokens": ["used"], +} + + +APPEND_ONLY_TABLES = [ + 
"event_reference_hashes", + "events", + "event_json", + "state_events", + "room_memberships", + "topics", + "room_names", + "rooms", + "local_media_repository", + "local_media_repository_thumbnails", + "remote_media_cache", + "remote_media_cache_thumbnails", + "redactions", + "event_edges", + "event_auth", + "received_transactions", + "sent_transactions", + "transaction_id_to_pdu", + "users", + "state_groups", + "state_groups_state", + "event_to_state_groups", + "rejections", + "event_search", + "presence_stream", + "push_rules_stream", + "ex_outlier_stream", + "cache_invalidation_stream_by_instance", + "public_room_list_stream", + "state_group_edges", + "stream_ordering_to_exterm", +] + + +IGNORED_TABLES = { + # We don't port these tables, as they're a faff and we can regenerate + # them anyway. + "user_directory", + "user_directory_search", + "user_directory_search_content", + "user_directory_search_docsize", + "user_directory_search_segdir", + "user_directory_search_segments", + "user_directory_search_stat", + "user_directory_search_pos", + "users_who_share_private_rooms", + "users_in_public_room", + # UI auth sessions have foreign keys so additional care needs to be taken, + # the sessions are transient anyway, so ignore them. + "ui_auth_sessions", + "ui_auth_sessions_credentials", + "ui_auth_sessions_ips", +} + + +# Error returned by the run function. Used at the top-level part of the script to +# handle errors and return codes. +end_error = None # type: Optional[str] +# The exec_info for the error, if any. If error is defined but not exec_info the script +# will show only the error message without the stacktrace, if exec_info is defined but +# not the error then the script will show nothing outside of what's printed in the run +# function. If both are defined, the script will print both the error and the stacktrace. +end_error_exec_info = None + + +class Store( + ClientIpBackgroundUpdateStore, + DeviceInboxBackgroundUpdateStore, + DeviceBackgroundUpdateStore, + EventsBackgroundUpdatesStore, + MediaRepositoryBackgroundUpdateStore, + RegistrationBackgroundUpdateStore, + RoomBackgroundUpdateStore, + RoomMemberBackgroundUpdateStore, + SearchBackgroundUpdateStore, + StateBackgroundUpdateStore, + MainStateBackgroundUpdateStore, + UserDirectoryBackgroundUpdateStore, + EndToEndKeyBackgroundStore, + StatsStore, + AccountDataWorkerStore, + PushRuleStore, + PusherWorkerStore, + PresenceBackgroundUpdateStore, + GroupServerWorkerStore, +): + def execute(self, f, *args, **kwargs): + return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) + + def execute_sql(self, sql, *args): + def r(txn): + txn.execute(sql, args) + return txn.fetchall() + + return self.db_pool.runInteraction("execute_sql", r) + + def insert_many_txn(self, txn, table, headers, rows): + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in headers), + ", ".join("%s" for _ in headers), + ) + + try: + txn.executemany(sql, rows) + except Exception: + logger.exception("Failed to insert: %s", table) + raise + + def set_room_is_public(self, room_id, is_public): + raise Exception( + "Attempt to set room_is_public during port_db: database not empty?" 
+ ) + + +class MockHomeserver: + def __init__(self, config): + self.clock = Clock(reactor) + self.config = config + self.hostname = config.server.server_name + self.version_string = "Synapse/" + get_distribution_version_string( + "matrix-synapse" + ) + + def get_clock(self): + return self.clock + + def get_reactor(self): + return reactor + + def get_instance_name(self): + return "master" + + +class Porter(object): + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + async def setup_table(self, table): + if table in APPEND_ONLY_TABLES: + # It's safe to just carry on inserting. + row = await self.postgres_store.db_pool.simple_select_one( + table="port_from_sqlite3", + keyvalues={"table_name": table}, + retcols=("forward_rowid", "backward_rowid"), + allow_none=True, + ) + + total_to_port = None + if row is None: + if table == "sent_transactions": + ( + forward_chunk, + already_ported, + total_to_port, + ) = await self._setup_sent_transactions() + backward_chunk = 0 + else: + await self.postgres_store.db_pool.simple_insert( + table="port_from_sqlite3", + values={ + "table_name": table, + "forward_rowid": 1, + "backward_rowid": 0, + }, + ) + + forward_chunk = 1 + backward_chunk = 0 + already_ported = 0 + else: + forward_chunk = row["forward_rowid"] + backward_chunk = row["backward_rowid"] + + if total_to_port is None: + already_ported, total_to_port = await self._get_total_count_to_port( + table, forward_chunk, backward_chunk + ) + else: + + def delete_all(txn): + txn.execute( + "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,) + ) + txn.execute("TRUNCATE %s CASCADE" % (table,)) + + await self.postgres_store.execute(delete_all) + + await self.postgres_store.db_pool.simple_insert( + table="port_from_sqlite3", + values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, + ) + + forward_chunk = 1 + backward_chunk = 0 + + already_ported, total_to_port = await self._get_total_count_to_port( + table, forward_chunk, backward_chunk + ) + + return table, already_ported, total_to_port, forward_chunk, backward_chunk + + async def get_table_constraints(self) -> Dict[str, Set[str]]: + """Returns a map of tables that have foreign key constraints to tables they depend on.""" + + def _get_constraints(txn): + # We can pull the information about foreign key constraints out from + # the postgres schema tables. 
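+            # `information_schema.table_constraints` lists each constraint
+            # with its type; `constraint_column_usage` records the table (and
+            # columns) that a constraint refers to. Joining the two on the
+            # constraint name therefore yields, for every FOREIGN KEY, the
+            # referencing table alongside the table it points at.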
+            sql = """
+                SELECT DISTINCT
+                    tc.table_name,
+                    ccu.table_name AS foreign_table_name
+                FROM
+                    information_schema.table_constraints AS tc
+                    INNER JOIN information_schema.constraint_column_usage AS ccu
+                    USING (table_schema, constraint_name)
+                WHERE tc.constraint_type = 'FOREIGN KEY'
+                AND tc.table_name != ccu.table_name;
+            """
+            txn.execute(sql)
+
+            results = {}
+            for table, foreign_table in txn:
+                results.setdefault(table, set()).add(foreign_table)
+            return results
+
+        return await self.postgres_store.db_pool.runInteraction(
+            "get_table_constraints", _get_constraints
+        )
+
+    async def handle_table(
+        self, table, postgres_size, table_size, forward_chunk, backward_chunk
+    ):
+        logger.info(
+            "Table %s: %i/%i (rows %i-%i) already ported",
+            table,
+            postgres_size,
+            table_size,
+            backward_chunk + 1,
+            forward_chunk - 1,
+        )
+
+        if not table_size:
+            return
+
+        self.progress.add_table(table, postgres_size, table_size)
+
+        if table == "event_search":
+            await self.handle_search_table(
+                postgres_size, table_size, forward_chunk, backward_chunk
+            )
+            return
+
+        if table in IGNORED_TABLES:
+            self.progress.update(table, table_size)  # Mark table as done
+            return
+
+        if table == "user_directory_stream_pos":
+            # We need to make sure there is a single row, `(X, null)`, as that
+            # is what synapse expects to be there.
+            await self.postgres_store.db_pool.simple_insert(
+                table=table, values={"stream_id": None}
+            )
+            self.progress.update(table, table_size)  # Mark table as done
+            return
+
+        forward_select = (
+            "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
+        )
+
+        backward_select = (
+            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,)
+        )
+
+        do_forward = [True]
+        do_backward = [True]
+
+        while True:
+
+            def r(txn):
+                forward_rows = []
+                backward_rows = []
+                if do_forward[0]:
+                    txn.execute(forward_select, (forward_chunk, self.batch_size))
+                    forward_rows = txn.fetchall()
+                    if not forward_rows:
+                        do_forward[0] = False
+
+                if do_backward[0]:
+                    txn.execute(backward_select, (backward_chunk, self.batch_size))
+                    backward_rows = txn.fetchall()
+                    if not backward_rows:
+                        do_backward[0] = False
+
+                if forward_rows or backward_rows:
+                    headers = [column[0] for column in txn.description]
+                else:
+                    headers = None
+
+                return headers, forward_rows, backward_rows
+
+            headers, frows, brows = await self.sqlite_store.db_pool.runInteraction(
+                "select", r
+            )
+
+            if frows or brows:
+                if frows:
+                    forward_chunk = max(row[0] for row in frows) + 1
+                if brows:
+                    backward_chunk = min(row[0] for row in brows) - 1
+
+                rows = frows + brows
+                rows = self._convert_rows(table, headers, rows)
+
+                def insert(txn):
+                    self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
+
+                    self.postgres_store.db_pool.simple_update_one_txn(
+                        txn,
+                        table="port_from_sqlite3",
+                        keyvalues={"table_name": table},
+                        updatevalues={
+                            "forward_rowid": forward_chunk,
+                            "backward_rowid": backward_chunk,
+                        },
+                    )
+
+                await self.postgres_store.execute(insert)
+
+                postgres_size += len(rows)
+
+                self.progress.update(table, postgres_size)
+            else:
+                return
+
+    async def handle_search_table(
+        self, postgres_size, table_size, forward_chunk, backward_chunk
+    ):
+        select = (
+            "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
+            " FROM event_search as es"
+            " INNER JOIN events AS e USING (event_id, room_id)"
+            " WHERE es.rowid >= ?"
+            " ORDER BY es.rowid LIMIT ?"
+ ) + + while True: + + def r(txn): + txn.execute(select, (forward_chunk, self.batch_size)) + rows = txn.fetchall() + headers = [column[0] for column in txn.description] + + return headers, rows + + headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) + + if rows: + forward_chunk = rows[-1][0] + 1 + + # We have to treat event_search differently since it has a + # different structure in the two different databases. + def insert(txn): + sql = ( + "INSERT INTO event_search (event_id, room_id, key," + " sender, vector, origin_server_ts, stream_ordering)" + " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)" + ) + + rows_dict = [] + for row in rows: + d = dict(zip(headers, row)) + if "\0" in d["value"]: + logger.warning("dropping search row %s", d) + else: + rows_dict.append(d) + + txn.executemany( + sql, + [ + ( + row["event_id"], + row["room_id"], + row["key"], + row["sender"], + row["value"], + row["origin_server_ts"], + row["stream_ordering"], + ) + for row in rows_dict + ], + ) + + self.postgres_store.db_pool.simple_update_one_txn( + txn, + table="port_from_sqlite3", + keyvalues={"table_name": "event_search"}, + updatevalues={ + "forward_rowid": forward_chunk, + "backward_rowid": backward_chunk, + }, + ) + + await self.postgres_store.execute(insert) + + postgres_size += len(rows) + + self.progress.update("event_search", postgres_size) + + else: + return + + def build_db_store( + self, + db_config: DatabaseConnectionConfig, + allow_outdated_version: bool = False, + ): + """Builds and returns a database store using the provided configuration. + + Args: + db_config: The database configuration + allow_outdated_version: True to suppress errors about the database server + version being too old to run a complete synapse + + Returns: + The built Store object. + """ + self.progress.set_state("Preparing %s" % db_config.config["name"]) + + engine = create_engine(db_config.config) + + hs = MockHomeserver(self.hs_config) + + with make_conn(db_config, engine, "portdb") as db_conn: + engine.check_database( + db_conn, allow_outdated_version=allow_outdated_version + ) + prepare_database(db_conn, engine, config=self.hs_config) + store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) + db_conn.commit() + + return store + + async def run_background_updates_on_postgres(self): + # Manually apply all background updates on the PostgreSQL database. + postgres_ready = ( + await self.postgres_store.db_pool.updates.has_completed_background_updates() + ) + + if not postgres_ready: + # Only say that we're running background updates when there are background + # updates to run. + self.progress.set_state("Running background updates on PostgreSQL") + + while not postgres_ready: + await self.postgres_store.db_pool.updates.do_next_background_update(100) + postgres_ready = await ( + self.postgres_store.db_pool.updates.has_completed_background_updates() + ) + + async def run(self): + """Ports the SQLite database to a PostgreSQL database. + + When a fatal error is met, its message is assigned to the global "end_error" + variable. When this error comes with a stacktrace, its exec_info is assigned to + the global "end_error_exec_info" variable. + """ + global end_error + + try: + # we allow people to port away from outdated versions of sqlite. + self.sqlite_store = self.build_db_store( + DatabaseConnectionConfig("master-sqlite", self.sqlite_config), + allow_outdated_version=True, + ) + + # Check if all background updates are done, abort if not. 
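+            # (Copying a database that still has pending background updates
+            # could snapshot tables mid-migration, so we refuse to continue
+            # until SQLite has finished them.)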
+            updates_complete = (
+                await self.sqlite_store.db_pool.updates.has_completed_background_updates()
+            )
+            if not updates_complete:
+                end_error = (
+                    "Pending background updates exist in the SQLite3 database."
+                    " Please start Synapse again and wait until every update has finished"
+                    " before running this script.\n"
+                )
+                return
+
+            self.postgres_store = self.build_db_store(
+                self.hs_config.database.get_single_database()
+            )
+
+            await self.run_background_updates_on_postgres()
+
+            self.progress.set_state("Creating port tables")
+
+            def create_port_table(txn):
+                txn.execute(
+                    "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
+                    " table_name varchar(100) NOT NULL UNIQUE,"
+                    " forward_rowid bigint NOT NULL,"
+                    " backward_rowid bigint NOT NULL"
+                    ")"
+                )
+
+            # The old port script created a table with just a "rowid" column.
+            # We want people to be able to rerun this script from an old port
+            # so that they can pick up any missing events that were not
+            # ported across.
+            def alter_table(txn):
+                txn.execute(
+                    "ALTER TABLE IF EXISTS port_from_sqlite3"
+                    " RENAME rowid TO forward_rowid"
+                )
+                txn.execute(
+                    "ALTER TABLE IF EXISTS port_from_sqlite3"
+                    " ADD backward_rowid bigint NOT NULL DEFAULT 0"
+                )
+
+            try:
+                await self.postgres_store.db_pool.runInteraction(
+                    "alter_table", alter_table
+                )
+            except Exception:
+                # On Error Resume Next
+                pass
+
+            await self.postgres_store.db_pool.runInteraction(
+                "create_port_table", create_port_table
+            )
+
+            # Step 2. Set up sequences
+            #
+            # We do this before porting the tables so that even if we fail
+            # halfway through, the postgres DB always has sequences that are
+            # greater than their respective tables. If we don't, then creating
+            # the `DataStore` object will fail due to the inconsistency.
+            self.progress.set_state("Setting up sequence generators")
+            await self._setup_state_group_id_seq()
+            await self._setup_user_id_seq()
+            await self._setup_events_stream_seqs()
+            await self._setup_sequence(
+                "device_inbox_sequence", ("device_inbox", "device_federation_outbox")
+            )
+            await self._setup_sequence(
+                "account_data_sequence",
+                ("room_account_data", "room_tags_revisions", "account_data"),
+            )
+            await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
+            await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
+            await self._setup_auth_chain_sequence()
+
+            # Step 3. Get tables.
+            self.progress.set_state("Fetching tables")
+            sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol(
+                table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
+            )
+
+            postgres_tables = await self.postgres_store.db_pool.simple_select_onecol(
+                table="information_schema.tables",
+                keyvalues={},
+                retcol="distinct table_name",
+            )
+
+            tables = set(sqlite_tables) & set(postgres_tables)
+            logger.info("Found %d tables", len(tables))
+
+            # Step 4. Figure out what still needs copying.
+            self.progress.set_state("Checking on port progress")
+            setup_res = await make_deferred_yieldable(
+                defer.gatherResults(
+                    [
+                        run_in_background(self.setup_table, table)
+                        for table in tables
+                        if table not in ["schema_version", "applied_schema_deltas"]
+                        and not table.startswith("sqlite_")
+                    ],
+                    consumeErrors=True,
+                )
+            )
+            # Map from table name to args passed to `handle_table`, i.e. a tuple
+            # of: `postgres_size`, `table_size`, `forward_chunk`, `backward_chunk`.
+            tables_to_port_info_map = {r[0]: r[1:] for r in setup_res}
+
+            # Step 5. Do the copying.
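+            # Each batch of tables below is copied concurrently via
+            # `run_in_background`, one `handle_table` call per table.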
+ # + # This is slightly convoluted as we need to ensure tables are ported + # in the correct order due to foreign key constraints. + self.progress.set_state("Copying to postgres") + + constraints = await self.get_table_constraints() + tables_ported = set() # type: Set[str] + + while tables_to_port_info_map: + # Pulls out all tables that are still to be ported and which + # only depend on tables that are already ported (if any). + tables_to_port = [ + table + for table in tables_to_port_info_map + if not constraints.get(table, set()) - tables_ported + ] + + await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.handle_table, + table, + *tables_to_port_info_map.pop(table), + ) + for table in tables_to_port + ], + consumeErrors=True, + ) + ) + + tables_ported.update(tables_to_port) + + self.progress.done() + except Exception as e: + global end_error_exec_info + end_error = str(e) + end_error_exec_info = sys.exc_info() + logger.exception("") + finally: + reactor.stop() + + def _convert_rows(self, table, headers, rows): + bool_col_names = BOOLEAN_COLUMNS.get(table, []) + + bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names] + + class BadValueException(Exception): + pass + + def conv(j, col): + if j in bool_cols: + return bool(col) + if isinstance(col, bytes): + return bytearray(col) + elif isinstance(col, str) and "\0" in col: + logger.warning( + "DROPPING ROW: NUL value in table %s col %s: %r", + table, + headers[j], + col, + ) + raise BadValueException() + return col + + outrows = [] + for row in rows: + try: + outrows.append( + tuple(conv(j, col) for j, col in enumerate(row) if j > 0) + ) + except BadValueException: + pass + + return outrows + + async def _setup_sent_transactions(self): + # Only save things from the last day + yesterday = int(time.time() * 1000) - 86400000 + + # And save the max transaction id from each destination + select = ( + "SELECT rowid, * FROM sent_transactions WHERE rowid IN (" + "SELECT max(rowid) FROM sent_transactions" + " GROUP BY destination" + ")" + ) + + def r(txn): + txn.execute(select) + rows = txn.fetchall() + headers = [column[0] for column in txn.description] + + ts_ind = headers.index("ts") + + return headers, [r for r in rows if r[ts_ind] < yesterday] + + headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) + + rows = self._convert_rows("sent_transactions", headers, rows) + + inserted_rows = len(rows) + if inserted_rows: + max_inserted_rowid = max(r[0] for r in rows) + + def insert(txn): + self.postgres_store.insert_many_txn( + txn, "sent_transactions", headers[1:], rows + ) + + await self.postgres_store.execute(insert) + else: + max_inserted_rowid = 0 + + def get_start_id(txn): + txn.execute( + "SELECT rowid FROM sent_transactions WHERE ts >= ?" 
+ " ORDER BY rowid ASC LIMIT 1", + (yesterday,), + ) + + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + return 1 + + next_chunk = await self.sqlite_store.execute(get_start_id) + next_chunk = max(max_inserted_rowid + 1, next_chunk) + + await self.postgres_store.db_pool.simple_insert( + table="port_from_sqlite3", + values={ + "table_name": "sent_transactions", + "forward_rowid": next_chunk, + "backward_rowid": 0, + }, + ) + + def get_sent_table_size(txn): + txn.execute( + "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) + ) + (size,) = txn.fetchone() + return int(size) + + remaining_count = await self.sqlite_store.execute(get_sent_table_size) + + total_count = remaining_count + inserted_rows + + return next_chunk, inserted_rows, total_count + + async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): + frows = await self.sqlite_store.execute_sql( + "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk + ) + + brows = await self.sqlite_store.execute_sql( + "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk + ) + + return frows[0][0] + brows[0][0] + + async def _get_already_ported_count(self, table): + rows = await self.postgres_store.execute_sql( + "SELECT count(*) FROM %s" % (table,) + ) + + return rows[0][0] + + async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): + remaining, done = await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self._get_remaining_count_to_port, + table, + forward_chunk, + backward_chunk, + ), + run_in_background(self._get_already_ported_count, table), + ], + ) + ) + + remaining = int(remaining) if remaining else 0 + done = int(done) if done else 0 + + return done, remaining + done + + async def _setup_state_group_id_seq(self) -> None: + curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True + ) + + if not curr_id: + return + + def r(txn): + next_id = curr_id + 1 + txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) + + await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) + + async def _setup_user_id_seq(self) -> None: + curr_id = await self.sqlite_store.db_pool.runInteraction( + "setup_user_id_seq", find_max_generated_user_id_localpart + ) + + def r(txn): + next_id = curr_id + 1 + txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) + + await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) + + async def _setup_events_stream_seqs(self) -> None: + """Set the event stream sequences to the correct values.""" + + # We get called before we've ported the events table, so we need to + # fetch the current positions from the SQLite store. 
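+        # Forward-filled events use positive stream orderings and backfilled
+        # events negative ones, hence the separate MAX(stream_ordering) and
+        # MAX(-MIN(stream_ordering), 1) lookups below.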
+ curr_forward_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="events", keyvalues={}, retcol="MAX(stream_ordering)", allow_none=True + ) + + curr_backward_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="events", + keyvalues={}, + retcol="MAX(-MIN(stream_ordering), 1)", + allow_none=True, + ) + + def _setup_events_stream_seqs_set_pos(txn): + if curr_forward_id: + txn.execute( + "ALTER SEQUENCE events_stream_seq RESTART WITH %s", + (curr_forward_id + 1,), + ) + + if curr_backward_id: + txn.execute( + "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s", + (curr_backward_id + 1,), + ) + + await self.postgres_store.db_pool.runInteraction( + "_setup_events_stream_seqs", + _setup_events_stream_seqs_set_pos, + ) + + async def _setup_sequence( + self, sequence_name: str, stream_id_tables: Iterable[str] + ) -> None: + """Set a sequence to the correct value.""" + current_stream_ids = [] + for stream_id_table in stream_id_tables: + max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table=stream_id_table, + keyvalues={}, + retcol="COALESCE(MAX(stream_id), 1)", + allow_none=True, + ) + current_stream_ids.append(max_stream_id) + + next_id = max(current_stream_ids) + 1 + + def r(txn): + sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,) + txn.execute(sql + " %s", (next_id,)) + + await self.postgres_store.db_pool.runInteraction( + "_setup_%s" % (sequence_name,), r + ) + + async def _setup_auth_chain_sequence(self) -> None: + curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="event_auth_chains", + keyvalues={}, + retcol="MAX(chain_id)", + allow_none=True, + ) + + def r(txn): + txn.execute( + "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s", + (curr_chain_id + 1,), + ) + + if curr_chain_id is not None: + await self.postgres_store.db_pool.runInteraction( + "_setup_event_auth_chain_id", + r, + ) + + +############################################## +# The following is simply UI stuff +############################################## + + +class Progress(object): + """Used to report progress of the port""" + + def __init__(self): + self.tables = {} + + self.start_time = int(time.time()) + + def add_table(self, table, cur, size): + self.tables[table] = { + "start": cur, + "num_done": cur, + "total": size, + "perc": int(cur * 100 / size), + } + + def update(self, table, num_done): + data = self.tables[table] + data["num_done"] = num_done + data["perc"] = int(num_done * 100 / data["total"]) + + def done(self): + pass + + +class CursesProgress(Progress): + """Reports progress to a curses window""" + + def __init__(self, stdscr): + self.stdscr = stdscr + + curses.use_default_colors() + curses.curs_set(0) + + curses.init_pair(1, curses.COLOR_RED, -1) + curses.init_pair(2, curses.COLOR_GREEN, -1) + + self.last_update = 0 + + self.finished = False + + self.total_processed = 0 + self.total_remaining = 0 + + super(CursesProgress, self).__init__() + + def update(self, table, num_done): + super(CursesProgress, self).update(table, num_done) + + self.total_processed = 0 + self.total_remaining = 0 + for data in self.tables.values(): + self.total_processed += data["num_done"] - data["start"] + self.total_remaining += data["total"] - data["num_done"] + + self.render() + + def render(self, force=False): + now = time.time() + + if not force and now - self.last_update < 0.2: + # reactor.callLater(1, self.render) + return + + self.stdscr.clear() + + rows, cols = self.stdscr.getmaxyx() + + duration = int(now) - 
int(self.start_time) + + minutes, seconds = divmod(duration, 60) + duration_str = "%02dm %02ds" % (minutes, seconds) + + if self.finished: + status = "Time spent: %s (Done!)" % (duration_str,) + else: + + if self.total_processed > 0: + left = float(self.total_remaining) / self.total_processed + + est_remaining = (int(now) - self.start_time) * left + est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60) + else: + est_remaining_str = "Unknown" + status = "Time spent: %s (est. remaining: %s)" % ( + duration_str, + est_remaining_str, + ) + + self.stdscr.addstr(0, 0, status, curses.A_BOLD) + + max_len = max(len(t) for t in self.tables.keys()) + + left_margin = 5 + middle_space = 1 + + items = self.tables.items() + items = sorted(items, key=lambda i: (i[1]["perc"], i[0])) + + for i, (table, data) in enumerate(items): + if i + 2 >= rows: + break + + perc = data["perc"] + + color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) + + self.stdscr.addstr( + i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color + ) + + size = 20 + + progress = "[%s%s]" % ( + "#" * int(perc * size / 100), + " " * (size - int(perc * size / 100)), + ) + + self.stdscr.addstr( + i + 2, + left_margin + max_len + middle_space, + "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), + ) + + if self.finished: + self.stdscr.addstr(rows - 1, 0, "Press any key to exit...") + + self.stdscr.refresh() + self.last_update = time.time() + + def done(self): + self.finished = True + self.render(True) + self.stdscr.getch() + + def set_state(self, state): + self.stdscr.clear() + self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD) + self.stdscr.refresh() + + +class TerminalProgress(Progress): + """Just prints progress to the terminal""" + + def update(self, table, num_done): + super(TerminalProgress, self).update(table, num_done) + + data = self.tables[table] + + print( + "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"]) + ) + + def set_state(self, state): + print(state + "...") + + +############################################## +############################################## + + +def main(): + parser = argparse.ArgumentParser( + description="A script to port an existing synapse SQLite database to" + " a new PostgreSQL database." + ) + parser.add_argument("-v", action="store_true") + parser.add_argument( + "--sqlite-database", + required=True, + help="The snapshot of the SQLite database file. 
This must not be" + " currently used by a running synapse server", + ) + parser.add_argument( + "--postgres-config", + type=argparse.FileType("r"), + required=True, + help="The database config file for the PostgreSQL database", + ) + parser.add_argument( + "--curses", action="store_true", help="display a curses based progress UI" + ) + + parser.add_argument( + "--batch-size", + type=int, + default=1000, + help="The number of rows to select from the SQLite table each" + " iteration [default=1000]", + ) + + args = parser.parse_args() + + logging_config = { + "level": logging.DEBUG if args.v else logging.INFO, + "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", + } + + if args.curses: + logging_config["filename"] = "port-synapse.log" + + logging.basicConfig(**logging_config) + + sqlite_config = { + "name": "sqlite3", + "args": { + "database": args.sqlite_database, + "cp_min": 1, + "cp_max": 1, + "check_same_thread": False, + }, + } + + hs_config = yaml.safe_load(args.postgres_config) + + if "database" not in hs_config: + sys.stderr.write("The configuration file must have a 'database' section.\n") + sys.exit(4) + + postgres_config = hs_config["database"] + + if "name" not in postgres_config: + sys.stderr.write("Malformed database config: no 'name'\n") + sys.exit(2) + if postgres_config["name"] != "psycopg2": + sys.stderr.write("Database must use the 'psycopg2' connector.\n") + sys.exit(3) + + config = HomeServerConfig() + config.parse_config_dict(hs_config, "", "") + + def start(stdscr=None): + if stdscr: + progress = CursesProgress(stdscr) + else: + progress = TerminalProgress() + + porter = Porter( + sqlite_config=sqlite_config, + progress=progress, + batch_size=args.batch_size, + hs_config=config, + ) + + @defer.inlineCallbacks + def run(): + with LoggingContext("synapse_port_db_run"): + yield defer.ensureDeferred(porter.run()) + + reactor.callWhenRunning(run) + + reactor.run() + + if args.curses: + curses.wrapper(start) + else: + start() + + if end_error: + if end_error_exec_info: + exc_type, exc_value, exc_traceback = end_error_exec_info + traceback.print_exception(exc_type, exc_value, exc_traceback) + + sys.stderr.write(end_error) + + sys.exit(5) + + +if __name__ == "__main__": + main() diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py new file mode 100755 index 0000000000..f43676afaa --- /dev/null +++ b/synapse/_scripts/update_synapse_database.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
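+
+# A minimal invocation sketch (the config path here is an example, not a
+# default):
+#
+#     update_synapse_database \
+#         --database-config homeserver.yaml \
+#         --run-background-updates
+#
+# Without `--run-background-updates` the script only brings the schema up to
+# date; background updates are then left to run when Synapse next starts.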
+ +import argparse +import logging +import sys + +import yaml +from matrix_common.versionstring import get_distribution_version_string + +from twisted.internet import defer, reactor + +from synapse.config.homeserver import HomeServerConfig +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.server import HomeServer +from synapse.storage import DataStore + +logger = logging.getLogger("update_database") + + +class MockHomeserver(HomeServer): + DATASTORE_CLASS = DataStore + + def __init__(self, config, **kwargs): + super(MockHomeserver, self).__init__( + config.server.server_name, reactor=reactor, config=config, **kwargs + ) + + self.version_string = "Synapse/" + get_distribution_version_string( + "matrix-synapse" + ) + + +def run_background_updates(hs): + store = hs.get_datastores().main + + async def run_background_updates(): + await store.db_pool.updates.run_background_updates(sleep=False) + # Stop the reactor to exit the script once every background update is run. + reactor.stop() + + def run(): + # Apply all background updates on the database. + defer.ensureDeferred( + run_as_background_process("background_updates", run_background_updates) + ) + + reactor.callWhenRunning(run) + + reactor.run() + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Updates a synapse database to the latest schema and optionally runs background updates" + " on it." + ) + ) + parser.add_argument("-v", action="store_true") + parser.add_argument( + "--database-config", + type=argparse.FileType("r"), + required=True, + help="Synapse configuration file, giving the details of the database to be updated", + ) + parser.add_argument( + "--run-background-updates", + action="store_true", + required=False, + help="run background updates after upgrading the database schema", + ) + + args = parser.parse_args() + + logging_config = { + "level": logging.DEBUG if args.v else logging.INFO, + "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", + } + + logging.basicConfig(**logging_config) + + # Load, process and sanity-check the config. + hs_config = yaml.safe_load(args.database_config) + + if "database" not in hs_config: + sys.stderr.write("The configuration file must have a 'database' section.\n") + sys.exit(4) + + config = HomeServerConfig() + config.parse_config_dict(hs_config, "", "") + + # Instantiate and initialise the homeserver object. + hs = MockHomeserver(config) + + # Setup instantiates the store within the homeserver object and updates the + # DB. + hs.setup() + + if args.run_background_updates: + run_background_updates(hs) + + +if __name__ == "__main__": + main() diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1265738dc1..8e19e2fc26 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -383,7 +383,7 @@ class RootConfig: Build a default configuration file This is used when the user explicitly asks us to generate a config file - (eg with --generate_config). + (eg with --generate-config). Args: config_dir_path: The path where the config files are kept. 
Used to diff --git a/tox.ini b/tox.ini index 04b972e2c5..8d6aa7580b 100644 --- a/tox.ini +++ b/tox.ini @@ -38,15 +38,7 @@ lint_targets = setup.py synapse tests - scripts # annoyingly, black doesn't find these so we have to list them - scripts/export_signing_key - scripts/generate_config - scripts/generate_log_config - scripts/hash_password - scripts/register_new_matrix_user - scripts/synapse_port_db - scripts/update_synapse_database scripts-dev scripts-dev/build_debian_packages scripts-dev/sign_json -- cgit 1.4.1 From 1103c5fe8a795eafc4aeedc547faa1b68d5a12f5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Mar 2022 08:18:51 -0500 Subject: Check if instances are lists, not sequences. (#12128) As a str is a sequence, the checks were not granular enough and would allow lists or strings, when only lists were valid. --- changelog.d/12128.misc | 1 + synapse/federation/federation_client.py | 8 ++++---- synapse/handlers/room_summary.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 changelog.d/12128.misc diff --git a/changelog.d/12128.misc b/changelog.d/12128.misc new file mode 100644 index 0000000000..0570a8e327 --- /dev/null +++ b/changelog.d/12128.misc @@ -0,0 +1 @@ +Fix data validation to compare to lists, not sequences. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 64e595e748..467275b98c 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1428,7 +1428,7 @@ class FederationClient(FederationBase): # Validate children_state of the room. children_state = room.pop("children_state", []) - if not isinstance(children_state, Sequence): + if not isinstance(children_state, list): raise InvalidResponseError("'room.children_state' must be a list") if any(not isinstance(e, dict) for e in children_state): raise InvalidResponseError("Invalid event in 'children_state' list") @@ -1440,14 +1440,14 @@ class FederationClient(FederationBase): # Validate the children rooms. children = res.get("children", []) - if not isinstance(children, Sequence): + if not isinstance(children, list): raise InvalidResponseError("'children' must be a list") if any(not isinstance(r, dict) for r in children): raise InvalidResponseError("Invalid room in 'children' list") # Validate the inaccessible children. 
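+        # NB: these validation checks test for `list` rather than the more
+        # general `Sequence`, since a `str` is itself a `Sequence` and must be
+        # rejected here.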
inaccessible_children = res.get("inaccessible_children", []) - if not isinstance(inaccessible_children, Sequence): + if not isinstance(inaccessible_children, list): raise InvalidResponseError("'inaccessible_children' must be a list") if any(not isinstance(r, str) for r in inaccessible_children): raise InvalidResponseError( @@ -1630,7 +1630,7 @@ def _validate_hierarchy_event(d: JsonDict) -> None: raise ValueError("Invalid event: 'content' must be a dict") via = content.get("via") - if not isinstance(via, Sequence): + if not isinstance(via, list): raise ValueError("Invalid event: 'via' must be a list") if any(not isinstance(v, str) for v in via): raise ValueError("Invalid event: 'via' must be a list of strings") diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 55c2cbdba8..3979cbba71 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -857,7 +857,7 @@ class _RoomEntry: def _has_valid_via(e: EventBase) -> bool: via = e.content.get("via") - if not via or not isinstance(via, Sequence): + if not via or not isinstance(via, list): return False for v in via: if not isinstance(v, str): -- cgit 1.4.1 From 6d282a9c89ae9fb55fe7ccc8d0ab16bf18b206ec Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Mar 2022 14:28:18 +0000 Subject: Make release script write correct no-op changelog (#12127) As we want to include the previous version in the "No new changes..." string. --- changelog.d/12127.misc | 1 + scripts-dev/release.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12127.misc diff --git a/changelog.d/12127.misc b/changelog.d/12127.misc new file mode 100644 index 0000000000..e42eca63a8 --- /dev/null +++ b/changelog.d/12127.misc @@ -0,0 +1 @@ +Update release script to insert the previous version when writing "No significant changes" line in the changelog. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 4e1f99fee4..046453e65f 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -17,6 +17,8 @@ """An interactive script for doing a release. See `cli()` below. """ +import glob +import os import re import subprocess import sys @@ -209,8 +211,8 @@ def prepare(): with open("synapse/__init__.py", "w") as f: f.write(parsed_synapse_ast.dumps()) - # Generate changelogs - run_until_successful("python3 -m towncrier", shell=True) + # Generate changelogs. + generate_and_write_changelog(current_version) # Generate debian changelogs if parsed_new_version.pre is not None: @@ -523,5 +525,29 @@ def get_changes_for_version(wanted_version: version.Version) -> str: return "\n".join(version_changelog) +def generate_and_write_changelog(current_version: version.Version): + # We do this by getting a draft so that we can edit it before writing to the + # changelog. + result = run_until_successful( + "python3 -m towncrier --draft", shell=True, capture_output=True + ) + new_changes = result.stdout.decode("utf-8") + new_changes = new_changes.replace( + "No significant changes.", f"No significant changes since {current_version}." 
+ ) + + # Prepend changes to changelog + with open("CHANGES.md", "r+") as f: + existing_content = f.read() + f.seek(0, 0) + f.write(new_changes) + f.write("\n") + f.write(existing_content) + + # Remove all the news fragments + for f in glob.iglob("changelog.d/*.*"): + os.remove(f) + + if __name__ == "__main__": cli() -- cgit 1.4.1 From b4461e7d8ab6cfe150f39f62aa68f7f13ef97a24 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 2 Mar 2022 16:11:16 +0000 Subject: Enable complexity checking in complexity checking docs example (#11998) --- changelog.d/11998.doc | 1 + .../running_synapse_on_single_board_computers.md | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) create mode 100644 changelog.d/11998.doc diff --git a/changelog.d/11998.doc b/changelog.d/11998.doc new file mode 100644 index 0000000000..33ab7b7880 --- /dev/null +++ b/changelog.d/11998.doc @@ -0,0 +1 @@ +Fix complexity checking config example in [Resource Constrained Devices](https://matrix-org.github.io/synapse/v1.54/other/running_synapse_on_single_board_computers.html) docs page. \ No newline at end of file diff --git a/docs/other/running_synapse_on_single_board_computers.md b/docs/other/running_synapse_on_single_board_computers.md index ea14afa8b2..dcf96f0056 100644 --- a/docs/other/running_synapse_on_single_board_computers.md +++ b/docs/other/running_synapse_on_single_board_computers.md @@ -31,28 +31,29 @@ Anything that requires modifying the device list [#7721](https://github.com/matr Put the below in a new file at /etc/matrix-synapse/conf.d/sbc.yaml to override the defaults in homeserver.yaml. ``` -# Set to false to disable presence tracking on this homeserver. +# Disable presence tracking, which is currently fairly resource intensive +# More info: https://github.com/matrix-org/synapse/issues/9478 use_presence: false -# When this is enabled, the room "complexity" will be checked before a user -# joins a new remote room. If it is above the complexity limit, the server will -# disallow joining, or will instantly leave. +# Set a small complexity limit, preventing users from joining large rooms +# which may be resource-intensive to remain a part of. +# +# Note that this will not prevent users from joining smaller rooms that +# eventually become complex. limit_remote_rooms: - # Uncomment to enable room complexity checking. - #enabled: true + enabled: true complexity: 3.0 # Database configuration database: + # Use postgres for the best performance name: psycopg2 args: user: matrix-synapse - # Generate a long, secure one with a password manager + # Generate a long, secure password using a password manager password: hunter2 database: matrix-synapse host: localhost - cp_min: 5 - cp_max: 10 ``` Currently the complexity is measured by [current_state_events / 500](https://github.com/matrix-org/synapse/blob/v1.20.1/synapse/storage/databases/main/events_worker.py#L986). 
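For example, a room whose current state contains 1500 events has a complexity of 1500 / 500 = 3.0, which sits exactly at the `complexity: 3.0` limit configured above.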
You can find join times and your most complex rooms like this: -- cgit 1.4.1 From 2ffaf30803f93273a4d8a65c9e6c3110c8433488 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 2 Mar 2022 17:34:14 +0100 Subject: Add type hints to `tests/rest/client` (#12108) * Add type hints to `tests/rest/client` * newsfile * fix imports * add `test_account.py` * Remove one type hint in `test_report_event.py` * change `on_create_room` to `async` * update new functions in `test_third_party_rules.py` * Add `test_filter.py` * add `test_rooms.py` * change to `assertEquals` to `assertEqual` * lint --- changelog.d/12108.misc | 1 + mypy.ini | 6 - tests/rest/client/test_account.py | 290 +++++++++++++++------------- tests/rest/client/test_filter.py | 29 +-- tests/rest/client/test_relations.py | 4 +- tests/rest/client/test_report_event.py | 25 ++- tests/rest/client/test_rooms.py | 271 +++++++++++++------------- tests/rest/client/test_third_party_rules.py | 108 +++++++---- tests/rest/client/test_typing.py | 41 ++-- 9 files changed, 423 insertions(+), 352 deletions(-) create mode 100644 changelog.d/12108.misc diff --git a/changelog.d/12108.misc b/changelog.d/12108.misc new file mode 100644 index 0000000000..0360dbd61e --- /dev/null +++ b/changelog.d/12108.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/mypy.ini b/mypy.ini index 6b1e995e64..23ca4eaa5a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -78,13 +78,7 @@ exclude = (?x) |tests/push/test_http.py |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py - |tests/rest/client/test_account.py - |tests/rest/client/test_filter.py - |tests/rest/client/test_report_event.py - |tests/rest/client/test_rooms.py - |tests/rest/client/test_third_party_rules.py |tests/rest/client/test_transactions.py - |tests/rest/client/test_typing.py |tests/rest/key/v2/test_remote_key_resource.py |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 6c4462e74a..def836054d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -15,11 +15,12 @@ import json import os import re from email.parser import Parser -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock import pkg_resources +from twisted.internet.interfaces import IReactorTCP from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,6 +31,7 @@ from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Email config. 
@@ -67,20 +69,27 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + *args: Any, + **kwargs: Any, + ) -> None: + self.email_attempts.append(msg_bytes) + + self.email_attempts: List[bytes] = [] hs.get_send_email_handler()._sendmail = sendmail return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) - def test_basic_password_reset(self): + def test_basic_password_reset(self) -> None: """Test basic password reset flow""" old_password = "monkey" new_password = "kangeroo" @@ -118,7 +127,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.attempt_wrong_password_login("kermit", old_password) @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_email(self): + def test_ratelimit_by_email(self) -> None: """Test that we ratelimit /requestToken for the same email.""" old_password = "monkey" new_password = "kangeroo" @@ -139,7 +148,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): ) ) - def reset(ip): + def reset(ip: str) -> None: client_secret = "foobar" session_id = self._request_token(email, client_secret, ip) @@ -166,7 +175,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertEqual(cm.exception.code, 429) - def test_basic_password_reset_canonicalise_email(self): + def test_basic_password_reset_canonicalise_email(self) -> None: """Test basic password reset flow Request password reset with different spelling """ @@ -206,7 +215,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) - def test_cant_reset_password_without_clicking_link(self): + def test_cant_reset_password_without_clicking_link(self) -> None: """Test that we do actually need to click the link in the email""" old_password = "monkey" new_password = "kangeroo" @@ -241,7 +250,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the new password self.attempt_wrong_password_login("kermit", new_password) - def test_no_valid_token(self): + def test_no_valid_token(self) -> None: """Test that we do actually need to request a token and can't just make a session up. """ @@ -277,7 +286,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.attempt_wrong_password_login("kermit", new_password) @unittest.override_config({"request_token_inhibit_3pid_errors": True}) - def test_password_reset_bad_email_inhibit_error(self): + def test_password_reset_bad_email_inhibit_error(self) -> None: """Test that triggering a password reset with an email address that isn't bound to an account doesn't leak the lack of binding for that address if configured that way. 
@@ -292,7 +301,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) - def _request_token(self, email, client_secret, ip="127.0.0.1"): + def _request_token( + self, + email: str, + client_secret: str, + ip: str = "127.0.0.1", + ) -> str: channel = self.make_request( "POST", b"account/password/email/requestToken", @@ -309,7 +323,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return channel.json_body["sid"] - def _validate_token(self, link): + def _validate_token(self, link: str) -> None: # Remove the host path = link.replace("https://example.com", "") @@ -339,7 +353,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code, channel.result) - def _get_link_from_email(self): + def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" raw_msg = self.email_attempts[-1].decode("UTF-8") @@ -354,14 +368,19 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): if not text: self.fail("Could not find text portion of email to parse") + assert text is not None match = re.search(r"https://example.com\S+", text) assert match, "Could not find link in email" return match.group(0) def _reset_password( - self, new_password, session_id, client_secret, expected_code=200 - ): + self, + new_password: str, + session_id: str, + client_secret: str, + expected_code: int = 200, + ) -> None: channel = self.make_request( "POST", b"account/password", @@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() return self.hs - def test_deactivate_account(self): + def test_deactivate_account(self) -> None: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -407,7 +426,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", "account/whoami", access_token=tok) self.assertEqual(channel.code, 401) - def test_pending_invites(self): + def test_pending_invites(self) -> None: """Tests that deactivating a user rejects every pending invite for them.""" store = self.hs.get_datastores().main @@ -448,7 +467,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertEqual(len(memberships), 1, memberships) self.assertEqual(memberships[0].room_id, room_id, memberships) - def deactivate(self, user_id, tok): + def deactivate(self, user_id: str, tok: str) -> None: request_data = json.dumps( { "auth": { @@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["allow_guest_access"] = True return config - def test_GET_whoami(self): + def test_GET_whoami(self) -> None: device_id = "wouldgohere" user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test", device_id=device_id) @@ -496,7 +515,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): }, ) - def test_GET_whoami_guests(self): + def test_GET_whoami_guests(self) -> None: channel = self.make_request( b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}" ) @@ -516,7 +535,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): }, ) - def test_GET_whoami_appservices(self): + def test_GET_whoami_appservices(self) -> None: user_id = "@as:test" as_token = "i_am_an_app_service" @@ 
-541,7 +560,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): ) self.assertFalse(hasattr(whoami, "device_id")) - def _whoami(self, tok): + def _whoami(self, tok: str) -> JsonDict: channel = self.make_request("GET", "account/whoami", {}, access_token=tok) self.assertEqual(channel.code, 200) return channel.json_body @@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Email config. @@ -576,16 +595,23 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + *args: Any, + **kwargs: Any, + ) -> None: + self.email_attempts.append(msg_bytes) + + self.email_attempts: List[bytes] = [] self.hs.get_send_email_handler()._sendmail = sendmail return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("kermit", "test") @@ -593,83 +619,73 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.email = "test@example.com" self.url_3pid = b"account/3pid" - def test_add_valid_email(self): - self.get_success(self._add_email(self.email, self.email)) + def test_add_valid_email(self) -> None: + self._add_email(self.email, self.email) - def test_add_valid_email_second_time(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - self.email, - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) + def test_add_valid_email_second_time(self) -> None: + self._add_email(self.email, self.email) + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", ) - def test_add_valid_email_second_time_canonicalise(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - "TEST@EXAMPLE.COM", - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) + def test_add_valid_email_second_time_canonicalise(self) -> None: + self._add_email(self.email, self.email) + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", ) - def test_add_email_no_at(self): - self.get_success( - self._request_token_invalid_email( - "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_no_at(self) -> None: + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_two_at(self): - self.get_success( - self._request_token_invalid_email( - "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_two_at(self) -> None: + self._request_token_invalid_email( + "foo@foo@test.bar", + 
expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_bad_format(self): - self.get_success( - self._request_token_invalid_email( - "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_bad_format(self) -> None: + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_domain_to_lower(self): - self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + def test_add_email_domain_to_lower(self) -> None: + self._add_email("foo@TEST.BAR", "foo@test.bar") - def test_add_email_domain_with_umlaut(self): - self.get_success(self._add_email("foo@Γ–umlaut.com", "foo@ΓΆumlaut.com")) + def test_add_email_domain_with_umlaut(self) -> None: + self._add_email("foo@Γ–umlaut.com", "foo@ΓΆumlaut.com") - def test_add_email_address_casefold(self): - self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + def test_add_email_address_casefold(self) -> None: + self._add_email("Strauß@Example.com", "strauss@example.com") - def test_address_trim(self): - self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + def test_address_trim(self) -> None: + self._add_email(" foo@test.bar ", "foo@test.bar") @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_ip(self): + def test_ratelimit_by_ip(self) -> None: """Tests that adding emails is ratelimited by IP""" # We expect to be able to set three emails before getting ratelimited. - self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) - self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) - self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + self._add_email("foo1@test.bar", "foo1@test.bar") + self._add_email("foo2@test.bar", "foo2@test.bar") + self._add_email("foo3@test.bar", "foo3@test.bar") with self.assertRaises(HttpResponseException) as cm: - self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + self._add_email("foo4@test.bar", "foo4@test.bar") self.assertEqual(cm.exception.code, 429) - def test_add_email_if_disabled(self): + def test_add_email_if_disabled(self) -> None: """Test adding email to profile when doing so is disallowed""" self.hs.config.registration.enable_3pid_changes = False @@ -695,7 +711,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -705,10 +721,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_delete_email(self): + def test_delete_email(self) -> None: """Test deleting an email from profile""" # Add a threepid self.get_success( @@ -727,7 +743,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, 
msg=channel.result["body"]) # Get user channel = self.make_request( @@ -736,10 +752,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_delete_email_if_disabled(self): + def test_delete_email_if_disabled(self) -> None: """Test deleting an email from profile when disallowed""" self.hs.config.registration.enable_3pid_changes = False @@ -761,7 +777,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -771,11 +787,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - def test_cant_add_email_without_clicking_link(self): + def test_cant_add_email_without_clicking_link(self) -> None: """Test that we do actually need to click the link in the email""" client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -797,7 +813,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -807,10 +823,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_no_valid_token(self): + def test_no_valid_token(self) -> None: """Test that we do actually need to request a token and can't just make a session up. 
""" @@ -832,7 +848,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -842,11 +858,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @override_config({"next_link_domain_whitelist": None}) - def test_next_link(self): + def test_next_link(self) -> None: """Tests a valid next_link parameter value with no whitelist (good case)""" self._request_token( "something@example.com", @@ -856,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": None}) - def test_next_link_exotic_protocol(self): + def test_next_link_exotic_protocol(self) -> None: """Tests using a esoteric protocol as a next_link parameter value. Someone may be hosting a client on IPFS etc. """ @@ -868,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": None}) - def test_next_link_file_uri(self): + def test_next_link_file_uri(self) -> None: """Tests next_link parameters cannot be file URI""" # Attempt to use a next_link value that points to the local disk self._request_token( @@ -879,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) - def test_next_link_domain_whitelist(self): + def test_next_link_domain_whitelist(self) -> None: """Tests next_link parameters must fit the whitelist if provided""" # Ensure not providing a next_link parameter still works @@ -912,7 +928,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": []}) - def test_empty_next_link_domain_whitelist(self): + def test_empty_next_link_domain_whitelist(self) -> None: """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially disallowed """ @@ -962,28 +978,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _request_token_invalid_email( self, - email, - expected_errcode, - expected_error, - client_secret="foobar", - ): + email: str, + expected_errcode: str, + expected_error: str, + client_secret: str = "foobar", + ) -> None: channel = self.make_request( "POST", b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_error, channel.json_body["error"]) - def _validate_token(self, link): + def _validate_token(self, link: str) -> None: # Remove the host path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) self.assertEqual(200, channel.code, channel.result) - def _get_link_from_email(self): + def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" raw_msg = self.email_attempts[-1].decode("UTF-8") @@ -998,12 +1014,13 @@ class 
ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 if not text:
 self.fail("Could not find text portion of email to parse")

+ assert text is not None
 match = re.search(r"https://example.com\S+", text)
 assert match, "Could not find link in email"

 return match.group(0)

- def _add_email(self, request_email, expected_email):
+ def _add_email(self, request_email: str, expected_email: str) -> None:
 """Test adding an email to profile"""
 previous_email_attempts = len(self.email_attempts)

@@ -1030,7 +1047,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 access_token=self.user_id_tok,
 )

- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])

 # Get user
 channel = self.make_request(
@@ -1039,7 +1056,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 access_token=self.user_id_tok,
 )

- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
 self.assertEqual("email", channel.json_body["threepids"][0]["medium"])

 threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):

 url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"

- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 config = self.default_config()
 config["experimental_features"] = {"msc3720_enabled": True}

 return self.setup_test_homeserver(config=config)

- def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 self.requester = self.register_user("requester", "password")
 self.requester_tok = self.login("requester", "password")
- self.server_name = homeserver.config.server.server_name
+ self.server_name = hs.config.server.server_name

- def test_missing_mxid(self):
+ def test_missing_mxid(self) -> None:
 """Tests that not providing any MXID raises an error."""
 self._test_status(
 users=None,
@@ -1074,7 +1091,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
 expected_errcode=Codes.MISSING_PARAM,
 )

- def test_invalid_mxid(self):
+ def test_invalid_mxid(self) -> None:
 """Tests that providing an invalid MXID raises an error."""
 self._test_status(
 users=["bad:test"],
@@ -1082,7 +1099,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
 expected_errcode=Codes.INVALID_PARAM,
 )

- def test_local_user_not_exists(self):
+ def test_local_user_not_exists(self) -> None:
 """Tests that the account status endpoint correctly reports that a user doesn't
 exist.
 """
@@ -1098,7 +1115,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
 expected_failures=[],
 )

- def test_local_user_exists(self):
+ def test_local_user_exists(self) -> None:
 """Tests that the account status endpoint correctly reports that a user exists.
""" @@ -1115,7 +1132,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[], ) - def test_local_user_deactivated(self): + def test_local_user_deactivated(self) -> None: """Tests that the account status endpoint correctly reports a deactivated user.""" user = self.register_user("someuser", "password") self.get_success( @@ -1135,7 +1152,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[], ) - def test_mixed_local_and_remote_users(self): + def test_mixed_local_and_remote_users(self) -> None: """Tests that if some users are remote the account status endpoint correctly merges the remote responses with the local result. """ @@ -1150,7 +1167,13 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): "@bad:badremote", ] - async def post_json(destination, path, data, *a, **kwa): + async def post_json( + destination: str, + path: str, + data: Optional[JsonDict] = None, + *a: Any, + **kwa: Any, + ) -> Union[JsonDict, list]: if destination == "remote": return { "account_statuses": { @@ -1160,9 +1183,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): }, } } - if destination == "otherremote": - return {} - if destination == "badremote": + elif destination == "badremote": # badremote tries to overwrite the status of a user that doesn't belong # to it (i.e. users[1]) with false data, which Synapse is expected to # ignore. @@ -1176,6 +1197,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): }, } } + # if destination == "otherremote" + else: + return {} # Register a mock that will return the expected result depending on the remote. self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) @@ -1205,7 +1229,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_failures: Optional[List[str]] = None, expected_errcode: Optional[str] = None, - ): + ) -> None: """Send a request to the account status endpoint and check that the response matches with what's expected. diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 5c31a54421..823e8ab8c4 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.rest.client import filter +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -30,11 +32,11 @@ class FilterTestCase(unittest.HomeserverTestCase): EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' servlets = [filter.register_servlets] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.filtering = hs.get_filtering() self.store = hs.get_datastores().main - def test_add_filter(self): + def test_add_filter(self) -> None: channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), @@ -43,11 +45,13 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, {"filter_id": "0"}) - filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) + filter = self.get_success( + self.store.get_user_filter(user_localpart="apple", filter_id=0) + ) self.pump() - self.assertEqual(filter.result, self.EXAMPLE_FILTER) + self.assertEqual(filter, self.EXAMPLE_FILTER) - def test_add_filter_for_other_user(self): + def test_add_filter_for_other_user(self) -> None: channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), @@ -57,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) - def test_add_filter_non_local_user(self): + def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False channel = self.make_request( @@ -70,14 +74,13 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) - def test_get_filter(self): - filter_id = defer.ensureDeferred( + def test_get_filter(self) -> None: + filter_id = self.get_success( self.filtering.add_user_filter( user_localpart="apple", user_filter=self.EXAMPLE_FILTER ) ) self.reactor.advance(1) - filter_id = filter_id.result channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) @@ -85,7 +88,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) - def test_get_filter_non_existant(self): + def test_get_filter_non_existant(self) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) @@ -95,7 +98,7 @@ class FilterTestCase(unittest.HomeserverTestCase): # Currently invalid params do not have an appropriate errcode # in errors.py - def test_get_filter_invalid_id(self): + def test_get_filter_invalid_id(self) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) @@ -103,7 +106,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"400") # No ID also returns an invalid_id error - def test_get_filter_no_id(self): + def test_get_filter_no_id(self) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 
a087cd7b21..709f851a38 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import itertools import urllib.parse -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -45,7 +45,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): ] hijack_auth = False - def default_config(self) -> dict: + def default_config(self) -> Dict[str, Any]: # We need to enable msc1849 support for aggregations config = super().default_config() diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py index ee6b0b9ebf..20a259fc43 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py @@ -14,8 +14,13 @@ import json +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, report_event, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") @@ -42,35 +47,35 @@ class ReportEventTestCase(unittest.HomeserverTestCase): self.event_id = resp["event_id"] self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" - def test_reason_str_and_score_int(self): + def test_reason_str_and_score_int(self) -> None: data = {"reason": "this makes me sad", "score": -100} self._assert_status(200, data) - def test_no_reason(self): + def test_no_reason(self) -> None: data = {"score": 0} self._assert_status(200, data) - def test_no_score(self): + def test_no_score(self) -> None: data = {"reason": "this makes me sad"} self._assert_status(200, data) - def test_no_reason_and_no_score(self): - data = {} + def test_no_reason_and_no_score(self) -> None: + data: JsonDict = {} self._assert_status(200, data) - def test_reason_int_and_score_str(self): + def test_reason_int_and_score_str(self) -> None: data = {"reason": 10, "score": "string"} self._assert_status(400, data) - def test_reason_zero_and_score_blank(self): + def test_reason_zero_and_score_blank(self) -> None: data = {"reason": 0, "score": ""} self._assert_status(400, data) - def test_reason_and_score_null(self): + def test_reason_and_score_null(self) -> None: data = {"reason": None, "score": None} self._assert_status(400, data) - def _assert_status(self, response_status, data): + def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( "POST", self.report_path, diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index e0b11e7264..37866ee330 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,11 +18,12 @@ """Tests REST events for /rooms paths.""" import json -from typing import Iterable, List +from typing import Any, Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( 
@@ -35,7 +36,9 @@ from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync +from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -45,11 +48,11 @@ PATH_PREFIX = b"/_matrix/client/api/v1" class RoomBase(unittest.HomeserverTestCase): - rmcreator_id = None + rmcreator_id: Optional[str] = None servlets = [room.register_servlets, room.register_deprecated_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver( "red", @@ -57,15 +60,15 @@ class RoomBase(unittest.HomeserverTestCase): federation_client=Mock(), ) - self.hs.get_federation_handler = Mock() + self.hs.get_federation_handler = Mock() # type: ignore[assignment] self.hs.get_federation_handler.return_value.maybe_backfill = Mock( return_value=make_awaitable(None) ) - async def _insert_client_ip(*args, **kwargs): + async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - self.hs.get_datastores().main.insert_client_ip = _insert_client_ip + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] return self.hs @@ -76,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase): user_id = "@sid1:red" rmcreator_id = "@notme:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id @@ -108,12 +111,12 @@ class RoomPermissionsTestCase(RoomBase): # auth as user_id now self.helper.auth_user_id = self.user_id - def test_can_do_action(self): + def test_can_do_action(self) -> None: msg_content = b'{"msgtype":"m.text","body":"hello"}' seq = iter(range(100)) - def send_msg_path(): + def send_msg_path() -> str: return "/rooms/%s/send/m.room.message/mid%s" % ( self.created_rmid, str(next(seq)), @@ -148,7 +151,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_topic_perms(self): + def test_topic_perms(self) -> None: topic_content = b'{"topic":"My Topic Name"}' topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid @@ -214,14 +217,14 @@ class RoomPermissionsTestCase(RoomBase): self.assertEqual(403, channel.code, msg=channel.result["body"]) def _test_get_membership( - self, room=None, members: Iterable = frozenset(), expect_code=None - ): + self, room: str, members: Iterable = frozenset(), expect_code: int = 200 + ) -> None: for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) self.assertEqual(expect_code, channel.code) - def test_membership_basic_room_perms(self): + def test_membership_basic_room_perms(self) -> None: # === room does not exist === room = self.uncreated_rmid # get membership of self, get membership of other, uncreated room @@ -241,7 +244,7 @@ class RoomPermissionsTestCase(RoomBase): self.helper.join(room=room, user=usr, expect_code=404) self.helper.leave(room=room, user=usr, expect_code=404) - def test_membership_private_room_perms(self): + def test_membership_private_room_perms(self) 
-> None: room = self.created_rmid # get membership of self, get membership of other, private room + invite # expect all 403s @@ -264,7 +267,7 @@ class RoomPermissionsTestCase(RoomBase): members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 ) - def test_membership_public_room_perms(self): + def test_membership_public_room_perms(self) -> None: room = self.created_public_rmid # get membership of self, get membership of other, public room + invite # expect 403 @@ -287,7 +290,7 @@ class RoomPermissionsTestCase(RoomBase): members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 ) - def test_invited_permissions(self): + def test_invited_permissions(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) @@ -310,7 +313,7 @@ class RoomPermissionsTestCase(RoomBase): expect_code=403, ) - def test_joined_permissions(self): + def test_joined_permissions(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.join(room=room, user=self.user_id) @@ -348,7 +351,7 @@ class RoomPermissionsTestCase(RoomBase): # set left of self, expect 200 self.helper.leave(room=room, user=self.user_id) - def test_leave_permissions(self): + def test_leave_permissions(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.join(room=room, user=self.user_id) @@ -383,7 +386,7 @@ class RoomPermissionsTestCase(RoomBase): ) # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember - def test_member_event_from_ban(self): + def test_member_event_from_ban(self) -> None: room = self.created_rmid self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) self.helper.join(room=room, user=self.user_id) @@ -475,21 +478,21 @@ class RoomsMemberListTestCase(RoomBase): user_id = "@sid1:red" - def test_get_member_list(self): + def test_get_member_list(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEqual(200, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_room(self): + def test_get_member_list_no_room(self) -> None: channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission(self): + def test_get_member_list_no_permission(self) -> None: room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission_with_at_token(self): + def test_get_member_list_no_permission_with_at_token(self) -> None: """ Tests that a stranger to the room cannot get the member list (in the case that they use an at token). @@ -509,7 +512,7 @@ class RoomsMemberListTestCase(RoomBase): ) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission_former_member(self): + def test_get_member_list_no_permission_former_member(self) -> None: """ Tests that a former member of the room can not get the member list. 
""" @@ -529,7 +532,7 @@ class RoomsMemberListTestCase(RoomBase): channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_no_permission_former_member_with_at_token(self): + def test_get_member_list_no_permission_former_member_with_at_token(self) -> None: """ Tests that a former member of the room can not get the member list (in the case that they use an at token). @@ -569,7 +572,7 @@ class RoomsMemberListTestCase(RoomBase): ) self.assertEqual(403, channel.code, msg=channel.result["body"]) - def test_get_member_list_mixed_memberships(self): + def test_get_member_list_mixed_memberships(self) -> None: room_creator = "@some_other_guy:red" room_id = self.helper.create_room_as(room_creator) room_path = "/rooms/%s/members" % room_id @@ -594,26 +597,26 @@ class RoomsCreateTestCase(RoomBase): user_id = "@sid1:red" - def test_post_room_no_keys(self): + def test_post_room_no_keys(self) -> None: # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") self.assertEqual(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) - def test_post_room_visibility_key(self): + def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) - def test_post_room_custom_key(self): + def test_post_room_custom_key(self) -> None: # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) - def test_post_room_known_and_unknown_keys(self): + def test_post_room_known_and_unknown_keys(self) -> None: # POST with custom + known config keys, expect new room id channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' @@ -621,7 +624,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) - def test_post_room_invalid_content(self): + def test_post_room_invalid_content(self) -> None: # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') self.assertEqual(400, channel.code) @@ -629,7 +632,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request("POST", "/createRoom", b'["hello"]') self.assertEqual(400, channel.code) - def test_post_room_invitees_invalid_mxid(self): + def test_post_room_invitees_invalid_mxid(self) -> None: # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 # Note the trailing space in the MXID here! channel = self.make_request( @@ -638,7 +641,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(400, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) - def test_post_room_invitees_ratelimit(self): + def test_post_room_invitees_ratelimit(self) -> None: """Test that invites sent when creating a room are ratelimited by a RateLimiter, which ratelimits them correctly, including by not limiting when the requester is exempt from ratelimiting. 
@@ -674,7 +677,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request("POST", "/createRoom", content) self.assertEqual(200, channel.code) - def test_spam_checker_may_join_room(self): + def test_spam_checker_may_join_room(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. """ @@ -704,12 +707,12 @@ class RoomTopicTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # create the room self.room_id = self.helper.create_room_as(self.user_id) self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) - def test_invalid_puts(self): + def test_invalid_puts(self) -> None: # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") self.assertEqual(400, channel.code, msg=channel.result["body"]) @@ -736,7 +739,7 @@ class RoomTopicTestCase(RoomBase): channel = self.make_request("PUT", self.path, content) self.assertEqual(400, channel.code, msg=channel.result["body"]) - def test_rooms_topic(self): + def test_rooms_topic(self) -> None: # nothing should be there channel = self.make_request("GET", self.path) self.assertEqual(404, channel.code, msg=channel.result["body"]) @@ -751,7 +754,7 @@ class RoomTopicTestCase(RoomBase): self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) - def test_rooms_topic_with_extra_keys(self): + def test_rooms_topic_with_extra_keys(self) -> None: # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) @@ -768,10 +771,10 @@ class RoomMemberStateTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_invalid_puts(self): + def test_invalid_puts(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") @@ -801,7 +804,7 @@ class RoomMemberStateTestCase(RoomBase): channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEqual(400, channel.code, msg=channel.result["body"]) - def test_rooms_members_self(self): + def test_rooms_members_self(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % ( urlparse.quote(self.room_id), self.user_id, @@ -812,13 +815,13 @@ class RoomMemberStateTestCase(RoomBase): channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEqual(200, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, content=b"") self.assertEqual(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEqual(expected_response, channel.json_body) - def test_rooms_members_other(self): + def test_rooms_members_other(self) -> None: self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( urlparse.quote(self.room_id), @@ -830,11 +833,11 @@ class RoomMemberStateTestCase(RoomBase): channel = self.make_request("PUT", path, content) self.assertEqual(200, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, content=b"") self.assertEqual(200, channel.code, 
msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) - def test_rooms_members_other_custom_keys(self): + def test_rooms_members_other_custom_keys(self) -> None: self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( urlparse.quote(self.room_id), @@ -849,7 +852,7 @@ class RoomMemberStateTestCase(RoomBase): channel = self.make_request("PUT", path, content) self.assertEqual(200, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, content=b"") self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) @@ -866,7 +869,7 @@ class RoomInviteRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invites_by_rooms_ratelimit(self): + def test_invites_by_rooms_ratelimit(self) -> None: """Tests that invites in a room are actually rate-limited.""" room_id = self.helper.create_room_as(self.user_id) @@ -878,7 +881,7 @@ class RoomInviteRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invites_by_users_ratelimit(self): + def test_invites_by_users_ratelimit(self) -> None: """Tests that invites to a specific user are actually rate-limited.""" for _ in range(3): @@ -897,7 +900,7 @@ class RoomJoinTestCase(RoomBase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user1 = self.register_user("thomas", "hackme") self.tok1 = self.login("thomas", "hackme") @@ -908,7 +911,7 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - def test_spam_checker_may_join_room(self): + def test_spam_checker_may_join_room(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. 
""" @@ -975,8 +978,8 @@ class RoomJoinRatelimitTestCase(RoomBase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): - super().prepare(reactor, clock, homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) # profile changes expect that the user is actually registered user = UserID.from_string(self.user_id) self.get_success(self.register_user(user.localpart, "supersecretpassword")) @@ -984,7 +987,7 @@ class RoomJoinRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) - def test_join_local_ratelimit(self): + def test_join_local_ratelimit(self) -> None: """Tests that local joins are actually rate-limited.""" for _ in range(3): self.helper.create_room_as(self.user_id) @@ -994,7 +997,7 @@ class RoomJoinRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) - def test_join_local_ratelimit_profile_change(self): + def test_join_local_ratelimit_profile_change(self) -> None: """Tests that sending a profile update into all of the user's joined rooms isn't rate-limited by the rate-limiter on joins.""" @@ -1031,7 +1034,7 @@ class RoomJoinRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) - def test_join_local_ratelimit_idempotent(self): + def test_join_local_ratelimit_idempotent(self) -> None: """Tests that the room join endpoints remain idempotent despite rate-limiting on room joins.""" room_id = self.helper.create_room_as(self.user_id) @@ -1056,7 +1059,7 @@ class RoomJoinRatelimitTestCase(RoomBase): "autocreate_auto_join_rooms": True, }, ) - def test_autojoin_rooms(self): + def test_autojoin_rooms(self) -> None: user_id = self.register_user("testuser", "password") # Check that the new user successfully joined the four rooms @@ -1071,10 +1074,10 @@ class RoomMessagesTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_invalid_puts(self): + def test_invalid_puts(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") @@ -1095,7 +1098,7 @@ class RoomMessagesTestCase(RoomBase): channel = self.make_request("PUT", path, b"") self.assertEqual(400, channel.code, msg=channel.result["body"]) - def test_rooms_messages_sent(self): + def test_rooms_messages_sent(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' @@ -1119,11 +1122,11 @@ class RoomInitialSyncTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # create the room self.room_id = self.helper.create_room_as(self.user_id) - def test_initial_sync(self): + def test_initial_sync(self) -> None: channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) self.assertEqual(200, channel.code) @@ -1131,7 +1134,7 @@ class RoomInitialSyncTestCase(RoomBase): self.assertEqual("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict - state = {} + state: JsonDict = {} for event in channel.json_body["state"]: if 
"state_key" not in event: continue @@ -1160,10 +1163,10 @@ class RoomMessageListTestCase(RoomBase): user_id = "@sid1:red" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_topo_token_is_accepted(self): + def test_topo_token_is_accepted(self) -> None: token = "t1-0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) @@ -1174,7 +1177,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_stream_token_is_accepted_for_fwd_pagianation(self): + def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: token = "s0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) @@ -1185,7 +1188,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_room_messages_purge(self): + def test_room_messages_purge(self) -> None: store = self.hs.get_datastores().main pagination_handler = self.hs.get_pagination_handler() @@ -1278,10 +1281,10 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Register the user who does the searching - self.user_id = self.register_user("user", "pass") + self.user_id2 = self.register_user("user", "pass") self.access_token = self.login("user", "pass") # Register the user who sends the message @@ -1289,12 +1292,12 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.other_access_token = self.login("otheruser", "pass") # Create a room - self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.room = self.helper.create_room_as(self.user_id2, tok=self.access_token) # Invite the other person self.helper.invite( room=self.room, - src=self.user_id, + src=self.user_id2, tok=self.access_token, targ=self.other_user_id, ) @@ -1304,7 +1307,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): room=self.room, user=self.other_user_id, tok=self.other_access_token ) - def test_finds_message(self): + def test_finds_message(self) -> None: """ The search functionality will search for content in messages if asked to do so. @@ -1333,7 +1336,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): # No context was requested, so we should get none. self.assertEqual(results["results"][0]["context"], {}) - def test_include_context(self): + def test_include_context(self) -> None: """ When event_context includes include_profile, profile information will be included in the search response. 
@@ -1379,7 +1382,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.url = b"/_matrix/client/r0/publicRooms" @@ -1389,11 +1392,11 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): return self.hs - def test_restricted_no_auth(self): + def test_restricted_no_auth(self) -> None: channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401, channel.result) - def test_restricted_auth(self): + def test_restricted_auth(self) -> None: self.register_user("user", "pass") tok = self.login("user", "pass") @@ -1412,19 +1415,19 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(federation_client=Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.register_user("user", "pass") self.token = self.login("user", "pass") self.federation_client = hs.get_federation_client() - def test_simple(self): + def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = ( - lambda *a, **k: defer.succeed({}) + self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined] + {} ) search_filter = {"generic_search_term": "foobar"} @@ -1437,7 +1440,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.result) - self.federation_client.get_public_rooms.assert_called_once_with( + self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined] "testserv", limit=100, since_token=None, @@ -1446,12 +1449,12 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): third_party_instance_id=None, ) - def test_fallback(self): + def test_fallback(self) -> None: "Test that searching public rooms over federation falls back if it gets a 404" # The `get_public_rooms` should be called again if the first call fails # with a 404, when using search filters. 
- self.federation_client.get_public_rooms.side_effect = ( + self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(404, "Not Found", b""), defer.succeed({}), ) @@ -1466,7 +1469,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.result) - self.federation_client.get_public_rooms.assert_has_calls( + self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined] [ call( "testserv", @@ -1497,14 +1500,14 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["allow_per_room_profiles"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") @@ -1522,7 +1525,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_per_room_profile_forbidden(self): + def test_per_room_profile_forbidden(self) -> None: data = {"membership": "join", "displayname": "other test user"} request_data = json.dumps(data) channel = self.make_request( @@ -1557,7 +1560,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.creator = self.register_user("creator", "test") self.creator_tok = self.login("creator", "test") @@ -1566,7 +1569,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) - def test_join_reason(self): + def test_join_reason(self) -> None: reason = "hello" channel = self.make_request( "POST", @@ -1578,7 +1581,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_leave_reason(self): + def test_leave_reason(self) -> None: self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" @@ -1592,7 +1595,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_kick_reason(self): + def test_kick_reason(self) -> None: self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" @@ -1606,7 +1609,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_ban_reason(self): + def test_ban_reason(self) -> None: self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" @@ -1620,7 +1623,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_unban_reason(self): + def test_unban_reason(self) -> None: reason = "hello" channel = self.make_request( "POST", @@ -1632,7 +1635,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_invite_reason(self): + def test_invite_reason(self) -> None: reason = "hello" channel = self.make_request( "POST", @@ -1644,7 +1647,7 @@ class 
RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def test_reject_invite_reason(self): + def test_reject_invite_reason(self) -> None: self.helper.invite( self.room_id, src=self.creator, @@ -1663,7 +1666,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) - def _check_for_reason(self, reason): + def _check_for_reason(self, reason: str) -> None: channel = self.make_request( "GET", "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( @@ -1704,12 +1707,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): "org.matrix.not_labels": ["#notfun"], } - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_context_filter_labels(self): + def test_context_filter_labels(self) -> None: """Test that we can filter by a label on a /context request.""" event_id = self._send_labelled_messages_in_room() @@ -1739,7 +1742,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): events_after[0]["content"]["body"], "with right label", events_after[0] ) - def test_context_filter_not_labels(self): + def test_context_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /context request.""" event_id = self._send_labelled_messages_in_room() @@ -1772,7 +1775,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): events_after[1]["content"]["body"], "with two wrong labels", events_after[1] ) - def test_context_filter_labels_not_labels(self): + def test_context_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label on a /context request. """ @@ -1801,7 +1804,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): events_after[0]["content"]["body"], "with wrong label", events_after[0] ) - def test_messages_filter_labels(self): + def test_messages_filter_labels(self) -> None: """Test that we can filter by a label on a /messages request.""" self._send_labelled_messages_in_room() @@ -1818,7 +1821,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - def test_messages_filter_not_labels(self): + def test_messages_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /messages request.""" self._send_labelled_messages_in_room() @@ -1839,7 +1842,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): events[3]["content"]["body"], "with two wrong labels", events[3] ) - def test_messages_filter_labels_not_labels(self): + def test_messages_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label on a /messages request. 
""" @@ -1862,7 +1865,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - def test_search_filter_labels(self): + def test_search_filter_labels(self) -> None: """Test that we can filter by a label on a /search request.""" request_data = json.dumps( { @@ -1899,7 +1902,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results[1]["result"]["content"]["body"], ) - def test_search_filter_not_labels(self): + def test_search_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /search request.""" request_data = json.dumps( { @@ -1946,7 +1949,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results[3]["result"]["content"]["body"], ) - def test_search_filter_labels_not_labels(self): + def test_search_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label on a /search request. """ @@ -1980,7 +1983,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results[0]["result"]["content"]["body"], ) - def _send_labelled_messages_in_room(self): + def _send_labelled_messages_in_room(self) -> str: """Sends several messages to a room with different labels (or without any) to test filtering by label. Returns: @@ -2056,12 +2059,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["experimental_features"] = {"msc3440_enabled": True} return config - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) @@ -2136,7 +2139,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): return channel.json_body["chunk"] - def test_filter_relation_senders(self): + def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. filter = {"io.element.relation_senders": [self.second_user_id]} chunk = self._filter_messages(filter) @@ -2159,7 +2162,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] ) - def test_filter_relation_type(self): + def test_filter_relation_type(self) -> None: # Messages which have annotations. filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) @@ -2185,7 +2188,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] ) - def test_filter_relation_senders_and_type(self): + def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. 
filter = { "io.element.relation_senders": [self.second_user_id], @@ -2205,7 +2208,7 @@ class ContextTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.tok = self.login("user", "password") self.room_id = self.helper.create_room_as( @@ -2218,7 +2221,7 @@ class ContextTestCase(unittest.HomeserverTestCase): self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) - def test_erased_sender(self): + def test_erased_sender(self) -> None: """Test that an erasure request results in the requester's events being hidden from any new member of the room. """ @@ -2332,7 +2335,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -2340,17 +2343,17 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): self.room_owner, tok=self.room_owner_tok ) - def test_no_aliases(self): + def test_no_aliases(self) -> None: res = self._get_aliases(self.room_owner_tok) self.assertEqual(res["aliases"], []) - def test_not_in_room(self): + def test_not_in_room(self) -> None: self.register_user("user", "test") user_tok = self.login("user", "test") res = self._get_aliases(user_tok, expected_code=403) self.assertEqual(res["errcode"], "M_FORBIDDEN") - def test_admin_user(self): + def test_admin_user(self) -> None: alias1 = self._random_alias() self._set_alias_via_directory(alias1) @@ -2360,7 +2363,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): res = self._get_aliases(user_tok) self.assertEqual(res["aliases"], [alias1]) - def test_with_aliases(self): + def test_with_aliases(self) -> None: alias1 = self._random_alias() alias2 = self._random_alias() @@ -2370,7 +2373,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): res = self._get_aliases(self.room_owner_tok) self.assertEqual(set(res["aliases"]), {alias1, alias2}) - def test_peekable_room(self): + def test_peekable_room(self) -> None: alias1 = self._random_alias() self._set_alias_via_directory(alias1) @@ -2404,7 +2407,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _random_alias(self) -> str: return RoomAlias(random_string(5), self.hs.hostname).to_string() - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias data = {"room_id": self.room_id} request_data = json.dumps(data) @@ -2423,7 +2426,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -2434,7 +2437,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): self.alias = "#alias:test" self._set_alias_via_directory(self.alias) - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + def _set_alias_via_directory(self, alias: str, 
expected_code: int = 200) -> None:
 url = "/_matrix/client/r0/directory/room/" + alias
 data = {"room_id": self.room_id}
 request_data = json.dumps(data)
@@ -2456,7 +2459,9 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 self.assertIsInstance(res, dict)
 return res

- def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
+ def _set_canonical_alias(
+ self, content: JsonDict, expected_code: int = 200
+ ) -> JsonDict:
 """Calls the endpoint under test. Returns the json response object."""
 channel = self.make_request(
 "PUT",
@@ -2469,7 +2474,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 res = channel.json_body
 self.assertIsInstance(res, dict)
 return res

- def test_canonical_alias(self):
+ def test_canonical_alias(self) -> None:
 """Test a basic alias message."""
 # There is no canonical alias to start with.
 self._get_canonical_alias(expected_code=404)
@@ -2488,7 +2493,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 res = self._get_canonical_alias()
 self.assertEqual(res, {})

- def test_alt_aliases(self):
+ def test_alt_aliases(self) -> None:
 """Test a canonical alias message with alt_aliases."""
 # Create an alias.
 self._set_canonical_alias({"alt_aliases": [self.alias]})
@@ -2504,7 +2509,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 res = self._get_canonical_alias()
 self.assertEqual(res, {})

- def test_alias_alt_aliases(self):
+ def test_alias_alt_aliases(self) -> None:
 """Test a canonical alias message with an alias and alt_aliases."""
 # Create an alias.
 self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
@@ -2520,7 +2525,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 res = self._get_canonical_alias()
 self.assertEqual(res, {})

- def test_partial_modify(self):
+ def test_partial_modify(self) -> None:
 """Test removing only the alt_aliases."""
 # Create an alias.
 self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
@@ -2536,7 +2541,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 res = self._get_canonical_alias()
 self.assertEqual(res, {"alias": self.alias})

- def test_add_alias(self):
+ def test_add_alias(self) -> None:
 """Test adding an alias."""
 # Create an additional alias.
second_alias = "#second:test" @@ -2556,7 +2561,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} ) - def test_bad_data(self): + def test_bad_data(self) -> None: """Invalid data for alt_aliases should cause errors.""" self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": None}, expected_code=400) @@ -2566,7 +2571,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): self._set_canonical_alias({"alt_aliases": True}, expected_code=400) self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) - def test_bad_alias(self): + def test_bad_alias(self) -> None: """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) @@ -2580,13 +2585,13 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("thomas", "hackme") self.tok = self.login("thomas", "hackme") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_threepid_invite_spamcheck(self): + def test_threepid_invite_spamcheck(self) -> None: # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we # can check its call_count later on during the test. diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index bfc04785b7..58f1ea11b7 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import threading -from typing import TYPE_CHECKING, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError +from synapse.api.room_versions import RoomVersion from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin from synapse.rest.client import account, login, profile, room +from synapse.server import HomeServer from synapse.types import JsonDict, Requester, StateMap +from synapse.util import Clock from synapse.util.frozenutils import unfreeze from tests import unittest @@ -34,7 +40,7 @@ thread_local = threading.local() class LegacyThirdPartyRulesTestModule: - def __init__(self, config: Dict, module_api: "ModuleApi"): + def __init__(self, config: Dict, module_api: "ModuleApi") -> None: # keep a record of the "current" rules module, so that the test can patch # it if desired. 
thread_local.rules_module = self @@ -42,32 +48,36 @@ class LegacyThirdPartyRulesTestModule: async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ): + ) -> bool: return True - async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + async def check_event_allowed( + self, event: EventBase, state: StateMap[EventBase] + ) -> Union[bool, dict]: return True @staticmethod - def parse_config(config): + def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): - def __init__(self, config: Dict, module_api: "ModuleApi"): + def __init__(self, config: Dict, module_api: "ModuleApi") -> None: super().__init__(config, module_api) - def on_create_room( + async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ): + ) -> bool: return False class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): - def __init__(self, config: Dict, module_api: "ModuleApi"): + def __init__(self, config: Dict, module_api: "ModuleApi") -> None: super().__init__(config, module_api) - async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + async def check_event_allowed( + self, event: EventBase, state: StateMap[EventBase] + ) -> JsonDict: d = event.get_dict() content = unfreeze(event.content) content["foo"] = "bar" @@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() load_legacy_third_party_event_rules(hs) @@ -94,22 +104,30 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Note that these checks are not relevant to this test case. # Have this homeserver auto-approve all event signature checking. - async def approve_all_signature_checking(_, pdu): + async def approve_all_signature_checking( + _: RoomVersion, pdu: EventBase + ) -> EventBase: return pdu - hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking + hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment] # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth(origin, event, context, *args, **kwargs): + async def _check_event_auth( + origin: str, + event: EventBase, + context: EventContext, + *args: Any, + **kwargs: Any, + ) -> EventContext: return context - hs.get_federation_event_handler()._check_event_auth = _check_event_auth + hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] return hs - def prepare(self, reactor, clock, homeserver): - super().prepare(reactor, clock, homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) # Create some users and a room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.invitee = self.register_user("invitee", "hackme") @@ -121,13 +139,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): except Exception: pass - def test_third_party_rules(self): + def test_third_party_rules(self) -> None: """Tests that a forbidden event is forbidden from being sent, but an allowed one can be sent. 
""" # patch the rules module with a Mock which will return False for some event # types - async def check(ev, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: return ev.type != "foo.bar.forbidden", None callback = Mock(spec=[], side_effect=check) @@ -161,7 +181,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): ) self.assertEqual(channel.result["code"], b"403", channel.result) - def test_third_party_rules_workaround_synapse_errors_pass_through(self): + def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None: """ Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042 is functional: that SynapseErrors are passed through from check_event_allowed @@ -172,7 +192,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """ class NastyHackException(SynapseError): - def error_dict(self): + def error_dict(self) -> JsonDict: """ This overrides SynapseError's `error_dict` to nastily inject JSON into the error response. @@ -182,7 +202,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): return result # add a callback that will raise our hacky exception - async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]: + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: raise NastyHackException(429, "message") self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] @@ -202,11 +224,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"}, ) - def test_cannot_modify_event(self): + def test_cannot_modify_event(self) -> None: """cannot accidentally modify an event before it is persisted""" # first patch the event checker so that it will try to modify the event - async def check(ev: EventBase, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: ev.content = {"x": "y"} return True, None @@ -223,10 +247,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # 500 Internal Server Error self.assertEqual(channel.code, 500, channel.result) - def test_modify_event(self): + def test_modify_event(self) -> None: """The module can return a modified version of the event""" # first patch the event checker so that it will modify the event - async def check(ev: EventBase, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: d = ev.get_dict() d["content"] = {"x": "y"} return True, d @@ -253,10 +279,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") - def test_message_edit(self): + def test_message_edit(self) -> None: """Ensure that the module doesn't cause issues with edited messages.""" # first patch the event checker so that it will modify the event - async def check(ev: EventBase, state): + async def check( + ev: EventBase, state: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: d = ev.get_dict() d["content"] = { "msgtype": "m.text", @@ -315,7 +343,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): ev = channel.json_body self.assertEqual(ev["content"]["body"], "EDITED BODY") - def test_send_event(self): + def test_send_event(self) -> None: """Tests that a module can send an event into a room via the module api""" content = { 
"msgtype": "m.text", @@ -344,7 +372,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): } } ) - def test_legacy_check_event_allowed(self): + def test_legacy_check_event_allowed(self) -> None: """Tests that the wrapper for legacy check_event_allowed callbacks works correctly. """ @@ -379,13 +407,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): } } ) - def test_legacy_on_create_room(self): + def test_legacy_on_create_room(self) -> None: """Tests that the wrapper for legacy on_create_room callbacks works correctly. """ self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) - def test_sent_event_end_up_in_room_state(self): + def test_sent_event_end_up_in_room_state(self) -> None: """Tests that a state event sent by a module while processing another state event doesn't get dropped from the state of the room. This is to guard against a bug where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830 @@ -400,7 +428,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): api = self.hs.get_module_api() # Define a callback that sends a custom event on power levels update. - async def test_fn(event: EventBase, state_events): + async def test_fn( + event: EventBase, state_events: StateMap[EventBase] + ) -> Tuple[bool, Optional[JsonDict]]: if event.is_state and event.type == EventTypes.PowerLevels: await api.create_and_send_event_into_room( { @@ -436,7 +466,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["i"], i) - def test_on_new_event(self): + def test_on_new_event(self) -> None: """Test that the on_new_event callback is called on new events""" on_new_event = Mock(make_awaitable(None)) self.hs.get_third_party_event_rules()._on_new_event_callbacks.append( @@ -501,7 +531,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) - def _update_power_levels(self, event_default: int = 0): + def _update_power_levels(self, event_default: int = 0) -> None: """Updates the room's power levels. Args: @@ -533,7 +563,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): tok=self.tok, ) - def test_on_profile_update(self): + def test_on_profile_update(self) -> None: """Tests that the on_profile_update module callback is correctly called on profile updates. """ @@ -592,7 +622,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(profile_info.display_name, displayname) self.assertEqual(profile_info.avatar_url, avatar_url) - def test_on_profile_update_admin(self): + def test_on_profile_update_admin(self) -> None: """Tests that the on_profile_update module callback is correctly called on profile updates triggered by a server admin. """ @@ -634,7 +664,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(profile_info.display_name, displayname) self.assertEqual(profile_info.avatar_url, avatar_url) - def test_on_user_deactivation_status_changed(self): + def test_on_user_deactivation_status_changed(self) -> None: """Tests that the on_user_deactivation_status_changed module callback is called correctly when processing a user's deactivation. 
""" @@ -691,7 +721,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): args = profile_mock.call_args[0] self.assertTrue(args[3]) - def test_on_user_deactivation_status_changed_admin(self): + def test_on_user_deactivation_status_changed_admin(self) -> None: """Tests that the on_user_deactivation_status_changed module callback is called correctly when processing a user's deactivation triggered by a server admin as well as a reactivation. diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 8b2da88e8a..43be711a64 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -14,11 +14,16 @@ # limitations under the License. """Tests REST events for /rooms paths.""" - +from typing import Any from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -33,7 +38,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): user = UserID.from_string(user_id) servlets = [room.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( "red", @@ -43,30 +48,34 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources.typing - hs.get_federation_handler = Mock() + hs.get_federation_handler = Mock() # type: ignore[assignment] - async def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } + async def get_user_by_access_token( + token: str, + rights: str = "access", + allow_expired: bool = False, + ) -> TokenLookupResult: + return TokenLookupResult( + user_id=self.user_id, + is_guest=False, + token_id=1, + ) - hs.get_auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] - async def _insert_client_ip(*args, **kwargs): + async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - hs.get_datastores().main.insert_client_ip = _insert_client_ip + hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) # Need another user to make notifications actually work self.helper.join(self.room_id, user="@jim:red") - def test_set_typing(self): + def test_set_typing(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), @@ -95,7 +104,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): ], ) - def test_set_not_typing(self): + def test_set_not_typing(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), @@ -103,7 +112,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code) - def test_typing_timeout(self): + def test_typing_timeout(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), -- cgit 1.4.1 From 106959b3cf1a59ab5469db639223b6a5b84fb7d7 Mon Sep 17 00:00:00 2001 From: 
David Robertson Date: Wed, 2 Mar 2022 17:24:52 +0000 Subject: Remove unused mocks from `test_typing` (#12136) * Remove unused mocks from `test_typing` It's not clear what these do. `get_user_by_access_token` has the wrong signature, including the return type. Tests all pass without these. I think we should nuke them. * Changelog * Fixup imports --- changelog.d/12136.misc | 1 + tests/rest/client/test_typing.py | 32 +------------------------------- 2 files changed, 2 insertions(+), 31 deletions(-) create mode 100644 changelog.d/12136.misc diff --git a/changelog.d/12136.misc b/changelog.d/12136.misc new file mode 100644 index 0000000000..98b1c1c9d8 --- /dev/null +++ b/changelog.d/12136.misc @@ -0,0 +1 @@ +Remove unused mocks from `test_typing`. \ No newline at end of file diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 43be711a64..d6da510773 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -14,14 +14,11 @@ # limitations under the License. """Tests REST events for /rooms paths.""" -from typing import Any -from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor from synapse.rest.client import room from synapse.server import HomeServer -from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from synapse.util import Clock @@ -39,35 +36,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): servlets = [room.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - - hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - ) - + hs = self.setup_test_homeserver("red") self.event_source = hs.get_event_sources().sources.typing - - hs.get_federation_handler = Mock() # type: ignore[assignment] - - async def get_user_by_access_token( - token: str, - rights: str = "access", - allow_expired: bool = False, - ) -> TokenLookupResult: - return TokenLookupResult( - user_id=self.user_id, - is_guest=False, - token_id=1, - ) - - hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] - - async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: - return None - - hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] - return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: -- cgit 1.4.1 From 1fbe0316a991e77289d4577b16ff3fcd27c26dc8 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 2 Mar 2022 18:00:26 +0000 Subject: Add suffixes to scripts in scripts-dev (#12137) * Rename scripts-dev to have suffixes * Update references to `scripts-dev` * Changelog * These scripts don't pass mypy --- .github/workflows/release-artifacts.yml | 4 +- .github/workflows/tests.yml | 4 +- changelog.d/12137.misc | 1 + docs/code_style.md | 2 +- mypy.ini | 12 +- scripts-dev/build_debian_packages | 217 -------------------------------- scripts-dev/build_debian_packages.py | 217 ++++++++++++++++++++++++++++++++ scripts-dev/check-newsfragment | 62 --------- scripts-dev/check-newsfragment.sh | 62 +++++++++ scripts-dev/generate_sample_config | 28 ----- scripts-dev/generate_sample_config.sh | 28 +++++ scripts-dev/lint.sh | 2 - scripts-dev/sign_json | 166 ------------------------ scripts-dev/sign_json.py | 166 ++++++++++++++++++++++++ tox.ini | 2 - 15 files changed, 490 insertions(+), 483 deletions(-) create mode 100644 changelog.d/12137.misc delete mode 100755
scripts-dev/build_debian_packages create mode 100755 scripts-dev/build_debian_packages.py delete mode 100755 scripts-dev/check-newsfragment create mode 100755 scripts-dev/check-newsfragment.sh delete mode 100755 scripts-dev/generate_sample_config create mode 100755 scripts-dev/generate_sample_config.sh delete mode 100755 scripts-dev/sign_json create mode 100755 scripts-dev/sign_json.py diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index eee3633d50..65ea761ad7 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -31,7 +31,7 @@ jobs: # if we're running from a tag, get the full list of distros; otherwise just use debian:sid dists='["debian:sid"]' if [[ $GITHUB_REF == refs/tags/* ]]; then - dists=$(scripts-dev/build_debian_packages --show-dists-json) + dists=$(scripts-dev/build_debian_packages.py --show-dists-json) fi echo "::set-output name=distros::$dists" # map the step outputs to job outputs @@ -74,7 +74,7 @@ jobs: # see https://github.com/docker/build-push-action/issues/252 # for the cache magic here run: | - ./src/scripts-dev/build_debian_packages \ + ./src/scripts-dev/build_debian_packages.py \ --docker-build-arg=--cache-from=type=local,src=/tmp/.buildx-cache \ --docker-build-arg=--cache-to=type=local,mode=max,dest=/tmp/.buildx-cache-new \ --docker-build-arg=--progress=plain \ diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e9e4277322..3f4e44ca59 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 - run: pip install -e . - - run: scripts-dev/generate_sample_config --check + - run: scripts-dev/generate_sample_config.sh --check lint: runs-on: ubuntu-latest @@ -51,7 +51,7 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v2 - run: "pip install 'towncrier>=18.6.0rc1'" - - run: scripts-dev/check-newsfragment + - run: scripts-dev/check-newsfragment.sh env: PULL_REQUEST_NUMBER: ${{ github.event.number }} diff --git a/changelog.d/12137.misc b/changelog.d/12137.misc new file mode 100644 index 0000000000..118ff77a91 --- /dev/null +++ b/changelog.d/12137.misc @@ -0,0 +1 @@ +Give `scripts-dev` scripts suffixes for neater CI config. \ No newline at end of file diff --git a/docs/code_style.md b/docs/code_style.md index 4d8e7c973d..e7c9cd1a5e 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -172,6 +172,6 @@ frobber: ``` Note that the sample configuration is generated from the synapse code -and is maintained by a script, `scripts-dev/generate_sample_config`. +and is maintained by a script, `scripts-dev/generate_sample_config.sh`. Making sure that the output from this script matches the desired format is left as an exercise for the reader! 
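For context on the code_style.md hunk above: the `--check` mode of `generate_sample_config.sh` (reproduced in full further down this patch) simply regenerates the sample config and diffs it against the committed copy. A minimal Python sketch of that check follows; the two paths are taken from the shell script, while everything else is illustrative and not part of the Synapse tree:

    #!/usr/bin/env python3
    # Sketch of `generate_sample_config.sh --check`: regenerate the sample
    # config and compare it byte-for-byte with the committed copy.
    import subprocess
    import sys

    SAMPLE_CONFIG = "docs/sample_config.yaml"

    # Without `-o`, the generator writes the sample config to stdout
    # (the shell script relies on the same behaviour).
    generated = subprocess.run(
        [
            "synapse/_scripts/generate_config.py",
            "--header-file",
            "docs/.sample_config_header.yaml",
        ],
        capture_output=True,
        text=True,
        check=True,
    ).stdout

    with open(SAMPLE_CONFIG) as f:
        committed = f.read()

    if generated != committed:
        print(f"{SAMPLE_CONFIG} is not up-to-date.", file=sys.stderr)
        sys.exit(1)

The tests.yml hunk above wires this check into CI, so a stale sample config fails the build.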
diff --git a/mypy.ini b/mypy.ini index 23ca4eaa5a..10971b7225 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,7 +11,7 @@ local_partial_types = True no_implicit_optional = True files = - scripts-dev/sign_json, + scripts-dev/, setup.py, synapse/, tests/ @@ -23,10 +23,20 @@ files = # https://docs.python.org/3/library/re.html#re.X exclude = (?x) ^( + |scripts-dev/build_debian_packages.py + |scripts-dev/check_signature.py + |scripts-dev/definitions.py + |scripts-dev/federation_client.py + |scripts-dev/hash_history.py + |scripts-dev/list_url_patterns.py + |scripts-dev/release.py + |scripts-dev/tail-synapse.py + |synapse/_scripts/export_signing_key.py |synapse/_scripts/move_remote_media_to_new_store.py |synapse/_scripts/synapse_port_db.py |synapse/_scripts/update_synapse_database.py + |synapse/storage/databases/__init__.py |synapse/storage/databases/main/__init__.py |synapse/storage/databases/main/cache.py diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages deleted file mode 100755 index 7ff96a1ee6..0000000000 --- a/scripts-dev/build_debian_packages +++ /dev/null @@ -1,217 +0,0 @@ -#!/usr/bin/env python3 - -# Build the Debian packages using Docker images. -# -# This script builds the Docker images and then executes them sequentially, each -# one building a Debian package for the targeted operating system. It is -# designed to be a "single command" to produce all the images. -# -# By default, builds for all known distributions, but a list of distributions -# can be passed on the commandline for debugging. - -import argparse -import json -import os -import signal -import subprocess -import sys -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Sequence - -DISTS = ( - "debian:buster", # oldstable: EOL 2022-08 - "debian:bullseye", - "debian:bookworm", - "debian:sid", - "ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14) - "ubuntu:impish", # 21.10 (EOL 2022-07) -) - -DESC = """\ -Builds .debs for synapse, using a Docker image for the build environment. - -By default, builds for all known distributions, but a list of distributions -can be passed on the commandline for debugging. -""" - -projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - - -class Builder(object): - def __init__( - self, redirect_stdout=False, docker_build_args: Optional[Sequence[str]] = None - ): - self.redirect_stdout = redirect_stdout - self._docker_build_args = tuple(docker_build_args or ()) - self.active_containers = set() - self._lock = threading.Lock() - self._failed = False - - def run_build(self, dist, skip_tests=False): - """Build deb for a single distribution""" - - if self._failed: - print("not building %s due to earlier failure" % (dist,)) - raise Exception("failed") - - try: - self._inner_build(dist, skip_tests) - except Exception as e: - print("build of %s failed: %s" % (dist, e), file=sys.stderr) - self._failed = True - raise - - def _inner_build(self, dist, skip_tests=False): - tag = dist.split(":", 1)[1] - - # Make the dir where the debs will live. - # - # Note that we deliberately put this outside the source tree, otherwise - # we tend to get source packages which are full of debs. (We could hack - # around that with more magic in the build_debian.sh script, but that - # doesn't solve the problem for natively-run dpkg-buildpakage). 
- debsdir = os.path.join(projdir, "../debs") - os.makedirs(debsdir, exist_ok=True) - - if self.redirect_stdout: - logfile = os.path.join(debsdir, "%s.buildlog" % (tag,)) - print("building %s: directing output to %s" % (dist, logfile)) - stdout = open(logfile, "w") - else: - stdout = None - - # first build a docker image for the build environment - build_args = ( - ( - "docker", - "build", - "--tag", - "dh-venv-builder:" + tag, - "--build-arg", - "distro=" + dist, - "-f", - "docker/Dockerfile-dhvirtualenv", - ) - + self._docker_build_args - + ("docker",) - ) - - subprocess.check_call( - build_args, - stdout=stdout, - stderr=subprocess.STDOUT, - cwd=projdir, - ) - - container_name = "synapse_build_" + tag - with self._lock: - self.active_containers.add(container_name) - - # then run the build itself - subprocess.check_call( - [ - "docker", - "run", - "--rm", - "--name", - container_name, - "--volume=" + projdir + ":/synapse/source:ro", - "--volume=" + debsdir + ":/debs", - "-e", - "TARGET_USERID=%i" % (os.getuid(),), - "-e", - "TARGET_GROUPID=%i" % (os.getgid(),), - "-e", - "DEB_BUILD_OPTIONS=%s" % ("nocheck" if skip_tests else ""), - "dh-venv-builder:" + tag, - ], - stdout=stdout, - stderr=subprocess.STDOUT, - ) - - with self._lock: - self.active_containers.remove(container_name) - - if stdout is not None: - stdout.close() - print("Completed build of %s" % (dist,)) - - def kill_containers(self): - with self._lock: - active = list(self.active_containers) - - for c in active: - print("killing container %s" % (c,)) - subprocess.run( - [ - "docker", - "kill", - c, - ], - stdout=subprocess.DEVNULL, - ) - with self._lock: - self.active_containers.remove(c) - - -def run_builds(builder, dists, jobs=1, skip_tests=False): - def sig(signum, _frame): - print("Caught SIGINT") - builder.kill_containers() - - signal.signal(signal.SIGINT, sig) - - with ThreadPoolExecutor(max_workers=jobs) as e: - res = e.map(lambda dist: builder.run_build(dist, skip_tests), dists) - - # make sure we consume the iterable so that exceptions are raised. - for _ in res: - pass - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=DESC, - ) - parser.add_argument( - "-j", - "--jobs", - type=int, - default=1, - help="specify the number of builds to run in parallel", - ) - parser.add_argument( - "--no-check", - action="store_true", - help="skip running tests after building", - ) - parser.add_argument( - "--docker-build-arg", - action="append", - help="specify an argument to pass to docker build", - ) - parser.add_argument( - "--show-dists-json", - action="store_true", - help="instead of building the packages, just list the dists to build for, as a json array", - ) - parser.add_argument( - "dist", - nargs="*", - default=DISTS, - help="a list of distributions to build for. Default: %(default)s", - ) - args = parser.parse_args() - if args.show_dists_json: - print(json.dumps(DISTS)) - else: - builder = Builder( - redirect_stdout=(args.jobs > 1), docker_build_args=args.docker_build_arg - ) - run_builds( - builder, - dists=args.dist, - jobs=args.jobs, - skip_tests=args.no_check, - ) diff --git a/scripts-dev/build_debian_packages.py b/scripts-dev/build_debian_packages.py new file mode 100755 index 0000000000..7ff96a1ee6 --- /dev/null +++ b/scripts-dev/build_debian_packages.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 + +# Build the Debian packages using Docker images. 
+# +# This script builds the Docker images and then executes them sequentially, each +# one building a Debian package for the targeted operating system. It is +# designed to be a "single command" to produce all the images. +# +# By default, builds for all known distributions, but a list of distributions +# can be passed on the commandline for debugging. + +import argparse +import json +import os +import signal +import subprocess +import sys +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Sequence + +DISTS = ( + "debian:buster", # oldstable: EOL 2022-08 + "debian:bullseye", + "debian:bookworm", + "debian:sid", + "ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14) + "ubuntu:impish", # 21.10 (EOL 2022-07) +) + +DESC = """\ +Builds .debs for synapse, using a Docker image for the build environment. + +By default, builds for all known distributions, but a list of distributions +can be passed on the commandline for debugging. +""" + +projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + + +class Builder(object): + def __init__( + self, redirect_stdout=False, docker_build_args: Optional[Sequence[str]] = None + ): + self.redirect_stdout = redirect_stdout + self._docker_build_args = tuple(docker_build_args or ()) + self.active_containers = set() + self._lock = threading.Lock() + self._failed = False + + def run_build(self, dist, skip_tests=False): + """Build deb for a single distribution""" + + if self._failed: + print("not building %s due to earlier failure" % (dist,)) + raise Exception("failed") + + try: + self._inner_build(dist, skip_tests) + except Exception as e: + print("build of %s failed: %s" % (dist, e), file=sys.stderr) + self._failed = True + raise + + def _inner_build(self, dist, skip_tests=False): + tag = dist.split(":", 1)[1] + + # Make the dir where the debs will live. + # + # Note that we deliberately put this outside the source tree, otherwise + # we tend to get source packages which are full of debs. (We could hack + # around that with more magic in the build_debian.sh script, but that + # doesn't solve the problem for natively-run dpkg-buildpakage). 
+ debsdir = os.path.join(projdir, "../debs") + os.makedirs(debsdir, exist_ok=True) + + if self.redirect_stdout: + logfile = os.path.join(debsdir, "%s.buildlog" % (tag,)) + print("building %s: directing output to %s" % (dist, logfile)) + stdout = open(logfile, "w") + else: + stdout = None + + # first build a docker image for the build environment + build_args = ( + ( + "docker", + "build", + "--tag", + "dh-venv-builder:" + tag, + "--build-arg", + "distro=" + dist, + "-f", + "docker/Dockerfile-dhvirtualenv", + ) + + self._docker_build_args + + ("docker",) + ) + + subprocess.check_call( + build_args, + stdout=stdout, + stderr=subprocess.STDOUT, + cwd=projdir, + ) + + container_name = "synapse_build_" + tag + with self._lock: + self.active_containers.add(container_name) + + # then run the build itself + subprocess.check_call( + [ + "docker", + "run", + "--rm", + "--name", + container_name, + "--volume=" + projdir + ":/synapse/source:ro", + "--volume=" + debsdir + ":/debs", + "-e", + "TARGET_USERID=%i" % (os.getuid(),), + "-e", + "TARGET_GROUPID=%i" % (os.getgid(),), + "-e", + "DEB_BUILD_OPTIONS=%s" % ("nocheck" if skip_tests else ""), + "dh-venv-builder:" + tag, + ], + stdout=stdout, + stderr=subprocess.STDOUT, + ) + + with self._lock: + self.active_containers.remove(container_name) + + if stdout is not None: + stdout.close() + print("Completed build of %s" % (dist,)) + + def kill_containers(self): + with self._lock: + active = list(self.active_containers) + + for c in active: + print("killing container %s" % (c,)) + subprocess.run( + [ + "docker", + "kill", + c, + ], + stdout=subprocess.DEVNULL, + ) + with self._lock: + self.active_containers.remove(c) + + +def run_builds(builder, dists, jobs=1, skip_tests=False): + def sig(signum, _frame): + print("Caught SIGINT") + builder.kill_containers() + + signal.signal(signal.SIGINT, sig) + + with ThreadPoolExecutor(max_workers=jobs) as e: + res = e.map(lambda dist: builder.run_build(dist, skip_tests), dists) + + # make sure we consume the iterable so that exceptions are raised. + for _ in res: + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=DESC, + ) + parser.add_argument( + "-j", + "--jobs", + type=int, + default=1, + help="specify the number of builds to run in parallel", + ) + parser.add_argument( + "--no-check", + action="store_true", + help="skip running tests after building", + ) + parser.add_argument( + "--docker-build-arg", + action="append", + help="specify an argument to pass to docker build", + ) + parser.add_argument( + "--show-dists-json", + action="store_true", + help="instead of building the packages, just list the dists to build for, as a json array", + ) + parser.add_argument( + "dist", + nargs="*", + default=DISTS, + help="a list of distributions to build for. Default: %(default)s", + ) + args = parser.parse_args() + if args.show_dists_json: + print(json.dumps(DISTS)) + else: + builder = Builder( + redirect_stdout=(args.jobs > 1), docker_build_args=args.docker_build_arg + ) + run_builds( + builder, + dists=args.dist, + jobs=args.jobs, + skip_tests=args.no_check, + ) diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment deleted file mode 100755 index 493558ad65..0000000000 --- a/scripts-dev/check-newsfragment +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env bash -# -# A script which checks that an appropriate news file has been added on this -# branch. 
- -echo -e "+++ \033[32mChecking newsfragment\033[m" - -set -e - -# make sure that origin/develop is up to date -git remote set-branches --add origin develop -git fetch -q origin develop - -pr="$PULL_REQUEST_NUMBER" - -# if there are changes in the debian directory, check that the debian changelog -# has been updated -if ! git diff --quiet FETCH_HEAD... -- debian; then - if git diff --quiet FETCH_HEAD... -- debian/changelog; then - echo "Updates to debian directory, but no update to the changelog." >&2 - echo "!! Please see the contributing guide for help writing your changelog entry:" >&2 - echo "https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#debian-changelog" >&2 - exit 1 - fi -fi - -# if there are changes *outside* the debian directory, check that the -# newsfragments have been updated. -if ! git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then - exit 0 -fi - -# Print a link to the contributing guide if the user makes a mistake -CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing your changelog entry: -https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog" - -# If check-newsfragment returns a non-zero exit code, print the contributing guide and exit -python -m towncrier.check --compare-with=origin/develop || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) - -echo -echo "--------------------------" -echo - -matched=0 -for f in $(git diff --diff-filter=d --name-only FETCH_HEAD... -- changelog.d); do - # check that any added newsfiles on this branch end with a full stop. - lastchar=$(tr -d '\n' < "$f" | tail -c 1) - if [ "$lastchar" != '.' ] && [ "$lastchar" != '!' ]; then - echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2 - echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 - exit 1 - fi - - # see if this newsfile corresponds to the right PR - [[ -n "$pr" && "$f" == changelog.d/"$pr".* ]] && matched=1 -done - -if [[ -n "$pr" && "$matched" -eq 0 ]]; then - echo -e "\e[31mERROR: Did not find a news fragment with the right number: expected changelog.d/$pr.*.\e[39m" >&2 - echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 - exit 1 -fi diff --git a/scripts-dev/check-newsfragment.sh b/scripts-dev/check-newsfragment.sh new file mode 100755 index 0000000000..493558ad65 --- /dev/null +++ b/scripts-dev/check-newsfragment.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# +# A script which checks that an appropriate news file has been added on this +# branch. + +echo -e "+++ \033[32mChecking newsfragment\033[m" + +set -e + +# make sure that origin/develop is up to date +git remote set-branches --add origin develop +git fetch -q origin develop + +pr="$PULL_REQUEST_NUMBER" + +# if there are changes in the debian directory, check that the debian changelog +# has been updated +if ! git diff --quiet FETCH_HEAD... -- debian; then + if git diff --quiet FETCH_HEAD... -- debian/changelog; then + echo "Updates to debian directory, but no update to the changelog." >&2 + echo "!! Please see the contributing guide for help writing your changelog entry:" >&2 + echo "https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#debian-changelog" >&2 + exit 1 + fi +fi + +# if there are changes *outside* the debian directory, check that the +# newsfragments have been updated. +if ! git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then + exit 0 +fi + +# Print a link to the contributing guide if the user makes a mistake +CONTRIBUTING_GUIDE_TEXT="!! 
Please see the contributing guide for help writing your changelog entry: +https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog" + +# If check-newsfragment returns a non-zero exit code, print the contributing guide and exit +python -m towncrier.check --compare-with=origin/develop || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) + +echo +echo "--------------------------" +echo + +matched=0 +for f in $(git diff --diff-filter=d --name-only FETCH_HEAD... -- changelog.d); do + # check that any added newsfiles on this branch end with a full stop. + lastchar=$(tr -d '\n' < "$f" | tail -c 1) + if [ "$lastchar" != '.' ] && [ "$lastchar" != '!' ]; then + echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2 + echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 + exit 1 + fi + + # see if this newsfile corresponds to the right PR + [[ -n "$pr" && "$f" == changelog.d/"$pr".* ]] && matched=1 +done + +if [[ -n "$pr" && "$matched" -eq 0 ]]; then + echo -e "\e[31mERROR: Did not find a news fragment with the right number: expected changelog.d/$pr.*.\e[39m" >&2 + echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 + exit 1 +fi diff --git a/scripts-dev/generate_sample_config b/scripts-dev/generate_sample_config deleted file mode 100755 index 185e277933..0000000000 --- a/scripts-dev/generate_sample_config +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env bash -# -# Update/check the docs/sample_config.yaml - -set -e - -cd "$(dirname "$0")/.." - -SAMPLE_CONFIG="docs/sample_config.yaml" -SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml" - -check() { - diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || return 1 -} - -if [ "$1" == "--check" ]; then - diff -u "$SAMPLE_CONFIG" <(synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml) >/dev/null || { - echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 - exit 1 - } - diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || { - echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 - exit 1 - } -else - synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG" - synapse/_scripts/generate_log_config.py -o "$SAMPLE_LOG_CONFIG" -fi diff --git a/scripts-dev/generate_sample_config.sh b/scripts-dev/generate_sample_config.sh new file mode 100755 index 0000000000..375897eacb --- /dev/null +++ b/scripts-dev/generate_sample_config.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# +# Update/check the docs/sample_config.yaml + +set -e + +cd "$(dirname "$0")/.." + +SAMPLE_CONFIG="docs/sample_config.yaml" +SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml" + +check() { + diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || return 1 +} + +if [ "$1" == "--check" ]; then + diff -u "$SAMPLE_CONFIG" <(synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml) >/dev/null || { + echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config.sh\`.\e[0m" >&2 + exit 1 + } + diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || { + echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. 
Regenerate it with \`scripts-dev/generate_sample_config.sh\`.\e[0m" >&2 + exit 1 + } +else + synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG" + synapse/_scripts/generate_log_config.py -o "$SAMPLE_LOG_CONFIG" +fi diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index df4d4934d0..2f5f2c3566 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -85,8 +85,6 @@ else "synapse" "docker" "tests" # annoyingly, black doesn't find these so we have to list them "scripts-dev" - "scripts-dev/build_debian_packages" - "scripts-dev/sign_json" "contrib" "synctl" "setup.py" "synmark" "stubs" ".ci" ) fi diff --git a/scripts-dev/sign_json b/scripts-dev/sign_json deleted file mode 100755 index 9459543106..0000000000 --- a/scripts-dev/sign_json +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import json -import sys -from json import JSONDecodeError - -import yaml -from signedjson.key import read_signing_keys -from signedjson.sign import sign_json - -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.util import json_encoder - - -def main(): - parser = argparse.ArgumentParser( - description="""Adds a signature to a JSON object. - -Example usage: - - $ scripts-dev/sign_json.py -N test -k localhost.signing.key "{}" - {"signatures":{"test":{"ed25519:a_ZnZh":"LmPnml6iM0iR..."}}} -""", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - parser.add_argument( - "-N", - "--server-name", - help="Name to give as the local homeserver. If unspecified, will be " - "read from the config file.", - ) - - parser.add_argument( - "-k", - "--signing-key-path", - help="Path to the file containing the private ed25519 key to sign the " - "request with.", - ) - - parser.add_argument( - "-K", - "--signing-key", - help="The private ed25519 key to sign the request with.", - ) - - parser.add_argument( - "-c", - "--config", - default="homeserver.yaml", - help=( - "Path to synapse config file, from which the server name and/or signing " - "key path will be read. Ignored if --server-name and --signing-key(-path) " - "are both given." - ), - ) - - parser.add_argument( - "--sign-event-room-version", - type=str, - help=( - "Sign the JSON as an event for the given room version, rather than raw JSON. " - "This means that we will add a 'hashes' object, and redact the event before " - "signing." - ), - ) - - input_args = parser.add_mutually_exclusive_group() - - input_args.add_argument("input_data", nargs="?", help="Raw JSON to be signed.") - - input_args.add_argument( - "-i", - "--input", - type=argparse.FileType("r"), - default=sys.stdin, - help=( - "A file from which to read the JSON to be signed. If neither --input nor " - "input_data are given, JSON will be read from stdin." 
- ), - ) - - parser.add_argument( - "-o", - "--output", - type=argparse.FileType("w"), - default=sys.stdout, - help="Where to write the signed JSON. Defaults to stdout.", - ) - - args = parser.parse_args() - - if not args.server_name or not (args.signing_key_path or args.signing_key): - read_args_from_config(args) - - if args.signing_key: - keys = read_signing_keys([args.signing_key]) - else: - with open(args.signing_key_path) as f: - keys = read_signing_keys(f) - - json_to_sign = args.input_data - if json_to_sign is None: - json_to_sign = args.input.read() - - try: - obj = json.loads(json_to_sign) - except JSONDecodeError as e: - print("Unable to parse input as JSON: %s" % e, file=sys.stderr) - sys.exit(1) - - if not isinstance(obj, dict): - print("Input json was not an object", file=sys.stderr) - sys.exit(1) - - if args.sign_event_room_version: - room_version = KNOWN_ROOM_VERSIONS.get(args.sign_event_room_version) - if not room_version: - print( - f"Unknown room version {args.sign_event_room_version}", file=sys.stderr - ) - sys.exit(1) - add_hashes_and_signatures(room_version, obj, args.server_name, keys[0]) - else: - sign_json(obj, args.server_name, keys[0]) - - for c in json_encoder.iterencode(obj): - args.output.write(c) - args.output.write("\n") - - -def read_args_from_config(args: argparse.Namespace) -> None: - with open(args.config, "r") as fh: - config = yaml.safe_load(fh) - if not args.server_name: - args.server_name = config["server_name"] - if not args.signing_key_path and not args.signing_key: - if "signing_key" in config: - args.signing_key = config["signing_key"] - elif "signing_key_path" in config: - args.signing_key_path = config["signing_key_path"] - else: - print( - "A signing key must be given on the commandline or in the config file.", - file=sys.stderr, - ) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/scripts-dev/sign_json.py b/scripts-dev/sign_json.py new file mode 100755 index 0000000000..9459543106 --- /dev/null +++ b/scripts-dev/sign_json.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +# +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import sys +from json import JSONDecodeError + +import yaml +from signedjson.key import read_signing_keys +from signedjson.sign import sign_json + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.util import json_encoder + + +def main(): + parser = argparse.ArgumentParser( + description="""Adds a signature to a JSON object. + +Example usage: + + $ scripts-dev/sign_json.py -N test -k localhost.signing.key "{}" + {"signatures":{"test":{"ed25519:a_ZnZh":"LmPnml6iM0iR..."}}} +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "-N", + "--server-name", + help="Name to give as the local homeserver. 
If unspecified, will be " + "read from the config file.", + ) + + parser.add_argument( + "-k", + "--signing-key-path", + help="Path to the file containing the private ed25519 key to sign the " + "request with.", + ) + + parser.add_argument( + "-K", + "--signing-key", + help="The private ed25519 key to sign the request with.", + ) + + parser.add_argument( + "-c", + "--config", + default="homeserver.yaml", + help=( + "Path to synapse config file, from which the server name and/or signing " + "key path will be read. Ignored if --server-name and --signing-key(-path) " + "are both given." + ), + ) + + parser.add_argument( + "--sign-event-room-version", + type=str, + help=( + "Sign the JSON as an event for the given room version, rather than raw JSON. " + "This means that we will add a 'hashes' object, and redact the event before " + "signing." + ), + ) + + input_args = parser.add_mutually_exclusive_group() + + input_args.add_argument("input_data", nargs="?", help="Raw JSON to be signed.") + + input_args.add_argument( + "-i", + "--input", + type=argparse.FileType("r"), + default=sys.stdin, + help=( + "A file from which to read the JSON to be signed. If neither --input nor " + "input_data are given, JSON will be read from stdin." + ), + ) + + parser.add_argument( + "-o", + "--output", + type=argparse.FileType("w"), + default=sys.stdout, + help="Where to write the signed JSON. Defaults to stdout.", + ) + + args = parser.parse_args() + + if not args.server_name or not (args.signing_key_path or args.signing_key): + read_args_from_config(args) + + if args.signing_key: + keys = read_signing_keys([args.signing_key]) + else: + with open(args.signing_key_path) as f: + keys = read_signing_keys(f) + + json_to_sign = args.input_data + if json_to_sign is None: + json_to_sign = args.input.read() + + try: + obj = json.loads(json_to_sign) + except JSONDecodeError as e: + print("Unable to parse input as JSON: %s" % e, file=sys.stderr) + sys.exit(1) + + if not isinstance(obj, dict): + print("Input json was not an object", file=sys.stderr) + sys.exit(1) + + if args.sign_event_room_version: + room_version = KNOWN_ROOM_VERSIONS.get(args.sign_event_room_version) + if not room_version: + print( + f"Unknown room version {args.sign_event_room_version}", file=sys.stderr + ) + sys.exit(1) + add_hashes_and_signatures(room_version, obj, args.server_name, keys[0]) + else: + sign_json(obj, args.server_name, keys[0]) + + for c in json_encoder.iterencode(obj): + args.output.write(c) + args.output.write("\n") + + +def read_args_from_config(args: argparse.Namespace) -> None: + with open(args.config, "r") as fh: + config = yaml.safe_load(fh) + if not args.server_name: + args.server_name = config["server_name"] + if not args.signing_key_path and not args.signing_key: + if "signing_key" in config: + args.signing_key = config["signing_key"] + elif "signing_key_path" in config: + args.signing_key_path = config["signing_key_path"] + else: + print( + "A signing key must be given on the commandline or in the config file.", + file=sys.stderr, + ) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tox.ini b/tox.ini index 8d6aa7580b..f4829200cc 100644 --- a/tox.ini +++ b/tox.ini @@ -40,8 +40,6 @@ lint_targets = tests # annoyingly, black doesn't find these so we have to list them scripts-dev - scripts-dev/build_debian_packages - scripts-dev/sign_json stubs contrib synctl -- cgit 1.4.1 From 11282ade1d8deeafa042a639e2685472d6347e69 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 2 Mar 2022 19:22:44 +0000 Subject: 
Move the `snapcraft` configuration to `contrib`. (#12142) * Move the `snapcraft` configuration to `contrib`. We're happy for people to package this as a snap image if it's useful, but we don't support or maintain it. I'd like to move the config to `contrib` to reflect this state of affairs. * Changelog --- MANIFEST.in | 1 - changelog.d/12142.misc | 1 + contrib/snap/snapcraft.yaml | 57 +++++++++++++++++++++++++++++++++++++++++++++ snap/snapcraft.yaml | 57 --------------------------------------------- 4 files changed, 58 insertions(+), 58 deletions(-) create mode 100644 changelog.d/12142.misc create mode 100644 contrib/snap/snapcraft.yaml delete mode 100644 snap/snapcraft.yaml diff --git a/MANIFEST.in b/MANIFEST.in index 7e903518e1..f1e295e583 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -52,5 +52,4 @@ prune contrib prune debian prune demo/etc prune docker -prune snap prune stubs diff --git a/changelog.d/12142.misc b/changelog.d/12142.misc new file mode 100644 index 0000000000..5d09f90b52 --- /dev/null +++ b/changelog.d/12142.misc @@ -0,0 +1 @@ +Move the snapcraft configuration file to `contrib`. \ No newline at end of file diff --git a/contrib/snap/snapcraft.yaml b/contrib/snap/snapcraft.yaml new file mode 100644 index 0000000000..dd4c8478d5 --- /dev/null +++ b/contrib/snap/snapcraft.yaml @@ -0,0 +1,57 @@ +name: matrix-synapse +base: core18 +version: git +summary: Reference Matrix homeserver +description: | + Synapse is the reference Matrix homeserver. + Matrix is a federated and decentralised instant messaging and VoIP system. + +grade: stable +confinement: strict + +apps: + matrix-synapse: + command: synctl --no-daemonize start $SNAP_COMMON/homeserver.yaml + stop-command: synctl -c $SNAP_COMMON stop + plugs: [network-bind, network] + daemon: simple + hash-password: + command: hash_password + generate-config: + command: generate_config + generate-signing-key: + command: generate_signing_key + register-new-matrix-user: + command: register_new_matrix_user + plugs: [network] + synctl: + command: synctl +parts: + matrix-synapse: + source: . + plugin: python + python-version: python3 + python-packages: + - '.[all]' + - pip + - setuptools + - setuptools-scm + - wheel + build-packages: + - libffi-dev + - libturbojpeg0-dev + - libssl-dev + - libxslt1-dev + - libpq-dev + - zlib1g-dev + stage-packages: + - libasn1-8-heimdal + - libgssapi3-heimdal + - libhcrypto4-heimdal + - libheimbase1-heimdal + - libheimntlm0-heimdal + - libhx509-5-heimdal + - libkrb5-26-heimdal + - libldap-2.4-2 + - libpq5 + - libsasl2-2 diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml deleted file mode 100644 index dd4c8478d5..0000000000 --- a/snap/snapcraft.yaml +++ /dev/null @@ -1,57 +0,0 @@ -name: matrix-synapse -base: core18 -version: git -summary: Reference Matrix homeserver -description: | - Synapse is the reference Matrix homeserver. - Matrix is a federated and decentralised instant messaging and VoIP system. - -grade: stable -confinement: strict - -apps: - matrix-synapse: - command: synctl --no-daemonize start $SNAP_COMMON/homeserver.yaml - stop-command: synctl -c $SNAP_COMMON stop - plugs: [network-bind, network] - daemon: simple - hash-password: - command: hash_password - generate-config: - command: generate_config - generate-signing-key: - command: generate_signing_key - register-new-matrix-user: - command: register_new_matrix_user - plugs: [network] - synctl: - command: synctl -parts: - matrix-synapse: - source: . 
- plugin: python - python-version: python3 - python-packages: - - '.[all]' - - pip - - setuptools - - setuptools-scm - - wheel - build-packages: - - libffi-dev - - libturbojpeg0-dev - - libssl-dev - - libxslt1-dev - - libpq-dev - - zlib1g-dev - stage-packages: - - libasn1-8-heimdal - - libgssapi3-heimdal - - libhcrypto4-heimdal - - libheimbase1-heimdal - - libheimntlm0-heimdal - - libhx509-5-heimdal - - libkrb5-26-heimdal - - libldap-2.4-2 - - libpq5 - - libsasl2-2 -- cgit 1.4.1 From ae8a616b4926a005cd73db835a1ebcea4f5cbee0 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Mar 2022 11:39:58 +0100 Subject: Correctly register deactivation and profile update module callbacks (#12141) --- changelog.d/12141.bugfix | 1 + synapse/events/third_party_rules.py | 10 +++++++--- synapse/module_api/__init__.py | 8 ++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12141.bugfix diff --git a/changelog.d/12141.bugfix b/changelog.d/12141.bugfix new file mode 100644 index 0000000000..03a2507e2c --- /dev/null +++ b/changelog.d/12141.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.54.0rc1 preventing the new module callbacks introduced in this release from being registered by modules. diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index dd3104faf3..ede72ee876 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -174,7 +174,9 @@ class ThirdPartyEventRules: ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, - on_deactivation: Optional[ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK] = None, + on_user_deactivation_status_changed: Optional[ + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + ] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -199,8 +201,10 @@ class ThirdPartyEventRules: if on_profile_update is not None: self._on_profile_update_callbacks.append(on_profile_update) - if on_deactivation is not None: - self._on_user_deactivation_status_changed_callbacks.append(on_deactivation) + if on_user_deactivation_status_changed is not None: + self._on_user_deactivation_status_changed_callbacks.append( + on_user_deactivation_status_changed, + ) async def check_event_allowed( self, event: EventBase, context: EventContext diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 7e46931869..c42eeedd87 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -59,6 +59,8 @@ from synapse.events.third_party_rules import ( CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, ON_CREATE_ROOM_CALLBACK, ON_NEW_EVENT_CALLBACK, + ON_PROFILE_UPDATE_CALLBACK, + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ) from synapse.handlers.account_validity import ( IS_USER_EXPIRED_CALLBACK, @@ -281,6 +283,10 @@ class ModuleApi: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, + on_user_deactivation_status_changed: Optional[ + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + ] = None, ) -> None: """Registers callbacks for third party event rules capabilities. 
@@ -292,6 +298,8 @@ class ModuleApi: check_threepid_can_be_invited=check_threepid_can_be_invited, check_visibility_can_be_modified=check_visibility_can_be_modified, on_new_event=on_new_event, + on_profile_update=on_profile_update, + on_user_deactivation_status_changed=on_user_deactivation_status_changed, ) def register_presence_router_callbacks( -- cgit 1.4.1 From 31b125ccec75e708b09f40205c8cfe692edfa6b4 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 3 Mar 2022 04:45:23 -0600 Subject: Enable MSC3030 Complement tests in Synapse (#12144) The Complement tests for MSC3030 are now merged, https://github.com/matrix-org/complement/pull/178 Synapse implementation: https://github.com/matrix-org/synapse/pull/9445 --- .github/workflows/tests.yml | 2 +- changelog.d/12144.misc | 1 + scripts-dev/complement.sh | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12144.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3f4e44ca59..fa9611d42b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -376,7 +376,7 @@ jobs: # Run Complement - run: | set -o pipefail - go test -v -json -p 1 -tags synapse_blacklist,msc2403 ./tests/... 2>&1 | gotestfmt + go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc3030 ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: diff --git a/changelog.d/12144.misc b/changelog.d/12144.misc new file mode 100644 index 0000000000..d8f71bb203 --- /dev/null +++ b/changelog.d/12144.misc @@ -0,0 +1 @@ +Enable [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) Complement tests in CI. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 0aecb3daf1..e3d3e0f293 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -71,4 +71,4 @@ fi # Run the tests! echo "Images built; running complement" -go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2403,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... -- cgit 1.4.1 From 61fd2a8f591f20fe9d1cffe659336664bf44e742 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 3 Mar 2022 10:52:35 +0000 Subject: Limit the size of the aggregation_key (#12101) There's no reason to let people use long keys. --- changelog.d/12101.misc | 1 + synapse/handlers/message.py | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/12101.misc diff --git a/changelog.d/12101.misc b/changelog.d/12101.misc new file mode 100644 index 0000000000..d165f73d13 --- /dev/null +++ b/changelog.d/12101.misc @@ -0,0 +1 @@ +Limit the size of `aggregation_key` on annotations.
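The handler hunk below adds that guard at event-creation time. As a rough standalone illustration of the rule being enforced (the names here are stand-ins, not Synapse APIs; only the 500-character cap and the m.annotation/key shape come from the patch):

    # Sketch of the new validation: annotation relations ("m.annotation",
    # i.e. reactions) carry an aggregation key such as an emoji, and keys
    # longer than 500 characters are now rejected with a 400 error.
    MAX_AGGREGATION_KEY_LENGTH = 500  # matches the limit added in message.py

    class ValidationError(Exception):
        """Stand-in for Synapse's SynapseError(400, ...)."""

    def validate_relation(relation: dict) -> None:
        # Only m.annotation relations carry an aggregation key.
        if relation.get("rel_type") != "m.annotation":
            return
        aggregation_key = relation.get("key", "")
        if len(aggregation_key) > MAX_AGGREGATION_KEY_LENGTH:
            raise ValidationError("Aggregation key is too long")

In the real handler the check runs before the duplicate-annotation lookup, so oversized keys never reach the database.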
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 61cb133ef2..0799ec9a84 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1069,6 +1069,9 @@ class EventCreationHandler: if relation_type == RelationTypes.ANNOTATION: aggregation_key = relation["key"] + if len(aggregation_key) > 500: + raise SynapseError(400, "Aggregation key is too long") + already_exists = await self.store.has_user_annotated_event( relates_to, event.type, aggregation_key, event.sender ) -- cgit 1.4.1 From a511a890d7c556ad357d27443e5665e6cc25e0b5 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 3 Mar 2022 05:19:20 -0600 Subject: Enable MSC2716 Complement tests in Synapse (#12145) Co-authored-by: Brendan Abolivier --- .github/workflows/tests.yml | 2 +- changelog.d/12145.misc | 1 + scripts-dev/complement.sh | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12145.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fa9611d42b..c89c50cd07 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -376,7 +376,7 @@ jobs: # Run Complement - run: | set -o pipefail - go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc3030 ./tests/... 2>&1 | gotestfmt + go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: diff --git a/changelog.d/12145.misc b/changelog.d/12145.misc new file mode 100644 index 0000000000..4092a2d66e --- /dev/null +++ b/changelog.d/12145.misc @@ -0,0 +1 @@ +Enable [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Complement tests in CI. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index e3d3e0f293..0a79a4063f 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -71,4 +71,4 @@ fi # Run the tests! echo "Images built; running complement" -go test -v -tags synapse_blacklist,msc2403,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2403,msc2716,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... -- cgit 1.4.1 From cea1b58c4a5850a59e14808324ce1d9f222f52e5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 3 Mar 2022 12:47:55 +0000 Subject: Don't impose version checks on dev extras at runtime (#12129) * Fix incorrect argument in test case * Add copyright header * Docstring and __all__ * Exclude dev dependencies * Use changelog from #12088 * Include version in error messages This will hopefully distinguish between the version of the source code and the version of the distribution package that is installed. * Linter script is your friend --- changelog.d/12129.misc | 1 + synapse/util/check_dependencies.py | 61 +++++++++++++++++++++++++++++++---- tests/util/test_check_dependencies.py | 21 ++++++++++-- 3 files changed, 74 insertions(+), 9 deletions(-) create mode 100644 changelog.d/12129.misc diff --git a/changelog.d/12129.misc b/changelog.d/12129.misc new file mode 100644 index 0000000000..ce4213650c --- /dev/null +++ b/changelog.d/12129.misc @@ -0,0 +1 @@ +Inspect application dependencies using `importlib.metadata` or its backport. \ No newline at end of file diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 3a1f6b3c75..39b0a91db3 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -1,3 +1,25 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module exposes a single function which checks synapse's dependencies are present +and correctly versioned. It makes use of `importlib.metadata` to do so. The details +are a bit murky: there's no easy way to get a map from "extras" to the packages they +require. But this is probably just symptomatic of Python's package management. +""" + import logging from typing import Iterable, NamedTuple, Optional @@ -10,6 +32,8 @@ try: except ImportError: import importlib_metadata as metadata # type: ignore[no-redef] +__all__ = ["check_requirements"] + class DependencyException(Exception): @property @@ -29,7 +53,17 @@ class DependencyException(Exception): yield '"' + i + '"' -EXTRAS = set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) +DEV_EXTRAS = {"lint", "mypy", "test", "dev"} +RUNTIME_EXTRAS = ( + set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) - DEV_EXTRAS +) +VERSION = metadata.version(DISTRIBUTION_NAME) + + +def _is_dev_dependency(req: Requirement) -> bool: + return req.marker is not None and any( + req.marker.evaluate({"extra": e}) for e in DEV_EXTRAS + ) class Dependency(NamedTuple): @@ -43,6 +77,9 @@ def _generic_dependencies() -> Iterable[Dependency]: assert requirements is not None for raw_requirement in requirements: req = Requirement(raw_requirement) + if _is_dev_dependency(req): + continue + # https://packaging.pypa.io/en/latest/markers.html#usage notes that # > Evaluating an extra marker with no environment is an error # so we pass in a dummy empty extra value here. @@ -56,6 +93,8 @@ def _dependencies_for_extra(extra: str) -> Iterable[Dependency]: assert requirements is not None for raw_requirement in requirements: req = Requirement(raw_requirement) + if _is_dev_dependency(req): + continue # Exclude mandatory deps by only selecting deps needed with this extra. 
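# Illustration (not part of the patch): `packaging` requirement markers make
# the extras handling explicit. For a declared dependency such as
#     Requirement("dummypkg >= 1; extra == 'postgres'")
# req.marker.evaluate({"extra": "postgres"}) returns True, while evaluating
# the same marker with {"extra": ""} (the dummy empty value used for
# mandatory deps above) or with one of the DEV_EXTRAS returns False - which
# is what _is_dev_dependency and the filtering below rely on.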
if ( req.marker is not None @@ -67,18 +106,26 @@ def _dependencies_for_extra(extra: str) -> Iterable[Dependency]: def _not_installed(requirement: Requirement, extra: Optional[str] = None) -> str: if extra: - return f"Need {requirement.name} for {extra}, but it is not installed" + return ( + f"Synapse {VERSION} needs {requirement.name} for {extra}, " + f"but it is not installed" + ) else: - return f"Need {requirement.name}, but it is not installed" + return f"Synapse {VERSION} needs {requirement.name}, but it is not installed" def _incorrect_version( requirement: Requirement, got: str, extra: Optional[str] = None ) -> str: if extra: - return f"Need {requirement} for {extra}, but got {requirement.name}=={got}" + return ( + f"Synapse {VERSION} needs {requirement} for {extra}, " + f"but got {requirement.name}=={got}" + ) else: - return f"Need {requirement}, but got {requirement.name}=={got}" + return ( + f"Synapse {VERSION} needs {requirement}, but got {requirement.name}=={got}" + ) def check_requirements(extra: Optional[str] = None) -> None: @@ -100,10 +147,10 @@ def check_requirements(extra: Optional[str] = None) -> None: # First work out which dependencies are required, and which are optional. if extra is None: dependencies = _generic_dependencies() - elif extra in EXTRAS: + elif extra in RUNTIME_EXTRAS: dependencies = _dependencies_for_extra(extra) else: - raise ValueError(f"Synapse does not provide the feature '{extra}'") + raise ValueError(f"Synapse {VERSION} does not provide the feature '{extra}'") deps_unfulfilled = [] errors = [] diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index 3c07252252..a91c33272f 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -65,6 +65,23 @@ class TestDependencyChecker(TestCase): # should not raise check_requirements() + def test_checks_ignore_dev_dependencies(self) -> None: + """Both generic and per-extra checks should ignore dev dependencies.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'mypy'"], + ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}): + # We're testing that none of these calls raise. + with self.mock_installed_package(None): + check_requirements() + check_requirements("cool-extra") + with self.mock_installed_package(old): + check_requirements() + check_requirements("cool-extra") + with self.mock_installed_package(new): + check_requirements() + check_requirements("cool-extra") + def test_generic_check_of_optional_dependency(self) -> None: """Complain if an optional package is old.""" with patch( @@ -85,11 +102,11 @@ class TestDependencyChecker(TestCase): with patch( "synapse.util.check_dependencies.metadata.requires", return_value=["dummypkg >= 1; extra == 'cool-extra'"], - ), patch("synapse.util.check_dependencies.EXTRAS", {"cool-extra"}): + ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}): with self.mock_installed_package(None): self.assertRaises(DependencyException, check_requirements, "cool-extra") with self.mock_installed_package(old): self.assertRaises(DependencyException, check_requirements, "cool-extra") with self.mock_installed_package(new): # should not raise - check_requirements() + check_requirements("cool-extra") -- cgit 1.4.1 From 1d11b452b70c768e4919bd9cf6bcaeda2050a3d4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 3 Mar 2022 10:43:06 -0500 Subject: Use the proper serialization format when bundling aggregations.
(#12090) This ensures that the `latest_event` field of the bundled aggregation for threads uses the same format as the other events in the response. --- changelog.d/12090.bugfix | 1 + synapse/appservice/api.py | 24 ++++--- synapse/events/utils.py | 81 ++++++++++++++------- synapse/handlers/events.py | 3 +- synapse/handlers/initial_sync.py | 9 ++- synapse/handlers/pagination.py | 7 +- synapse/rest/client/notifications.py | 9 ++- synapse/rest/client/sync.py | 132 ++++++++++------------------------- tests/events/test_utils.py | 5 +- tests/rest/client/test_relations.py | 2 - 10 files changed, 130 insertions(+), 143 deletions(-) create mode 100644 changelog.d/12090.bugfix diff --git a/changelog.d/12090.bugfix b/changelog.d/12090.bugfix new file mode 100644 index 0000000000..087065dcb1 --- /dev/null +++ b/changelog.d/12090.bugfix @@ -0,0 +1 @@ +Use the proper serialization format for bundled thread aggregations. The bug has existed since Synapse v1.48.0. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index a0ea958af6..98fe354014 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -25,7 +25,7 @@ from synapse.appservice import ( TransactionUnusedFallbackKeys, ) from synapse.events import EventBase -from synapse.events.utils import serialize_event +from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache @@ -321,16 +321,18 @@ class ApplicationServiceApi(SimpleHttpClient): serialize_event( e, time_now, - as_client_event=True, - # If this is an invite or a knock membership event, and we're interested - # in this user, then include any stripped state alongside the event. - include_stripped_room_state=( - e.type == EventTypes.Member - and ( - e.membership == Membership.INVITE - or e.membership == Membership.KNOCK - ) - and service.is_interested_in_user(e.state_key) + config=SerializeEventConfig( + as_client_event=True, + # If this is an invite or a knock membership event, and we're interested + # in this user, then include any stripped state alongside the event. + include_stripped_room_state=( + e.type == EventTypes.Member + and ( + e.membership == Membership.INVITE + or e.membership == Membership.KNOCK + ) + and service.is_interested_in_user(e.state_key) + ), ), ) for e in events diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 9386fa29dd..ee34cb46e4 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -26,6 +26,7 @@ from typing import ( Union, ) +import attr from frozendict import frozendict from synapse.api.constants import EventContentFields, EventTypes, RelationTypes @@ -303,29 +304,37 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: return d +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SerializeEventConfig: + as_client_event: bool = True + # Function to convert from federation format to client format + event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 + # ID of the user's auth token - used for namespacing of transaction IDs + token_id: Optional[int] = None + # List of event fields to include. If empty, all fields will be returned. + only_event_fields: Optional[List[str]] = None + # Some events can have stripped room state stored in the `unsigned` field. + # This is required for invite and knock functionality. 
If this option is + # False, that state will be removed from the event before it is returned. + # Otherwise, it will be kept. + include_stripped_room_state: bool = False + + +_DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig() + + def serialize_event( e: Union[JsonDict, EventBase], time_now_ms: int, *, - as_client_event: bool = True, - event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, - token_id: Optional[str] = None, - only_event_fields: Optional[List[str]] = None, - include_stripped_room_state: bool = False, + config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, ) -> JsonDict: """Serialize event for clients Args: e time_now_ms - as_client_event - event_format - token_id - only_event_fields - include_stripped_room_state: Some events can have stripped room state - stored in the `unsigned` field. This is required for invite and knock - functionality. If this option is False, that state will be removed from the - event before it is returned. Otherwise, it will be kept. + config: Event serialization config Returns: The serialized event dictionary. @@ -348,11 +357,11 @@ def serialize_event( if "redacted_because" in e.unsigned: d["unsigned"]["redacted_because"] = serialize_event( - e.unsigned["redacted_because"], time_now_ms, event_format=event_format + e.unsigned["redacted_because"], time_now_ms, config=config ) - if token_id is not None: - if token_id == getattr(e.internal_metadata, "token_id", None): + if config.token_id is not None: + if config.token_id == getattr(e.internal_metadata, "token_id", None): txn_id = getattr(e.internal_metadata, "txn_id", None) if txn_id is not None: d["unsigned"]["transaction_id"] = txn_id @@ -361,13 +370,14 @@ def serialize_event( # that are meant to provide metadata about a room to an invitee/knocker. They are # intended to only be included in specific circumstances, such as down sync, and # should not be included in any other case. - if not include_stripped_room_state: + if not config.include_stripped_room_state: d["unsigned"].pop("invite_room_state", None) d["unsigned"].pop("knock_room_state", None) - if as_client_event: - d = event_format(d) + if config.as_client_event: + d = config.event_format(d) + only_event_fields = config.only_event_fields if only_event_fields: if not isinstance(only_event_fields, list) or not all( isinstance(f, str) for f in only_event_fields @@ -390,18 +400,18 @@ class EventClientSerializer: event: Union[JsonDict, EventBase], time_now: int, *, + config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None, - **kwargs: Any, ) -> JsonDict: """Serializes a single event. Args: event: The event being serialized. time_now: The current time in milliseconds + config: Event serialization config bundle_aggregations: Whether to include the bundled aggregations for this event. Only applies to non-state events. (State events never include bundled aggregations.) - **kwargs: Arguments to pass to `serialize_event` Returns: The serialized event @@ -410,7 +420,7 @@ class EventClientSerializer: if not isinstance(event, EventBase): return event - serialized_event = serialize_event(event, time_now, **kwargs) + serialized_event = serialize_event(event, time_now, config=config) # Check if there are any bundled aggregations to include with the event. 
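# Illustration (not part of the patch): `bundle_aggregations` is keyed by
# event ID, e.g. {"$event_id": BundledAggregations(...)}, which is why the
# code below looks up bundle_aggregations[event.event_id] for each event
# rather than sharing one aggregation object across events.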
if bundle_aggregations: @@ -419,6 +429,7 @@ class EventClientSerializer: self._inject_bundled_aggregations( event, time_now, + config, bundle_aggregations[event.event_id], serialized_event, ) @@ -456,6 +467,7 @@ class EventClientSerializer: self, event: EventBase, time_now: int, + config: SerializeEventConfig, aggregations: "BundledAggregations", serialized_event: JsonDict, ) -> None: @@ -466,6 +478,7 @@ class EventClientSerializer: time_now: The current time in milliseconds aggregations: The bundled aggregation to serialize. serialized_event: The serialized event which may be modified. + config: Event serialization config """ serialized_aggregations = {} @@ -493,8 +506,8 @@ class EventClientSerializer: thread = aggregations.thread # Don't bundle aggregations as this could recurse forever. - serialized_latest_event = self.serialize_event( - thread.latest_event, time_now, bundle_aggregations=None + serialized_latest_event = serialize_event( + thread.latest_event, time_now, config=config ) # Manually apply an edit, if one exists. if thread.latest_edit: @@ -515,20 +528,34 @@ class EventClientSerializer: ) def serialize_events( - self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any + self, + events: Iterable[Union[JsonDict, EventBase]], + time_now: int, + *, + config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, + bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None, ) -> List[JsonDict]: """Serializes multiple events. Args: event time_now: The current time in milliseconds - **kwargs: Arguments to pass to `serialize_event` + config: Event serialization config + bundle_aggregations: Whether to include the bundled aggregations for this + event. Only applies to non-state events. (State events never include + bundled aggregations.) 
Returns: The list of serialized events """ return [ - self.serialize_event(event, time_now=time_now, **kwargs) for event in events + self.serialize_event( + event, + time_now, + config=config, + bundle_aggregations=bundle_aggregations, + ) + for event in events ] diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 97e75e60c3..d2ccb5c5d3 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase +from synapse.events.utils import SerializeEventConfig from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, UserID @@ -120,7 +121,7 @@ class EventStreamHandler: chunks = self._event_serializer.serialize_events( events, time_now, - as_client_event=as_client_event, + config=SerializeEventConfig(as_client_event=as_client_event), ) chunk = { diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 344f20f37c..316cfae24f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events import EventBase +from synapse.events.utils import SerializeEventConfig from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.handlers.receipts import ReceiptEventSource @@ -156,6 +157,8 @@ class InitialSyncHandler: if limit is None: limit = 10 + serializer_options = SerializeEventConfig(as_client_event=as_client_event) + async def handle_room(event: RoomsForUser) -> None: d: JsonDict = { "room_id": event.room_id, @@ -173,7 +176,7 @@ class InitialSyncHandler: d["invite"] = self._event_serializer.serialize_event( invite_event, time_now, - as_client_event=as_client_event, + config=serializer_options, ) rooms_ret.append(d) @@ -225,7 +228,7 @@ class InitialSyncHandler: self._event_serializer.serialize_events( messages, time_now=time_now, - as_client_event=as_client_event, + config=serializer_options, ) ), "start": await start_token.to_string(self.store), @@ -235,7 +238,7 @@ class InitialSyncHandler: d["state"] = self._event_serializer.serialize_events( current_state.values(), time_now=time_now, - as_client_event=as_client_event, + config=serializer_options, ) account_data_events = [] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5c01a426ff..183fabcfc0 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -22,6 +22,7 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter +from synapse.events.utils import SerializeEventConfig from synapse.handlers.room import ShutdownRoomResponse from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter @@ -541,13 +542,15 @@ class PaginationHandler: time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(as_client_event=as_client_event) + chunk = { "chunk": ( self._event_serializer.serialize_events( events, 
time_now, + config=serialize_options, bundle_aggregations=aggregations, - as_client_event=as_client_event, ) ), "start": await from_token.to_string(self.store), @@ -556,7 +559,7 @@ class PaginationHandler: if state: chunk["state"] = self._event_serializer.serialize_events( - state, time_now, as_client_event=as_client_event + state, time_now, config=serialize_options ) return chunk diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 20377a9ac6..ff040de6b8 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -16,7 +16,10 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReceiptTypes -from synapse.events.utils import format_event_for_client_v2_without_room_id +from synapse.events.utils import ( + SerializeEventConfig, + format_event_for_client_v2_without_room_id, +) from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -75,7 +78,9 @@ class NotificationsServlet(RestServlet): self._event_serializer.serialize_event( notif_events[pa.event_id], self.clock.time_msec(), - event_format=format_event_for_client_v2_without_room_id, + config=SerializeEventConfig( + event_format=format_event_for_client_v2_without_room_id + ), ) ), } diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f3018ff690..53c385a86c 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -14,24 +14,14 @@ import itertools import logging from collections import defaultdict -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState -from synapse.events import EventBase from synapse.events.utils import ( + SerializeEventConfig, format_event_for_client_v2_without_room_id, format_event_raw, ) @@ -48,7 +38,6 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace -from synapse.storage.databases.main.relations import BundledAggregations from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -239,28 +228,31 @@ class SyncRestServlet(RestServlet): else: raise Exception("Unknown event format %s" % (filter.event_format,)) + serialize_options = SerializeEventConfig( + event_format=event_formatter, + token_id=access_token_id, + only_event_fields=filter.event_fields, + ) + stripped_serialize_options = SerializeEventConfig( + event_format=event_formatter, + token_id=access_token_id, + include_stripped_room_state=True, + ) + joined = await self.encode_joined( - sync_result.joined, - time_now, - access_token_id, - filter.event_fields, - event_formatter, + sync_result.joined, time_now, serialize_options ) invited = await self.encode_invited( - sync_result.invited, time_now, access_token_id, event_formatter + sync_result.invited, time_now, stripped_serialize_options ) knocked = await self.encode_knocked( - sync_result.knocked, time_now, access_token_id, event_formatter + sync_result.knocked, time_now, 
stripped_serialize_options ) archived = await self.encode_archived( - sync_result.archived, - time_now, - access_token_id, - filter.event_fields, - event_formatter, + sync_result.archived, time_now, serialize_options ) logger.debug("building sync response dict") @@ -339,9 +331,7 @@ class SyncRestServlet(RestServlet): self, rooms: List[JoinedSyncResult], time_now: int, - token_id: Optional[int], - event_fields: List[str], - event_formatter: Callable[[JsonDict], JsonDict], + serialize_options: SerializeEventConfig, ) -> JsonDict: """ Encode the joined rooms in a sync result @@ -349,24 +339,14 @@ class SyncRestServlet(RestServlet): Args: rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations - token_id: ID of the user's auth token - used for namespacing - of transaction IDs - event_fields: List of event fields to include. If empty, - all fields will be returned. - event_formatter: function to convert from federation format - to client format + serialize_options: Event serializer options Returns: The joined rooms list, in our response format """ joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, - time_now, - token_id, - joined=True, - only_fields=event_fields, - event_formatter=event_formatter, + room, time_now, joined=True, serialize_options=serialize_options ) return joined @@ -376,8 +356,7 @@ class SyncRestServlet(RestServlet): self, rooms: List[InvitedSyncResult], time_now: int, - token_id: Optional[int], - event_formatter: Callable[[JsonDict], JsonDict], + serialize_options: SerializeEventConfig, ) -> JsonDict: """ Encode the invited rooms in a sync result @@ -385,10 +364,7 @@ class SyncRestServlet(RestServlet): Args: rooms: list of sync results for rooms this user is invited to time_now: current time - used as a baseline for age calculations - token_id: ID of the user's auth token - used for namespacing - of transaction IDs - event_formatter: function to convert from federation format - to client format + serialize_options: Event serializer options Returns: The invited rooms list, in our response format @@ -396,11 +372,7 @@ class SyncRestServlet(RestServlet): invited = {} for room in rooms: invite = self._event_serializer.serialize_event( - room.invite, - time_now, - token_id=token_id, - event_format=event_formatter, - include_stripped_room_state=True, + room.invite, time_now, config=serialize_options ) unsigned = dict(invite.get("unsigned", {})) invite["unsigned"] = unsigned @@ -415,8 +387,7 @@ class SyncRestServlet(RestServlet): self, rooms: List[KnockedSyncResult], time_now: int, - token_id: Optional[int], - event_formatter: Callable[[Dict], Dict], + serialize_options: SerializeEventConfig, ) -> Dict[str, Dict[str, Any]]: """ Encode the rooms we've knocked on in a sync result. @@ -424,8 +395,7 @@ class SyncRestServlet(RestServlet): Args: rooms: list of sync results for rooms this user is knocking on time_now: current time - used as a baseline for age calculations - token_id: ID of the user's auth token - used for namespacing of transaction IDs - event_formatter: function to convert from federation format to client format + serialize_options: Event serializer options Returns: The list of rooms the user has knocked on, in our response format. 
@@ -433,11 +403,7 @@ class SyncRestServlet(RestServlet): knocked = {} for room in rooms: knock = self._event_serializer.serialize_event( - room.knock, - time_now, - token_id=token_id, - event_format=event_formatter, - include_stripped_room_state=True, + room.knock, time_now, config=serialize_options ) # Extract the `unsigned` key from the knock event. @@ -470,9 +436,7 @@ class SyncRestServlet(RestServlet): self, rooms: List[ArchivedSyncResult], time_now: int, - token_id: Optional[int], - event_fields: List[str], - event_formatter: Callable[[JsonDict], JsonDict], + serialize_options: SerializeEventConfig, ) -> JsonDict: """ Encode the archived rooms in a sync result @@ -480,23 +444,14 @@ class SyncRestServlet(RestServlet): Args: rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations - token_id: ID of the user's auth token - used for namespacing - of transaction IDs - event_fields: List of event fields to include. If empty, - all fields will be returned. - event_formatter: function to convert from federation format to client format + serialize_options: Event serializer options Returns: The archived rooms list, in our response format """ joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, - time_now, - token_id, - joined=False, - only_fields=event_fields, - event_formatter=event_formatter, + room, time_now, joined=False, serialize_options=serialize_options ) return joined @@ -505,10 +460,8 @@ class SyncRestServlet(RestServlet): self, room: Union[JoinedSyncResult, ArchivedSyncResult], time_now: int, - token_id: Optional[int], joined: bool, - only_fields: Optional[List[str]], - event_formatter: Callable[[JsonDict], JsonDict], + serialize_options: SerializeEventConfig, ) -> JsonDict: """ Args: @@ -524,20 +477,6 @@ class SyncRestServlet(RestServlet): Returns: The room, encoded in our response format """ - - def serialize( - events: Iterable[EventBase], - aggregations: Optional[Dict[str, BundledAggregations]] = None, - ) -> List[JsonDict]: - return self._event_serializer.serialize_events( - events, - time_now=time_now, - bundle_aggregations=aggregations, - token_id=token_id, - event_format=event_formatter, - only_event_fields=only_fields, - ) - state_dict = room.state timeline_events = room.timeline.events @@ -554,9 +493,14 @@ class SyncRestServlet(RestServlet): event.room_id, ) - serialized_state = serialize(state_events) - serialized_timeline = serialize( - timeline_events, room.timeline.bundled_aggregations + serialized_state = self._event_serializer.serialize_events( + state_events, time_now, config=serialize_options + ) + serialized_timeline = self._event_serializer.serialize_events( + timeline_events, + time_now, + config=serialize_options, + bundle_aggregations=room.timeline.bundled_aggregations, ) account_data = room.account_data diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 45e3395b33..00ad19e446 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -16,6 +16,7 @@ from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( + SerializeEventConfig, copy_power_levels_contents, prune_event, serialize_event, @@ -392,7 +393,9 @@ class PruneEventTestCase(unittest.TestCase): class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): - return serialize_event(ev, 1479807801915, 
only_event_fields=fields) + return serialize_event( + ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) + ) def test_event_fields_works_with_keys(self): self.assertEqual( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 709f851a38..53062b41de 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -704,10 +704,8 @@ class RelationsTestCase(BaseRelationsTestCase): } }, "event_id": thread_2, - "room_id": self.room, "sender": self.user_id, "type": "m.room.test", - "user_id": self.user_id, }, relations_dict[RelationTypes.THREAD].get("latest_event"), ) -- cgit 1.4.1 From 7e91107be1a4287873266e588a3c5b415279f4c8 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 3 Mar 2022 17:05:44 +0100 Subject: Add type hints to `tests/rest` (#12146) * Add type hints to `tests/rest` * newsfile * change import from `SigningKey` --- changelog.d/12146.misc | 1 + mypy.ini | 7 +--- tests/rest/key/v2/test_remote_key_resource.py | 44 ++++++++++++++-------- tests/rest/media/v1/test_base.py | 4 +- tests/rest/media/v1/test_filepath.py | 48 ++++++++++++------------ tests/rest/media/v1/test_html_preview.py | 54 +++++++++++++-------------- tests/rest/media/v1/test_oembed.py | 10 ++--- tests/rest/test_health.py | 8 ++-- tests/rest/test_well_known.py | 20 +++++----- 9 files changed, 104 insertions(+), 92 deletions(-) create mode 100644 changelog.d/12146.misc diff --git a/changelog.d/12146.misc b/changelog.d/12146.misc new file mode 100644 index 0000000000..3ca7c47212 --- /dev/null +++ b/changelog.d/12146.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest`. diff --git a/mypy.ini b/mypy.ini index 10971b7225..481e8a5366 100644 --- a/mypy.ini +++ b/mypy.ini @@ -89,8 +89,6 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_transactions.py - |tests/rest/key/v2/test_remote_key_resource.py - |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py |tests/rest/media/v1/test_url_preview.py |tests/scripts/test_new_matrix_user.py @@ -254,10 +252,7 @@ disallow_untyped_defs = True [mypy-tests.storage.test_user_directory] disallow_untyped_defs = True -[mypy-tests.rest.admin.*] -disallow_untyped_defs = True - -[mypy-tests.rest.client.*] +[mypy-tests.rest.*] disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 4672a68596..978c252f84 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -13,19 +13,24 @@ # limitations under the License. 
import urllib.parse from io import BytesIO, StringIO +from typing import Any, Dict, Optional, Union from unittest.mock import Mock import signedjson.key from canonicaljson import encode_canonical_json -from nacl.signing import SigningKey from signedjson.sign import sign_json +from signedjson.types import SigningKey -from twisted.web.resource import NoResource +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import NoResource, Resource from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.http.site import SynapseRequest from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult +from synapse.types import JsonDict +from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree from synapse.util.stringutils import random_string @@ -35,11 +40,11 @@ from tests.utils import default_config class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock() return self.setup_test_homeserver(federation_http_client=self.http_client) - def create_test_resource(self): + def create_test_resource(self) -> Resource: return create_resource_tree( {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() ) @@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect an outgoing GET request for the given key """ - async def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json( + destination: str, + path: str, + ignore_backoff: bool = False, + **kwargs: Any, + ) -> Union[JsonDict, list]: self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): Checks that the response is a 200 and returns the decoded json body. """ channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel, self.site) + # channel is a `FakeChannel` but `HTTPChannel` is expected + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] req.content = BytesIO(b"") req.requestReceived( b"GET", @@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): resp = channel.json_body return resp - def test_get_key(self): + def test_get_key(self) -> None: """Fetch a remote key""" SERVER_NAME = "remote.server" testkey = signedjson.key.generate_signing_key("ver1") @@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): self.assertIn(SERVER_NAME, keys[0]["signatures"]) self.assertIn(self.hs.hostname, keys[0]["signatures"]) - def test_get_own_key(self): + def test_get_own_key(self) -> None: """Fetch our own key""" testkey = signedjson.key.generate_signing_key("ver1") self.expect_outgoing_key_request(self.hs.hostname, testkey) @@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): endpoint, to check that the two implementations are compatible. 
""" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # replace the signing key with our own @@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): return config - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # make a second homeserver, configured to use the first one as a key notary self.http_client2 = Mock() config = default_config(name="keyclient") @@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - async def post_json(destination, path, data): + async def post_json( + destination: str, path: str, data: Optional[JsonDict] = None + ) -> Union[JsonDict, list]: self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, @@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): ) channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel, self.site) + # channel is a `FakeChannel` but `HTTPChannel` is expected + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] req.content = BytesIO(encode_canonical_json(data)) req.requestReceived( @@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): self.http_client2.post_json.side_effect = post_json - def test_get_key(self): + def test_get_key(self) -> None: """Fetch a key belonging to a random server""" # make up a key to be fetched. testkey = signedjson.key.generate_signing_key("abc") @@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): signedjson.key.encode_verify_key_base64(testkey.verify_key), ) - def test_get_notary_key(self): + def test_get_notary_key(self) -> None: """Fetch a key belonging to the notary server""" # make up a key to be fetched. 
We randomise the keyid to try to get it to # appear before the key server signing key sometimes (otherwise we bail out @@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): signedjson.key.encode_verify_key_base64(testkey.verify_key), ) - def test_get_notary_keyserver_key(self): + def test_get_notary_keyserver_key(self) -> None: """Fetch the notary's keyserver key""" # we expect hs1 to make a regular key request to itself self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key) diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index f761e23f1b..c73179151a 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase): b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar", } - def tests(self): + def tests(self) -> None: for hdr, expected in self.TEST_CASES.items(): res = get_filename_from_headers({b"Content-Disposition": [hdr]}) self.assertEqual( res, expected, - "expected output for %s to be %s but was %s" % (hdr, expected, res), + f"expected output for {hdr!r} to be {expected} but was {res}", ) diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 913bc530aa..43e6f0f70a 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/rest/media/v1/test_filepath.py @@ -21,12 +21,12 @@ from tests import unittest class MediaFilePathsTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.filepaths = MediaFilePaths("/media_store") - def test_local_media_filepath(self): + def test_local_media_filepath(self) -> None: """Test local media paths""" self.assertEqual( self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_local_media_thumbnail(self): + def test_local_media_thumbnail(self) -> None: """Test local media thumbnail paths""" self.assertEqual( self.filepaths.local_media_thumbnail_rel( @@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_local_media_thumbnail_dir(self): + def test_local_media_thumbnail_dir(self) -> None: """Test local media thumbnail directory paths""" self.assertEqual( self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"), "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_remote_media_filepath(self): + def test_remote_media_filepath(self) -> None: """Test remote media paths""" self.assertEqual( self.filepaths.remote_media_filepath_rel( @@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_remote_media_thumbnail(self): + def test_remote_media_thumbnail(self) -> None: """Test remote media thumbnail paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_rel( @@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_remote_media_thumbnail_legacy(self): + def test_remote_media_thumbnail_legacy(self) -> None: """Test old-style remote media thumbnail paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_rel_legacy(
"remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg", ) - def test_remote_media_thumbnail_dir(self): + def test_remote_media_thumbnail_dir(self) -> None: """Test remote media thumbnail directory paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_dir( @@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_filepath(self): + def test_url_cache_filepath(self) -> None: """Test URL cache paths""" self.assertEqual( self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"), @@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar", ) - def test_url_cache_filepath_legacy(self): + def test_url_cache_filepath_legacy(self) -> None: """Test old-style URL cache paths""" self.assertEqual( self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), @@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_filepath_dirs_to_delete(self): + def test_url_cache_filepath_dirs_to_delete(self) -> None: """Test URL cache cleanup paths""" self.assertEqual( self.filepaths.url_cache_filepath_dirs_to_delete( @@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ["/media_store/url_cache/2020-01-02"], ) - def test_url_cache_filepath_dirs_to_delete_legacy(self): + def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None: """Test old-style URL cache cleanup paths""" self.assertEqual( self.filepaths.url_cache_filepath_dirs_to_delete( @@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_url_cache_thumbnail(self): + def test_url_cache_thumbnail(self) -> None: """Test URL cache thumbnail paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_rel( @@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", ) - def test_url_cache_thumbnail_legacy(self): + def test_url_cache_thumbnail_legacy(self) -> None: """Test old-style URL cache thumbnail paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_rel( @@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_url_cache_thumbnail_directory(self): + def test_url_cache_thumbnail_directory(self) -> None: """Test URL cache thumbnail directory paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_directory_rel( @@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", ) - def test_url_cache_thumbnail_directory_legacy(self): + def test_url_cache_thumbnail_directory_legacy(self) -> None: """Test old-style URL cache thumbnail directory paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_directory_rel( @@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_thumbnail_dirs_to_delete(self): + def test_url_cache_thumbnail_dirs_to_delete(self) -> None: """Test URL cache thumbnail cleanup paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_dirs_to_delete( @@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_url_cache_thumbnail_dirs_to_delete_legacy(self): + def 
test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None: """Test old-style URL cache thumbnail cleanup paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_dirs_to_delete( @@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_server_name_validation(self): + def test_server_name_validation(self) -> None: """Test validation of server names""" self._test_path_validation( [ @@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_file_id_validation(self): + def test_file_id_validation(self) -> None: """Test validation of local, remote and legacy URL cache file / media IDs""" # File / media IDs get split into three parts to form paths, consisting of the # first two characters, next two characters and rest of the ID. @@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase): invalid_values=invalid_file_ids, ) - def test_url_cache_media_id_validation(self): + def test_url_cache_media_id_validation(self) -> None: """Test validation of URL cache media IDs""" self._test_path_validation( [ @@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_content_type_validation(self): + def test_content_type_validation(self) -> None: """Test validation of thumbnail content types""" self._test_path_validation( [ @@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_thumbnail_method_validation(self): + def test_thumbnail_method_validation(self) -> None: """Test validation of thumbnail methods""" self._test_path_validation( [ @@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase): parameter: str, valid_values: Iterable[str], invalid_values: Iterable[str], - ): + ) -> None: """Test that the specified methods validate the named parameter as expected Args: diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index a4b57e3d1f..3fb37a2a59 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -32,7 +32,7 @@ class SummarizeTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" - def test_long_summarize(self): + def test_long_summarize(self) -> None: example_paras = [ """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in @@ -90,7 +90,7 @@ class SummarizeTestCase(unittest.TestCase): " Tromsøya had a population of 36,088.
Substantial parts of the urban…", ) - def test_short_summarize(self): + def test_short_summarize(self) -> None: example_paras = [ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -117,7 +117,7 @@ class SummarizeTestCase(unittest.TestCase): " most of the year.", ) - def test_small_then_large_summarize(self): + def test_small_then_large_summarize(self) -> None: example_paras = [ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -150,7 +150,7 @@ class CalcOgTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" - def test_simple(self): + def test_simple(self) -> None: html = b""" Foo @@ -165,7 +165,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_comment(self): + def test_comment(self) -> None: html = b""" Foo @@ -181,7 +181,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_comment2(self): + def test_comment2(self) -> None: html = b""" Foo @@ -206,7 +206,7 @@ class CalcOgTestCase(unittest.TestCase): }, ) - def test_script(self): + def test_script(self) -> None: html = b""" Foo @@ -222,7 +222,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_missing_title(self): + def test_missing_title(self) -> None: html = b""" @@ -236,7 +236,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - def test_h1_as_title(self): + def test_h1_as_title(self) -> None: html = b""" @@ -251,7 +251,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) - def test_missing_title_and_broken_h1(self): + def test_missing_title_and_broken_h1(self) -> None: html = b""" @@ -266,19 +266,19 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - def test_empty(self): + def test_empty(self) -> None: """Test a body with no data in it.""" html = b"" tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) - def test_no_tree(self): + def test_no_tree(self) -> None: """A valid body with no tree in it.""" html = b"\x00" tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) - def test_xml(self): + def test_xml(self) -> None: """Test decoding XML and ensure it works properly.""" # Note that the strip() call is important to ensure the xml tag starts # at the initial byte.
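# Illustration (not part of the patch): without the strip(), the payload
# would begin with the newline and indentation of the triple-quoted string,
# so the "<?xml ...?>" declaration handed to decode_body() would no longer
# sit at the very first byte.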
@@ -293,7 +293,7 @@ class CalcOgTestCase(unittest.TestCase): og = parse_html_to_open_graph(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_invalid_encoding(self): + def test_invalid_encoding(self) -> None: """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" html = b""" @@ -307,7 +307,7 @@ class CalcOgTestCase(unittest.TestCase): og = parse_html_to_open_graph(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_invalid_encoding2(self): + def test_invalid_encoding2(self) -> None: """A body which doesn't match the sent character encoding.""" # Note that this contains an invalid UTF-8 sequence in the title. html = b""" @@ -322,7 +322,7 @@ class CalcOgTestCase(unittest.TestCase): og = parse_html_to_open_graph(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) - def test_windows_1252(self): + def test_windows_1252(self) -> None: """A body which uses cp1252, but doesn't declare that.""" html = b""" @@ -338,7 +338,7 @@ class MediaEncodingTestCase(unittest.TestCase): - def test_meta_charset(self): + def test_meta_charset(self) -> None: """A character encoding is found via the meta tag.""" encodings = _get_html_media_encodings( b""" @@ -363,7 +363,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_meta_charset_underscores(self): + def test_meta_charset_underscores(self) -> None: """A character encoding contains underscore.""" encodings = _get_html_media_encodings( b""" @@ -376,7 +376,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) - def test_xml_encoding(self): + def test_xml_encoding(self) -> None: """A character encoding is found via the xml declaration.""" encodings = _get_html_media_encodings( b""" @@ -388,7 +388,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_meta_xml_encoding(self): + def test_meta_xml_encoding(self) -> None: """Meta tags take precedence over XML encoding.""" encodings = _get_html_media_encodings( b""" @@ -402,7 +402,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) - def test_content_type(self): + def test_content_type(self) -> None: """A character encoding is found via the Content-Type header.""" # Test a few variations of the header.
headers = ( @@ -417,12 +417,12 @@ class MediaEncodingTestCase(unittest.TestCase): encodings = _get_html_media_encodings(b"", header) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_fallback(self): + def test_fallback(self) -> None: """A character encoding cannot be found in the body or header.""" encodings = _get_html_media_encodings(b"", "text/html") self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - def test_duplicates(self): + def test_duplicates(self) -> None: """Ensure each encoding is only attempted once.""" encodings = _get_html_media_encodings( b""" @@ -436,7 +436,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - def test_unknown_invalid(self): + def test_unknown_invalid(self) -> None: """A character encoding should be ignored if it is unknown or invalid.""" encodings = _get_html_media_encodings( b""" @@ -451,7 +451,7 @@ class MediaEncodingTestCase(unittest.TestCase): class RebaseUrlTestCase(unittest.TestCase): - def test_relative(self): + def test_relative(self) -> None: """Relative URLs should be resolved based on the context of the base URL.""" self.assertEqual( rebase_url("subpage", "https://example.com/foo/"), @@ -466,14 +466,14 @@ class RebaseUrlTestCase(unittest.TestCase): "https://example.com/bar", ) - def test_absolute(self): + def test_absolute(self) -> None: """Absolute URLs should not be modified.""" self.assertEqual( rebase_url("https://alice.com/a/", "https://example.com/foo/"), "https://alice.com/a/", ) - def test_data(self): + def test_data(self) -> None: """Data URLs should not be modified.""" self.assertEqual( rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py index 048d0ca44a..f38d7225f8 100644 --- a/tests/rest/media/v1/test_oembed.py +++ b/tests/rest/media/v1/test_oembed.py @@ -16,7 +16,7 @@ import json from twisted.test.proto_helpers import MemoryReactor -from synapse.rest.media.v1.oembed import OEmbedProvider +from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase class OEmbedTests(HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): - self.oembed = OEmbedProvider(homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.oembed = OEmbedProvider(hs) - def parse_response(self, response: JsonDict): + def parse_response(self, response: JsonDict) -> OEmbedResult: return self.oembed.parse_oembed_response( "https://test", json.dumps(response).encode("utf-8") ) - def test_version(self): + def test_version(self) -> None: """Accept versions that are similar to 1.0 as a string or int (or missing).""" for version in ("1.0", 1.0, 1): result = self.parse_response({"version": version, "type": "link"}) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py index 01d48c3860..da325955f8 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from http import HTTPStatus from synapse.rest.health import HealthResource @@ -19,12 +19,12 @@ from tests import unittest class HealthCheckTests(unittest.HomeserverTestCase): - def create_test_resource(self): + def create_test_resource(self) -> HealthResource: # replace the JsonResource with a HealthResource. return HealthResource() - def test_health(self): + def test_health(self) -> None: channel = self.make_request("GET", "/health", shorthand=False) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 118aa93a32..11f78f52b8 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + from twisted.web.resource import Resource from synapse.rest.well_known import well_known_resource @@ -19,7 +21,7 @@ from tests import unittest class WellKnownTests(unittest.HomeserverTestCase): - def create_test_resource(self): + def create_test_resource(self) -> Resource: # replace the JsonResource with a Resource wrapping the WellKnownResource res = Resource() res.putChild(b".well-known", well_known_resource(self.hs)) @@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase): "default_identity_server": "https://testis", } ) - def test_client_well_known(self): + def test_client_well_known(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual( channel.json_body, { @@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase): "public_baseurl": None, } ) - def test_client_well_known_no_public_baseurl(self): + def test_client_well_known_no_public_baseurl(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, 404) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @unittest.override_config({"serve_server_wellknown": True}) - def test_server_well_known(self): + def test_server_well_known(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual( channel.json_body, {"m.server": "test:443"}, ) - def test_server_well_known_disabled(self): + def test_server_well_known_disabled(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, 404) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) -- cgit 1.4.1 From 9297d040a72b70c7cc0ec15319afdd99b01ba885 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 3 Mar 2022 17:14:09 +0000 Subject: Detox, part 2 of N (#12152) I've argued in #11537 that poetry and tox don't cooperate well at the moment. (See also #12119.) Therefore I'm pruning away bits of tox to make the transition to poetry easier. This change removes the commands for coverage. We don't use coverage in anger at the moment. It shouldn't be too hard to add coverage as a dev-dependency and reintroduce this if we really want it. 
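(If coverage is reintroduced as a dev-dependency later, the removed tox environments are easy to reproduce directly from coverage.py's Python API — a rough sketch, assuming `coverage` is installed; the `htmlcov` output directory is just an illustrative choice, not part of Synapse's tooling:)

```python
# Rough equivalent of the removed `combine`/`cov-html` tox environments,
# driven through coverage.py's Python API rather than tox. Illustrative only.
import coverage

cov = coverage.Coverage()             # picks up .coveragerc and the .coverage data file
cov.combine()                         # merge .coverage.* files from parallel test runs
cov.save()                            # persist the combined data
cov.report()                          # print a text summary to stdout
cov.html_report(directory="htmlcov")  # write an HTML report (directory is illustrative)
```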
--- changelog.d/12152.misc | 1 + tox.ini | 26 -------------------------- 2 files changed, 1 insertion(+), 26 deletions(-) create mode 100644 changelog.d/12152.misc diff --git a/changelog.d/12152.misc b/changelog.d/12152.misc new file mode 100644 index 0000000000..b9877eaccb --- /dev/null +++ b/changelog.d/12152.misc @@ -0,0 +1 @@ +Prune unused jobs from `tox` config. \ No newline at end of file diff --git a/tox.ini b/tox.ini index f4829200cc..04d282a705 100644 --- a/tox.ini +++ b/tox.ini @@ -158,32 +158,6 @@ commands = extras = lint commands = isort -c --df {[base]lint_targets} -[testenv:combine] -skip_install = true -usedevelop = false -deps = - coverage - pip>=10 -commands= - coverage combine - coverage report - -[testenv:cov-erase] -skip_install = true -usedevelop = false -deps = - coverage -commands= - coverage erase - -[testenv:cov-html] -skip_install = true -usedevelop = false -deps = - coverage -commands= - coverage html - [testenv:mypy] deps = {[base]deps} -- cgit 1.4.1 From fb0ffa96766a4b6f298f53af2d212e4c4d09d9e9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 3 Mar 2022 18:14:09 +0000 Subject: Rename various ApplicationServices interested methods (#11915) --- changelog.d/11915.misc | 1 + synapse/appservice/__init__.py | 133 ++++++++++++++++++++++++------------ synapse/handlers/appservice.py | 4 +- synapse/handlers/directory.py | 6 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/typing.py | 4 +- tests/appservice/test_appservice.py | 45 +++++++++--- tests/handlers/test_appservice.py | 56 +++++++++++---- 8 files changed, 175 insertions(+), 76 deletions(-) create mode 100644 changelog.d/11915.misc diff --git a/changelog.d/11915.misc b/changelog.d/11915.misc new file mode 100644 index 0000000000..e3cef1511e --- /dev/null +++ b/changelog.d/11915.misc @@ -0,0 +1 @@ +Simplify the `ApplicationService` class' set of public methods related to interest checking. \ No newline at end of file diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 4d3f8e4923..07ec95f1d6 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -175,27 +175,14 @@ class ApplicationService: return namespace.exclusive return False - async def _matches_user(self, event: EventBase, store: "DataStore") -> bool: - if self.is_interested_in_user(event.sender): - return True - - # also check m.room.member state key - if event.type == EventTypes.Member and self.is_interested_in_user( - event.state_key - ): - return True - - does_match = await self.matches_user_in_member_list(event.room_id, store) - return does_match - @cached(num_args=1, cache_context=True) - async def matches_user_in_member_list + async def _matches_user_in_member_list ( self, room_id: str, store: "DataStore", cache_context: _CacheContext, ) -> bool: - """Check if this service is interested a room based upon it's membership + """Check if this service is interested in a room based upon its membership Args: room_id: The room to check. @@ -214,47 +201,110 @@ class ApplicationService: return True return False - def _matches_room_id(self, event: EventBase) -> bool: - if hasattr(event, "room_id"): - return self.is_interested_in_room(event.room_id) - return False + def is_interested_in_user( + self, + user_id: str, + ) -> bool: + """ + Returns whether the application service is interested in a given user ID.
+ + The appservice is considered to be interested in a user if either: the + user ID is in the appservice's user namespace, or if the user is the + appservice's configured sender_localpart. + + Args: + user_id: The ID of the user to check. + + Returns: + True if the application service is interested in the user, False if not. + """ + return ( + # User is the appservice's sender_localpart user + user_id == self.sender + # User is in the appservice's user namespace + or self.is_user_in_namespace(user_id) + ) + + @cached(num_args=1, cache_context=True) + async def is_interested_in_room( + self, + room_id: str, + store: "DataStore", + cache_context: _CacheContext, + ) -> bool: + """ + Returns whether the application service is interested in a given room ID. + + The appservice is considered to be interested in the room if either: the ID or one + of the aliases of the room is in the appservice's room ID or alias namespace + respectively, or if one of the members of the room falls into the appservice's user + namespace. - async def _matches_aliases(self, event: EventBase, store: "DataStore") -> bool: - alias_list = await store.get_aliases_for_room(event.room_id) + Args: + room_id: The ID of the room to check. + store: The homeserver's datastore class. + + Returns: + True if the application service is interested in the room, False if not. + """ + # Check if we have interest in this room ID + if self.is_room_id_in_namespace(room_id): + return True + + # likewise with the room's aliases (if it has any) + alias_list = await store.get_aliases_for_room(room_id) for alias in alias_list: - if self.is_interested_in_alias(alias): + if self.is_room_alias_in_namespace(alias): return True - return False + # And finally, perform an expensive check on whether any of the + # users in the room match the appservice's user namespace + return await self._matches_user_in_member_list( + room_id, store, on_invalidate=cache_context.invalidate + ) - async def is_interested(self, event: EventBase, store: "DataStore") -> bool: + @cached(num_args=1, cache_context=True) + async def is_interested_in_event( + self, + event_id: str, + event: EventBase, + store: "DataStore", + cache_context: _CacheContext, + ) -> bool: """Check if this service is interested in this event. Args: + event_id: The ID of the event to check. This is purely used for simplifying the + caching of calls to this method. + event: The event to check. store: The datastore to query. Returns: - True if this service would like to know about this event. + True if this service would like to know about this event, otherwise False.
""" - # Do cheap checks first - if self._matches_room_id(event): + # Check if we're interested in this event's sender by namespace (or if they're the + # sender_localpart user) + if self.is_interested_in_user(event.sender): return True - # This will check the namespaces first before - # checking the store, so should be run before _matches_aliases - if await self._matches_user(event, store): + # additionally, if this is a membership event, perform the same checks on + # the user it references + if event.type == EventTypes.Member and self.is_interested_in_user( + event.state_key + ): return True - # This will check the store, so should be run last - if await self._matches_aliases(event, store): + # This will check the datastore, so should be run last + if await self.is_interested_in_room( + event.room_id, store, on_invalidate=cache_context.invalidate + ): return True return False - @cached(num_args=1) + @cached(num_args=1, cache_context=True) async def is_interested_in_presence( - self, user_id: UserID, store: "DataStore" + self, user_id: UserID, store: "DataStore", cache_context: _CacheContext ) -> bool: """Check if this service is interested a user's presence @@ -272,20 +322,19 @@ class ApplicationService: # Then find out if the appservice is interested in any of those rooms for room_id in room_ids: - if await self.matches_user_in_member_list(room_id, store): + if await self.is_interested_in_room( + room_id, store, on_invalidate=cache_context.invalidate + ): return True return False - def is_interested_in_user(self, user_id: str) -> bool: - return ( - bool(self._matches_regex(ApplicationService.NS_USERS, user_id)) - or user_id == self.sender - ) + def is_user_in_namespace(self, user_id: str) -> bool: + return bool(self._matches_regex(ApplicationService.NS_USERS, user_id)) - def is_interested_in_alias(self, alias: str) -> bool: + def is_room_alias_in_namespace(self, alias: str) -> bool: return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias)) - def is_interested_in_room(self, room_id: str) -> bool: + def is_room_id_in_namespace(self, room_id: str) -> bool: return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id)) def is_exclusive_user(self, user_id: str) -> bool: diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index e6461cc3c9..bd913e524e 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -571,7 +571,7 @@ class ApplicationServicesHandler: room_alias_str = room_alias.to_string() services = self.store.get_app_services() alias_query_services = [ - s for s in services if (s.is_interested_in_alias(room_alias_str)) + s for s in services if (s.is_room_alias_in_namespace(room_alias_str)) ] for alias_service in alias_query_services: is_known_alias = await self.appservice_api.query_alias( @@ -660,7 +660,7 @@ class ApplicationServicesHandler: # inside of a list comprehension anymore. 
interested_list = [] for s in services: - if await s.is_interested(event, self.store): + if await s.is_interested_in_event(event.event_id, event, self.store): interested_list.append(s) return interested_list diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index b7064c6624..33d827a45b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -119,7 +119,7 @@ class DirectoryHandler: service = requester.app_service if service: - if not service.is_interested_in_alias(room_alias_str): + if not service.is_room_alias_in_namespace(room_alias_str): raise SynapseError( 400, "This application service has not reserved this kind of alias.", @@ -221,7 +221,7 @@ class DirectoryHandler: async def delete_appservice_association( self, service: ApplicationService, room_alias: RoomAlias ) -> None: - if not service.is_interested_in_alias(room_alias.to_string()): + if not service.is_room_alias_in_namespace(room_alias.to_string()): raise SynapseError( 400, "This application service has not reserved this kind of alias", @@ -376,7 +376,7 @@ class DirectoryHandler: # non-exclusive locks on the alias (or there are no interested services) services = self.store.get_app_services() interested_services = [ - s for s in services if s.is_interested_in_alias(alias.to_string()) + s for s in services if s.is_room_alias_in_namespace(alias.to_string()) ] for service in interested_services: diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index b4132c353a..6250bb3bdf 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -269,7 +269,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): # Then filter down to rooms that the AS can read events = [] for room_id, event in rooms_to_events.items(): - if not await service.matches_user_in_member_list(room_id, self.store): + if not await service.is_interested_in_room(room_id, self.store): continue events.append(event) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 843c68eb0f..3b89126528 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -486,9 +486,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): if handler._room_serials[room_id] <= from_key: continue - if not await service.matches_user_in_member_list( - room_id, self._main_store - ): + if not await service.is_interested_in_room(room_id, self._main_store): continue events.append(self._make_event_for(room_id)) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 9bd6275e92..edc584d0cf 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -36,7 +36,10 @@ class ApplicationServiceTestCase(unittest.TestCase): hostname="matrix.org", # only used by get_groups_for_user ) self.event = Mock( - type="m.something", room_id="!foo:bar", sender="@someone:somewhere" + event_id="$abc:xyz", + type="m.something", + room_id="!foo:bar", + sender="@someone:somewhere", ) self.store = Mock() @@ -50,7 +53,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -62,7 +67,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) 
) ) ) @@ -76,7 +83,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -90,7 +99,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -104,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -121,7 +134,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -174,7 +189,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -191,7 +208,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -207,7 +226,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -225,7 +246,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(event=self.event, store=self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 072e6bbcdd..cead9f90df 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -59,11 +59,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.event_source = hs.get_event_sources() def test_notify_interested_services(self): - interested_service = self._mkservice(is_interested=True) + interested_service = self._mkservice(is_interested_in_event=True) services = [ - self._mkservice(is_interested=False), + self._mkservice(is_interested_in_event=False), interested_service, - self._mkservice(is_interested=False), + self._mkservice(is_interested_in_event=False), ] self.mock_as_api.query_user.return_value = make_awaitable(True) @@ -85,7 +85,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" - services = [self._mkservice(is_interested=True)] + services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services self.mock_store.get_user_by_id.return_value = make_awaitable(None) @@ -102,7 +102,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" - services = [self._mkservice(is_interested=True)] 
+ services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) @@ -127,11 +127,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): room_id = "!alpha:bet" servers = ["aperture"] - interested_service = self._mkservice_alias(is_interested_in_alias=True) + interested_service = self._mkservice_alias(is_room_alias_in_namespace=True) services = [ - self._mkservice_alias(is_interested_in_alias=False), + self._mkservice_alias(is_room_alias_in_namespace=False), interested_service, - self._mkservice_alias(is_interested_in_alias=False), + self._mkservice_alias(is_room_alias_in_namespace=False), ] self.mock_as_api.query_alias.return_value = make_awaitable(True) @@ -275,7 +275,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): to be pushed out to interested appservices, and that the stream ID is updated accordingly. """ - interested_service = self._mkservice(is_interested=True) + interested_service = self._mkservice(is_interested_in_event=True) services = [interested_service] self.mock_store.get_app_services.return_value = services self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( @@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): Test sending out of order ephemeral events to the appservice handler are ignored. """ - interested_service = self._mkservice(is_interested=True) + interested_service = self._mkservice(is_interested_in_event=True) services = [interested_service] self.mock_store.get_app_services.return_value = services @@ -325,17 +325,45 @@ class AppServiceHandlerTestCase(unittest.TestCase): interested_service, ephemeral=[] ) - def _mkservice(self, is_interested, protocols=None): + def _mkservice( + self, is_interested_in_event: bool, protocols: Optional[Iterable] = None + ) -> Mock: + """ + Create a new mock representing an ApplicationService. + + Args: + is_interested_in_event: Whether this application service will be considered + interested in all events. + protocols: The third-party protocols that this application service claims to + support. + + Returns: + A mock representing the ApplicationService. + """ service = Mock() - service.is_interested.return_value = make_awaitable(is_interested) + service.is_interested_in_event.return_value = make_awaitable( + is_interested_in_event + ) service.token = "mock_service_token" service.url = "mock_service_url" service.protocols = protocols return service - def _mkservice_alias(self, is_interested_in_alias): + def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock: + """ + Create a new mock representing an ApplicationService that is or is not interested + in any given room alias. + + Args: + is_room_alias_in_namespace: If true, the application service will be interested + in all room aliases that are queried against it. If false, the application + service will not be interested in any room aliases. + + Returns: + A mock representing the ApplicationService.
+ """ service = Mock() - service.is_interested_in_alias.return_value = is_interested_in_alias + service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace service.token = "mock_service_token" service.url = "mock_service_url" return service -- cgit 1.4.1 From 8533c8b03d8916e3805c7d0e0020226017680147 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 3 Mar 2022 19:58:08 +0000 Subject: Avoid generating state groups for local out-of-band leaves (#12154) If we locally generate a rejection for an invite received over federation, it is stored as an outlier (because we probably don't have the state for the room). However, currently we still generate a state group for it (even though the state in that state group will be nonsense). By setting the `outlier` param on `create_event`, we avoid the nonsensical state. --- changelog.d/12154.misc | 1 + synapse/handlers/room_member.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12154.misc diff --git a/changelog.d/12154.misc b/changelog.d/12154.misc new file mode 100644 index 0000000000..18d2a4728b --- /dev/null +++ b/changelog.d/12154.misc @@ -0,0 +1 @@ +Avoid generating state groups for local out-of-band leaves. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a582837cf0..7cbc484b06 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1736,8 +1736,8 @@ class RoomMemberMasterHandler(RoomMemberHandler): txn_id=txn_id, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, + outlier=True, ) - event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True result_event = await self.event_creation_handler.handle_new_client_event( -- cgit 1.4.1 From d56202b0383627fdb4e04404d62922dce16868f8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 4 Mar 2022 10:25:18 +0000 Subject: Fix type of `events` in `StateGroupStorage` and `StateHandler` (#12156) We make multiple passes over this, so a regular iterable won't do. --- changelog.d/12156.misc | 1 + synapse/state/__init__.py | 6 +++--- synapse/storage/state.py | 8 ++++---- 3 files changed, 8 insertions(+), 7 deletions(-) create mode 100644 changelog.d/12156.misc diff --git a/changelog.d/12156.misc b/changelog.d/12156.misc new file mode 100644 index 0000000000..4818d988d7 --- /dev/null +++ b/changelog.d/12156.misc @@ -0,0 +1 @@ +Fix some type annotations. 
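(Why `Collection` rather than `Iterable` matters here: an `Iterable` may be a one-shot generator, so a second pass over it silently sees nothing, whereas a `Collection` guarantees `len()` and repeated iteration. A minimal standalone illustration — plain Python, not Synapse code:)

```python
from typing import Collection, Iterable, Tuple


def count_twice(event_ids: Iterable[str]) -> Tuple[int, int]:
    # An Iterable may be a one-shot generator: the second pass sees nothing.
    return (sum(1 for _ in event_ids), sum(1 for _ in event_ids))


print(count_twice(["$a", "$b"]))             # (2, 2): a list can be re-iterated
print(count_twice(e for e in ["$a", "$b"]))  # (2, 0): the generator is exhausted


def count_twice_safely(event_ids: Collection[str]) -> Tuple[int, int]:
    # Collection requires len() and repeated iteration, so a type checker
    # rejects passing a bare generator to this signature.
    return (len(event_ids), sum(1 for _ in event_ids))
```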
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 6babd5963c..21888cc8c5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -194,7 +194,7 @@ class StateHandler: } async def get_current_state_ids( - self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None + self, room_id: str, latest_event_ids: Optional[Collection[str]] = None ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room @@ -243,7 +243,7 @@ class StateHandler: return await self.get_hosts_in_room_at_events(room_id, event_ids) async def get_hosts_in_room_at_events( - self, room_id: str, event_ids: Iterable[str] + self, room_id: str, event_ids: Collection[str] ) -> Set[str]: """Get the hosts that were in a room at the given event ids @@ -404,7 +404,7 @@ class StateHandler: @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: Iterable[str] + self, room_id: str, event_ids: Collection[str] ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e79ecf64a0..86f1a5373b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -561,7 +561,7 @@ class StateGroupStorage: return state_group_delta.prev_group, state_group_delta.delta_ids async def get_state_groups_ids( - self, _room_id: str, event_ids: Iterable[str] + self, _room_id: str, event_ids: Collection[str] ) -> Dict[int, MutableStateMap[str]]: """Get the event IDs of all the state for the state groups for the given events @@ -596,7 +596,7 @@ class StateGroupStorage: return group_to_state[state_group] async def get_state_groups( - self, room_id: str, event_ids: Iterable[str] + self, room_id: str, event_ids: Collection[str] ) -> Dict[int, List[EventBase]]: """Get the state groups for the given list of event_ids @@ -648,7 +648,7 @@ class StateGroupStorage: return self.stores.state._get_state_groups_from_groups(groups, state_filter) async def get_state_for_events( - self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -684,7 +684,7 @@ class StateGroupStorage: return {event: event_to_state[event] for event in event_ids} async def get_state_ids_for_events( - self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids -- cgit 1.4.1 From 87c230c27cdeb7e421f61f1271a500c760f1f63b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 4 Mar 2022 10:31:19 +0000 Subject: Update client-visibility filtering for outlier events (#12155) Avoid trying to get the state for outliers, which isn't a sensible thing to do. 
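(The rule the diff below implements can be distilled to a few lines — a simplified sketch, not the actual Synapse code: outliers carry no room state, so they are normally hidden from clients, with a narrow exception for the user's own out-of-band membership events.)

```python
# Simplified sketch of the visibility rule for outliers (illustrative only):
# an outlier has no state, so only the user's own out-of-band membership
# (e.g. an incoming invite, or a locally-generated rejection) is shown.
def outlier_visible_to(event, user_id: str) -> bool:
    assert event.internal_metadata.outlier
    return event.type == "m.room.member" and event.state_key == user_id
```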
--- changelog.d/12155.misc | 1 + synapse/visibility.py | 17 ++++++++++- tests/test_visibility.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 changelog.d/12155.misc diff --git a/changelog.d/12155.misc b/changelog.d/12155.misc new file mode 100644 index 0000000000..9f333e718a --- /dev/null +++ b/changelog.d/12155.misc @@ -0,0 +1 @@ +Avoid trying to calculate the state at outlier events. diff --git a/synapse/visibility.py b/synapse/visibility.py index 1b970ce479..281cbe4d88 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -81,8 +81,9 @@ async def filter_events_for_client( types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) + # we exclude outliers at this point, and then handle them separately later event_id_to_state = await storage.state.get_state_for_events( - frozenset(e.event_id for e in events), + frozenset(e.event_id for e in events if not e.internal_metadata.outlier), state_filter=StateFilter.from_types(types), ) @@ -154,6 +155,17 @@ async def filter_events_for_client( if event.event_id in always_include_ids: return event + # we need to handle outliers separately, since we don't have the room state. + if event.internal_metadata.outlier: + # Normally these can't be seen by clients, but we make an exception + # for out-of-band membership events (eg, incoming invites, or rejections of + # said invite) for the user themselves. + if event.type == EventTypes.Member and event.state_key == user_id: + logger.debug("Returning out-of-band-membership event %s", event) + return event + + return None + state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. @@ -198,6 +210,9 @@ async def filter_events_for_client( # Always allow the user to see their own leave events, otherwise # they won't see the room disappear if they reject the invite + # + # (Note this doesn't work for out-of-band invite rejections, which don't + # have prev_state populated. They are handled above in the outlier code.) if membership == "leave" and ( prev_membership == "join" or prev_membership == "invite" ): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 219b5660b1..532e3fe9cd 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -13,11 +13,12 @@ # limitations under the License. import logging from typing import Optional +from unittest.mock import patch from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase -from synapse.types import JsonDict -from synapse.visibility import filter_events_for_server +from synapse.events import EventBase, make_event_from_dict +from synapse.types import JsonDict, create_requester +from synapse.visibility import filter_events_for_client, filter_events_for_server from tests import unittest from tests.utils import create_room @@ -185,3 +186,72 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.get_success(self.storage.persistence.persist_event(event, context)) return event + + +class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): + def test_out_of_band_invite_rejection(self): + # this is where we have received an invite event over federation, and then + # rejected it.
+ invite_pdu = { + "room_id": "!room:id", + "depth": 1, + "auth_events": [], + "prev_events": [], + "origin_server_ts": 1, + "sender": "@someone:" + self.OTHER_SERVER_NAME, + "type": "m.room.member", + "state_key": "@user:test", + "content": {"membership": "invite"}, + } + self.add_hashes_and_signatures(invite_pdu) + invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id + + self.get_success( + self.hs.get_federation_server().on_invite_request( + self.OTHER_SERVER_NAME, + invite_pdu, + "9", + ) + ) + + # stub out do_remotely_reject_invite so that we fall back to a locally- + # generated rejection + with patch.object( + self.hs.get_federation_handler(), + "do_remotely_reject_invite", + side_effect=Exception(), + ): + reject_event_id, _ = self.get_success( + self.hs.get_room_member_handler().remote_reject_invite( + invite_event_id, + txn_id=None, + requester=create_requester("@user:test"), + content={}, + ) + ) + + invite_event, reject_event = self.get_success( + self.hs.get_datastores().main.get_events_as_list( + [invite_event_id, reject_event_id] + ) + ) + + # the invited user should be able to see both the invite and the rejection + self.assertEqual( + self.get_success( + filter_events_for_client( + self.hs.get_storage(), "@user:test", [invite_event, reject_event] + ) + ), + [invite_event, reject_event], + ) + + # other users should see neither + self.assertEqual( + self.get_success( + filter_events_for_client( + self.hs.get_storage(), "@other:test", [invite_event, reject_event] + ) + ), + [], + ) -- cgit 1.4.1 From 423cca9efe06d78aaca5f62fb74ee7e5bceebe49 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 4 Mar 2022 11:48:15 +0000 Subject: Spread out sending device lists to remote hosts (#12132) --- changelog.d/12132.feature | 1 + synapse/federation/send_queue.py | 2 +- synapse/federation/sender/__init__.py | 26 +++++++---- synapse/federation/sender/per_destination_queue.py | 10 +++++ synapse/handlers/device.py | 2 +- synapse/replication/tcp/client.py | 2 +- tests/federation/test_federation_sender.py | 52 ++++++++++++++++++++-- 7 files changed, 79 insertions(+), 16 deletions(-) create mode 100644 changelog.d/12132.feature diff --git a/changelog.d/12132.feature b/changelog.d/12132.feature new file mode 100644 index 0000000000..3b8362ad35 --- /dev/null +++ b/changelog.d/12132.feature @@ -0,0 +1 @@ +Improve performance of logging in for large accounts. diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0d7c4f5067..d720b5fd3f 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -244,7 +244,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): self.notifier.on_new_replication_data() - def send_device_messages(self, destination: str) -> None: + def send_device_messages(self, destination: str, immediate: bool = False) -> None: """As per FederationSender""" # We don't need to replicate this as it gets sent down a different # stream. 
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 6106a486d1..30e2421efc 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -118,7 +118,12 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def send_device_messages(self, destination: str) -> None: + def send_device_messages(self, destination: str, immediate: bool = True) -> None: + """Tells the sender that a new device message is ready to be sent to the + destination. The `immediate` flag specifies whether we should attempt to + send the messages immediately, or whether sending can be delayed for a + short while (to aid performance). + """ raise NotImplementedError() @abc.abstractmethod @@ -146,9 +151,8 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): @attr.s -class _PresenceQueue: - """A queue of destinations that need to be woken up due to new presence - updates. +class _DestinationWakeupQueue: + """A queue of destinations that need to be woken up due to new updates. Staggers waking up of per destination queues to ensure that we don't attempt to start TLS connections with many hosts all at once, leading to pinned CPU. @@ -175,7 +179,7 @@ if not self.processing: self._handle() - @wrap_as_background_process("_PresenceQueue.handle") + @wrap_as_background_process("_DestinationWakeupQueue.handle") async def _handle(self) -> None: """Background process to drain the queue.""" @@ -297,7 +301,7 @@ class FederationSender(AbstractFederationSender): self._external_cache = hs.get_external_cache() - self._presence_queue = _PresenceQueue(self, self.clock) + self._destination_wakeup_queue = _DestinationWakeupQueue(self, self.clock) def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -614,7 +618,7 @@ class FederationSender(AbstractFederationSender): states, start_loop=False ) - self._presence_queue.add_to_queue(destination) + self._destination_wakeup_queue.add_to_queue(destination) def build_and_send_edu( self, @@ -667,7 +671,7 @@ class FederationSender(AbstractFederationSender): else: queue.send_edu(edu) - def send_device_messages(self, destination: str) -> None: + def send_device_messages(self, destination: str, immediate: bool = False) -> None: if destination == self.server_name: logger.warning("Not sending device update to ourselves") return @@ -677,7 +681,11 @@ class FederationSender(AbstractFederationSender): ): return - self._get_per_destination_queue(destination).attempt_new_transaction() + if immediate: + self._get_per_destination_queue(destination).attempt_new_transaction() + else: + self._get_per_destination_queue(destination).mark_new_data() + self._destination_wakeup_queue.add_to_queue(destination) def wake_destination(self, destination: str) -> None: """Called when we want to retry sending transactions to a remote. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index c8768f22bc..d80f0ac5e8 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -219,6 +219,16 @@ class PerDestinationQueue: self._pending_edus.append(edu) self.attempt_new_transaction() + def mark_new_data(self) -> None: + """Marks that the destination has new data to send, without starting a + new transaction.
+ + If a transaction loop is already in progress then a new transaction will + be attempted when the current one finishes. + """ + + self._new_data_to_send = True + def attempt_new_transaction(self) -> None: """Try to start a new transaction to this destination diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 934b5bd734..d90cb259a6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -506,7 +506,7 @@ class DeviceHandler(DeviceWorkerHandler): "Sending device list update notif for %r to: %r", user_id, hosts ) for host in hosts: - self.federation_sender.send_device_messages(host) + self.federation_sender.send_device_messages(host, immediate=False) log_kv({"message": "sent device update to host", "host": host}) async def notify_user_signature_update( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 1b8479b0b4..b8fc1d4db9 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -380,7 +380,7 @@ class FederationSenderHandler: # changes. hosts = {row.entity for row in rows if not row.entity.startswith("@")} for host in hosts: - self.federation_sender.send_device_messages(host) + self.federation_sender.send_device_messages(host, immediate=False) elif stream_name == ToDeviceStream.NAME: # The to_device stream includes stuff to be pushed to both local diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 60e0c31f43..e90592855a 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -201,9 +201,12 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.assertEqual(len(self.edus), 1) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # a second call should produce no new device EDUs self.hs.get_federation_sender().send_device_messages("host2") - self.pump() self.assertEqual(self.edus, []) # a second device @@ -232,6 +235,10 @@ device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1") device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2") + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect two more edus self.assertEqual(len(self.edus), 2) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) @@ -265,6 +272,10 @@ e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect signing key update edu self.assertEqual(len(self.edus), 2) self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") @@ -284,6 +295,10 @@ ) self.assertEqual(ret["failures"], {}) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect two edus, in one or two transactions. We don't know what order the # devices will be updated.
self.assertEqual(len(self.edus), 2) @@ -307,6 +322,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect three edus self.assertEqual(len(self.edus), 3) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) @@ -318,6 +337,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect three edus, in an unknown order self.assertEqual(len(self.edus), 3) for edu in self.edus: @@ -350,12 +373,19 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + self.assertGreaterEqual(mock_send_txn.call_count, 4) # recover the server mock_send_txn.side_effect = self.record_transaction self.hs.get_federation_sender().send_device_messages("host2") - self.pump() + + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) # for each device, there should be a single update self.assertEqual(len(self.edus), 3) @@ -390,6 +420,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + self.assertGreaterEqual(mock_send_txn.call_count, 4) # run the prune job @@ -401,7 +435,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction self.hs.get_federation_sender().send_device_messages("host2") - self.pump() + + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) # there should be a single update for this user. self.assertEqual(len(self.edus), 1) @@ -435,6 +472,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # delete them again self.get_success( self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) @@ -451,7 +492,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction self.hs.get_federation_sender().send_device_messages("host2") - self.pump() + + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) # ... and we should get a single update for this user. 
self.assertEqual(len(self.edus), 1) -- cgit 1.4.1 From 4aeb00ca20a0d9dbb2a104591aca081c723eb6d9 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 4 Mar 2022 11:58:49 +0000 Subject: Move synctl into `synapse._scripts` and expose as an entrypoint (#12140) --- .dockerignore | 1 - MANIFEST.in | 1 - changelog.d/12140.misc | 1 + docker/Dockerfile | 2 +- docs/postgres.md | 8 +- docs/turn-howto.md | 5 +- docs/upgrade.md | 23 ++- scripts-dev/lint.sh | 2 +- setup.py | 2 +- synapse/_scripts/synctl.py | 360 +++++++++++++++++++++++++++++++++++++++++++++ synctl | 360 --------------------------------------------- tox.ini | 1 - 12 files changed, 393 insertions(+), 373 deletions(-) create mode 100644 changelog.d/12140.misc create mode 100755 synapse/_scripts/synctl.py delete mode 100755 synctl diff --git a/.dockerignore b/.dockerignore index 617f701597..434231fce9 100644 --- a/.dockerignore +++ b/.dockerignore @@ -7,6 +7,5 @@ !MANIFEST.in !README.rst !setup.py -!synctl **/__pycache__ diff --git a/MANIFEST.in b/MANIFEST.in index f1e295e583..d744c090ac 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,3 @@ -include synctl include LICENSE include VERSION include *.rst diff --git a/changelog.d/12140.misc b/changelog.d/12140.misc new file mode 100644 index 0000000000..33a21a29f0 --- /dev/null +++ b/changelog.d/12140.misc @@ -0,0 +1 @@ +Move `synctl` into `synapse._scripts` and expose as an entry point. \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index 327275a9ca..24b5515eb9 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -46,7 +46,7 @@ RUN \ && rm -rf /var/lib/apt/lists/* # Copy just what we need to pip install -COPY MANIFEST.in README.rst setup.py synctl /synapse/ +COPY MANIFEST.in README.rst setup.py /synapse/ COPY synapse/__init__.py /synapse/synapse/__init__.py COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py diff --git a/docs/postgres.md b/docs/postgres.md index 0562021da5..de4e2ba4b7 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -153,9 +153,9 @@ database file (typically `homeserver.db`) to another location. Once the copy is complete, restart synapse. For instance: ```sh -./synctl stop +synctl stop cp homeserver.db homeserver.db.snapshot -./synctl start +synctl start ``` Copy the old config file into a new config file: @@ -192,10 +192,10 @@ Once that has completed, change the synapse config to point at the PostgreSQL database configuration file `homeserver-postgres.yaml`: ```sh -./synctl stop +synctl stop mv homeserver.yaml homeserver-old-sqlite.yaml mv homeserver-postgres.yaml homeserver.yaml -./synctl start +synctl start ``` Synapse should now be running against PostgreSQL. diff --git a/docs/turn-howto.md b/docs/turn-howto.md index eba7ca6124..3a2cd04e36 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md @@ -238,8 +238,9 @@ After updating the homeserver configuration, you must restart synapse: * If you use synctl: ```sh - cd /where/you/run/synapse - ./synctl restart + # Depending on how Synapse is installed, synctl may already be on + # your PATH. If not, you may need to activate a virtual environment. + synctl restart ``` * If you use systemd: ```sh diff --git a/docs/upgrade.md b/docs/upgrade.md index f9be3ac6bc..0d0bb066ee 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -47,7 +47,7 @@ this document. 3. 
Restart Synapse: ```bash - ./synctl restart + synctl restart ``` To check whether your update was successful, you can check the running @@ -85,6 +85,27 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.55.0 + +## `synctl` script has been moved + +The `synctl` script +[has been made](https://github.com/matrix-org/synapse/pull/12140) an +[entry point](https://packaging.python.org/en/latest/specifications/entry-points/) +and no longer exists at the root of Synapse's source tree. If you wish to use +`synctl` to manage your homeserver, you should invoke `synctl` directly, e.g. +`synctl start` instead of `./synctl start` or `/path/to/synctl start`. + +You will need to ensure `synctl` is on your `PATH`. + - This is automatically the case when using + [Debian packages](https://packages.matrix.org/debian/) or + [docker images](https://hub.docker.com/r/matrixdotorg/synapse) + provided by Matrix.org. + - When installing from a wheel, sdist, or PyPI, a `synctl` executable is added + to your Python installation's `bin`. This should be on your `PATH` + automatically, though you might need to activate a virtual environment + depending on how you installed Synapse. + # Upgrading to v1.54.0 ## Legacy structured logging configuration removal diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 2f5f2c3566..c063fafa97 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -85,7 +85,7 @@ else "synapse" "docker" "tests" # annoyingly, black doesn't find these so we have to list them "scripts-dev" - "contrib" "synctl" "setup.py" "synmark" "stubs" ".ci" + "contrib" "setup.py" "synmark" "stubs" ".ci" ) fi fi diff --git a/setup.py b/setup.py index 318df16766..439ed75d72 100755 --- a/setup.py +++ b/setup.py @@ -155,6 +155,7 @@ setup( # Application "synapse_homeserver = synapse.app.homeserver:main", "synapse_worker = synapse.app.generic_worker:main", + "synctl = synapse._scripts.synctl:main", # Scripts "export_signing_key = synapse._scripts.export_signing_key:main", "generate_config = synapse._scripts.generate_config:main", @@ -177,6 +178,5 @@ setup( "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", ], - scripts=["synctl"], cmdclass={"test": TestCommand}, ) diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py new file mode 100755 index 0000000000..1ab36949c7 --- /dev/null +++ b/synapse/_scripts/synctl.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import collections +import errno +import glob +import os +import os.path +import signal +import subprocess +import sys +import time +from typing import Iterable, Optional + +import yaml + +from synapse.config import find_config_files + +MAIN_PROCESS = "synapse.app.homeserver" + +GREEN = "\x1b[1;32m" +YELLOW = "\x1b[1;33m" +RED = "\x1b[1;31m" +NORMAL = "\x1b[m" + +SYNCTL_CACHE_FACTOR_WARNING = """\ +Setting 'synctl_cache_factor' in the config is deprecated. Instead, please do +one of the following: + - Either set the environment variable 'SYNAPSE_CACHE_FACTOR' + - or set 'caches.global_factor' in the homeserver config. +--------------------------------------------------------------------------------""" + + +def pid_running(pid): + try: + os.kill(pid, 0) + except OSError as err: + if err.errno == errno.EPERM: + pass # process exists + else: + return False + + # When running in a container, orphan processes may not get reaped and their + # PIDs may remain valid. Try to work around the issue. + try: + with open(f"/proc/{pid}/status") as status_file: + if "zombie" in status_file.read(): + return False + except Exception: + # This isn't Linux or `/proc/` is unavailable. + # Assume that the process is still running. + pass + + return True + + +def write(message, colour=NORMAL, stream=sys.stdout): + # Lets check if we're writing to a TTY before colouring + should_colour = False + try: + should_colour = stream.isatty() + except AttributeError: + # Just in case `isatty` isn't defined on everything. The python + # docs are incredibly vague. + pass + + if not should_colour: + stream.write(message + "\n") + else: + stream.write(colour + message + NORMAL + "\n") + + +def abort(message, colour=RED, stream=sys.stderr): + write(message, colour, stream) + sys.exit(1) + + +def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool) -> bool: + """Attempts to start a synapse main or worker process. + Args: + pidfile: the pidfile we expect the process to create + app: the python module to run + config_files: config files to pass to synapse + daemonize: if True, will include a --daemonize argument to synapse + + Returns: + True if the process started successfully or was already running + False if there was an error starting the process + """ + + if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())): + print(app + " already running") + return True + + args = [sys.executable, "-m", app] + for c in config_files: + args += ["-c", c] + if daemonize: + args.append("--daemonize") + + try: + subprocess.check_call(args) + write("started %s(%s)" % (app, ",".join(config_files)), colour=GREEN) + return True + except subprocess.CalledProcessError as e: + err = "%s(%s) failed to start (exit code: %d). Check the Synapse logfile" % ( + app, + ",".join(config_files), + e.returncode, + ) + if daemonize: + err += ", or run synctl with --no-daemonize" + err += "." + write(err, colour=RED, stream=sys.stderr) + return False + + +def stop(pidfile: str, app: str) -> Optional[int]: + """Attempts to kill a synapse worker from the pidfile. 
+ Args: + pidfile: path to file containing worker's pid + app: name of the worker's appservice + + Returns: + process id, or None if the process was not running + """ + + if os.path.exists(pidfile): + pid = int(open(pidfile).read()) + try: + os.kill(pid, signal.SIGTERM) + write("stopped %s" % (app,), colour=GREEN) + return pid + except OSError as err: + if err.errno == errno.ESRCH: + write("%s not running" % (app,), colour=YELLOW) + elif err.errno == errno.EPERM: + abort("Cannot stop %s: Operation not permitted" % (app,)) + else: + abort("Cannot stop %s: Unknown error" % (app,)) + else: + write( + "No running worker of %s found (from %s)\nThe process might be managed by another controller (e.g. systemd)" + % (app, pidfile), + colour=YELLOW, + ) + return None + + +Worker = collections.namedtuple( + "Worker", ["app", "configfile", "pidfile", "cache_factor", "cache_factors"] +) + + +def main(): + + parser = argparse.ArgumentParser() + + parser.add_argument( + "action", + choices=["start", "stop", "restart"], + help="whether to start, stop or restart the synapse", + ) + parser.add_argument( + "configfile", + nargs="?", + default="homeserver.yaml", + help="the homeserver config file. Defaults to homeserver.yaml. May also be" + " a directory with *.yaml files", + ) + parser.add_argument( + "-w", "--worker", metavar="WORKERCONFIG", help="start or stop a single worker" + ) + parser.add_argument( + "-a", + "--all-processes", + metavar="WORKERCONFIGDIR", + help="start or stop all the workers in the given directory" + " and the main synapse process", + ) + parser.add_argument( + "--no-daemonize", + action="store_false", + dest="daemonize", + help="Run synapse in the foreground for debugging. " + "Will work only if the daemonize option is not set in the config.", + ) + + options = parser.parse_args() + + if options.worker and options.all_processes: + write('Cannot use "--worker" with "--all-processes"', stream=sys.stderr) + sys.exit(1) + if not options.daemonize and options.all_processes: + write('Cannot use "--no-daemonize" with "--all-processes"', stream=sys.stderr) + sys.exit(1) + + configfile = options.configfile + + if not os.path.exists(configfile): + write( + f"Config file {configfile} does not exist.\n" + f"To generate a config file, run:\n" + f" {sys.executable} -m {MAIN_PROCESS}" + f" -c {configfile} --generate-config" + f" --server-name= --report-stats=\n", + stream=sys.stderr, + ) + sys.exit(1) + + config_files = find_config_files([configfile]) + config = {} + for config_file in config_files: + with open(config_file) as file_stream: + yaml_config = yaml.safe_load(file_stream) + if yaml_config is not None: + config.update(yaml_config) + + pidfile = config["pid_file"] + cache_factor = config.get("synctl_cache_factor") + start_stop_synapse = True + + if cache_factor: + write(SYNCTL_CACHE_FACTOR_WARNING) + os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) + + cache_factors = config.get("synctl_cache_factors", {}) + for cache_name, factor in cache_factors.items(): + os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) + + worker_configfiles = [] + if options.worker: + start_stop_synapse = False + worker_configfile = options.worker + if not os.path.exists(worker_configfile): + write( + "No worker config found at %r" % (worker_configfile,), stream=sys.stderr + ) + sys.exit(1) + worker_configfiles.append(worker_configfile) + + if options.all_processes: + # To start the main synapse with -a you need to add a worker file + # with worker_app == "synapse.app.homeserver" + 
start_stop_synapse = False + worker_configdir = options.all_processes + if not os.path.isdir(worker_configdir): + write( + "No worker config directory found at %r" % (worker_configdir,), + stream=sys.stderr, + ) + sys.exit(1) + worker_configfiles.extend( + sorted(glob.glob(os.path.join(worker_configdir, "*.yaml"))) + ) + + workers = [] + for worker_configfile in worker_configfiles: + with open(worker_configfile) as stream: + worker_config = yaml.safe_load(stream) + worker_app = worker_config["worker_app"] + if worker_app == "synapse.app.homeserver": + # We need to special case all of this to pick up options that may + # be set in the main config file or in this worker config file. + worker_pidfile = worker_config.get("pid_file") or pidfile + worker_cache_factor = ( + worker_config.get("synctl_cache_factor") or cache_factor + ) + worker_cache_factors = ( + worker_config.get("synctl_cache_factors") or cache_factors + ) + # The master process doesn't support using worker_* config. + for key in worker_config: + if key == "worker_app": # But we allow worker_app + continue + assert not key.startswith( + "worker_" + ), "Main process cannot use worker_* config" + else: + worker_pidfile = worker_config["worker_pid_file"] + worker_cache_factor = worker_config.get("synctl_cache_factor") + worker_cache_factors = worker_config.get("synctl_cache_factors", {}) + workers.append( + Worker( + worker_app, + worker_configfile, + worker_pidfile, + worker_cache_factor, + worker_cache_factors, + ) + ) + + action = options.action + + if action == "stop" or action == "restart": + running_pids = [] + for worker in workers: + pid = stop(worker.pidfile, worker.app) + if pid is not None: + running_pids.append(pid) + + if start_stop_synapse: + pid = stop(pidfile, MAIN_PROCESS) + if pid is not None: + running_pids.append(pid) + + if len(running_pids) > 0: + write("Waiting for processes to exit...") + for running_pid in running_pids: + while pid_running(running_pid): + time.sleep(0.2) + write("All processes exited") + + if action == "start" or action == "restart": + error = False + if start_stop_synapse: + if not start(pidfile, MAIN_PROCESS, (configfile,), options.daemonize): + error = True + + for worker in workers: + env = os.environ.copy() + + if worker.cache_factor: + os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) + + for cache_name, factor in worker.cache_factors.items(): + os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) + + if not start( + worker.pidfile, + worker.app, + (configfile, worker.configfile), + options.daemonize, + ): + error = True + + # Reset env back to the original + os.environ.clear() + os.environ.update(env) + + if error: + exit(1) + + +if __name__ == "__main__": + main() diff --git a/synctl b/synctl deleted file mode 100755 index 1ab36949c7..0000000000 --- a/synctl +++ /dev/null @@ -1,360 +0,0 @@ -#!/usr/bin/env python -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import collections -import errno -import glob -import os -import os.path -import signal -import subprocess -import sys -import time -from typing import Iterable, Optional - -import yaml - -from synapse.config import find_config_files - -MAIN_PROCESS = "synapse.app.homeserver" - -GREEN = "\x1b[1;32m" -YELLOW = "\x1b[1;33m" -RED = "\x1b[1;31m" -NORMAL = "\x1b[m" - -SYNCTL_CACHE_FACTOR_WARNING = """\ -Setting 'synctl_cache_factor' in the config is deprecated. Instead, please do -one of the following: - - Either set the environment variable 'SYNAPSE_CACHE_FACTOR' - - or set 'caches.global_factor' in the homeserver config. ---------------------------------------------------------------------------------""" - - -def pid_running(pid): - try: - os.kill(pid, 0) - except OSError as err: - if err.errno == errno.EPERM: - pass # process exists - else: - return False - - # When running in a container, orphan processes may not get reaped and their - # PIDs may remain valid. Try to work around the issue. - try: - with open(f"/proc/{pid}/status") as status_file: - if "zombie" in status_file.read(): - return False - except Exception: - # This isn't Linux or `/proc/` is unavailable. - # Assume that the process is still running. - pass - - return True - - -def write(message, colour=NORMAL, stream=sys.stdout): - # Lets check if we're writing to a TTY before colouring - should_colour = False - try: - should_colour = stream.isatty() - except AttributeError: - # Just in case `isatty` isn't defined on everything. The python - # docs are incredibly vague. - pass - - if not should_colour: - stream.write(message + "\n") - else: - stream.write(colour + message + NORMAL + "\n") - - -def abort(message, colour=RED, stream=sys.stderr): - write(message, colour, stream) - sys.exit(1) - - -def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool) -> bool: - """Attempts to start a synapse main or worker process. - Args: - pidfile: the pidfile we expect the process to create - app: the python module to run - config_files: config files to pass to synapse - daemonize: if True, will include a --daemonize argument to synapse - - Returns: - True if the process started successfully or was already running - False if there was an error starting the process - """ - - if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())): - print(app + " already running") - return True - - args = [sys.executable, "-m", app] - for c in config_files: - args += ["-c", c] - if daemonize: - args.append("--daemonize") - - try: - subprocess.check_call(args) - write("started %s(%s)" % (app, ",".join(config_files)), colour=GREEN) - return True - except subprocess.CalledProcessError as e: - err = "%s(%s) failed to start (exit code: %d). Check the Synapse logfile" % ( - app, - ",".join(config_files), - e.returncode, - ) - if daemonize: - err += ", or run synctl with --no-daemonize" - err += "." - write(err, colour=RED, stream=sys.stderr) - return False - - -def stop(pidfile: str, app: str) -> Optional[int]: - """Attempts to kill a synapse worker from the pidfile. 
- Args: - pidfile: path to file containing worker's pid - app: name of the worker's appservice - - Returns: - process id, or None if the process was not running - """ - - if os.path.exists(pidfile): - pid = int(open(pidfile).read()) - try: - os.kill(pid, signal.SIGTERM) - write("stopped %s" % (app,), colour=GREEN) - return pid - except OSError as err: - if err.errno == errno.ESRCH: - write("%s not running" % (app,), colour=YELLOW) - elif err.errno == errno.EPERM: - abort("Cannot stop %s: Operation not permitted" % (app,)) - else: - abort("Cannot stop %s: Unknown error" % (app,)) - else: - write( - "No running worker of %s found (from %s)\nThe process might be managed by another controller (e.g. systemd)" - % (app, pidfile), - colour=YELLOW, - ) - return None - - -Worker = collections.namedtuple( - "Worker", ["app", "configfile", "pidfile", "cache_factor", "cache_factors"] -) - - -def main(): - - parser = argparse.ArgumentParser() - - parser.add_argument( - "action", - choices=["start", "stop", "restart"], - help="whether to start, stop or restart the synapse", - ) - parser.add_argument( - "configfile", - nargs="?", - default="homeserver.yaml", - help="the homeserver config file. Defaults to homeserver.yaml. May also be" - " a directory with *.yaml files", - ) - parser.add_argument( - "-w", "--worker", metavar="WORKERCONFIG", help="start or stop a single worker" - ) - parser.add_argument( - "-a", - "--all-processes", - metavar="WORKERCONFIGDIR", - help="start or stop all the workers in the given directory" - " and the main synapse process", - ) - parser.add_argument( - "--no-daemonize", - action="store_false", - dest="daemonize", - help="Run synapse in the foreground for debugging. " - "Will work only if the daemonize option is not set in the config.", - ) - - options = parser.parse_args() - - if options.worker and options.all_processes: - write('Cannot use "--worker" with "--all-processes"', stream=sys.stderr) - sys.exit(1) - if not options.daemonize and options.all_processes: - write('Cannot use "--no-daemonize" with "--all-processes"', stream=sys.stderr) - sys.exit(1) - - configfile = options.configfile - - if not os.path.exists(configfile): - write( - f"Config file {configfile} does not exist.\n" - f"To generate a config file, run:\n" - f" {sys.executable} -m {MAIN_PROCESS}" - f" -c {configfile} --generate-config" - f" --server-name= --report-stats=\n", - stream=sys.stderr, - ) - sys.exit(1) - - config_files = find_config_files([configfile]) - config = {} - for config_file in config_files: - with open(config_file) as file_stream: - yaml_config = yaml.safe_load(file_stream) - if yaml_config is not None: - config.update(yaml_config) - - pidfile = config["pid_file"] - cache_factor = config.get("synctl_cache_factor") - start_stop_synapse = True - - if cache_factor: - write(SYNCTL_CACHE_FACTOR_WARNING) - os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) - - cache_factors = config.get("synctl_cache_factors", {}) - for cache_name, factor in cache_factors.items(): - os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) - - worker_configfiles = [] - if options.worker: - start_stop_synapse = False - worker_configfile = options.worker - if not os.path.exists(worker_configfile): - write( - "No worker config found at %r" % (worker_configfile,), stream=sys.stderr - ) - sys.exit(1) - worker_configfiles.append(worker_configfile) - - if options.all_processes: - # To start the main synapse with -a you need to add a worker file - # with worker_app == "synapse.app.homeserver" - 
start_stop_synapse = False - worker_configdir = options.all_processes - if not os.path.isdir(worker_configdir): - write( - "No worker config directory found at %r" % (worker_configdir,), - stream=sys.stderr, - ) - sys.exit(1) - worker_configfiles.extend( - sorted(glob.glob(os.path.join(worker_configdir, "*.yaml"))) - ) - - workers = [] - for worker_configfile in worker_configfiles: - with open(worker_configfile) as stream: - worker_config = yaml.safe_load(stream) - worker_app = worker_config["worker_app"] - if worker_app == "synapse.app.homeserver": - # We need to special case all of this to pick up options that may - # be set in the main config file or in this worker config file. - worker_pidfile = worker_config.get("pid_file") or pidfile - worker_cache_factor = ( - worker_config.get("synctl_cache_factor") or cache_factor - ) - worker_cache_factors = ( - worker_config.get("synctl_cache_factors") or cache_factors - ) - # The master process doesn't support using worker_* config. - for key in worker_config: - if key == "worker_app": # But we allow worker_app - continue - assert not key.startswith( - "worker_" - ), "Main process cannot use worker_* config" - else: - worker_pidfile = worker_config["worker_pid_file"] - worker_cache_factor = worker_config.get("synctl_cache_factor") - worker_cache_factors = worker_config.get("synctl_cache_factors", {}) - workers.append( - Worker( - worker_app, - worker_configfile, - worker_pidfile, - worker_cache_factor, - worker_cache_factors, - ) - ) - - action = options.action - - if action == "stop" or action == "restart": - running_pids = [] - for worker in workers: - pid = stop(worker.pidfile, worker.app) - if pid is not None: - running_pids.append(pid) - - if start_stop_synapse: - pid = stop(pidfile, MAIN_PROCESS) - if pid is not None: - running_pids.append(pid) - - if len(running_pids) > 0: - write("Waiting for processes to exit...") - for running_pid in running_pids: - while pid_running(running_pid): - time.sleep(0.2) - write("All processes exited") - - if action == "start" or action == "restart": - error = False - if start_stop_synapse: - if not start(pidfile, MAIN_PROCESS, (configfile,), options.daemonize): - error = True - - for worker in workers: - env = os.environ.copy() - - if worker.cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) - - for cache_name, factor in worker.cache_factors.items(): - os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) - - if not start( - worker.pidfile, - worker.app, - (configfile, worker.configfile), - options.daemonize, - ): - error = True - - # Reset env back to the original - os.environ.clear() - os.environ.update(env) - - if error: - exit(1) - - -if __name__ == "__main__": - main() diff --git a/tox.ini b/tox.ini index 04d282a705..f1f96b27ea 100644 --- a/tox.ini +++ b/tox.ini @@ -42,7 +42,6 @@ lint_targets = scripts-dev stubs contrib - synctl synmark .ci docker -- cgit 1.4.1 From 36071d39f784ddc2271c91a72a55d9ee4f8689bb Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 4 Mar 2022 12:01:51 +0000 Subject: Changelog (#12153) --- .github/workflows/tests.yml | 1 + changelog.d/12153.misc | 1 + tox.ini | 1 - 3 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12153.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c89c50cd07..3bce95b0e0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,6 +17,7 @@ jobs: - uses: actions/setup-python@v2 - run: pip install -e . 
      - run: scripts-dev/generate_sample_config.sh --check
+      - run: scripts-dev/config-lint.sh
 
   lint:
     runs-on: ubuntu-latest
diff --git a/changelog.d/12153.misc b/changelog.d/12153.misc
new file mode 100644
index 0000000000..f02d140f38
--- /dev/null
+++ b/changelog.d/12153.misc
@@ -0,0 +1 @@
+Move CI checks out of tox, to facilitate a move to using poetry.
\ No newline at end of file
diff --git a/tox.ini b/tox.ini
index f1f96b27ea..3ffd2c3e97 100644
--- a/tox.ini
+++ b/tox.ini
@@ -151,7 +151,6 @@ extras = lint
 commands =
     python -m black --check --diff {[base]lint_targets}
     flake8 {[base]lint_targets} {env:PEP8SUFFIX:}
-    {toxinidir}/scripts-dev/config-lint.sh
 
 [testenv:check_isort]
 extras = lint
-- cgit 1.4.1


From cd1ae3d0b438ff453b7d4750c4fe901f266fcbb6 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Fri, 4 Mar 2022 07:10:10 -0500
Subject: Remove backwards compatibility with RelationPaginationToken. (#12138)

---
 changelog.d/12138.removal           |  1 +
 synapse/rest/client/relations.py    | 55 +++++++---------------------
 synapse/storage/relations.py        | 31 ----------------
 tests/rest/client/test_relations.py | 73 +------------------------------------
 4 files changed, 16 insertions(+), 144 deletions(-)
 create mode 100644 changelog.d/12138.removal

diff --git a/changelog.d/12138.removal b/changelog.d/12138.removal
new file mode 100644
index 0000000000..6ed84d476c
--- /dev/null
+++ b/changelog.d/12138.removal
@@ -0,0 +1 @@
+Remove backwards compatibility with pagination tokens from the `/relations` and `/aggregations` endpoints generated from Synapse < v1.52.0.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 487ea38b55..07fa1cdd4c 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -27,50 +27,15 @@ from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
-from synapse.storage.relations import (
-    AggregationPaginationToken,
-    PaginationChunk,
-    RelationPaginationToken,
-)
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.types import JsonDict, StreamToken
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
 
-async def _parse_token(
-    store: "DataStore", token: Optional[str]
-) -> Optional[StreamToken]:
-    """
-    For backwards compatibility support RelationPaginationToken, but new pagination
-    tokens are generated as full StreamTokens, to be compatible with /sync and /messages.
-    """
-    if not token:
-        return None
-    # Luckily the format for StreamToken and RelationPaginationToken differ enough
-    # that they can easily be separated. An "_" appears in the serialization of
-    # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses
-    # "-" only for separators.
- if "_" in token: - return await StreamToken.from_string(store, token) - else: - relation_token = RelationPaginationToken.from_string(token) - return StreamToken( - room_key=RoomStreamToken(relation_token.topological, relation_token.stream), - presence_key=0, - typing_key=0, - receipt_key=0, - account_data_key=0, - push_rules_key=0, - to_device_key=0, - device_list_key=0, - groups_key=0, - ) - - class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -122,8 +87,12 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - from_token = await _parse_token(self.store, from_token_str) - to_token = await _parse_token(self.store, to_token_str) + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, @@ -317,8 +286,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - from_token = await _parse_token(self.store, from_token_str) - to_token = await _parse_token(self.store, to_token_str) + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) result = await self.store.get_relations_for_event( event_id=parent_id, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 36ca2b8273..fba270150b 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -54,37 +54,6 @@ class PaginationChunk: return d -@attr.s(frozen=True, slots=True, auto_attribs=True) -class RelationPaginationToken: - """Pagination token for relation pagination API. - - As the results are in topological order, we can use the - `topological_ordering` and `stream_ordering` fields of the events at the - boundaries of the chunk as pagination tokens. - - Attributes: - topological: The topological ordering of the boundary event - stream: The stream ordering of the boundary event. - """ - - topological: int - stream: int - - @staticmethod - def from_string(string: str) -> "RelationPaginationToken": - try: - t, s = string.split("-") - return RelationPaginationToken(int(t), int(s)) - except ValueError: - raise SynapseError(400, "Invalid relation pagination token") - - async def to_string(self, store: "DataStore") -> str: - return "%d-%d" % (self.topological, self.stream) - - def as_tuple(self) -> Tuple[Any, ...]: - return attr.astuple(self) - - @attr.s(frozen=True, slots=True, auto_attribs=True) class AggregationPaginationToken: """Pagination token for relation aggregation pagination API. 
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 53062b41de..274f9c44c1 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -24,8 +24,7 @@ from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync from synapse.server import HomeServer -from synapse.storage.relations import RelationPaginationToken -from synapse.types import JsonDict, StreamToken +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -281,15 +280,6 @@ class RelationsTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) - def _stream_token_to_relation_token(self, token: str) -> str: - """Convert a StreamToken into a legacy token (RelationPaginationToken).""" - room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key - return self.get_success( - RelationPaginationToken( - topological=room_key.topological, stream=room_key.stream - ).to_string(self.store) - ) - def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -330,34 +320,6 @@ class RelationsTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - # Reset and try again, but convert the tokens to the legacy format. - prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") @@ -543,39 +505,6 @@ class RelationsTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - # Reset and try again, but convert the tokens to the legacy format. 
- prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" -- cgit 1.4.1 From 158e0937eb56f14dd851549939037de499526fd2 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 4 Mar 2022 13:10:05 +0000 Subject: Add test for `ObservableDeferred`'s cancellation behaviour (#12149) Signed-off-by: Sean Quah --- changelog.d/12149.misc | 1 + tests/util/test_async_helpers.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 changelog.d/12149.misc diff --git a/changelog.d/12149.misc b/changelog.d/12149.misc new file mode 100644 index 0000000000..d39af96723 --- /dev/null +++ b/changelog.d/12149.misc @@ -0,0 +1 @@ +Add test for `ObservableDeferred`'s cancellation behaviour. diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 362014f4cb..ff53ce114b 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -100,6 +100,34 @@ class ObservableDeferredTest(TestCase): self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result") self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") + def test_cancellation(self): + """Test that cancelling an observer does not affect other observers.""" + origin_d: "Deferred[int]" = Deferred() + observable = ObservableDeferred(origin_d, consumeErrors=True) + + observer1 = observable.observe() + observer2 = observable.observe() + observer3 = observable.observe() + + self.assertFalse(observer1.called) + self.assertFalse(observer2.called) + self.assertFalse(observer3.called) + + # cancel the second observer + observer2.cancel() + self.assertFalse(observer1.called) + self.failureResultOf(observer2, CancelledError) + self.assertFalse(observer3.called) + + # other observers resolve as normal + origin_d.callback(123) + self.assertEqual(observer1.result, 123, "observer 1 callback result") + self.assertEqual(observer3.result, 123, "observer 3 callback result") + + # additional observers resolve as normal + observer4 = observable.observe() + self.assertEqual(observer4.result, 123, "observer 4 callback result") + class TimeoutDeferredTest(TestCase): def setUp(self): -- cgit 1.4.1 From 75574726a766f09d955c05672d400c65cb341810 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 4 Mar 2022 15:37:02 +0000 Subject: Add type hints for `ObservableDeferred` attributes (#12159) Signed-off-by: Sean Quah --- changelog.d/12159.misc | 1 + synapse/util/async_helpers.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 
changelog.d/12159.misc diff --git a/changelog.d/12159.misc b/changelog.d/12159.misc new file mode 100644 index 0000000000..30500f2fd9 --- /dev/null +++ b/changelog.d/12159.misc @@ -0,0 +1 @@ +Add type hints for `ObservableDeferred` attributes. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 60c03a66fd..a9f67dcbac 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -40,7 +40,7 @@ from typing import ( ) import attr -from typing_extensions import ContextManager +from typing_extensions import ContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -96,6 +96,10 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): __slots__ = ["_deferred", "_observers", "_result"] + _deferred: "defer.Deferred[_T]" + _observers: Union[List["defer.Deferred[_T]"], Tuple[()]] + _result: Union[None, Tuple[Literal[True], _T], Tuple[Literal[False], Failure]] + def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) @@ -158,12 +162,14 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): effect the underlying deferred. """ if not self._result: + assert isinstance(self._observers, list) d: "defer.Deferred[_T]" = defer.Deferred() self._observers.append(d) return d + elif self._result[0]: + return defer.succeed(self._result[1]) else: - success, res = self._result - return defer.succeed(res) if success else defer.fail(res) + return defer.fail(self._result[1]) def observers(self) -> "Collection[defer.Deferred[_T]]": return self._observers @@ -175,6 +181,8 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): return self._result is not None and self._result[0] is True def get_result(self) -> Union[_T, Failure]: + if self._result is None: + raise ValueError(f"{self!r} has no result yet") return self._result[1] def __getattr__(self, name: str) -> Any: -- cgit 1.4.1 From 0752ab7a3621b90073f9332fbfdc8afe16a3be01 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 4 Mar 2022 17:57:27 +0000 Subject: Reduce to-device queries for /sync. (#12163) --- changelog.d/12163.misc | 1 + synapse/storage/databases/main/deviceinbox.py | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/12163.misc diff --git a/changelog.d/12163.misc b/changelog.d/12163.misc new file mode 100644 index 0000000000..13de0895f5 --- /dev/null +++ b/changelog.d/12163.misc @@ -0,0 +1 @@ +Reduce number of DB queries made during processing of `/sync`. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 1392363de1..b4a1b041b1 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -298,6 +298,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): # This user has new messages sent to them. Query messages for them user_ids_to_query.add(user_id) + if not user_ids_to_query: + return {}, to_stream_id + def get_device_messages_txn(txn: LoggingTransaction): # Build a query to select messages from any of the given devices that # are between the given stream id bounds. 
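
For context on the short-circuit added in #12163 above: returning early when no user has pending to-device messages means the `get_device_messages_txn` database transaction is never scheduled at all. A minimal, self-contained sketch of the pattern follows; the function is an illustrative stand-in, not Synapse's actual `DeviceInboxWorkerStore` method.

    from typing import Dict, List, Set, Tuple

    def get_device_messages(
        user_ids_to_query: Set[str], to_stream_id: int
    ) -> Tuple[Dict[str, List[dict]], int]:
        if not user_ids_to_query:
            # The #12163 guard: nothing to fetch, so skip the DB
            # transaction (and its scheduling overhead) entirely.
            return {}, to_stream_id
        # ...otherwise build and run the real query, as in the diff above.
        raise NotImplementedError("illustrative sketch only")
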
-- cgit 1.4.1


From d2ef1a79cf942b2f8c1a4724eb0c4bca0b60edbe Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Fri, 4 Mar 2022 22:40:24 +0000
Subject: Relax version guard for packaging (#12166)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

It's just occurred to me that #12088 pulled in the "packaging" package
(~=21.3). I pulled in the newest version I had at the time. I only use it
for packaging.requirements.Requirement, which was added in packaging 16.1:
https://github.com/pypa/packaging/releases/tag/16.1

https://pkgs.org/download/python3-packaging suggests that the oldest
version we care about is 17.1 in Ubuntu Bionic. So I think with this bound
we're hunky dory.
---
 changelog.d/12166.misc         | 1 +
 synapse/python_dependencies.py | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/12166.misc

diff --git a/changelog.d/12166.misc b/changelog.d/12166.misc
new file mode 100644
index 0000000000..24b4a7c7de
--- /dev/null
+++ b/changelog.d/12166.misc
@@ -0,0 +1 @@
+Relax the version guard for "packaging" added in #12088.
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 8f48a33936..b40a7bbb76 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -83,8 +83,8 @@ REQUIREMENTS = [
     # ijson 3.1.4 fixes a bug with "." in property names
     "ijson>=3.1.4",
     "matrix-common~=1.1.0",
-    # For runtime introspection of our dependencies
-    "packaging~=21.3",
+    # We need packaging.requirements.Requirement, added in 16.1.
+    "packaging>=16.1",
 ]
 
 CONDITIONAL_REQUIREMENTS = {
-- cgit 1.4.1


From 0211f18d65b20c9bd77e64f296f8790f4267cf28 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Mon, 7 Mar 2022 12:24:06 +0000
Subject: Switch the `tests-done` job to an Action (#12161)

I've factored it out for easier use in other workflows.
---
 .github/workflows/tests.yml | 30 +++++++++---------------------
 changelog.d/12161.misc      |  1 +
 2 files changed, 10 insertions(+), 21 deletions(-)
 create mode 100644 changelog.d/12161.misc

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 3bce95b0e0..613a773775 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -388,34 +388,22 @@ jobs:
   tests-done:
     if: ${{ always() }}
     needs:
+      - check-sampleconfig
       - lint
       - lint-crlf
       - lint-newsfile
       - trial
       - trial-olddeps
       - sytest
+      - export-data
       - portdb
       - complement
     runs-on: ubuntu-latest
    steps:
-      - name: Set build result
-        env:
-          NEEDS_CONTEXT: ${{ toJSON(needs) }}
-        # the `jq` incantation dumps out a series of "<job> <result>" lines.
-        # we set it to an intermediate variable to avoid a pipe, which makes it
-        # hard to set $rc.
-        run: |
-          rc=0
-          results=$(jq -r 'to_entries[] | [.key,.value.result] | join(" ")' <<< $NEEDS_CONTEXT)
-          while read job result ; do
-            # The newsfile lint may be skipped on non PR builds
-            if [ $result == "skipped" ] && [ $job == "lint-newsfile" ]; then
-              continue
-            fi
-
-            if [ "$result" != "success" ]; then
-              echo "::set-failed ::Job $job returned $result"
-              rc=1
-            fi
-          done <<< $results
-          exit $rc
+      - uses: matrix-org/done-action@v2
+        with:
+          needs: ${{ toJSON(needs) }}
+
+          # The newsfile lint may be skipped on non PR builds
+          skippable:
+            lint-newsfile
diff --git a/changelog.d/12161.misc b/changelog.d/12161.misc
new file mode 100644
index 0000000000..43eff08d46
--- /dev/null
+++ b/changelog.d/12161.misc
@@ -0,0 +1 @@
+Use a prebuilt Action for the `tests-done` CI job.
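
To make the replaced logic easier to follow, here is a hedged Python rendering of what the removed bash/jq step computed, and what `matrix-org/done-action` is presumably doing on our behalf; the function name and the `skippable` default below are illustrative, not part of either implementation.

    import json

    def tests_done(needs_context: str, skippable=("lint-newsfile",)) -> int:
        """Return 0 if every needed job succeeded (or was skippably skipped)."""
        rc = 0
        for job, info in json.loads(needs_context).items():
            result = info["result"]
            # The newsfile lint may be skipped on non-PR builds.
            if result == "skipped" and job in skippable:
                continue
            if result != "success":
                print(f"Job {job} returned {result}")
                rc = 1
        return rc

    # e.g. tests_done('{"lint": {"result": "success"}, "sytest": {"result": "failure"}}') == 1
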
-- cgit 1.4.1 From f63bedef07360216a8de71dc38f00f1aea503903 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 7 Mar 2022 09:00:05 -0500 Subject: Invalidate caches when an event with a relation is redacted. (#12121) The caches for the target of the relation must be cleared so that the bundled aggregations are re-calculated after the redaction is processed. --- changelog.d/12113.bugfix | 1 + changelog.d/12113.misc | 1 - changelog.d/12121.bugfix | 1 + synapse/storage/databases/main/cache.py | 2 + synapse/storage/databases/main/events.py | 38 +++++- tests/rest/client/test_relations.py | 207 ++++++++++++++++++++++++------- 6 files changed, 202 insertions(+), 48 deletions(-) create mode 100644 changelog.d/12113.bugfix delete mode 100644 changelog.d/12113.misc create mode 100644 changelog.d/12121.bugfix diff --git a/changelog.d/12113.bugfix b/changelog.d/12113.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12113.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12113.misc b/changelog.d/12113.misc deleted file mode 100644 index 102e064053..0000000000 --- a/changelog.d/12113.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor the tests for event relations. diff --git a/changelog.d/12121.bugfix b/changelog.d/12121.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12121.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index c428dd5596..abd54c7dc7 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -200,6 +200,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_relations_for_event.invalidate((relates_to,)) self.get_aggregation_groups_for_event.invalidate((relates_to,)) self.get_applicable_edit.invalidate((relates_to,)) + self.get_thread_summary.invalidate((relates_to,)) + self.get_thread_participated.invalidate((relates_to,)) async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ca2a9ba9d1..1dc83aa5e3 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1518,7 +1518,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redaction(txn, event.redacts) + self._handle_redact_relations(txn, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1943,15 +1943,43 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. + def _handle_redact_relations( + self, txn: LoggingTransaction, redacted_event_id: str + ) -> None: + """Handles receiving a redaction and checking whether the redacted event + has any relations which must be removed from the database. Args: txn - redacted_event_id (str): The event that was redacted. + redacted_event_id: The event that was redacted. """ + # Fetch the current relation of the event being redacted. 
+ redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_relations", + keyvalues={"event_id": redacted_event_id}, + retcol="relates_to_id", + allow_none=True, + ) + # Any relation information for the related event must be cleared. + if redacted_relates_to is not None: + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_applicable_edit, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_summary, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_participated, (redacted_relates_to,) + ) + self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 274f9c44c1..a40a5de399 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1273,7 +1273,21 @@ class RelationsTestCase(BaseRelationsTestCase): class RelationRedactionTestCase(BaseRelationsTestCase): - """Test the behaviour of relations when the parent or child event is redacted.""" + """ + Test the behaviour of relations when the parent or child event is redacted. + + The behaviour of each relation type is subtly different which causes the tests + to be a bit repetitive, they follow a naming scheme of: + + test_redact_(relation|parent)_{relation_type} + + The first bit of "relation" means that the event with the relation defined + on it (the child event) is to be redacted. A "parent" means that the target + of the relation (the parent event) is to be redacted. + + The relation_type describes which type of relation is under test (i.e. it is + related to the value of rel_type in the event content). + """ def _redact(self, event_id: str) -> None: channel = self.make_request( @@ -1284,9 +1298,53 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) + def _make_relation_requests(self) -> Tuple[List[str], JsonDict]: + """ + Makes requests and ensures they result in a 200 response, returns a + tuple of results: + + 1. `/relations` -> Returns a list of event IDs. + 2. `/event` -> Returns the response's m.relations field (from unsigned), + if it exists. + """ + + # Request the relations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + + # Fetch the bundled aggregations of the event. 
+ channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + bundled_relations = channel.json_body["unsigned"].get("m.relations", {}) + + return event_ids, bundled_relations + + def _get_aggregations(self) -> List[JsonDict]: + """Request /aggregations on the parent ID and includes the returned chunk.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + return channel.json_body["chunk"] + def test_redact_relation_annotation(self) -> None: - """Test that annotations of an event are properly handled after the + """ + Test that annotations of an event are properly handled after the annotation is redacted. + + The redacted relation should not be included in bundled aggregations or + the response to relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEqual(200, channel.code, channel.json_body) @@ -1296,24 +1354,97 @@ class RelationRedactionTestCase(BaseRelationsTestCase): RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) self.assertEqual(200, channel.code, channel.json_body) + unredacted_event_id = channel.json_body["event_id"] + + # Both relations should exist. + event_ids, relations = self._make_relation_requests() + self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, + ) + + # Both relations appear in the aggregation. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}]) # Redact one of the reactions. self._redact(to_redact_event_id) - # Ensure that the aggregations are correct. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + # The unredacted aggregation should still exist. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_relation_thread(self) -> None: + """ + Test that thread replies are properly handled after the thread reply redacted. + + The redacted event should not be included in bundled aggregations or + the response to relations. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, ) self.assertEqual(200, channel.code, channel.json_body) + unredacted_event_id = channel.json_body["event_id"] + # Note that the *last* event in the thread is redacted, as that gets + # included in the bundled aggregation. + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 2", "msgtype": "m.text"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + # Both relations exist. 
+        event_ids, relations = self._make_relation_requests()
+        self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
+        self.assertDictContainsSubset(
+            {
+                "count": 2,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        # And the latest event returned is the event that will be redacted.
         self.assertEqual(
-            channel.json_body,
-            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            to_redact_event_id,
         )
 
-    def test_redact_relation_edit(self) -> None:
+        # Redact one of the reactions.
+        self._redact(to_redact_event_id)
+
+        # The unredacted relation should still exist.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEquals(event_ids, [unredacted_event_id])
+        self.assertDictContainsSubset(
+            {
+                "count": 1,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        # And the latest event is now the unredacted event.
+        self.assertEqual(
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            unredacted_event_id,
+        )
+
+    def test_redact_parent_edit(self) -> None:
         """Test that edits of an event are redacted when the original
         event is redacted.
         """
@@ -1331,34 +1462,19 @@
         self.assertEqual(200, channel.code, channel.json_body)
 
         # Check the relation is returned
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}/m.replace/m.room.message",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(len(channel.json_body["chunk"]), 1)
+        event_ids, relations = self._make_relation_requests()
+        self.assertEqual(len(event_ids), 1)
+        self.assertIn(RelationTypes.REPLACE, relations)
 
         # Redact the original event
         self._redact(self.parent_id)
 
-        # Try to check for remaining m.replace relations
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}/m.replace/m.room.message",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        # Check that no relations are returned
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
+        # The relations are not returned.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEqual(len(event_ids), 0)
+        self.assertEqual(relations, {})
 
-    def test_redact_parent(self) -> None:
+    def test_redact_parent_annotation(self) -> None:
         """Test that annotations of an event are redacted when the original event
         is redacted.
         """
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
         self.assertEqual(200, channel.code, channel.json_body)
 
+        # The relations should exist.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEqual(len(event_ids), 1)
+        self.assertIn(RelationTypes.ANNOTATION, relations)
+
+        # The aggregation should exist.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
 
+        # Redact the original event.
         self._redact(self.parent_id)
 
-        # Check that aggregations returns zero
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
+        # The relations are not returned.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEqual(event_ids, [])
+        self.assertEqual(relations, {})
 
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
+        # There's nothing to aggregate.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [])
-- cgit 1.4.1


From 26211fec24d8d0a967de33147e148166359ec8cb Mon Sep 17 00:00:00 2001
From: Shay
Date: Mon, 7 Mar 2022 09:44:33 -0800
Subject: Fix a bug in background updates wherein background updates are never run using the default batch size (#12157)

---
 changelog.d/12157.bugfix                    |  1 +
 synapse/storage/background_updates.py       |  8 +++++---
 tests/rest/admin/test_background_updates.py | 18 ++++++++----------
 tests/storage/test_background_update.py     |  4 ++--
 4 files changed, 16 insertions(+), 15 deletions(-)
 create mode 100644 changelog.d/12157.bugfix

diff --git a/changelog.d/12157.bugfix b/changelog.d/12157.bugfix
new file mode 100644
index 0000000000..c3d2e700bb
--- /dev/null
+++ b/changelog.d/12157.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in #4864 whereby background updates are never run with the default background batch size.
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index d64910aded..4acc2c997d 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -102,10 +102,12 @@ class BackgroundUpdatePerformance:
         Returns:
             A duration in ms as a float
         """
-        if self.avg_duration_ms == 0:
-            return 0
-        elif self.total_item_count == 0:
+        # We want to return None if this is the first background update item
+        if self.total_item_count == 0:
             return None
+        # Avoid dividing by zero
+        elif self.avg_duration_ms == 0:
+            return 0
         else:
             # Use the exponential moving average so that we can adapt to
             # changes in how long the update process takes.
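
A stripped-down sketch of why the reordering above matters: with the old order, the zero-duration check ran first, so the very first call (when both counters are still zero) returned `0` rather than `None`, and the caller never fell back to `DEFAULT_BACKGROUND_BATCH_SIZE`. The class below is a simplified, hypothetical stand-in for `BackgroundUpdatePerformance`, and the final items-per-ms formula is an assumption for illustration only.

    from typing import Optional

    class PerfSketch:
        """Hypothetical, simplified stand-in for BackgroundUpdatePerformance."""

        def __init__(self) -> None:
            self.total_item_count = 0
            self.avg_duration_ms = 0.0

        def average_items_per_ms(self) -> Optional[float]:
            if self.total_item_count == 0:
                # First run: no data yet, so signal "use the default batch size".
                return None
            elif self.avg_duration_ms == 0:
                # Avoid dividing by zero.
                return 0
            # Assumed formula, for illustration: items per millisecond.
            return self.total_item_count / self.avg_duration_ms

    assert PerfSketch().average_items_per_ms() is None  # caller falls back to the default
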
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index fb36aa9940..becec84524 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -155,10 +155,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -210,10 +210,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -239,10 +239,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -278,11 +278,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.05263157894736842, "total_duration_ms": 2000.0, - "total_item_count": ( - 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE - ), + "total_item_count": (110), } }, "enabled": True, diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 39dcc094bd..9fdf54ea31 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -66,13 +66,13 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), - by=0.01, + by=0.02, ) self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update -- cgit 1.4.1 From 2eef234ae367657d4fe5cb0bef6bda67e97b7e4d Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 8 Mar 2022 10:47:28 +0000 Subject: Fix a bug introduced in 1.54.0rc1 which meant that Synapse would refuse to start if pre-release versions of dependencies were installed. (#12177) * Add failing test to characterise the regression #12176 * Permit pre-release versions of specified packages * Newsfile (bugfix) Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/12177.bugfix | 1 + synapse/util/check_dependencies.py | 3 ++- tests/util/test_check_dependencies.py | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12177.bugfix diff --git a/changelog.d/12177.bugfix b/changelog.d/12177.bugfix new file mode 100644 index 0000000000..3f2852f345 --- /dev/null +++ b/changelog.d/12177.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.54.0rc1 which meant that Synapse would refuse to start if pre-release versions of dependencies were installed. 
\ No newline at end of file diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 39b0a91db3..12cd804939 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -163,7 +163,8 @@ def check_requirements(extra: Optional[str] = None) -> None: deps_unfulfilled.append(requirement.name) errors.append(_not_installed(requirement, extra)) else: - if not requirement.specifier.contains(dist.version): + # We specify prereleases=True to allow prereleases such as RCs. + if not requirement.specifier.contains(dist.version, prereleases=True): deps_unfulfilled.append(requirement.name) errors.append(_incorrect_version(requirement, dist.version, extra)) diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index a91c33272f..38e9f58ac6 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -27,7 +27,9 @@ class DummyDistribution(metadata.Distribution): old = DummyDistribution("0.1.2") +old_release_candidate = DummyDistribution("0.1.2rc3") new = DummyDistribution("1.2.3") +new_release_candidate = DummyDistribution("1.2.3rc4") # could probably use stdlib TestCase --- no need for twisted here @@ -110,3 +112,20 @@ class TestDependencyChecker(TestCase): with self.mock_installed_package(new): # should not raise check_requirements("cool-extra") + + def test_release_candidates_satisfy_dependency(self) -> None: + """ + Tests that release candidates count as far as satisfying a dependency + is concerned. + (Regression test, see #12176.) + """ + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(old_release_candidate): + self.assertRaises(DependencyException, check_requirements) + + with self.mock_installed_package(new_release_candidate): + # should not raise + check_requirements() -- cgit 1.4.1 From ea992adf867c0c74dccfd6a40e0f73933ccf2899 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Tue, 8 Mar 2022 10:55:26 +0000 Subject: 1.54.0 --- CHANGES.md | 18 ++++++++++++++++++ changelog.d/12127.misc | 1 - changelog.d/12129.misc | 1 - changelog.d/12141.bugfix | 1 - changelog.d/12166.misc | 1 - changelog.d/12177.bugfix | 1 - debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 8 files changed, 25 insertions(+), 6 deletions(-) delete mode 100644 changelog.d/12127.misc delete mode 100644 changelog.d/12129.misc delete mode 100644 changelog.d/12141.bugfix delete mode 100644 changelog.d/12166.misc delete mode 100644 changelog.d/12177.bugfix diff --git a/CHANGES.md b/CHANGES.md index 0a87f5cd42..9d27cd35aa 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,21 @@ +Synapse 1.54.0 (2022-03-08) +=========================== + +Bugfixes +-------- + +- Fix a bug introduced in Synapse 1.54.0rc1 preventing the new module callbacks introduced in this release from being registered by modules. ([\#12141](https://github.com/matrix-org/synapse/issues/12141)) +- Fix a bug introduced in Synapse 1.54.0rc1 which meant that Synapse would refuse to start if pre-release versions of dependencies were installed. ([\#12177](https://github.com/matrix-org/synapse/issues/12177)) + + +Internal Changes +---------------- + +- Update release script to insert the previous version when writing "No significant changes" line in the changelog. ([\#12127](https://github.com/matrix-org/synapse/issues/12127)) +- Inspect application dependencies using `importlib.metadata` or its backport. 
([\#12129](https://github.com/matrix-org/synapse/issues/12129)) +- Relax the version guard for "packaging" added in #12088. ([\#12166](https://github.com/matrix-org/synapse/issues/12166)) + + Synapse 1.54.0rc1 (2022-03-02) ============================== diff --git a/changelog.d/12127.misc b/changelog.d/12127.misc deleted file mode 100644 index e42eca63a8..0000000000 --- a/changelog.d/12127.misc +++ /dev/null @@ -1 +0,0 @@ -Update release script to insert the previous version when writing "No significant changes" line in the changelog. diff --git a/changelog.d/12129.misc b/changelog.d/12129.misc deleted file mode 100644 index ce4213650c..0000000000 --- a/changelog.d/12129.misc +++ /dev/null @@ -1 +0,0 @@ -Inspect application dependencies using `importlib.metadata` or its backport. \ No newline at end of file diff --git a/changelog.d/12141.bugfix b/changelog.d/12141.bugfix deleted file mode 100644 index 03a2507e2c..0000000000 --- a/changelog.d/12141.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse 1.54.0rc1 preventing the new module callbacks introduced in this release from being registered by modules. diff --git a/changelog.d/12166.misc b/changelog.d/12166.misc deleted file mode 100644 index 24b4a7c7de..0000000000 --- a/changelog.d/12166.misc +++ /dev/null @@ -1 +0,0 @@ -Relax the version guard for "packaging" added in #12088. diff --git a/changelog.d/12177.bugfix b/changelog.d/12177.bugfix deleted file mode 100644 index 3f2852f345..0000000000 --- a/changelog.d/12177.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in 1.54.0rc1 which meant that Synapse would refuse to start if pre-release versions of dependencies were installed. \ No newline at end of file diff --git a/debian/changelog b/debian/changelog index df3db85b8e..02136a0d60 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.54.0) stable; urgency=medium + + * New synapse release 1.54.0. + + -- Synapse Packaging team Tue, 08 Mar 2022 10:54:52 +0000 + matrix-synapse-py3 (1.54.0~rc1) stable; urgency=medium * New synapse release 1.54.0~rc1. diff --git a/synapse/__init__.py b/synapse/__init__.py index b21e1ed0f3..c6727024f0 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.54.0rc1" +__version__ = "1.54.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.4.1
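
As a closing note on the #12177 fix above: the behaviour it corrects is easy to reproduce with the `packaging` library directly. The dependency name below is the same made-up `dummypkg` used in the new test; the API calls are real.

    from packaging.requirements import Requirement

    req = Requirement("dummypkg>=1")

    # By default, pre-releases do not satisfy a specifier...
    assert not req.specifier.contains("1.2.3rc4")

    # ...but with prereleases=True (the one-line change in
    # synapse/util/check_dependencies.py), a release candidate
    # now counts as satisfying the bound.
    assert req.specifier.contains("1.2.3rc4", prereleases=True)
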