From 1bc4feb6c9762216e930daf0ddbdb86c77bf7724 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 21 Mar 2023 14:19:54 -0400 Subject: Apply & bundle edits for non-message events. (#15295) --- synapse/storage/databases/main/relations.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index bc3a83919c..3955a8a9a5 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -472,12 +472,11 @@ class RelationsWorkerStore(SQLBaseStore): the event will map to None. """ - # We only allow edits for `m.room.message` events that have the same sender - # and event type. We can't assert these things during regular event auth so - # we have to do the checks post hoc. + # We only allow edits for events that have the same sender and event type. + # We can't assert these things during regular event auth so we have to do + # the checks post hoc. - # Fetches latest edit that has the same type and sender as the - # original, and is an `m.room.message`. + # Fetches latest edit that has the same type and sender as the original. if isinstance(self.database_engine, PostgresEngine): # The `DISTINCT ON` clause will pick the *first* row it encounters, # so ordering by origin server ts + event ID desc will ensure we get @@ -493,7 +492,6 @@ class RelationsWorkerStore(SQLBaseStore): WHERE %s AND relation_type = ? - AND edit.type = 'm.room.message' ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC """ else: @@ -512,7 +510,6 @@ class RelationsWorkerStore(SQLBaseStore): WHERE %s AND relation_type = ? 
- AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts, edit.event_id """ -- cgit 1.5.1 From 3b0083c92adf76daf4161908565de9e5efc08074 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 22 Mar 2023 17:15:34 +0000 Subject: Use immutabledict instead of frozendict (#15113) Additionally: * Consistently use `freeze()` in test --------- Co-authored-by: Patrick Cloke Co-authored-by: 6543 <6543@obermui.de> --- changelog.d/15113.misc | 1 + poetry.lock | 125 ++++--------------------------- pyproject.toml | 12 ++- stubs/frozendict.pyi | 39 ---------- synapse/__init__.py | 19 +++-- synapse/crypto/event_signing.py | 2 +- synapse/events/snapshot.py | 4 +- synapse/events/utils.py | 2 +- synapse/events/validator.py | 2 +- synapse/state/__init__.py | 10 ++- synapse/storage/databases/main/stream.py | 4 +- synapse/types/__init__.py | 12 +-- synapse/types/state.py | 26 ++++--- synapse/util/__init__.py | 20 ++--- synapse/util/frozenutils.py | 6 +- tests/api/test_filtering.py | 6 +- tests/config/test_workers.py | 6 +- tests/push/test_push_rule_evaluator.py | 18 ++--- tests/storage/test_state.py | 40 +++++----- tests/types/test_state.py | 14 ++-- 20 files changed, 124 insertions(+), 244 deletions(-) create mode 100644 changelog.d/15113.misc delete mode 100644 stubs/frozendict.pyi (limited to 'synapse/storage/databases') diff --git a/changelog.d/15113.misc b/changelog.d/15113.misc new file mode 100644 index 0000000000..6917dd5652 --- /dev/null +++ b/changelog.d/15113.misc @@ -0,0 +1 @@ +Use `immutabledict` instead of `frozendict`. 
diff --git a/poetry.lock b/poetry.lock index ff8b43bac7..76fbfafcf9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -160,23 +160,16 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] [[package]] name = "canonicaljson" -version = "1.6.5" +version = "2.0.0" description = "Canonical JSON" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "canonicaljson-1.6.5-py3-none-any.whl", hash = "sha256:806ea6f2cbb7405d20259e1c36dd1214ba5c242fa9165f5bd0bf2081f82c23fb"}, - {file = "canonicaljson-1.6.5.tar.gz", hash = "sha256:68dfc157b011e07d94bf74b5d4ccc01958584ed942d9dfd5fdd706609e81cd4b"}, + {file = "canonicaljson-2.0.0-py3-none-any.whl", hash = "sha256:c38a315de3b5a0532f1ec1f9153cd3d716abfc565a558d00a4835428a34fca5b"}, + {file = "canonicaljson-2.0.0.tar.gz", hash = "sha256:e2fdaef1d7fadc5d9cb59bd3d0d41b064ddda697809ac4325dced721d12f113f"}, ] -[package.dependencies] -simplejson = ">=3.14.0" -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.8\""} - -[package.extras] -frozendict = ["frozendict (>=1.0)"] - [[package]] name = "certifi" version = "2022.12.7" @@ -453,33 +446,6 @@ files = [ [package.extras] dev = ["Sphinx", "coverage", "flake8", "lxml", "memory-profiler", "mypy (==0.910)", "tox", "xmlschema (>=1.8.0)"] -[[package]] -name = "frozendict" -version = "2.3.4" -description = "A simple immutable dictionary" -category = "main" -optional = false -python-versions = ">=3.6" -files = [ - {file = "frozendict-2.3.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4a3b32d47282ae0098b9239a6d53ec539da720258bd762d62191b46f2f87c5fc"}, - {file = "frozendict-2.3.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84c9887179a245a66a50f52afa08d4d92ae0f269839fab82285c70a0fa0dd782"}, - {file = "frozendict-2.3.4-cp310-cp310-win_amd64.whl", hash = "sha256:b98a0d65a59af6da03f794f90b0c3085a7ee14e7bf8f0ef36b079ee8aa992439"}, - {file = "frozendict-2.3.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = 
"sha256:3d8042b7dab5e992e30889c9b71b781d5feef19b372d47d735e4d7d45846fd4a"}, - {file = "frozendict-2.3.4-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25a6d2e8b7cf6b6e5677a1a4b53b4073e5d9ec640d1db30dc679627668d25e90"}, - {file = "frozendict-2.3.4-cp36-cp36m-win_amd64.whl", hash = "sha256:dbbe1339ac2646523e0bb00d1896085d1f70de23780e4927ca82b36ab8a044d3"}, - {file = "frozendict-2.3.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95bac22f7f09d81f378f2b3f672b7a50a974ca180feae1507f5e21bc147e8bc8"}, - {file = "frozendict-2.3.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dae686722c144b333c4dbdc16323a5de11406d26b76d2be1cc175f90afacb5ba"}, - {file = "frozendict-2.3.4-cp37-cp37m-win_amd64.whl", hash = "sha256:389f395a74eb16992217ac1521e689c1dea2d70113bcb18714669ace1ed623b9"}, - {file = "frozendict-2.3.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ccb6450a416c9cc9acef7683e637e28356e3ceeabf83521f74cc2718883076b7"}, - {file = "frozendict-2.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aca59108b77cadc13ba7dfea7e8f50811208c7652a13dc6c7f92d7782a24d299"}, - {file = "frozendict-2.3.4-cp38-cp38-win_amd64.whl", hash = "sha256:3ec86ebf143dd685184215c27ec416c36e0ba1b80d81b1b9482f7d380c049b4e"}, - {file = "frozendict-2.3.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5809e6ff6b7257043a486f7a3b73a7da71cf69a38980b4171e4741291d0d9eb3"}, - {file = "frozendict-2.3.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c550ed7fdf1962984bec21630c584d722b3ee5d5f57a0ae2527a0121dc0414a"}, - {file = "frozendict-2.3.4-cp39-cp39-win_amd64.whl", hash = "sha256:3e93aebc6e69a8ef329bbe9afb8342bd33c7b5c7a0c480cb9f7e60b0cbe48072"}, - {file = "frozendict-2.3.4-py3-none-any.whl", hash = "sha256:d722f3d89db6ae35ef35ecc243c40c800eb344848c83dba4798353312cd37b15"}, - {file = "frozendict-2.3.4.tar.gz", hash = "sha256:15b4b18346259392b0d27598f240e9390fafbff882137a9c48a1e0104fb17f78"}, -] 
- [[package]] name = "gitdb" version = "4.0.9" @@ -725,6 +691,18 @@ files = [ {file = "ijson-3.2.0.post0.tar.gz", hash = "sha256:80a5bd7e9923cab200701f67ad2372104328b99ddf249dbbe8834102c852d316"}, ] +[[package]] +name = "immutabledict" +version = "2.2.3" +description = "Immutable wrapper around dictionaries (a fork of frozendict)" +category = "main" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "immutabledict-2.2.3-py3-none-any.whl", hash = "sha256:a7b078ebcc4a58ddc73b55f808b26e7c8c2d5183fad325615112689e1a63e714"}, + {file = "immutabledict-2.2.3.tar.gz", hash = "sha256:0e1e8a3f2b3ff062daa19795f947e9ec7a58add269d44e34d3ab4319e1343853"}, +] + [[package]] name = "importlib-metadata" version = "6.0.0" @@ -2174,77 +2152,6 @@ unpaddedbase64 = ">=1.0.1" [package.extras] dev = ["typing-extensions (>=3.5)"] -[[package]] -name = "simplejson" -version = "3.17.6" -description = "Simple, fast, extensible JSON encoder/decoder for Python" -category = "main" -optional = false -python-versions = ">=2.5, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "simplejson-3.17.6-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a89acae02b2975b1f8e4974cb8cdf9bf9f6c91162fb8dec50c259ce700f2770a"}, - {file = "simplejson-3.17.6-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:82ff356ff91be0ab2293fc6d8d262451eb6ac4fd999244c4b5f863e049ba219c"}, - {file = "simplejson-3.17.6-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:0de783e9c2b87bdd75b57efa2b6260c24b94605b5c9843517577d40ee0c3cc8a"}, - {file = "simplejson-3.17.6-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:d24a9e61df7a7787b338a58abfba975414937b609eb6b18973e25f573bc0eeeb"}, - {file = "simplejson-3.17.6-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:e8603e691580487f11306ecb066c76f1f4a8b54fb3bdb23fa40643a059509366"}, - {file = "simplejson-3.17.6-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:9b01e7b00654115965a206e3015f0166674ec1e575198a62a977355597c0bef5"}, - {file = 
"simplejson-3.17.6-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:37bc0cf0e5599f36072077e56e248f3336917ded1d33d2688624d8ed3cefd7d2"}, - {file = "simplejson-3.17.6-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:cf6e7d5fe2aeb54898df18db1baf479863eae581cce05410f61f6b4188c8ada1"}, - {file = "simplejson-3.17.6-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:bdfc54b4468ed4cd7415928cbe782f4d782722a81aeb0f81e2ddca9932632211"}, - {file = "simplejson-3.17.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd16302d39c4d6f4afde80edd0c97d4db643327d355a312762ccd9bd2ca515ed"}, - {file = "simplejson-3.17.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:deac4bdafa19bbb89edfb73b19f7f69a52d0b5bd3bb0c4ad404c1bbfd7b4b7fd"}, - {file = "simplejson-3.17.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a8bbdb166e2fb816e43ab034c865147edafe28e1b19c72433147789ac83e2dda"}, - {file = "simplejson-3.17.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7854326920d41c3b5d468154318fe6ba4390cb2410480976787c640707e0180"}, - {file = "simplejson-3.17.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:04e31fa6ac8e326480703fb6ded1488bfa6f1d3f760d32e29dbf66d0838982ce"}, - {file = "simplejson-3.17.6-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f63600ec06982cdf480899026f4fda622776f5fabed9a869fdb32d72bc17e99a"}, - {file = "simplejson-3.17.6-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e03c3b8cc7883a54c3f34a6a135c4a17bc9088a33f36796acdb47162791b02f6"}, - {file = "simplejson-3.17.6-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a2d30d6c1652140181dc6861f564449ad71a45e4f165a6868c27d36745b65d40"}, - {file = "simplejson-3.17.6-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a1aa6e4cae8e3b8d5321be4f51c5ce77188faf7baa9fe1e78611f93a8eed2882"}, - {file = "simplejson-3.17.6-cp310-cp310-win32.whl", hash = 
"sha256:97202f939c3ff341fc3fa84d15db86156b1edc669424ba20b0a1fcd4a796a045"}, - {file = "simplejson-3.17.6-cp310-cp310-win_amd64.whl", hash = "sha256:80d3bc9944be1d73e5b1726c3bbfd2628d3d7fe2880711b1eb90b617b9b8ac70"}, - {file = "simplejson-3.17.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9fa621b3c0c05d965882c920347b6593751b7ab20d8fa81e426f1735ca1a9fc7"}, - {file = "simplejson-3.17.6-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd2fb11922f58df8528adfca123f6a84748ad17d066007e7ac977720063556bd"}, - {file = "simplejson-3.17.6-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:724c1fe135aa437d5126138d977004d165a3b5e2ee98fc4eb3e7c0ef645e7e27"}, - {file = "simplejson-3.17.6-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4ff4ac6ff3aa8f814ac0f50bf218a2e1a434a17aafad4f0400a57a8cc62ef17f"}, - {file = "simplejson-3.17.6-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:67093a526e42981fdd954868062e56c9b67fdd7e712616cc3265ad0c210ecb51"}, - {file = "simplejson-3.17.6-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:5d6b4af7ad7e4ac515bc6e602e7b79e2204e25dbd10ab3aa2beef3c5a9cad2c7"}, - {file = "simplejson-3.17.6-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:1c9b1ed7ed282b36571638297525f8ef80f34b3e2d600a56f962c6044f24200d"}, - {file = "simplejson-3.17.6-cp36-cp36m-win32.whl", hash = "sha256:632ecbbd2228575e6860c9e49ea3cc5423764d5aa70b92acc4e74096fb434044"}, - {file = "simplejson-3.17.6-cp36-cp36m-win_amd64.whl", hash = "sha256:4c09868ddb86bf79b1feb4e3e7e4a35cd6e61ddb3452b54e20cf296313622566"}, - {file = "simplejson-3.17.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4b6bd8144f15a491c662f06814bd8eaa54b17f26095bb775411f39bacaf66837"}, - {file = "simplejson-3.17.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5decdc78849617917c206b01e9fc1d694fd58caa961be816cb37d3150d613d9a"}, - 
{file = "simplejson-3.17.6-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:521877c7bd060470806eb6335926e27453d740ac1958eaf0d8c00911bc5e1802"}, - {file = "simplejson-3.17.6-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:65b998193bd7b0c7ecdfffbc825d808eac66279313cb67d8892bb259c9d91494"}, - {file = "simplejson-3.17.6-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ac786f6cb7aa10d44e9641c7a7d16d7f6e095b138795cd43503769d4154e0dc2"}, - {file = "simplejson-3.17.6-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:3ff5b3464e1ce86a8de8c88e61d4836927d5595c2162cab22e96ff551b916e81"}, - {file = "simplejson-3.17.6-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:69bd56b1d257a91e763256d63606937ae4eb890b18a789b66951c00062afec33"}, - {file = "simplejson-3.17.6-cp37-cp37m-win32.whl", hash = "sha256:b81076552d34c27e5149a40187a8f7e2abb2d3185576a317aaf14aeeedad862a"}, - {file = "simplejson-3.17.6-cp37-cp37m-win_amd64.whl", hash = "sha256:07ecaafc1b1501f275bf5acdee34a4ad33c7c24ede287183ea77a02dc071e0c0"}, - {file = "simplejson-3.17.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:068670af975247acbb9fc3d5393293368cda17026db467bf7a51548ee8f17ee1"}, - {file = "simplejson-3.17.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4d1c135af0c72cb28dd259cf7ba218338f4dc027061262e46fe058b4e6a4c6a3"}, - {file = "simplejson-3.17.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:23fe704da910ff45e72543cbba152821685a889cf00fc58d5c8ee96a9bad5f94"}, - {file = "simplejson-3.17.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f444762fed1bc1fd75187ef14a20ed900c1fbb245d45be9e834b822a0223bc81"}, - {file = "simplejson-3.17.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:681eb4d37c9a9a6eb9b3245a5e89d7f7b2b9895590bb08a20aa598c1eb0a1d9d"}, - {file = 
"simplejson-3.17.6-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8e8607d8f6b4f9d46fee11447e334d6ab50e993dd4dbfb22f674616ce20907ab"}, - {file = "simplejson-3.17.6-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b10556817f09d46d420edd982dd0653940b90151d0576f09143a8e773459f6fe"}, - {file = "simplejson-3.17.6-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e1ec8a9ee0987d4524ffd6299e778c16cc35fef6d1a2764e609f90962f0b293a"}, - {file = "simplejson-3.17.6-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0b4126cac7d69ac06ff22efd3e0b3328a4a70624fcd6bca4fc1b4e6d9e2e12bf"}, - {file = "simplejson-3.17.6-cp38-cp38-win32.whl", hash = "sha256:35a49ebef25f1ebdef54262e54ae80904d8692367a9f208cdfbc38dbf649e00a"}, - {file = "simplejson-3.17.6-cp38-cp38-win_amd64.whl", hash = "sha256:743cd768affaa508a21499f4858c5b824ffa2e1394ed94eb85caf47ac0732198"}, - {file = "simplejson-3.17.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:fb62d517a516128bacf08cb6a86ecd39fb06d08e7c4980251f5d5601d29989ba"}, - {file = "simplejson-3.17.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:12133863178a8080a3dccbf5cb2edfab0001bc41e5d6d2446af2a1131105adfe"}, - {file = "simplejson-3.17.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5540fba2d437edaf4aa4fbb80f43f42a8334206ad1ad3b27aef577fd989f20d9"}, - {file = "simplejson-3.17.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d74ee72b5071818a1a5dab47338e87f08a738cb938a3b0653b9e4d959ddd1fd9"}, - {file = "simplejson-3.17.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:28221620f4dcabdeac310846629b976e599a13f59abb21616356a85231ebd6ad"}, - {file = "simplejson-3.17.6-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b09bc62e5193e31d7f9876220fb429ec13a6a181a24d897b9edfbbdbcd678851"}, - {file = 
"simplejson-3.17.6-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7255a37ff50593c9b2f1afa8fafd6ef5763213c1ed5a9e2c6f5b9cc925ab979f"}, - {file = "simplejson-3.17.6-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:401d40969cee3df7bda211e57b903a534561b77a7ade0dd622a8d1a31eaa8ba7"}, - {file = "simplejson-3.17.6-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a649d0f66029c7eb67042b15374bd93a26aae202591d9afd71e111dd0006b198"}, - {file = "simplejson-3.17.6-cp39-cp39-win32.whl", hash = "sha256:522fad7be85de57430d6d287c4b635813932946ebf41b913fe7e880d154ade2e"}, - {file = "simplejson-3.17.6-cp39-cp39-win_amd64.whl", hash = "sha256:3fe87570168b2ae018391e2b43fbf66e8593a86feccb4b0500d134c998983ccc"}, - {file = "simplejson-3.17.6.tar.gz", hash = "sha256:cf98038d2abf63a1ada5730e91e84c642ba6c225b0198c3684151b1f80c5f8a6"}, -] - [[package]] name = "six" version = "1.16.0" @@ -3013,4 +2920,4 @@ user-search = ["pyicu"] [metadata] lock-version = "2.0" python-versions = "^3.7.1" -content-hash = "de2c4c8de336593478ce02581a5336afe2544db93ea82f3955b34c3653c29a26" +content-hash = "0ca92e52a1952f9485172efe25a039351280c28f0a158869557dc2f8855786fe" diff --git a/pyproject.toml b/pyproject.toml index 19dc7c1536..c0111dd796 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -153,15 +153,13 @@ python = "^3.7.1" # ---------------------- # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0 jsonschema = ">=3.0.0" -# frozendict 2.1.2 is broken on Debian 10: https://github.com/Marco-Sulla/python-frozendict/issues/41 -# We cannot test our wheels against the 2.3.5 release in CI. Putting in an upper bound for this -# because frozendict has been more trouble than it's worth; we would like to move to immutabledict. -frozendict = ">=1,!=2.1.2,<2.3.5" +# We choose 2.0 as a lower bound: the most recent backwards incompatible release. 
+# It seems generally available, judging by https://pkgs.org/search/?q=immutabledict +immutabledict = ">=2.0" # We require 2.1.0 or higher for type hints. Previous guard was >= 1.1.0 unpaddedbase64 = ">=2.1.0" -# We require 1.5.0 to work around an issue when running against the C implementation of -# frozendict: https://github.com/matrix-org/python-canonicaljson/issues/36 -canonicaljson = "^1.5.0" +# We require 2.0.0 for immutabledict support. +canonicaljson = "^2.0.0" # we use the type definitions added in signedjson 1.1. signedjson = "^1.1.0" # validating SSL certs for IP addresses requires service_identity 18.1. diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi deleted file mode 100644 index 196dee4461..0000000000 --- a/stubs/frozendict.pyi +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Stub for frozendict. - -from __future__ import annotations - -from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload - -_KT = TypeVar("_KT", bound=Hashable) # Key type. -_VT = TypeVar("_VT") # Value type. - -class frozendict(Mapping[_KT, _VT]): - @overload - def __init__(self, **kwargs: _VT) -> None: ... - @overload - def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... - @overload - def __init__( - self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT - ) -> None: ... - def __getitem__(self, key: _KT) -> _VT: ... 
- def __contains__(self, key: Any) -> bool: ... - def copy(self, **add_or_replace: Any) -> frozendict: ... - def __iter__(self) -> Iterator[_KT]: ... - def __len__(self) -> int: ... - def __repr__(self) -> str: ... - def __hash__(self) -> int: ... diff --git a/synapse/__init__.py b/synapse/__init__.py index a203ed533a..b97ee59f15 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -17,9 +17,9 @@ """ This is an implementation of a Matrix homeserver. """ -import json import os import sys +from typing import Any, Dict from synapse.util.rust import check_rust_lib_up_to_date from synapse.util.stringutils import strtobool @@ -61,11 +61,20 @@ try: except ImportError: pass -# Use the standard library json implementation instead of simplejson. +# Teach canonicaljson how to serialise immutabledicts. try: - from canonicaljson import set_json_library - - set_json_library(json) + from canonicaljson import register_preserialisation_callback + from immutabledict import immutabledict + + def _immutabledict_cb(d: immutabledict) -> Dict[str, Any]: + try: + return d._dict + except Exception: + # Paranoia: fall back to a `dict()` call, in case a future version of + # immutabledict removes `_dict` from the implementation. + return dict(d) + + register_preserialisation_callback(immutabledict, _immutabledict_cb) except ImportError: pass diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 23b799ac32..1a293f1df0 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -51,7 +51,7 @@ def check_event_content_hash( # some malformed events lack a 'hashes'. Protect against it being missing # or a weird type by basically treating it the same as an unhashed event. 
hashes = event.get("hashes") - # nb it might be a frozendict or a dict + # nb it might be a immutabledict or a dict if not isinstance(hashes, collections.abc.Mapping): raise SynapseError( 400, "Malformed 'hashes': %s" % (type(hashes),), Codes.UNAUTHORIZED diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index c04ad08cbb..9b4d692cf4 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Optional, Tuple import attr -from frozendict import frozendict +from immutabledict import immutabledict from synapse.appservice import ApplicationService from synapse.events import EventBase @@ -489,4 +489,4 @@ def _decode_state_dict( if input is None: return None - return frozendict({(etype, state_key): v for etype, state_key, v in input}) + return immutabledict({(etype, state_key): v for etype, state_key, v in input}) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index b9c15ffcdb..e41c7a4b83 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -567,7 +567,7 @@ PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] def copy_and_fixup_power_levels_contents( old_power_levels: PowerLevelsContent, ) -> Dict[str, Union[int, Dict[str, int]]]: - """Copy the content of a power_levels event, unfreezing frozendicts along the way. + """Copy the content of a power_levels event, unfreezing immutabledicts along the way. We accept as input power level values which are strings, provided they represent an integer, e.g. `"`100"` instead of 100. 
Such strings are converted to integers diff --git a/synapse/events/validator.py b/synapse/events/validator.py index fb1737b910..6f0e4386d3 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -258,7 +258,7 @@ POWER_LEVELS_SCHEMA = { def _create_power_level_validator() -> Type[jsonschema.Draft7Validator]: validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA) - # by default jsonschema does not consider a frozendict to be an object so + # by default jsonschema does not consider a immutabledict to be an object so # we need to use a custom type checker # https://python-jsonschema.readthedocs.io/en/stable/validate/?highlight=object#validating-with-additional-types type_checker = validator.TYPE_CHECKER.redefine( diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4dc25df67e..6031095249 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -33,7 +33,7 @@ from typing import ( ) import attr -from frozendict import frozendict +from immutabledict import immutabledict from prometheus_client import Counter, Histogram from synapse.api.constants import EventTypes @@ -105,14 +105,18 @@ class _StateCacheEntry: # # This can be None if we have a `state_group` (as then we can fetch the # state from the DB.) - self._state = frozendict(state) if state is not None else None + self._state: Optional[StateMap[str]] = ( + immutabledict(state) if state is not None else None + ) # the ID of a state group if one and only one is involved. # otherwise, None otherwise? 
self.state_group = state_group self.prev_group = prev_group - self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None + self.delta_ids: Optional[StateMap[str]] = ( + immutabledict(delta_ids) if delta_ids is not None else None + ) async def get_state( self, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index ac5fbf6b86..2b8779bbb8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -50,7 +50,7 @@ from typing import ( ) import attr -from frozendict import frozendict +from immutabledict import immutabledict from typing_extensions import Literal from twisted.internet import defer @@ -557,7 +557,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if p > min_pos } - return RoomStreamToken(None, min_pos, frozendict(positions)) + return RoomStreamToken(None, min_pos, immutabledict(positions)) async def get_room_events_stream_for_rooms( self, diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 33363867c4..c09b9cf87d 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -35,7 +35,7 @@ from typing import ( ) import attr -from frozendict import frozendict +from immutabledict import immutabledict from signedjson.key import decode_verify_key_bytes from signedjson.types import VerifyKey from typing_extensions import Final, TypedDict @@ -490,12 +490,12 @@ class RoomStreamToken: ) stream: int = attr.ib(validator=attr.validators.instance_of(int)) - instance_map: "frozendict[str, int]" = attr.ib( - factory=frozendict, + instance_map: "immutabledict[str, int]" = attr.ib( + factory=immutabledict, validator=attr.validators.deep_mapping( key_validator=attr.validators.instance_of(str), value_validator=attr.validators.instance_of(int), - mapping_validator=attr.validators.instance_of(frozendict), + mapping_validator=attr.validators.instance_of(immutabledict), ), ) @@ -531,7 +531,7 @@ class RoomStreamToken: return cls( 
topological=None, stream=stream, - instance_map=frozendict(instance_map), + instance_map=immutabledict(instance_map), ) except CancelledError: raise @@ -566,7 +566,7 @@ class RoomStreamToken: for instance in set(self.instance_map).union(other.instance_map) } - return RoomStreamToken(None, max_stream, frozendict(instance_map)) + return RoomStreamToken(None, max_stream, immutabledict(instance_map)) def as_historical_tuple(self) -> Tuple[int, int]: """Returns a tuple of `(topological, stream)` for historical tokens. diff --git a/synapse/types/state.py b/synapse/types/state.py index 4b3071acce..1e78a74047 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -28,7 +28,7 @@ from typing import ( ) import attr -from frozendict import frozendict +from immutabledict import immutabledict from synapse.api.constants import EventTypes from synapse.types import MutableStateMap, StateKey, StateMap @@ -56,7 +56,7 @@ class StateFilter: appear in `types`. """ - types: "frozendict[str, Optional[FrozenSet[str]]]" + types: "immutabledict[str, Optional[FrozenSet[str]]]" include_others: bool = False def __attrs_post_init__(self) -> None: @@ -67,7 +67,7 @@ class StateFilter: object.__setattr__( self, "types", - frozendict({k: v for k, v in self.types.items() if v is not None}), + immutabledict({k: v for k, v in self.types.items() if v is not None}), ) @staticmethod @@ -112,7 +112,7 @@ class StateFilter: type_dict.setdefault(typ, set()).add(s) # type: ignore return StateFilter( - types=frozendict( + types=immutabledict( (k, frozenset(v) if v is not None else None) for k, v in type_dict.items() ) @@ -139,7 +139,7 @@ class StateFilter: The new state filter """ return StateFilter( - types=frozendict({EventTypes.Member: frozenset(members)}), + types=immutabledict({EventTypes.Member: frozenset(members)}), include_others=True, ) @@ -159,7 +159,7 @@ class StateFilter: types_with_frozen_values[state_types] = None return StateFilter( - frozendict(types_with_frozen_values), 
include_others=include_others + immutabledict(types_with_frozen_values), include_others=include_others ) def return_expanded(self) -> "StateFilter": @@ -217,7 +217,7 @@ class StateFilter: # We want to return all non-members, but only particular # memberships return StateFilter( - types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), + types=immutabledict({EventTypes.Member: self.types[EventTypes.Member]}), include_others=True, ) else: @@ -381,14 +381,16 @@ class StateFilter: if state_keys is None: member_filter = StateFilter.all() else: - member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) + member_filter = StateFilter( + immutabledict({EventTypes.Member: state_keys}) + ) elif self.include_others: member_filter = StateFilter.all() else: member_filter = StateFilter.none() non_member_filter = StateFilter( - types=frozendict( + types=immutabledict( {k: v for k, v in self.types.items() if k != EventTypes.Member} ), include_others=self.include_others, @@ -578,8 +580,8 @@ class StateFilter: return False -_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True) _ALL_NON_MEMBER_STATE_FILTER = StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), include_others=True + types=immutabledict({EventTypes.Member: frozenset()}), include_others=True ) -_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) +_NONE_STATE_FILTER = StateFilter(types=immutabledict(), include_others=False) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 7be9d5f113..9ddd26ccaa 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -18,7 +18,7 @@ import typing from typing import Any, Callable, Dict, Generator, Optional, Sequence import attr -from frozendict import frozendict +from immutabledict import immutabledict from matrix_common.versionstring import get_distribution_version_string from 
typing_extensions import ParamSpec @@ -41,22 +41,18 @@ def _reject_invalid_json(val: Any) -> None: raise ValueError("Invalid JSON value: '%s'" % val) -def _handle_frozendict(obj: Any) -> Dict[Any, Any]: - """Helper for json_encoder. Makes frozendicts serializable by returning +def _handle_immutabledict(obj: Any) -> Dict[Any, Any]: + """Helper for json_encoder. Makes immutabledicts serializable by returning the underlying dict """ - if type(obj) is frozendict: + if type(obj) is immutabledict: # fishing the protected dict out of the object is a bit nasty, # but we don't really want the overhead of copying the dict. try: # Safety: we catch the AttributeError immediately below. - # See https://github.com/matrix-org/python-canonicaljson/issues/36#issuecomment-927816293 - # for discussion on how frozendict's internals have changed over time. - return obj._dict # type: ignore[attr-defined] + return obj._dict except AttributeError: - # When the C implementation of frozendict is used, - # there isn't a `_dict` attribute with a dict - # so we resort to making a copy of the frozendict + # If all else fails, resort to making a copy of the immutabledict return dict(obj) raise TypeError( "Object of type %s is not JSON serializable" % obj.__class__.__name__ @@ -64,11 +60,11 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]: # A custom JSON encoder which: -# * handles frozendicts +# * handles immutabledicts # * produces valid JSON (no NaNs etc) # * reduces redundant whitespace json_encoder = json.JSONEncoder( - allow_nan=False, separators=(",", ":"), default=_handle_frozendict + allow_nan=False, separators=(",", ":"), default=_handle_immutabledict ) # Create a custom decoder to reject Python extensions to JSON. 
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 7223af1a36..889caa2601 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -14,14 +14,14 @@ import collections.abc from typing import Any -from frozendict import frozendict +from immutabledict import immutabledict def freeze(o: Any) -> Any: if isinstance(o, dict): - return frozendict({k: freeze(v) for k, v in o.items()}) + return immutabledict({k: freeze(v) for k, v in o.items()}) - if isinstance(o, frozendict): + if isinstance(o, immutabledict): return o if isinstance(o, (bytes, str)): diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 0f45615160..6c6a9ab4b4 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -18,7 +18,6 @@ from typing import List from unittest.mock import patch import jsonschema -from frozendict import frozendict from twisted.test.proto_helpers import MemoryReactor @@ -29,6 +28,7 @@ from synapse.api.presence import UserPresenceState from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock +from synapse.util.frozenutils import freeze from tests import unittest from tests.events.test_utils import MockEvent @@ -343,12 +343,12 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertFalse(Filter(self.hs, definition)._check(event)) - # check it works with frozendicts too + # check it works with frozen dictionaries too event = MockEvent( sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content=frozendict({EventContentFields.LABELS: ["#fun"]}), + content=freeze({EventContentFields.LABELS: ["#fun"]}), ) self.assertTrue(Filter(self.hs, definition)._check(event)) diff --git a/tests/config/test_workers.py b/tests/config/test_workers.py index ef6294ecb2..49a6bdf408 100644 --- a/tests/config/test_workers.py +++ b/tests/config/test_workers.py @@ -14,14 +14,14 @@ from typing import Any, Mapping, Optional from unittest.mock 
import Mock -from frozendict import frozendict +from immutabledict import immutabledict from synapse.config import ConfigError from synapse.config.workers import WorkerConfig from tests.unittest import TestCase -_EMPTY_FROZENDICT: Mapping[str, Any] = frozendict() +_EMPTY_IMMUTABLEDICT: Mapping[str, Any] = immutabledict() class WorkerDutyConfigTestCase(TestCase): @@ -29,7 +29,7 @@ class WorkerDutyConfigTestCase(TestCase): self, worker_app: str, worker_name: Optional[str], - extras: Mapping[str, Any] = _EMPTY_FROZENDICT, + extras: Mapping[str, Any] = _EMPTY_IMMUTABLEDICT, ) -> WorkerConfig: root_config = Mock() root_config.worker_app = worker_app diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 52c4aafea6..b2536562e0 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -14,8 +14,6 @@ from typing import Any, Dict, List, Optional, Union, cast -import frozendict - from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -318,11 +316,11 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should only match at the start/end of the value", ) - # it should work on frozendicts too + # it should work on frozen dictionaries too self._assert_matches( condition, - frozendict.frozendict({"value": "FoobaZ"}), - "patterns should match on frozendicts", + freeze({"value": "FoobaZ"}), + "patterns should match on frozen dictionaries", ) # wildcards should match @@ -425,11 +423,11 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "incorrect types should not match", ) - # it should work on frozendicts too + # it should work on frozen dictionaries too self._assert_matches( condition, - frozendict.frozendict({"value": "foobaz"}), - "values should match on frozendicts", + freeze({"value": "foobaz"}), + "values should match on frozen dictionaries", ) def test_exact_event_match_boolean(self) -> None: @@ -546,11 +544,11 @@ class 
PushRuleEvaluatorTestCase(unittest.TestCase): "does not search in a string", ) - # it should work on frozendicts too + # it should work on frozen dictionaries too self._assert_matches( condition, freeze({"value": ["foobaz"]}), - "values should match on frozendicts", + "values should match on frozen dictionaries", ) def test_no_body(self) -> None: diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 62aed6af0a..0b9446c36c 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -14,7 +14,7 @@ import logging -from frozendict import frozendict +from immutabledict import immutabledict from twisted.test.proto_helpers import MemoryReactor @@ -198,7 +198,7 @@ class StateStoreTestCase(HomeserverTestCase): self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( - types=frozendict( + types=immutabledict( {EventTypes.Member: frozenset({self.u_alice.to_string()})} ), include_others=True, @@ -220,7 +220,7 @@ class StateStoreTestCase(HomeserverTestCase): self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), + types=immutabledict({EventTypes.Member: frozenset()}), include_others=True, ), ) @@ -246,7 +246,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), include_others=True + types=immutabledict({EventTypes.Member: frozenset()}), + include_others=True, ), ) @@ -263,7 +264,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), include_others=True + types=immutabledict({EventTypes.Member: frozenset()}), + include_others=True, ), ) @@ -276,7 +278,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - 
types=frozendict({EventTypes.Member: None}), include_others=True + types=immutabledict({EventTypes.Member: None}), include_others=True ), ) @@ -293,7 +295,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: None}), include_others=True + types=immutabledict({EventTypes.Member: None}), include_others=True ), ) @@ -313,7 +315,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}), include_others=True, ), ) @@ -331,7 +333,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}), include_others=True, ), ) @@ -345,7 +347,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}), include_others=False, ), ) @@ -396,7 +398,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), include_others=True + types=immutabledict({EventTypes.Member: frozenset()}), + include_others=True, ), ) @@ -408,7 +411,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), include_others=True + types=immutabledict({EventTypes.Member: frozenset()}), + include_others=True, ), ) @@ -421,7 +425,7 @@ class 
StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: None}), include_others=True + types=immutabledict({EventTypes.Member: None}), include_others=True ), ) @@ -432,7 +436,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: None}), include_others=True + types=immutabledict({EventTypes.Member: None}), include_others=True ), ) @@ -451,7 +455,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}), include_others=True, ), ) @@ -463,7 +467,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}), include_others=True, ), ) @@ -477,7 +481,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}), include_others=False, ), ) @@ -489,7 +493,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types=frozendict({EventTypes.Member: frozenset({e5.state_key})}), + types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}), include_others=False, ), ) diff --git a/tests/types/test_state.py b/tests/types/test_state.py index eb809f9fb7..1d89582c44 100644 --- a/tests/types/test_state.py +++ b/tests/types/test_state.py @@ -1,4 +1,4 @@ -from frozendict import 
frozendict +from immutabledict import immutabledict from synapse.api.constants import EventTypes from synapse.types.state import StateFilter @@ -172,7 +172,7 @@ class StateFilterDifferenceTestCase(TestCase): }, include_others=False, ), - StateFilter(types=frozendict(), include_others=True), + StateFilter(types=immutabledict(), include_others=True), ) # (wildcard on state keys) - (no state keys) @@ -188,7 +188,7 @@ class StateFilterDifferenceTestCase(TestCase): include_others=False, ), StateFilter( - types=frozendict(), + types=immutabledict(), include_others=True, ), ) @@ -279,7 +279,7 @@ class StateFilterDifferenceTestCase(TestCase): {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, include_others=True, ), - StateFilter(types=frozendict(), include_others=False), + StateFilter(types=immutabledict(), include_others=False), ) # (wildcard on state keys) - (specific state keys) @@ -332,7 +332,7 @@ class StateFilterDifferenceTestCase(TestCase): include_others=True, ), StateFilter( - types=frozendict(), + types=immutabledict(), include_others=False, ), ) @@ -403,7 +403,7 @@ class StateFilterDifferenceTestCase(TestCase): {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, include_others=True, ), - StateFilter(types=frozendict(), include_others=False), + StateFilter(types=immutabledict(), include_others=False), ) # (wildcard on state keys) - (specific state keys) @@ -450,7 +450,7 @@ class StateFilterDifferenceTestCase(TestCase): include_others=True, ), StateFilter( - types=frozendict(), + types=immutabledict(), include_others=False, ), ) -- cgit 1.5.1 From e6af49fbea939d9e69ed05e0a0ced5948c722ea4 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Fri, 24 Mar 2023 11:44:01 +0000 Subject: Reintroduce membership tables event stream ordering (#15128) * Add `event_stream_ordering` column to membership state tables Specifically this adds the column to `current_state_events`, `local_current_membership` and `room_memberships`. 
Each of these tables is regularly joined with the `events` table to get the stream ordering and denormalising this into each table will yield significant query performance improvements once used. * Make denormalised `event_stream_ordering` columns foreign keys * Add comment in schema file explaining new denormalised columns * Add triggers to enforce consistency of `event_stream_ordering` columns * Re-order purge room tables to account for foreign keys * Bump schema version to 75 Co-authored-by: David Robertson Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/15128.misc | 1 + synapse/storage/databases/main/events.py | 23 +++++-- synapse/storage/databases/main/purge_events.py | 6 +- synapse/storage/schema/__init__.py | 14 ++-- .../01membership_tables_event_stream_ordering.sql | 20 ++++++ ...ership_tables_event_stream_ordering_triggers.py | 79 ++++++++++++++++++++++ 6 files changed, 131 insertions(+), 12 deletions(-) create mode 100644 changelog.d/15128.misc create mode 100644 synapse/storage/schema/main/delta/74/01membership_tables_event_stream_ordering.sql create mode 100644 synapse/storage/schema/main/delta/74/02membership_tables_event_stream_ordering_triggers.py (limited to 'synapse/storage/databases') diff --git a/changelog.d/15128.misc b/changelog.d/15128.misc new file mode 100644 index 0000000000..c09911e48d --- /dev/null +++ b/changelog.d/15128.misc @@ -0,0 +1 @@ +Add denormalised event stream ordering column to membership state tables for future use. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a8a4ed4436..193959b250 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1126,11 +1126,15 @@ class PersistEventsStore: # been inserted into room_memberships. 
txn.execute_batch( """INSERT INTO current_state_events - (room_id, type, state_key, event_id, membership) - VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + (room_id, type, state_key, event_id, membership, event_stream_ordering) + VALUES ( + ?, ?, ?, ?, + (SELECT membership FROM room_memberships WHERE event_id = ?), + (SELECT stream_ordering FROM events WHERE event_id = ?) + ) """, [ - (room_id, key[0], key[1], ev_id, ev_id) + (room_id, key[0], key[1], ev_id, ev_id, ev_id) for key, ev_id in to_insert.items() ], ) @@ -1157,11 +1161,15 @@ class PersistEventsStore: if to_insert: txn.execute_batch( """INSERT INTO local_current_membership - (room_id, user_id, event_id, membership) - VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + (room_id, user_id, event_id, membership, event_stream_ordering) + VALUES ( + ?, ?, ?, + (SELECT membership FROM room_memberships WHERE event_id = ?), + (SELECT stream_ordering FROM events WHERE event_id = ?) 
+ ) """, [ - (room_id, key[1], ev_id, ev_id) + (room_id, key[1], ev_id, ev_id, ev_id) for key, ev_id in to_insert.items() if key[0] == EventTypes.Member and self.is_mine_id(key[1]) ], @@ -1769,6 +1777,7 @@ class PersistEventsStore: table="room_memberships", keys=( "event_id", + "event_stream_ordering", "user_id", "sender", "room_id", @@ -1779,6 +1788,7 @@ class PersistEventsStore: values=[ ( event.event_id, + event.internal_metadata.stream_ordering, event.state_key, event.user_id, event.room_id, @@ -1811,6 +1821,7 @@ class PersistEventsStore: keyvalues={"room_id": event.room_id, "user_id": event.state_key}, values={ "event_id": event.event_id, + "event_stream_ordering": event.internal_metadata.stream_ordering, "membership": event.membership, }, ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 7a7c0d9c75..efbd3e75d9 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -428,14 +428,16 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "partial_state_events", "partial_state_rooms_servers", "partial_state_rooms", + # Note: the _membership(s) tables have foreign keys to the `events` table + # so must be deleted first. + "local_current_membership", + "room_memberships", "events", "federation_inbound_events_staging", - "local_current_membership", "receipts_graph", "receipts_linearized", "room_aliases", "room_depth", - "room_memberships", "room_stats_state", "room_stats_current", "room_stats_earliest_token", diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index d3103a6c7a..a28f2b997c 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-SCHEMA_VERSION = 74 # remember to update the list below when updating +SCHEMA_VERSION = 75 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -91,13 +91,19 @@ Changes in SCHEMA_VERSION = 74: - A query on `event_stream_ordering` column has now been disambiguated (i.e. the codebase can handle the `current_state_events`, `local_current_memberships` and `room_memberships` tables having an `event_stream_ordering` column). + +Changes in SCHEMA_VERSION = 75: + - The `event_stream_ordering` column in membership tables (`current_state_events`, + `local_current_membership` & `room_memberships`) is now being populated for new + rows. When the background job to populate historical rows lands this will + become the compat schema version. """ SCHEMA_COMPAT_VERSION = ( - # The threads_id column must exist for event_push_actions, event_push_summary, - # receipts_linearized, and receipts_graph. - 73 + # Queries against `event_stream_ordering` columns in membership tables must + # be disambiguated. + 74 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/storage/schema/main/delta/74/01membership_tables_event_stream_ordering.sql b/synapse/storage/schema/main/delta/74/01membership_tables_event_stream_ordering.sql new file mode 100644 index 0000000000..e2608f3a2e --- /dev/null +++ b/synapse/storage/schema/main/delta/74/01membership_tables_event_stream_ordering.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 Beeper + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Each of these are denormalised copies of `stream_ordering` from the corresponding row in `events` which +-- we use to improve database performance by reducing JOINs. +ALTER TABLE current_state_events ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering); +ALTER TABLE local_current_membership ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering); +ALTER TABLE room_memberships ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering); diff --git a/synapse/storage/schema/main/delta/74/02membership_tables_event_stream_ordering_triggers.py b/synapse/storage/schema/main/delta/74/02membership_tables_event_stream_ordering_triggers.py new file mode 100644 index 0000000000..e32e9083b3 --- /dev/null +++ b/synapse/storage/schema/main/delta/74/02membership_tables_event_stream_ordering_triggers.py @@ -0,0 +1,79 @@ +# Copyright 2022 Beeper +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This migration adds triggers to the room membership tables to enforce consistency.
+Triggers cannot be expressed in .sql files, so we have to use a separate file. +""" +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # Complain if the `event_stream_ordering` in membership tables doesn't match + # the `stream_ordering` row with the same `event_id` in `events`. + if isinstance(database_engine, Sqlite3Engine): + for table in ( + "current_state_events", + "local_current_membership", + "room_memberships", + ): + cur.execute( + f""" + CREATE TRIGGER IF NOT EXISTS {table}_bad_event_stream_ordering + BEFORE INSERT ON {table} + FOR EACH ROW + BEGIN + SELECT RAISE(ABORT, 'Incorrect event_stream_ordering in {table}') + WHERE EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.stream_ordering != NEW.event_stream_ordering + ); + END; + """ + ) + elif isinstance(database_engine, PostgresEngine): + cur.execute( + """ + CREATE OR REPLACE FUNCTION check_event_stream_ordering() RETURNS trigger AS $BODY$ + BEGIN + IF EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.stream_ordering != NEW.event_stream_ordering + ) THEN + RAISE EXCEPTION 'Incorrect event_stream_ordering'; + END IF; + RETURN NEW; + END; + $BODY$ LANGUAGE plpgsql; + """ + ) + + for table in ( + "current_state_events", + "local_current_membership", + "room_memberships", + ): + cur.execute( + f""" + CREATE TRIGGER check_event_stream_ordering BEFORE INSERT OR UPDATE ON {table} + FOR EACH ROW + EXECUTE PROCEDURE check_event_stream_ordering() + """ + ) + else: + raise NotImplementedError("Unknown database engine") -- cgit 1.5.1 From 5b70f240cf70b390db7e74ab614ace108fc08d70 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 24 Mar 2023 16:09:39 +0100 Subject: Make cleaning up pushers depend on the device_id instead of the token_id (#15280) This makes it so that we rely 
on the `device_id` to delete pushers on logout, instead of relying on the `access_token_id`. This ensures we're not removing pushers on token refresh, and prepares for a world without access token IDs (also known as the OIDC). This actually runs the `set_device_id_for_pushers` background update, which was forgotten in #13831. Note that for backwards compatibility it still deletes pushers based on the `access_token` until the background update finishes. --- changelog.d/15280.misc | 1 + synapse/_scripts/synapse_port_db.py | 6 ++- synapse/handlers/auth.py | 8 ++- synapse/handlers/device.py | 2 + synapse/handlers/register.py | 4 +- synapse/push/__init__.py | 7 ++- synapse/push/pusherpool.py | 58 ++++++++++++++++------ synapse/rest/admin/users.py | 1 - synapse/rest/client/pusher.py | 1 - synapse/storage/databases/main/pusher.py | 40 +++++++++++---- .../74/02_set_device_id_for_pushers_bg_update.sql | 19 +++++++ tests/push/test_email.py | 6 +-- tests/push/test_http.py | 46 ++++++++--------- tests/replication/test_pusher_shard.py | 4 +- tests/rest/admin/test_user.py | 4 +- 15 files changed, 142 insertions(+), 65 deletions(-) create mode 100644 changelog.d/15280.misc create mode 100644 synapse/storage/schema/main/delta/74/02_set_device_id_for_pushers_bg_update.sql (limited to 'synapse/storage/databases') diff --git a/changelog.d/15280.misc b/changelog.d/15280.misc new file mode 100644 index 0000000000..41d56b0cf0 --- /dev/null +++ b/changelog.d/15280.misc @@ -0,0 +1 @@ +Make the pushers rely on the `device_id` instead of the `access_token_id` for various operations. 
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 94b86c1d6f..1dcb397ba4 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -68,7 +68,10 @@ from synapse.storage.databases.main.media_repository import ( MediaRepositoryBackgroundUpdateStore, ) from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore -from synapse.storage.databases.main.pusher import PusherWorkerStore +from synapse.storage.databases.main.pusher import ( + PusherBackgroundUpdatesStore, + PusherWorkerStore, +) from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, @@ -226,6 +229,7 @@ class Store( AccountDataWorkerStore, PushRuleStore, PusherWorkerStore, + PusherBackgroundUpdatesStore, PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, RelationsWorkerStore, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 308e38edea..1e89447044 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1504,8 +1504,10 @@ class AuthHandler: ) # delete pushers associated with this access token + # XXX(quenting): This is only needed until the 'set_device_id_for_pushers' + # background update completes. if token.token_id is not None: - await self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_tokens( token.user_id, (token.token_id,) ) @@ -1535,7 +1537,9 @@ class AuthHandler: ) # delete pushers associated with the access tokens - await self.hs.get_pusherpool().remove_pushers_by_access_token( + # XXX(quenting): This is only needed until the 'set_device_id_for_pushers' + # background update completes. 
+ await self.hs.get_pusherpool().remove_pushers_by_access_tokens( user_id, (token_id for _, token_id, _ in tokens_and_devices) ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6f7963df43..9ded6389ac 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -503,6 +503,8 @@ class DeviceHandler(DeviceWorkerHandler): else: raise + await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids) + # Delete data specific to each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6b110dcb6e..c8bf2439af 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -1013,11 +1013,11 @@ class RegistrationHandler: user_tuple = await self.store.get_user_by_access_token(token) # The token better still exist. assert user_tuple - token_id = user_tuple.token_id + device_id = user_tuple.device_id await self.pusher_pool.add_or_update_pusher( user_id=user_id, - access_token=token_id, + device_id=device_id, kind="email", app_id="m.email", app_display_name="Email Notifications", diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index a0c760239d..9e3a98741a 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -103,7 +103,7 @@ class PusherConfig: id: Optional[str] user_name: str - access_token: Optional[int] + profile_tag: str kind: str app_id: str @@ -119,6 +119,11 @@ class PusherConfig: enabled: bool device_id: Optional[str] + # XXX(quenting): The access_token is not persisted anymore for new pushers, but we + # keep it when reading from the database, so that we don't get stale pushers + # while the "set_device_id_for_pushers" background update is running. 
+ access_token: Optional[int] + def as_dict(self) -> Dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" return { diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index e2648cbc93..6517e3566f 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -25,7 +25,7 @@ from synapse.metrics.background_process_metrics import ( from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.push.pusher import PusherFactory from synapse.replication.http.push import ReplicationRemovePusherRestServlet -from synapse.types import JsonDict, RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken, StrCollection from synapse.util.async_helpers import concurrently_execute from synapse.util.threepids import canonicalise_email @@ -97,7 +97,6 @@ class PusherPool: async def add_or_update_pusher( self, user_id: str, - access_token: Optional[int], kind: str, app_id: str, app_display_name: str, @@ -128,6 +127,22 @@ class PusherPool: # stream ordering, so it will process pushes from this point onwards. last_stream_ordering = self.store.get_room_max_stream_ordering() + # Before we actually persist the pusher, we check if the user already has one + # for this app ID and pushkey. If so, we want to keep the access token and + # device ID in place, since this could be one device modifying + # (e.g. enabling/disabling) another device's pusher. + # XXX(quenting): Even though we're not persisting the access_token_id for new + # pushers anymore, we still need to copy existing access_token_ids over when + # updating a pusher, in case the "set_device_id_for_pushers" background update + # hasn't run yet. 
+ access_token_id = None + existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) + if existing_config: + device_id = existing_config.device_id + access_token_id = existing_config.access_token + # we try to create the pusher just to validate the config: it # will then get pulled out of the database, # recreated, added and started: this means we have only one @@ -136,7 +151,6 @@ class PusherPool: PusherConfig( id=None, user_name=user_id, - access_token=access_token, profile_tag=profile_tag, kind=kind, app_id=app_id, @@ -151,23 +165,12 @@ class PusherPool: failing_since=None, enabled=enabled, device_id=device_id, + access_token=access_token_id, ) ) - # Before we actually persist the pusher, we check if the user already has one - # this app ID and pushkey. If so, we want to keep the access token and device ID - # in place, since this could be one device modifying (e.g. enabling/disabling) - # another device's pusher. - existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( - user_id, app_id, pushkey - ) - if existing_config: - access_token = existing_config.access_token - device_id = existing_config.device_id - await self.store.add_pusher( user_id=user_id, - access_token=access_token, kind=kind, app_id=app_id, app_display_name=app_display_name, @@ -180,6 +183,7 @@ class PusherPool: profile_tag=profile_tag, enabled=enabled, device_id=device_id, + access_token_id=access_token_id, ) pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) @@ -199,7 +203,7 @@ class PusherPool: ) await self.remove_pusher(p.app_id, p.pushkey, p.user_name) - async def remove_pushers_by_access_token( + async def remove_pushers_by_access_tokens( self, user_id: str, access_tokens: Iterable[int] ) -> None: """Remove the pushers for a given user corresponding to a set of @@ -209,6 +213,8 @@ class PusherPool: user_id: user to remove pushers for access_tokens: access token *ids* to remove pushers for 
""" + # XXX(quenting): This is only needed until the "set_device_id_for_pushers" + # background update finishes tokens = set(access_tokens) for p in await self.store.get_pushers_by_user_id(user_id): if p.access_token in tokens: @@ -220,6 +226,26 @@ class PusherPool: ) await self.remove_pusher(p.app_id, p.pushkey, p.user_name) + async def remove_pushers_by_devices( + self, user_id: str, devices: StrCollection + ) -> None: + """Remove the pushers for a given user corresponding to a set of devices + + Args: + user_id: user to remove pushers for + devices: device IDs to remove pushers for + """ + device_ids = set(devices) + for p in await self.store.get_pushers_by_user_id(user_id): + if p.device_id in device_ids: + logger.info( + "Removing pusher for app id %s, pushkey %s, user %s", + p.app_id, + p.pushkey, + p.user_name, + ) + await self.remove_pusher(p.app_id, p.pushkey, p.user_name) + def on_new_notifications(self, max_token: RoomStreamToken) -> None: if not self.pushers: # nothing to do here. 
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 281e8fd0ad..331f225116 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -425,7 +425,6 @@ class UserRestServletV2(RestServlet): ): await self.pusher_pool.add_or_update_pusher( user_id=user_id, - access_token=None, kind="email", app_id="m.email", app_display_name="Email Notifications", diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 975eef2144..1a8f5292ac 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -126,7 +126,6 @@ class PushersSetRestServlet(RestServlet): try: await self.pusher_pool.add_or_update_pusher( user_id=user.to_string(), - access_token=requester.access_token_id, kind=content["kind"], app_id=content["app_id"], app_display_name=content["app_display_name"], diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 9a24f7a655..ab76b754e0 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -509,19 +509,24 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): async def _set_device_id_for_pushers( self, progress: JsonDict, batch_size: int ) -> int: - """Background update to populate the device_id column of the pushers table.""" + """ + Background update to populate the device_id column and clear the access_token + column for the pushers table. + """ last_pusher_id = progress.get("pusher_id", 0) def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int: txn.execute( """ - SELECT p.id, at.device_id + SELECT + p.id AS pusher_id, + p.device_id AS pusher_device_id, + at.device_id AS token_device_id FROM pushers AS p - INNER JOIN access_tokens AS at + LEFT JOIN access_tokens AS at ON p.access_token = at.id WHERE p.access_token IS NOT NULL - AND at.device_id IS NOT NULL AND p.id > ? ORDER BY p.id LIMIT ? 
@@ -533,13 +538,27 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): if len(rows) == 0: return 0 + # The reason we're clearing the access_token column here is a bit subtle. + # When a user logs out, we: + # (1) delete the access token + # (2) delete the device + # + # Ideally, we would delete the pushers only via its link to the device + # during (2), but since this background update might not have fully run yet, + # we're still deleting the pushers via the access token during (1). self.db_pool.simple_update_many_txn( txn=txn, table="pushers", key_names=("id",), - key_values=[(row["id"],) for row in rows], - value_names=("device_id",), - value_values=[(row["device_id"],) for row in rows], + key_values=[(row["pusher_id"],) for row in rows], + value_names=("device_id", "access_token"), + # If there was already a device_id on the pusher, we only want to clear + # the access_token column, so we keep the existing device_id. Otherwise, + # we set the device_id we got from joining the access_tokens table. 
+ value_values=[ + (row["pusher_device_id"] or row["token_device_id"], None) + for row in rows + ], ) self.db_pool.updates._background_update_progress_txn( @@ -568,7 +587,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): async def add_pusher( self, user_id: str, - access_token: Optional[int], kind: str, app_id: str, app_display_name: str, @@ -581,13 +599,13 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): profile_tag: str = "", enabled: bool = True, device_id: Optional[str] = None, + access_token_id: Optional[int] = None, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: await self.db_pool.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ - "access_token": access_token, "kind": kind, "app_display_name": app_display_name, "device_display_name": device_display_name, @@ -599,6 +617,10 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): "id": stream_id, "enabled": enabled, "device_id": device_id, + # XXX(quenting): We're only really persisting the access token ID + # when updating an existing pusher. This is in case the + # 'set_device_id_for_pushers' background update hasn't finished yet. + "access_token": access_token_id, }, desc="add_pusher", ) diff --git a/synapse/storage/schema/main/delta/74/02_set_device_id_for_pushers_bg_update.sql b/synapse/storage/schema/main/delta/74/02_set_device_id_for_pushers_bg_update.sql new file mode 100644 index 0000000000..1367fb6267 --- /dev/null +++ b/synapse/storage/schema/main/delta/74/02_set_device_id_for_pushers_bg_update.sql @@ -0,0 +1,19 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Triggers the background update to set the device_id for pushers +-- that don't have one, and clear the access_token column. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7402, 'set_device_id_for_pushers', '{}'); diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 4ea5472eb4..4b5c96aeae 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -105,7 +105,7 @@ class EmailPusherTests(HomeserverTestCase): self.hs.get_datastores().main.get_user_by_access_token(self.access_token) ) assert user_tuple is not None - self.token_id = user_tuple.token_id + self.device_id = user_tuple.device_id # We need to add email to account before we can create a pusher. 
self.get_success( @@ -117,7 +117,7 @@ class EmailPusherTests(HomeserverTestCase): pusher = self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, - access_token=self.token_id, + device_id=self.device_id, kind="email", app_id="m.email", app_display_name="Email Notifications", @@ -141,7 +141,7 @@ class EmailPusherTests(HomeserverTestCase): self.get_success_or_raise( self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, - access_token=self.token_id, + device_id=self.device_id, kind="email", app_id="m.email", app_display_name="Email Notifications", diff --git a/tests/push/test_http.py b/tests/push/test_http.py index c280ddcdf6..99cec0836b 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -67,13 +67,13 @@ class HTTPPusherTests(HomeserverTestCase): self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_tuple is not None - token_id = user_tuple.token_id + device_id = user_tuple.device_id def test_data(data: Any) -> None: self.get_failure( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, - access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", @@ -114,12 +114,12 @@ class HTTPPusherTests(HomeserverTestCase): self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_tuple is not None - token_id = user_tuple.token_id + device_id = user_tuple.device_id self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, - access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", @@ -235,12 +235,12 @@ class HTTPPusherTests(HomeserverTestCase): self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_tuple is not None - token_id = user_tuple.token_id + device_id = user_tuple.device_id self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, - 
access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", @@ -356,12 +356,12 @@ class HTTPPusherTests(HomeserverTestCase): self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_tuple is not None - token_id = user_tuple.token_id + device_id = user_tuple.device_id self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, - access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", @@ -443,12 +443,12 @@ class HTTPPusherTests(HomeserverTestCase): self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_tuple is not None - token_id = user_tuple.token_id + device_id = user_tuple.device_id self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, - access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", @@ -521,12 +521,12 @@ class HTTPPusherTests(HomeserverTestCase): self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_tuple is not None - token_id = user_tuple.token_id + device_id = user_tuple.device_id self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, - access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", @@ -628,12 +628,12 @@ class HTTPPusherTests(HomeserverTestCase): self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_tuple is not None - token_id = user_tuple.token_id + device_id = user_tuple.device_id self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, - access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", @@ -764,12 +764,12 @@ class HTTPPusherTests(HomeserverTestCase): 
self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_tuple is not None - token_id = user_tuple.token_id + device_id = user_tuple.device_id self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, - access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", @@ -778,7 +778,6 @@ class HTTPPusherTests(HomeserverTestCase): lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, enabled=enabled, - device_id=user_tuple.device_id, ) ) @@ -895,19 +894,17 @@ class HTTPPusherTests(HomeserverTestCase): def test_update_different_device_access_token_device_id(self) -> None: """Tests that if we create a pusher from one device, the update it from another - device, the access token and device ID associated with the pusher stays the - same. + device, the device ID associated with the pusher stays the same. """ # Create a user with a pusher. user_id, access_token = self._make_user_with_pusher("user") - # Get the token ID for the current access token, since that's what we store in - # the pushers table. Also get the device ID from it. + # Get the device ID for the current access token, since that's what we store in + # the pushers table. user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_tuple is not None - token_id = user_tuple.token_id device_id = user_tuple.device_id # Generate a new access token, and update the pusher with it. @@ -920,10 +917,9 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers: List[PusherConfig] = list(ret) - # Check that we still have one pusher, and that the access token and device ID - # associated with it didn't change. + # Check that we still have one pusher, and that the device ID associated with + # it didn't change. 
self.assertEqual(len(pushers), 1) - self.assertEqual(pushers[0].access_token, token_id) self.assertEqual(pushers[0].device_id, device_id) @override_config({"experimental_features": {"msc3881_enabled": True}}) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 0798b021c3..dcb3e6669b 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -51,12 +51,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): self.hs.get_datastores().main.get_user_by_access_token(access_token) ) assert user_dict is not None - token_id = user_dict.token_id + device_id = user_dict.device_id self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, - access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4b8f889a71..b4241ceaf0 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -3047,12 +3047,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token(other_user_token) ) assert user_tuple is not None - token_id = user_tuple.token_id + device_id = user_tuple.device_id self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=self.other_user, - access_token=token_id, + device_id=device_id, kind="http", app_id="m.http", app_display_name="HTTP Push Notifications", -- cgit 1.5.1 From 5f7c9082805846cc07bfef2d48c6f6cfc9f723e9 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 24 Mar 2023 15:31:12 +0000 Subject: As an optimisation, use `TRUNCATE` on Postgres when clearing the user directory tables. 
(#15316) --- changelog.d/15316.misc | 1 + synapse/storage/databases/main/user_directory.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 changelog.d/15316.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/15316.misc b/changelog.d/15316.misc new file mode 100644 index 0000000000..1f408739f0 --- /dev/null +++ b/changelog.d/15316.misc @@ -0,0 +1 @@ +As an optimisation, use `TRUNCATE` on Postgres when clearing the user directory tables. \ No newline at end of file diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 97f09b73dd..9fced4b997 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -698,10 +698,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): """Delete the entire user directory""" def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None: - txn.execute("DELETE FROM user_directory") - txn.execute("DELETE FROM user_directory_search") - txn.execute("DELETE FROM users_in_public_rooms") - txn.execute("DELETE FROM users_who_share_private_rooms") + # SQLite doesn't support TRUNCATE. + # On Postgres, DELETE FROM does a table scan but TRUNCATE is more efficient. 
+ truncate = ( + "DELETE FROM" + if isinstance(self.database_engine, Sqlite3Engine) + else "TRUNCATE" + ) + txn.execute(f"{truncate} user_directory") + txn.execute(f"{truncate} user_directory_search") + txn.execute(f"{truncate} users_in_public_rooms") + txn.execute(f"{truncate} users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) await self.db_pool.runInteraction( -- cgit 1.5.1 From bd4d958aaf7c2123abb3665e7a7b199cf8ce27ee Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 28 Mar 2023 09:46:47 +0100 Subject: Bump ruff from 0.0.252 to 0.0.259 (#15328) * Bump ruff from 0.0.252 to 0.0.259 Bumps [ruff](https://github.com/charliermarsh/ruff) from 0.0.252 to 0.0.259. - [Release notes](https://github.com/charliermarsh/ruff/releases) - [Changelog](https://github.com/charliermarsh/ruff/blob/main/BREAKING_CHANGES.md) - [Commits](https://github.com/charliermarsh/ruff/compare/v0.0.252...v0.0.259) --- updated-dependencies: - dependency-name: ruff dependency-type: direct:development update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] * Fix new warnings * Mypy * Newsfile --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Erik Johnston --- changelog.d/15328.misc | 1 + poetry.lock | 38 +++++++++++++------------- pyproject.toml | 2 +- synapse/events/__init__.py | 4 +-- synapse/events/utils.py | 2 +- synapse/storage/database.py | 4 +-- synapse/storage/databases/main/events.py | 5 ++-- synapse/storage/databases/main/pusher.py | 2 +- synapse/storage/databases/main/stats.py | 14 ++++++++-- synapse/storage/databases/main/stream.py | 5 +++- tests/replication/slave/storage/test_events.py | 2 +- tests/server.py | 10 +++++-- 12 files changed, 54 insertions(+), 35 deletions(-) create mode 100644 changelog.d/15328.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/15328.misc b/changelog.d/15328.misc new file mode 100644 index 0000000000..e3e5953332 --- /dev/null +++ b/changelog.d/15328.misc @@ -0,0 +1 @@ +Bump ruff from 0.0.252 to 0.0.259. diff --git a/poetry.lock b/poetry.lock index 294ce49a8d..978a6e1598 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2323,29 +2323,29 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "ruff" -version = "0.0.252" +version = "0.0.259" description = "An extremely fast Python linter, written in Rust." 
category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.252-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:349367a227c4db7abbc3a9993efea8a608b5bea4bb4a1e5fc6f0d56819524f92"}, - {file = "ruff-0.0.252-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:ce77f9106d96b4faf7865860fb5155b9deaf6f699d9c279118c5ad947739ecaf"}, - {file = "ruff-0.0.252-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edadb0b050293b4e60dab979ba6a4e734d9c899cbe316a0ee5b65e3cdd39c750"}, - {file = "ruff-0.0.252-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4efdae98937d1e4d23ab0b7fc7e8e6b6836cc7d2d42238ceeacbc793ef780542"}, - {file = "ruff-0.0.252-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c8546d879f7d3f669379a03e7b103d90e11901976ab508aeda59c03dfd8a359e"}, - {file = "ruff-0.0.252-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:83fdc7169b6c1fb5fe8d1cdf345697f558c1b433ef97df9ca11defa2a8f3ee9e"}, - {file = "ruff-0.0.252-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84ed9be1a17e2a556a571a5b959398633dd10910abd8dcf8b098061e746e892d"}, - {file = "ruff-0.0.252-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f5e77bd9ba4438cf2ee32154e2673afe22f538ef29f5d65ca47e3dc46c42cf8"}, - {file = "ruff-0.0.252-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a5179b94b45c0f8512eaff3ab304c14714a46df2e9ca72a9d96084adc376b71"}, - {file = "ruff-0.0.252-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:92efd8a71157595df5bc46aaaa0613d8a2fbc5cddc53ae7b749c16025c324732"}, - {file = "ruff-0.0.252-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd350fc10832cfd28e681d829a8aa83ea3e653326e0ea9d98637dfb8d46177d2"}, - {file = "ruff-0.0.252-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f119240c9631216e846166e06023b1d878e25fbac93bf20da50069e91cfbfaee"}, - {file = 
"ruff-0.0.252-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5c5a49f89f5ede93d16eddfeeadd7e5739ec703e8f63ac95eac30236b9e49da3"}, - {file = "ruff-0.0.252-py3-none-win32.whl", hash = "sha256:89a897dc743f2fe063483ea666097e72e848f4bbe40493fe0533e61799959f6e"}, - {file = "ruff-0.0.252-py3-none-win_amd64.whl", hash = "sha256:cdc89ad6ff88519b1fb1816ac82a9ad910762c90ff5fd64dda7691b72d36aff7"}, - {file = "ruff-0.0.252-py3-none-win_arm64.whl", hash = "sha256:4b594a17cf53077165429486650658a0e1b2ac6ab88954f5afd50d2b1b5657a9"}, - {file = "ruff-0.0.252.tar.gz", hash = "sha256:6992611ab7bdbe7204e4831c95ddd3febfeece2e6f5e44bbed044454c7db0f63"}, + {file = "ruff-0.0.259-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:f3938dc45e2a3f818e9cbd53007265c22246fbfded8837b2c563bf0ebde1a226"}, + {file = "ruff-0.0.259-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:22e1e35bf5f12072cd644d22afd9203641ccf258bc14ff91aa1c43dc14f6047d"}, + {file = "ruff-0.0.259-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2fb20e89e85d147c85caa807707a1488bccc1f3854dc3d53533e89b52a0c5ff"}, + {file = "ruff-0.0.259-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:49e903bcda19f6bb0725a962c058eb5d61f40d84ef52ed53b61939b69402ab4e"}, + {file = "ruff-0.0.259-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71f0ef1985e9a6696fa97da8459917fa34bdaa2c16bd33bd5edead585b7d44f7"}, + {file = "ruff-0.0.259-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7cfef26619cba184d59aa7fa17b48af5891d51fc0b755a9bc533478a10d4d066"}, + {file = "ruff-0.0.259-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79b02fa17ec1fd8d306ae302cb47fb614b71e1f539997858243769bcbe78c6d9"}, + {file = "ruff-0.0.259-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:428507fb321b386dda70d66cd1a8aa0abf51d7c197983d83bb9e4fa5ee60300b"}, + {file = 
"ruff-0.0.259-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5fbaea9167f1852757f02133e5daacdb8c75b3431343205395da5b10499927a"}, + {file = "ruff-0.0.259-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:40ae87f2638484b7e8a7567b04a7af719f1c484c5bf132038b702bb32e1f6577"}, + {file = "ruff-0.0.259-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:29e2b77b7d5da6a7dd5cf9b738b511355c5734ece56f78e500d4b5bffd58c1a0"}, + {file = "ruff-0.0.259-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5b3c1beacf6037e7f0781d4699d9a2dd4ba2462f475be5b1f45cf84c4ba3c69d"}, + {file = "ruff-0.0.259-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:daaea322e7e85f4c13d82be9536309e1c4b8b9851bb0cbc7eeb15d490fd46bf9"}, + {file = "ruff-0.0.259-py3-none-win32.whl", hash = "sha256:38704f151323aa5858370a2f792e122cc25e5d1aabe7d42ceeab83da18f0b456"}, + {file = "ruff-0.0.259-py3-none-win_amd64.whl", hash = "sha256:aa9449b898287e621942cc71b9327eceb8f0c357e4065fecefb707ef2d978df8"}, + {file = "ruff-0.0.259-py3-none-win_arm64.whl", hash = "sha256:e4f39e18702de69faaaee3969934b92d7467285627f99a5b6ecd55a7d9f5d086"}, + {file = "ruff-0.0.259.tar.gz", hash = "sha256:8b56496063ab3bfdf72339a5fbebb8bd46e5c5fee25ef11a9f03b208fa0562ec"}, ] [[package]] @@ -3426,4 +3426,4 @@ user-search = ["pyicu"] [metadata] lock-version = "2.0" python-versions = "^3.7.1" -content-hash = "0a1dd4be3dff3c8cc71bd57a4eb48e1d92f155db7230e61fbb54f8af03619509" +content-hash = "102eed4faa13eab195555ea070f235acd1e3f0ff9cf028afcac6c51b3e409071" diff --git a/pyproject.toml b/pyproject.toml index b04edb611d..9a6306ee70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -311,7 +311,7 @@ all = [ # We pin black so that our tests don't start failing on new releases. 
isort = ">=5.10.1" black = ">=22.3.0" -ruff = "0.0.252" +ruff = "0.0.259" # Typechecking mypy = "*" diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 91118a8d84..d475fe7ae5 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -462,7 +462,7 @@ class FrozenEvent(EventBase): # Signatures is a dict of dicts, and this is faster than doing a # copy.deepcopy signatures = { - name: {sig_id: sig for sig_id, sig in sigs.items()} + name: dict(sigs.items()) for name, sigs in event_dict.pop("signatures", {}).items() } @@ -510,7 +510,7 @@ class FrozenEventV2(EventBase): # Signatures is a dict of dicts, and this is faster than doing a # copy.deepcopy signatures = { - name: {sig_id: sig for sig_id, sig in sigs.items()} + name: dict(sigs.items()) for name, sigs in event_dict.pop("signatures", {}).items() } diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e41c7a4b83..c14c7791db 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -355,7 +355,7 @@ def serialize_event( time_now_ms = int(time_now_ms) # Should this strip out None's? 
- d = {k: v for k, v in e.get_dict().items()} + d = dict(e.get_dict().items()) d["event_id"] = e.event_id diff --git a/synapse/storage/database.py b/synapse/storage/database.py index fec4ae5b97..226ccc1671 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1504,8 +1504,8 @@ class DatabasePool: self.engine.lock_table(txn, "user_ips") for keyv, valv in zip(key_values, value_values): - _keys = {x: y for x, y in zip(key_names, keyv)} - _vals = {x: y for x, y in zip(value_names, valv)} + _keys = dict(zip(key_names, keyv)) + _vals = dict(zip(value_names, valv)) self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 193959b250..ccd9f9d141 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -27,6 +27,7 @@ from typing import ( Optional, Set, Tuple, + cast, ) import attr @@ -1348,9 +1349,7 @@ class PersistEventsStore: [event.event_id for event, _ in events_and_contexts], ) - have_persisted: Dict[str, bool] = { - event_id: outlier for event_id, outlier in txn - } + have_persisted = dict(cast(Iterable[Tuple[str, bool]], txn)) logger.debug( "_update_outliers_txn: events=%s have_persisted=%s", diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index ab76b754e0..aeb6034f46 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -518,7 +518,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int: txn.execute( """ - SELECT + SELECT p.id AS pusher_id, p.device_id AS pusher_device_id, at.device_id AS token_device_id diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index d3393d8e49..97c4dc2603 100644 --- a/synapse/storage/databases/main/stats.py +++ 
b/synapse/storage/databases/main/stats.py @@ -16,7 +16,17 @@ import logging from enum import Enum from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, + cast, +) from typing_extensions import Counter @@ -523,7 +533,7 @@ class StatsStore(StateDeltasStore): """, (room_id,), ) - membership_counts = {membership: cnt for membership, cnt in txn} + membership_counts = dict(cast(Iterable[Tuple[str, int]], txn)) txn.execute( """ diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 2b8779bbb8..92cbe262a6 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -41,6 +41,7 @@ from typing import ( Any, Collection, Dict, + Iterable, List, Optional, Set, @@ -1343,7 +1344,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): GROUP BY type """ txn.execute(sql) - min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position + min_positions = dict( + cast(Iterable[Tuple[str, int]], txn) + ) # Map from type -> min position # Ensure we do actually have some values here assert set(min_positions) == {"federation", "events"} diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 57c781a0c3..b2125b1fea 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -412,7 +412,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): self.get_success( self.master_store.add_push_actions_to_staging( event.event_id, - {user_id: actions for user_id, actions in push_actions}, + dict(push_actions), False, "main", ) diff --git a/tests/server.py b/tests/server.py index 5de9722766..bb059630fa 100644 --- a/tests/server.py +++ b/tests/server.py @@ -983,7 +983,9 @@ def setup_test_homeserver( dropped = True 
except psycopg2.OperationalError as e: warnings.warn( - "Couldn't drop old db: " + str(e), category=UserWarning + "Couldn't drop old db: " + str(e), + category=UserWarning, + stacklevel=2, ) time.sleep(0.5) @@ -991,7 +993,11 @@ def setup_test_homeserver( db_conn.close() if not dropped: - warnings.warn("Failed to drop old DB.", category=UserWarning) + warnings.warn( + "Failed to drop old DB.", + category=UserWarning, + stacklevel=2, + ) if not LEAVE_DB: # Register the cleanup hook -- cgit 1.5.1 From 5282ba1e2bbff2635dc09aec45fd42a56c1a4545 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Mar 2023 14:26:27 -0400 Subject: Implement MSC3983 to proxy /keys/claim queries to appservices. (#15314) Experimental support for MSC3983 is behind a configuration flag. If enabled, for users which are exclusively owned by an application service then the appservice will be queried for one-time keys *if* there are none uploaded to Synapse. --- changelog.d/15314.feature | 1 + synapse/appservice/api.py | 56 +++++++++++++++++ synapse/config/experimental.py | 5 ++ synapse/federation/federation_server.py | 20 +++--- synapse/handlers/appservice.py | 74 +++++++++++++++++++++- synapse/handlers/e2e_keys.py | 57 ++++++++++++++--- synapse/storage/databases/main/end_to_end_keys.py | 36 ++++++++--- tests/appservice/test_api.py | 59 ++++++++++++++++++ tests/handlers/test_e2e_keys.py | 76 ++++++++++++++++++++++- 9 files changed, 355 insertions(+), 29 deletions(-) create mode 100644 changelog.d/15314.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/15314.feature b/changelog.d/15314.feature new file mode 100644 index 0000000000..68b289b0cc --- /dev/null +++ b/changelog.d/15314.feature @@ -0,0 +1 @@ +Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)). 
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 4812fb4496..51ee0e79df 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -388,6 +388,62 @@ class ApplicationServiceApi(SimpleHttpClient): failed_transactions_counter.labels(service.id).inc() return False + async def claim_client_keys( + self, service: "ApplicationService", query: List[Tuple[str, str, str]] + ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: + """Claim one time keys from an application service. + + Args: + query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A tuple of: + A map of user ID -> a map device ID -> a map of key ID -> JSON dict. + + A copy of the input which has not been fulfilled because the + appservice doesn't support this endpoint or has not returned + data for that tuple. + """ + if service.url is None: + return {}, query + + # This is required by the configuration. + assert service.hs_token is not None + + # Create the expected payload shape. + body: Dict[str, Dict[str, List[str]]] = {} + for user_id, device, algorithm in query: + body.setdefault(user_id, {}).setdefault(device, []).append(algorithm) + + uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim" + try: + response = await self.post_json_get_json( + uri, + body, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, + ) + except CodeMessageException as e: + # The appservice doesn't support this endpoint. + if e.code == 404 or e.code == 405: + return {}, query + logger.warning("claim_keys to %s received %s", uri, e.code) + return {}, query + except Exception as ex: + logger.warning("claim_keys to %s threw exception %s", uri, ex) + return {}, query + + # Check if the appservice fulfilled all of the queried user/device/algorithms + # or if some are still missing. + # + # TODO This places a lot of faith in the response shape being correct. 
+ missing = [ + (user_id, device, algorithm) + for user_id, device, algorithm in query + if algorithm not in response.get(user_id, {}).get(device, []) + ] + + return response, missing + def _serialize( self, service: "ApplicationService", events: Iterable[EventBase] ) -> List[JsonDict]: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 99dcd27c74..53e6fc2b54 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -74,6 +74,11 @@ class ExperimentalConfig(Config): "msc3202_transaction_extensions", False ) + # MSC3983: Proxying OTK claim requests to exclusive ASes. + self.msc3983_appservice_otk_claims: bool = experimental.get( + "msc3983_appservice_otk_claims", False + ) + # MSC3706 (server-side support for partial state in /send_join responses) # Synapse will always serve partial state responses to requests using the stable # query parameter `omit_members`. If this flag is set, Synapse will also serve diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 6d99845de5..64e99292ec 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -86,7 +86,7 @@ from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary from synapse.types import JsonDict, StateMap, get_domain_from_id -from synapse.util import json_decoder, unwrapFirstError +from synapse.util import unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name @@ -135,6 +135,7 @@ class FederationServer(FederationBase): self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() self._room_member_handler = hs.get_room_member_handler() + 
self._e2e_keys_handler = hs.get_e2e_keys_handler() self._state_storage_controller = hs.get_storage_controllers().state @@ -1012,15 +1013,14 @@ class FederationServer(FederationBase): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = await self.store.claim_e2e_one_time_keys(query) - - json_result: Dict[str, Dict[str, dict]] = {} - for user_id, device_keys in results.items(): - for device_id, keys in device_keys.items(): - for key_id, json_str in keys.items(): - json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_str) - } + results = await self._e2e_keys_handler.claim_local_one_time_keys(query) + + json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for result in results: + for user_id, device_keys in result.items(): + for device_id, keys in device_keys.items(): + for key_id, key in keys.items(): + json_result.setdefault(user_id, {})[device_id] = {key_id: key} logger.info( "Claimed one-time-keys: %s", diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index ec3ab968e9..953df4d9cd 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from prometheus_client import Counter @@ -829,3 +838,66 @@ class ApplicationServicesHandler: if unknown_user: return await self.query_user_exists(user_id) return True + + async def claim_e2e_one_time_keys( + self, query: Iterable[Tuple[str, str, str]] + ) -> Tuple[ + Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]] + ]: + """Claim one time keys from application services. 
+ + Args: + query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A tuple of: + An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + + A copy of the input which has not been fulfilled (either because + they are not appservice users or the appservice does not support + providing OTKs). + """ + services = self.store.get_app_services() + + # Partition the users by appservice. + query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {} + missing = [] + for user_id, device, algorithm in query: + if not self.store.get_if_app_services_interested_in_user(user_id): + missing.append((user_id, device, algorithm)) + continue + + # Find the associated appservice. + for service in services: + if service.is_exclusive_user(user_id): + query_by_appservice.setdefault(service.id, []).append( + (user_id, device, algorithm) + ) + continue + + # Query each service in parallel. + results = await make_deferred_yieldable( + defer.DeferredList( + [ + run_in_background( + self.appservice_api.claim_client_keys, + # We know this must be an app service. + self.store.get_app_service_by_id(service_id), # type: ignore[arg-type] + service_query, + ) + for service_id, service_query in query_by_appservice.items() + ], + consumeErrors=True, + ) + ) + + # Patch together the results -- they are all independent (since they + # require exclusive control over the users). They get returned as a list + # and the caller combines them. + claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = [] + for success, result in results: + if success: + claimed_keys.append(result[0]) + missing.extend(result[1]) + + return claimed_keys, missing diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 4e9c8d8db0..9e7c2c45b5 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple @@ -53,6 +52,7 @@ class E2eKeysHandler: self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() + self._appservice_handler = hs.get_application_service_handler() self.is_mine = hs.is_mine self.clock = hs.get_clock() @@ -88,6 +88,10 @@ class E2eKeysHandler: max_count=10, ) + self._query_appservices_for_otks = ( + hs.config.experimental.msc3983_appservice_otk_claims + ) + @trace @cancellable async def query_devices( @@ -542,6 +546,42 @@ class E2eKeysHandler: return ret + async def claim_local_one_time_keys( + self, local_query: List[Tuple[str, str, str]] + ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: + """Claim one time keys for local users. + + 1. Attempt to claim OTKs from the database. + 2. Ask application services if they provide OTKs. + 3. Attempt to fetch fallback keys from the database. + + Args: + local_query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + """ + + otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) + + # If the application services have not provided any keys via the C-S + # API, query it directly for one-time keys. + if self._query_appservices_for_otks: + ( + appservice_results, + not_found, + ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) + else: + appservice_results = [] + + # For each user that does not have a one-time keys available, see if + # there is a fallback key. + fallback_results = await self.store.claim_e2e_fallback_keys(not_found) + + # Return the results in order, each item from the input query should + # only appear once in the combined list. 
+ return (otk_results, *appservice_results, fallback_results) + @trace async def claim_one_time_keys( self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] @@ -561,17 +601,18 @@ class E2eKeysHandler: set_tag("local_key_query", str(local_query)) set_tag("remote_key_query", str(remote_queries)) - results = await self.store.claim_e2e_one_time_keys(local_query) + results = await self.claim_local_one_time_keys(local_query) # A map of user ID -> device ID -> key ID -> key. json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for result in results: + for user_id, device_keys in result.items(): + for device_id, keys in device_keys.items(): + for key_id, key in keys.items(): + json_result.setdefault(user_id, {})[device_id] = {key_id: key} + + # Remote failures. failures: Dict[str, JsonDict] = {} - for user_id, device_keys in results.items(): - for device_id, keys in device_keys.items(): - for key_id, json_str in keys.items(): - json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_str) - } @trace async def claim_client_keys(destination: str) -> None: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index a3b6c8ae8e..dc7768c50c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -51,7 +51,7 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict -from synapse.util import json_encoder +from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter @@ -1028,14 +1028,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def 
claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str]] - ) -> Dict[str, Dict[str, Dict[str, str]]]: + ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: """Take a list of one time keys out of the database. Args: query_list: An iterable of tuples of (user ID, device ID, algorithm). Returns: - A map of user ID -> a map device ID -> a map of key ID -> JSON bytes. + A tuple pf: + A map of user ID -> a map device ID -> a map of key ID -> JSON. + + A copy of the input which has not been fulfilled. """ @trace @@ -1115,7 +1118,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker key_id, key_json = otk_row return f"{algorithm}:{key_id}", key_json - results: Dict[str, Dict[str, Dict[str, str]]] = {} + results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + missing: List[Tuple[str, str, str]] = [] for user_id, device_id, algorithm in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that @@ -1138,11 +1142,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[claim_row[0]] = claim_row[1] - continue + device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + else: + missing.append((user_id, device_id, algorithm)) + + return results, missing + + async def claim_e2e_fallback_keys( + self, query_list: Iterable[Tuple[str, str, str]] + ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: + """Take a list of fallback keys out of the database. - # No one-time key available, so see if there's a fallback - # key + Args: + query_list: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A map of user ID -> a map device ID -> a map of key ID -> JSON. 
+ """ + results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for user_id, device_id, algorithm in query_list: row = await self.db_pool.simple_select_one( table="e2e_fallback_keys_json", keyvalues={ @@ -1179,7 +1197,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) - device_results[f"{algorithm}:{key_id}"] = key_json + device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) return results diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 9d183b733e..0dd02b7d58 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -105,3 +105,62 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): ) self.assertEqual(self.request_url, URL_LOCATION) self.assertEqual(result, SUCCESS_RESULT_LOCATION) + + def test_claim_keys(self) -> None: + """ + Tests that the /keys/claim response is properly parsed for missing + keys. + """ + + RESPONSE: JsonDict = { + "@alice:example.org": { + "DEVICE_1": { + "signed_curve25519:AAAAHg": { + # We don't really care about the content of the keys, + # they get passed back transparently. + }, + "signed_curve25519:BBBBHg": {}, + }, + "DEVICE_2": {"signed_curve25519:CCCCHg": {}}, + }, + } + + async def post_json_get_json( + uri: str, + post_json: Any, + headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], + ) -> JsonDict: + # Ensure the access token is passed as both a header and query arg. + if not headers.get("Authorization"): + raise RuntimeError("Access token not provided") + + self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) + return RESPONSE + + # We assign to a method, which mypy doesn't like. + self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment] + + MISSING_KEYS = [ + # Known user, known device, missing algorithm. 
+ ("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"), + # Known user, missing device. + ("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"), + # Unknown user. + ("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"), + ] + + claimed_keys, missing = self.get_success( + self.api.claim_client_keys( + self.service, + [ + # Found devices + ("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"), + ("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"), + ("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"), + ] + + MISSING_KEYS, + ) + ) + + self.assertEqual(claimed_keys, RESPONSE) + self.assertEqual(missing, MISSING_KEYS) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 6b4cba65d0..4ff04fc66b 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -23,18 +23,24 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.appservice import ApplicationService from synapse.handlers.device import DeviceHandler from synapse.server import HomeServer +from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.types import JsonDict from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable +from tests.unittest import override_config class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - return self.setup_test_homeserver(federation_client=mock.Mock()) + self.appservice_api = mock.Mock() + return self.setup_test_homeserver( + federation_client=mock.Mock(), application_service_api=self.appservice_api + ) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() @@ -941,3 +947,71 @@ class 
E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # The two requests to the local homeserver should be identical. self.assertEqual(response_1, response_2) + + @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}}) + def test_query_appservice(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id_1 = "xyz" + fallback_key = {"alg1:k1": "fallback_key1"} + device_id_2 = "abc" + otk = {"alg1:k2": "key2"} + + # Inject an appservice interested in this user. + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + self.hs.get_datastores().main.services_cache = [appservice] + self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( + [appservice] + ) + + # Setup a response, but only for device 2. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")]) + ) + + # we shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(res, []) + + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # claiming an OTK when no OTKs are available should ask the appservice, then + # query the fallback keys. 
+ claim_res = self.get_success( + self.handler.claim_one_time_keys( + { + "one_time_keys": { + local_user: {device_id_1: "alg1", device_id_2: "alg1"} + } + }, + timeout=None, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: fallback_key, device_id_2: otk} + }, + }, + ) -- cgit 1.5.1 From 78cdb72cd6b0e007c314d9fed9f629dfc5b937a6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 Mar 2023 12:07:14 +0100 Subject: Delete stale non-e2e devices for users, take 3 (#15183) This should help reduce the number of devices e.g. simple bots the repeatedly login rack up. We only delete non-e2e devices as they should be safe to delete, whereas if we delete e2e devices for a user we may accidentally break their ability to receive e2e keys for a message. --- changelog.d/15183.misc | 1 + synapse/handlers/device.py | 2 +- synapse/handlers/register.py | 50 ++++++++++++++++++- synapse/storage/databases/main/devices.py | 80 ++++++++++++++++++++++++++++++- tests/handlers/test_admin.py | 2 +- tests/handlers/test_device.py | 2 +- tests/storage/test_client_ips.py | 4 +- 7 files changed, 134 insertions(+), 7 deletions(-) create mode 100644 changelog.d/15183.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/15183.misc b/changelog.d/15183.misc new file mode 100644 index 0000000000..f9bfc581ad --- /dev/null +++ b/changelog.d/15183.misc @@ -0,0 +1 @@ +Prune user's old devices on login if they have too many. 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9ded6389ac..0fc165a8d6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -485,7 +485,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: StrCollection) -> None: """Delete several devices Args: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c8bf2439af..bb1df1e60f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -16,7 +16,7 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter from typing_extensions import TypedDict @@ -40,6 +40,7 @@ from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, @@ -48,6 +49,7 @@ from synapse.replication.http.register import ( from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester from synapse.types.state import StateFilter +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -110,6 +112,10 @@ class RegistrationHandler: self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self._server_name = hs.hostname + # The set of users that 
we're currently pruning devices for. Ensures + # that we don't have two such jobs for the same user running at once. + self._currently_pruning_devices_for_users: Set[str] = set() + self.spam_checker = hs.get_spam_checker() if hs.config.worker.worker_app: @@ -121,7 +127,10 @@ class RegistrationHandler: ReplicationPostRegisterActionsServlet.make_client(hs) ) else: - self.device_handler = hs.get_device_handler() + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self.device_handler = device_handler + self._register_device_client = self.register_device_inner self.pusher_pool = hs.get_pusherpool() @@ -851,6 +860,9 @@ class RegistrationHandler: # This can only run on the main process. assert isinstance(self.device_handler, DeviceHandler) + # Prune the user's device list if they already have a lot of devices. + await self._maybe_prune_too_many_devices(user_id) + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, @@ -919,6 +931,40 @@ class RegistrationHandler: "refresh_token": refresh_token, } + async def _maybe_prune_too_many_devices(self, user_id: str) -> None: + """Delete any excess old devices this user may have.""" + + if user_id in self._currently_pruning_devices_for_users: + return + + # We also cap the number of users whose devices we prune at the same + # time, to avoid performance problems. + if len(self._currently_pruning_devices_for_users) > 5: + return + + device_ids = await self.store.check_too_many_devices_for_user(user_id) + if not device_ids: + return + + # Now spawn a background loop that deletes said devices. 
+ async def _prune_too_many_devices_loop() -> None: + if user_id in self._currently_pruning_devices_for_users: + return + + self._currently_pruning_devices_for_users.add(user_id) + + try: + for batch in batch_iter(device_ids, 10): + await self.device_handler.delete_devices(user_id, batch) + + await self.clock.sleep(60) + finally: + self._currently_pruning_devices_for_users.discard(user_id) + + run_as_background_process( + "_prune_too_many_devices_loop", _prune_too_many_devices_loop + ) + async def post_registration_actions( self, user_id: str, auth_result: dict, access_token: Optional[str] ) -> None: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 5503621ad6..7647cda2c6 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1599,6 +1599,73 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows + async def check_too_many_devices_for_user(self, user_id: str) -> List[str]: + """Check if the user has a lot of devices, and if so return the set of + devices we can prune. + + This does *not* return hidden devices or devices with E2E keys. + """ + + num_devices = await self.db_pool.simple_select_one_onecol( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcol="COALESCE(COUNT(*), 0)", + desc="count_devices", + ) + + # We let users have up to ten devices without pruning. + if num_devices <= 10: + return [] + + # We always prune devices not seen in the last 14 days... + max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 + + # ... but we also cap the maximum number of devices the user can have to + # 50. + if num_devices > 50: + # Choose a last seen that ensures we keep at most 50 devices. + sql = """ + SELECT last_seen FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? 
+ AND NOT hidden + AND last_seen IS NOT NULL + AND key_json IS NULL + ORDER BY last_seen DESC + LIMIT 1 + OFFSET 50 + """ + + rows = await self.db_pool.execute( + "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) + ) + if rows: + max_last_seen = max(rows[0][0], max_last_seen) + + # Fetch the devices to delete. + sql = """ + SELECT DISTINCT device_id FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen < ? + AND key_json IS NULL + ORDER BY last_seen + """ + + def check_too_many_devices_for_user_txn( + txn: LoggingTransaction, + ) -> List[str]: + txn.execute(sql, (user_id, max_last_seen)) + return [device_id for device_id, in txn] + + return await self.db_pool.runInteraction( + "check_too_many_devices_for_user", + check_too_many_devices_for_user_txn, + ) + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1657,6 +1724,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, + "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1702,7 +1770,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + @cached(max_entries=0) + async def delete_device(self, user_id: str, device_id: str) -> None: + raise NotImplementedError() + + # Note: sometimes deleting rows out of `device_inbox` can take a long time, + # so we use a cache so that we deduplicate in flight requests to delete + # devices. + @cachedList(cached_method_name="delete_device", list_name="device_ids") + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> dict: """Deletes several devices. 
Args: @@ -1739,6 +1815,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) + return {} + async def update_device( self, user_id: str, device_id: str, new_display_name: Optional[str] = None ) -> None: diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 5569ccef8a..f0ba3775c8 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -272,7 +272,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertIn("device_id", args[0][0]) self.assertIsNone(args[0][0]["display_name"]) self.assertIsNone(args[0][0]["last_seen_user_agent"]) - self.assertIsNone(args[0][0]["last_seen_ts"]) + self.assertEqual(args[0][0]["last_seen_ts"], 600) self.assertIsNone(args[0][0]["last_seen_ip"]) def test_connections(self) -> None: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ce7525e29c..a456bffd63 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": None, + "last_seen_ts": 1000000, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index cd0079871c..f989986538 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -170,6 +170,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) + last_seen = self.clock.time_msec() + if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -190,7 +192,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": None, + "last_seen": last_seen, }, ], ) -- cgit 1.5.1