diff --git a/Cargo.lock b/Cargo.lock
index 59d2aec215..6e97fb8fb1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -323,18 +323,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "serde"
-version = "1.0.148"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e53f64bb4ba0191d6d0676e1b141ca55047d83b74f5607e6d8eb88126c52c2dc"
+checksum = "e326c9ec8042f1b5da33252c8a37e9ffbd2c9bef0155215b6e6c80c790e05f91"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
-version = "1.0.148"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a55492425aa53521babf6137309e7d34c20bbfbbfcfe2c7f3a047fd1f6b92c0c"
+checksum = "42a3df25b0713732468deadad63ab9da1f1fd75a48a15024b50363f128db627e"
dependencies = [
"proc-macro2",
"quote",
diff --git a/changelog.d/14464.feature b/changelog.d/14464.feature
new file mode 100644
index 0000000000..688ea32117
--- /dev/null
+++ b/changelog.d/14464.feature
@@ -0,0 +1 @@
+Improve user search for international display names.
diff --git a/changelog.d/14595.misc b/changelog.d/14595.misc
deleted file mode 100644
index f9bfc581ad..0000000000
--- a/changelog.d/14595.misc
+++ /dev/null
@@ -1 +0,0 @@
-Prune user's old devices on login if they have too many.
diff --git a/changelog.d/14646.misc b/changelog.d/14646.misc
new file mode 100644
index 0000000000..d44571b731
--- /dev/null
+++ b/changelog.d/14646.misc
@@ -0,0 +1 @@
+Add missing type hints.
diff --git a/changelog.d/14649.misc b/changelog.d/14649.misc
deleted file mode 100644
index f9bfc581ad..0000000000
--- a/changelog.d/14649.misc
+++ /dev/null
@@ -1 +0,0 @@
-Prune user's old devices on login if they have too many.
diff --git a/changelog.d/14650.bugfix b/changelog.d/14650.bugfix
new file mode 100644
index 0000000000..5e18641bf7
--- /dev/null
+++ b/changelog.d/14650.bugfix
@@ -0,0 +1,2 @@
+Fix a bug introduced in Synapse 1.72.0 where the background updates to add non-thread unique indexes on receipts would fail if they were previously interrupted.
+
diff --git a/changelog.d/14656.misc b/changelog.d/14656.misc
new file mode 100644
index 0000000000..9725bb6187
--- /dev/null
+++ b/changelog.d/14656.misc
@@ -0,0 +1 @@
+Bump flake8-bugbear from 22.10.27 to 22.12.6.
diff --git a/changelog.d/14657.misc b/changelog.d/14657.misc
new file mode 100644
index 0000000000..3964488f88
--- /dev/null
+++ b/changelog.d/14657.misc
@@ -0,0 +1 @@
+Bump packaging from 21.3 to 22.0.
diff --git a/changelog.d/14658.misc b/changelog.d/14658.misc
new file mode 100644
index 0000000000..9dc62a8ceb
--- /dev/null
+++ b/changelog.d/14658.misc
@@ -0,0 +1 @@
+Bump types-pillow from 9.3.0.1 to 9.3.0.4.
diff --git a/changelog.d/14659.misc b/changelog.d/14659.misc
new file mode 100644
index 0000000000..70cf6c9c4d
--- /dev/null
+++ b/changelog.d/14659.misc
@@ -0,0 +1 @@
+Bump serde from 1.0.148 to 1.0.150.
diff --git a/changelog.d/14660.misc b/changelog.d/14660.misc
new file mode 100644
index 0000000000..541f98bd93
--- /dev/null
+++ b/changelog.d/14660.misc
@@ -0,0 +1 @@
+Bump phonenumbers from 8.13.1 to 8.13.2.
diff --git a/changelog.d/14661.misc b/changelog.d/14661.misc
new file mode 100644
index 0000000000..25d3b6fe61
--- /dev/null
+++ b/changelog.d/14661.misc
@@ -0,0 +1 @@
+Bump authlib from 1.1.0 to 1.2.0.
diff --git a/changelog.d/14662.removal b/changelog.d/14662.removal
new file mode 100644
index 0000000000..19a387bbb4
--- /dev/null
+++ b/changelog.d/14662.removal
@@ -0,0 +1 @@
+(remove from changelog: unreleased) Revert the deletion of stale devices due to performance issues.
\ No newline at end of file
diff --git a/debian/changelog b/debian/changelog
index 163b7210bf..5d3c4f7d6b 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,10 @@
+matrix-synapse-py3 (1.74.0~rc1) UNRELEASED; urgency=medium
+
+ * New dependency on libicu-dev to provide improved results for user
+ search.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 06 Dec 2022 15:28:10 +0000
+
matrix-synapse-py3 (1.73.0) stable; urgency=medium
* New Synapse release 1.73.0.
diff --git a/debian/control b/debian/control
index 86f5a66d02..bc628cec08 100644
--- a/debian/control
+++ b/debian/control
@@ -8,6 +8,8 @@ Build-Depends:
dh-virtualenv (>= 1.1),
libsystemd-dev,
libpq-dev,
+ libicu-dev,
+ pkg-config,
lsb-release,
python3-dev,
python3,
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 185d5bc3d4..7e5123210a 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -97,6 +97,8 @@ RUN \
zlib1g-dev \
git \
curl \
+ libicu-dev \
+ pkg-config \
&& rm -rf /var/lib/apt/lists/*
diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index 73165f6f85..f3b5b00ce6 100644
--- a/docker/Dockerfile-dhvirtualenv
+++ b/docker/Dockerfile-dhvirtualenv
@@ -84,6 +84,8 @@ RUN apt-get update -qq -o Acquire::Languages=none \
python3-venv \
sqlite3 \
libpq-dev \
+ libicu-dev \
+ pkg-config \
xmlsec1
# Install rust and ensure it's in the PATH
diff --git a/mypy.ini b/mypy.ini
index c3fbd1a955..a4a1e4511a 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -88,6 +88,9 @@ disallow_untyped_defs = False
[mypy-tests.*]
disallow_untyped_defs = False
+[mypy-tests.handlers.test_sso]
+disallow_untyped_defs = True
+
[mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True
@@ -103,16 +106,7 @@ disallow_untyped_defs = True
[mypy-tests.state.test_profile]
disallow_untyped_defs = True
-[mypy-tests.storage.test_id_generators]
-disallow_untyped_defs = True
-
-[mypy-tests.storage.test_profile]
-disallow_untyped_defs = True
-
-[mypy-tests.handlers.test_sso]
-disallow_untyped_defs = True
-
-[mypy-tests.storage.test_user_directory]
+[mypy-tests.storage.*]
disallow_untyped_defs = True
[mypy-tests.rest.*]
diff --git a/poetry.lock b/poetry.lock
index 1c10f0458a..6fd4bd5ba5 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -13,8 +13,8 @@ tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900
tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
[[package]]
-name = "Authlib"
-version = "1.1.0"
+name = "authlib"
+version = "1.2.0"
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
category = "main"
optional = true
@@ -260,7 +260,7 @@ pyflakes = ">=2.5.0,<2.6.0"
[[package]]
name = "flake8-bugbear"
-version = "22.10.27"
+version = "22.12.6"
description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle."
category = "dev"
optional = false
@@ -633,14 +633,11 @@ tests = ["Sphinx", "doubles", "flake8", "flake8-quotes", "gevent", "mock", "pyte
[[package]]
name = "packaging"
-version = "21.3"
+version = "22.0"
description = "Core utilities for Python packages"
category = "main"
optional = false
-python-versions = ">=3.6"
-
-[package.dependencies]
-pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
+python-versions = ">=3.7"
[[package]]
name = "parameterized"
@@ -663,7 +660,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
[[package]]
name = "phonenumbers"
-version = "8.13.1"
+version = "8.13.2"
description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers."
category = "main"
optional = false
@@ -838,6 +835,14 @@ optional = false
python-versions = ">=3.5"
[[package]]
+name = "pyicu"
+version = "2.10.2"
+description = "Python extension wrapping the ICU C++ API"
+category = "main"
+optional = true
+python-versions = "*"
+
+[[package]]
name = "pyjwt"
version = "2.4.0"
description = "JSON Web Token implementation in Python"
@@ -902,17 +907,6 @@ docs = ["sphinx (!=5.2.0,!=5.2.0.post0)", "sphinx-rtd-theme"]
test = ["flaky", "pretend", "pytest (>=3.0.1)"]
[[package]]
-name = "pyparsing"
-version = "3.0.7"
-description = "Python parsing module"
-category = "main"
-optional = false
-python-versions = ">=3.6"
-
-[package.extras]
-diagrams = ["jinja2", "railroad-diagrams"]
-
-[[package]]
name = "pyrsistent"
version = "0.18.1"
description = "Persistent/Functional/Immutable data structures"
@@ -1440,7 +1434,7 @@ python-versions = "*"
[[package]]
name = "types-pillow"
-version = "9.3.0.1"
+version = "9.3.0.4"
description = "Typing stubs for Pillow"
category = "dev"
optional = false
@@ -1622,7 +1616,7 @@ docs = ["Sphinx", "repoze.sphinx.autointerface"]
test = ["zope.i18nmessageid", "zope.testing", "zope.testrunner"]
[extras]
-all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler"]
+all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler", "pyicu"]
cache-memory = ["Pympler"]
jwt = ["authlib"]
matrix-synapse-ldap3 = ["matrix-synapse-ldap3"]
@@ -1635,20 +1629,21 @@ sentry = ["sentry-sdk"]
systemd = ["systemd-python"]
test = ["parameterized", "idna"]
url-preview = ["lxml"]
+user-search = ["pyicu"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7.1"
-content-hash = "8c44ceeb9df5c3ab43040400e0a6b895de49417e61293a1ba027640b34f03263"
+content-hash = "f20007013f33bc35a01e412c48adc62a936030f3074e06286674c5ad7f44d300"
[metadata.files]
attrs = [
{file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"},
{file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"},
]
-Authlib = [
- {file = "Authlib-1.1.0-py2.py3-none-any.whl", hash = "sha256:be4b6a1dea51122336c210a6945b27a105b9ac572baffd15b07bcff4376c1523"},
- {file = "Authlib-1.1.0.tar.gz", hash = "sha256:0a270c91409fc2b7b0fbee6996e09f2ee3187358762111a9a4225c874b94e891"},
+authlib = [
+ {file = "Authlib-1.2.0-py2.py3-none-any.whl", hash = "sha256:4ddf4fd6cfa75c9a460b361d4bd9dac71ffda0be879dbe4292a02e92349ad55a"},
+ {file = "Authlib-1.2.0.tar.gz", hash = "sha256:4fa3e80883a5915ef9f5bc28630564bc4ed5b5af39812a3ff130ec76bd631e9d"},
]
automat = [
{file = "Automat-22.10.0-py2.py3-none-any.whl", hash = "sha256:c3164f8742b9dc440f3682482d32aaff7bb53f71740dd018533f9de286b64180"},
@@ -1836,8 +1831,8 @@ flake8 = [
{file = "flake8-5.0.4.tar.gz", hash = "sha256:6fbe320aad8d6b95cec8b8e47bc933004678dc63095be98528b7bdd2a9f510db"},
]
flake8-bugbear = [
- {file = "flake8-bugbear-22.10.27.tar.gz", hash = "sha256:a6708608965c9e0de5fff13904fed82e0ba21ac929fe4896459226a797e11cd5"},
- {file = "flake8_bugbear-22.10.27-py3-none-any.whl", hash = "sha256:6ad0ab754507319060695e2f2be80e6d8977cfcea082293089a9226276bd825d"},
+ {file = "flake8-bugbear-22.12.6.tar.gz", hash = "sha256:4cdb2c06e229971104443ae293e75e64c6107798229202fbe4f4091427a30ac0"},
+ {file = "flake8_bugbear-22.12.6-py3-none-any.whl", hash = "sha256:b69a510634f8a9c298dfda2b18a8036455e6b19ecac4fe582e4d7a0abfa50a30"},
]
flake8-comprehensions = [
{file = "flake8-comprehensions-3.10.1.tar.gz", hash = "sha256:412052ac4a947f36b891143430fef4859705af11b2572fbb689f90d372cf26ab"},
@@ -2246,8 +2241,8 @@ opentracing = [
{file = "opentracing-2.4.0.tar.gz", hash = "sha256:a173117e6ef580d55874734d1fa7ecb6f3655160b8b8974a2a1e98e5ec9c840d"},
]
packaging = [
- {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"},
- {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
+ {file = "packaging-22.0-py3-none-any.whl", hash = "sha256:957e2148ba0e1a3b282772e791ef1d8083648bc131c8ab0c1feba110ce1146c3"},
+ {file = "packaging-22.0.tar.gz", hash = "sha256:2198ec20bd4c017b8f9717e00f0c8714076fc2fd93816750ab48e2c41de2cfd3"},
]
parameterized = [
{file = "parameterized-0.8.1-py2.py3-none-any.whl", hash = "sha256:9cbb0b69a03e8695d68b3399a8a5825200976536fe1cb79db60ed6a4c8c9efe9"},
@@ -2258,8 +2253,8 @@ pathspec = [
{file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"},
]
phonenumbers = [
- {file = "phonenumbers-8.13.1-py2.py3-none-any.whl", hash = "sha256:07a95c2f178687fd1c3f722cf792b3d33e3a225ae71577e500c99b28544cd6d0"},
- {file = "phonenumbers-8.13.1.tar.gz", hash = "sha256:7cadfe900e833857500b7bafa3e5a7eddc3263eb66b66a767870b33e44665f92"},
+ {file = "phonenumbers-8.13.2-py2.py3-none-any.whl", hash = "sha256:884b26f775205261f4dc861371dce217c1661a4942fb3ec3624e290fb51869bf"},
+ {file = "phonenumbers-8.13.2.tar.gz", hash = "sha256:0179f688d48c0e7e161eb7b9d86d587940af1f5174f97c1fdfd893c599c0d94a"},
]
pillow = [
{file = "Pillow-9.3.0-1-cp37-cp37m-win32.whl", hash = "sha256:e6ea6b856a74d560d9326c0f5895ef8050126acfdc7ca08ad703eb0081e82b74"},
@@ -2427,6 +2422,9 @@ pygments = [
{file = "Pygments-2.11.2-py3-none-any.whl", hash = "sha256:44238f1b60a76d78fc8ca0528ee429702aae011c265fe6a8dd8b63049ae41c65"},
{file = "Pygments-2.11.2.tar.gz", hash = "sha256:4e426f72023d88d03b2fa258de560726ce890ff3b630f88c21cbb8b2503b8c6a"},
]
+pyicu = [
+ {file = "PyICU-2.10.2.tar.gz", hash = "sha256:0c3309eea7fab6857507ace62403515b60fe096cbfb4f90d14f55ff75c5441c1"},
+]
pyjwt = [
{file = "PyJWT-2.4.0-py3-none-any.whl", hash = "sha256:72d1d253f32dbd4f5c88eaf1fdc62f3a19f676ccbadb9dbc5d07e951b2b26daf"},
{file = "PyJWT-2.4.0.tar.gz", hash = "sha256:d42908208c699b3b973cbeb01a969ba6a96c821eefb1c5bfe4c390c01d67abba"},
@@ -2455,10 +2453,6 @@ pyopenssl = [
{file = "pyOpenSSL-22.1.0-py3-none-any.whl", hash = "sha256:b28437c9773bb6c6958628cf9c3bebe585de661dba6f63df17111966363dd15e"},
{file = "pyOpenSSL-22.1.0.tar.gz", hash = "sha256:7a83b7b272dd595222d672f5ce29aa030f1fb837630ef229f62e72e395ce8968"},
]
-pyparsing = [
- {file = "pyparsing-3.0.7-py3-none-any.whl", hash = "sha256:a6c06a88f252e6c322f65faf8f418b16213b51bdfaece0524c1c1bc30c63c484"},
- {file = "pyparsing-3.0.7.tar.gz", hash = "sha256:18ee9022775d270c55187733956460083db60b37d0d0fb357445f3094eed3eea"},
-]
pyrsistent = [
{file = "pyrsistent-0.18.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:df46c854f490f81210870e509818b729db4488e1f30f2a1ce1698b2295a878d1"},
{file = "pyrsistent-0.18.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d45866ececf4a5fff8742c25722da6d4c9e180daa7b405dc0a2a2790d668c26"},
@@ -2809,8 +2803,8 @@ types-opentracing = [
{file = "types_opentracing-2.4.10-py3-none-any.whl", hash = "sha256:66d9cfbbdc4a6f8ca8189a15ad26f0fe41cee84c07057759c5d194e2505b84c2"},
]
types-pillow = [
- {file = "types-Pillow-9.3.0.1.tar.gz", hash = "sha256:f3b7cada3fa496c78d75253c6b1f07a843d625f42e5639b320a72acaff6f7cfb"},
- {file = "types_Pillow-9.3.0.1-py3-none-any.whl", hash = "sha256:79837755fe9659f29efd1016e9903ac4a500e0c73260483f07296bd6ca47668b"},
+ {file = "types-Pillow-9.3.0.4.tar.gz", hash = "sha256:c18d466dc18550d96b8b4a279ff94f0cbad696825b5ad55466604f1daf5709de"},
+ {file = "types_Pillow-9.3.0.4-py3-none-any.whl", hash = "sha256:98b8484ff343676f6f7051682a6cfd26896e993e86b3ce9badfa0ec8750f5405"},
]
types-psycopg2 = [
{file = "types-psycopg2-2.9.21.2.tar.gz", hash = "sha256:bff045579642ce00b4a3c8f2e401b7f96dfaa34939f10be64b0dd3b53feca57d"},
diff --git a/pyproject.toml b/pyproject.toml
index df59fa0562..bb383683cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -208,6 +208,7 @@ hiredis = { version = "*", optional = true }
Pympler = { version = "*", optional = true }
parameterized = { version = ">=0.7.4", optional = true }
idna = { version = ">=2.5", optional = true }
+pyicu = { version = ">=2.10.2", optional = true }
[tool.poetry.extras]
# NB: Packages that should be part of `pip install matrix-synapse[all]` need to be specified
@@ -230,6 +231,10 @@ redis = ["txredisapi", "hiredis"]
# Required to use experimental `caches.track_memory_usage` config option.
cache-memory = ["pympler"]
test = ["parameterized", "idna"]
+# Allows for better search for international characters in the user directory. This
+# requires libicu's development headers installed on the system (e.g. libicu-dev on
+# Debian-based distributions).
+user-search = ["pyicu"]
# The duplication here is awful. I hate hate hate hate hate it. However, for now I want
# to ensure you can still `pip install matrix-synapse[all]` like today. Two motivations:
@@ -261,6 +266,8 @@ all = [
"txredisapi", "hiredis",
# cache-memory
"pympler",
+ # improved user search
+ "pyicu",
# omitted:
# - test: it's useful to have this separate from dev deps in the olddeps job
# - systemd: this is a system-based requirement
diff --git a/stubs/icu.pyi b/stubs/icu.pyi
new file mode 100644
index 0000000000..efeda7938a
--- /dev/null
+++ b/stubs/icu.pyi
@@ -0,0 +1,25 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Stub for PyICU.
+
+class Locale:
+ @staticmethod
+ def getDefault() -> Locale: ...
+
+class BreakIterator:
+ @staticmethod
+ def createWordInstance(locale: Locale) -> BreakIterator: ...
+ def setText(self, text: str) -> None: ...
+ def nextBoundary(self) -> int: ...
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index c935c7be90..d4750a32e6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -52,7 +52,6 @@ from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.cancellation import cancellable
-from synapse.util.iterutils import batch_iter
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination
@@ -422,9 +421,6 @@ class DeviceHandler(DeviceWorkerHandler):
self._check_device_name_length(initial_device_display_name)
- # Prune the user's device list if they already have a lot of devices.
- await self._prune_too_many_devices(user_id)
-
if device_id is not None:
new_device = await self.store.store_device(
user_id=user_id,
@@ -456,33 +452,6 @@ class DeviceHandler(DeviceWorkerHandler):
raise errors.StoreError(500, "Couldn't generate a device ID.")
- async def _prune_too_many_devices(self, user_id: str) -> None:
- """Delete any excess old devices this user may have."""
- device_ids = await self.store.check_too_many_devices_for_user(user_id, 100)
- if not device_ids:
- return
-
- logger.info("Pruning %d old devices for user %s", len(device_ids), user_id)
-
- # We don't want to block and try and delete tonnes of devices at once,
- # so we cap the number of devices we delete synchronously.
- first_batch, remaining_device_ids = device_ids[:10], device_ids[10:]
- await self.delete_devices(user_id, first_batch)
-
- if not remaining_device_ids:
- return
-
- # Now spawn a background loop that deletes the rest.
- async def _prune_too_many_devices_loop() -> None:
- for batch in batch_iter(remaining_device_ids, 10):
- await self.delete_devices(user_id, batch)
-
- await self.clock.sleep(1)
-
- run_as_background_process(
- "_prune_too_many_devices_loop", _prune_too_many_devices_loop
- )
-
async def _delete_stale_devices(self) -> None:
"""Background task that deletes devices which haven't been accessed for more than
a configured time period.
@@ -512,7 +481,7 @@ class DeviceHandler(DeviceWorkerHandler):
device_ids = [d for d in device_ids if d != except_device_id]
await self.delete_devices(user_id, device_ids)
- async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None:
+ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
"""Delete several devices
Args:
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 2056ecb2c3..a99aea8926 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -544,6 +544,48 @@ class BackgroundUpdater:
The named index will be dropped upon completion of the new index.
"""
+ async def updater(progress: JsonDict, batch_size: int) -> int:
+ await self.create_index_in_background(
+ index_name=index_name,
+ table=table,
+ columns=columns,
+ where_clause=where_clause,
+ unique=unique,
+ psql_only=psql_only,
+ replaces_index=replaces_index,
+ )
+ await self._end_background_update(update_name)
+ return 1
+
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ updater, oneshot=True
+ )
+
+ async def create_index_in_background(
+ self,
+ index_name: str,
+ table: str,
+ columns: Iterable[str],
+ where_clause: Optional[str] = None,
+ unique: bool = False,
+ psql_only: bool = False,
+ replaces_index: Optional[str] = None,
+ ) -> None:
+ """Add an index in the background.
+
+ Args:
+ update_name: update_name to register for
+ index_name: name of index to add
+ table: table to add index to
+ columns: columns/expressions to include in index
+ where_clause: A WHERE clause to specify a partial unique index.
+ unique: true to make a UNIQUE index
+ psql_only: true to only create this index on psql databases (useful
+ for virtual sqlite tables)
+ replaces_index: The name of an index that this index replaces.
+ The named index will be dropped upon completion of the new index.
+ """
+
def create_index_psql(conn: Connection) -> None:
conn.rollback()
# postgres insists on autocommit for the index
@@ -618,16 +660,11 @@ class BackgroundUpdater:
else:
runner = create_index_sqlite
- async def updater(progress: JsonDict, batch_size: int) -> int:
- if runner is not None:
- logger.info("Adding index %s to %s", index_name, table)
- await self.db_pool.runWithConnection(runner)
- await self._end_background_update(update_name)
- return 1
+ if runner is None:
+ return
- self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
- updater, oneshot=True
- )
+ logger.info("Adding index %s to %s", index_name, table)
+ await self.db_pool.runWithConnection(runner)
async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 95d4c0622d..a5bb4d404e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1569,77 +1569,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
return rows
- async def check_too_many_devices_for_user(
- self, user_id: str, limit: int
- ) -> List[str]:
- """Check if the user has a lot of devices, and if so return the set of
- devices we can prune.
-
- This does *not* return hidden devices or devices with E2E keys.
-
- Returns at most `limit` number of devices, ordered by last seen.
- """
-
- num_devices = await self.db_pool.simple_select_one_onecol(
- table="devices",
- keyvalues={"user_id": user_id, "hidden": False},
- retcol="COALESCE(COUNT(*), 0)",
- desc="count_devices",
- )
-
- # We let users have up to ten devices without pruning.
- if num_devices <= 10:
- return []
-
- # We prune everything older than N days.
- max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000
-
- if num_devices > 50:
- # If the user has more than 50 devices, then we chose a last seen
- # that ensures we keep at most 50 devices.
- sql = """
- SELECT last_seen FROM devices
- LEFT JOIN e2e_device_keys_json USING (user_id, device_id)
- WHERE
- user_id = ?
- AND NOT hidden
- AND last_seen IS NOT NULL
- AND key_json IS NULL
- ORDER BY last_seen DESC
- LIMIT 1
- OFFSET 50
- """
-
- rows = await self.db_pool.execute(
- "check_too_many_devices_for_user_last_seen", None, sql, (user_id,)
- )
- if rows:
- max_last_seen = max(rows[0][0], max_last_seen)
-
- # Now fetch the devices to delete.
- sql = """
- SELECT device_id FROM devices
- LEFT JOIN e2e_device_keys_json USING (user_id, device_id)
- WHERE
- user_id = ?
- AND NOT hidden
- AND last_seen < ?
- AND key_json IS NULL
- ORDER BY last_seen
- LIMIT ?
- """
-
- def check_too_many_devices_for_user_txn(
- txn: LoggingTransaction,
- ) -> List[str]:
- txn.execute(sql, (user_id, max_last_seen, limit))
- return [device_id for device_id, in txn]
-
- return await self.db_pool.runInteraction(
- "check_too_many_devices_for_user",
- check_too_many_devices_for_user_txn,
- )
-
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Because we have write access, this will be a StreamIdGenerator
@@ -1698,7 +1627,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values={},
insertion_values={
"display_name": initial_device_display_name,
- "last_seen": self._clock.time_msec(),
"hidden": False,
},
desc="store_device",
@@ -1744,15 +1672,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
raise StoreError(500, "Problem storing device.")
- @cached(max_entries=0)
- async def delete_device(self, user_id: str, device_id: str) -> None:
- raise NotImplementedError()
-
- # Note: sometimes deleting rows out of `device_inbox` can take a long time,
- # so we use a cache so that we deduplicate in flight requests to delete
- # devices.
- @cachedList(cached_method_name="delete_device", list_name="device_ids")
- async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> dict:
+ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
"""Deletes several devices.
Args:
@@ -1789,8 +1709,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
- return {}
-
async def update_device(
self, user_id: str, device_id: str, new_display_name: Optional[str] = None
) -> None:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 643c47d608..4c691642e2 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -140,7 +140,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_device_keys_for_cs_api(
self,
- query_list: List[Tuple[str, Optional[str]]],
+ query_list: Collection[Tuple[str, Optional[str]]],
include_displaynames: bool = True,
) -> Dict[str, Dict[str, JsonDict]]:
"""Fetch a list of device keys, formatted suitably for the C/S API.
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index a580e4bdda..e06725f69c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -924,39 +924,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
return batch_size
- async def _create_receipts_index(self, index_name: str, table: str) -> None:
- """Adds a unique index on `(room_id, receipt_type, user_id)` to the given
- receipts table, for non-thread receipts."""
-
- def _create_index(conn: LoggingDatabaseConnection) -> None:
- conn.rollback()
-
- # we have to set autocommit, because postgres refuses to
- # CREATE INDEX CONCURRENTLY without it.
- if isinstance(self.database_engine, PostgresEngine):
- conn.set_session(autocommit=True)
-
- try:
- c = conn.cursor()
-
- # Now that the duplicates are gone, we can create the index.
- concurrently = (
- "CONCURRENTLY"
- if isinstance(self.database_engine, PostgresEngine)
- else ""
- )
- sql = f"""
- CREATE UNIQUE INDEX {concurrently} {index_name}
- ON {table}(room_id, receipt_type, user_id)
- WHERE thread_id IS NULL
- """
- c.execute(sql)
- finally:
- if isinstance(self.database_engine, PostgresEngine):
- conn.set_session(autocommit=False)
-
- await self.db_pool.runWithConnection(_create_index)
-
async def _background_receipts_linearized_unique_index(
self, progress: dict, batch_size: int
) -> int:
@@ -999,9 +966,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
_remote_duplicate_receipts_txn,
)
- await self._create_receipts_index(
- "receipts_linearized_unique_index",
- "receipts_linearized",
+ await self.db_pool.updates.create_index_in_background(
+ index_name="receipts_linearized_unique_index",
+ table="receipts_linearized",
+ columns=["room_id", "receipt_type", "user_id"],
+ where_clause="thread_id IS NULL",
+ unique=True,
)
await self.db_pool.updates._end_background_update(
@@ -1050,9 +1020,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
_remote_duplicate_receipts_txn,
)
- await self._create_receipts_index(
- "receipts_graph_unique_index",
- "receipts_graph",
+ await self.db_pool.updates.create_index_in_background(
+ index_name="receipts_graph_unique_index",
+ table="receipts_graph",
+ columns=["room_id", "receipt_type", "user_id"],
+ where_clause="thread_id IS NULL",
+ unique=True,
)
await self.db_pool.updates._end_background_update(
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af9952f513..14ef5b040d 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,14 @@ from typing import (
cast,
)
+try:
+ # Figure out if ICU support is available for searching users.
+ import icu
+
+ USE_ICU = True
+except ModuleNotFoundError:
+ USE_ICU = False
+
from typing_extensions import TypedDict
from synapse.api.errors import StoreError
@@ -900,7 +908,7 @@ def _parse_query_sqlite(search_term: str) -> str:
"""
# Pull out the individual words, discarding any non-word characters.
- results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+ results = _parse_words(search_term)
return " & ".join("(%s* OR %s)" % (result, result) for result in results)
@@ -910,12 +918,63 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""
-
- # Pull out the individual words, discarding any non-word characters.
- results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+ results = _parse_words(search_term)
both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
exact = " & ".join("%s" % (result,) for result in results)
prefix = " & ".join("%s:*" % (result,) for result in results)
return both, exact, prefix
+
+
+def _parse_words(search_term: str) -> List[str]:
+ """Split the provided search string into a list of its words.
+
+ If support for ICU (International Components for Unicode) is available, use it.
+ Otherwise, fall back to using a regex to detect word boundaries. This latter
+ solution works well enough for most latin-based languages, but doesn't work as well
+ with other languages.
+
+ Args:
+ search_term: The search string.
+
+ Returns:
+ A list of the words in the search string.
+ """
+ if USE_ICU:
+ return _parse_words_with_icu(search_term)
+
+ return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+
+def _parse_words_with_icu(search_term: str) -> List[str]:
+ """Break down the provided search string into its individual words using ICU
+ (International Components for Unicode).
+
+ Args:
+ search_term: The search string.
+
+ Returns:
+ A list of the words in the search string.
+ """
+ results = []
+ breaker = icu.BreakIterator.createWordInstance(icu.Locale.getDefault())
+ breaker.setText(search_term)
+ i = 0
+ while True:
+ j = breaker.nextBoundary()
+ if j < 0:
+ break
+
+ result = search_term[i:j]
+
+ # libicu considers spaces and punctuation between words as words, but we don't
+ # want to include those in results as they would result in syntax errors in SQL
+ # queries (e.g. "foo bar" would result in the search query including "foo & &
+ # bar").
+ if len(re.findall(r"([\w\-]+)", result, re.UNICODE)):
+ results.append(result)
+
+ i = j
+
+ return results
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index e51cac9b33..ce7525e29c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -20,8 +20,6 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
-from synapse.rest import admin
-from synapse.rest.client import account, login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -32,12 +30,6 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
- servlets = [
- login.register_servlets,
- admin.register_servlets,
- account.register_servlets,
- ]
-
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
handler = hs.get_device_handler()
@@ -123,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
"device_id": "xyz",
"display_name": "display 0",
"last_seen_ip": None,
- "last_seen_ts": 1000000,
+ "last_seen_ts": None,
},
device_map["xyz"],
)
@@ -237,29 +229,6 @@ class DeviceTestCase(unittest.HomeserverTestCase):
NotFoundError,
)
- def test_login_delete_old_devices(self) -> None:
- """Delete old devices if the user already has too many."""
-
- user_id = self.register_user("user", "pass")
-
- # Create a bunch of devices
- for _ in range(50):
- self.login("user", "pass")
- self.reactor.advance(1)
-
- # Advance the clock for ages (as we only delete old devices)
- self.reactor.advance(60 * 60 * 24 * 300)
-
- # Log in again to start the pruning
- self.login("user", "pass")
-
- # Give the background job time to do its thing
- self.reactor.pump([1.0] * 100)
-
- # We should now only have the most recent device.
- devices = self.get_success(self.handler.get_devices_by_user(user_id))
- self.assertEqual(len(devices), 1)
-
def _record_users(self) -> None:
# check this works for both devices which have a recorded client_ip,
# and those which don't.
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index 50c20c5b92..373707b275 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.rest import admin
from synapse.rest.client import devices
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -25,11 +29,11 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
devices.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
- def test_background_remove_deleted_devices_from_device_inbox(self):
+ def test_background_remove_deleted_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete old device_inboxes works properly."""
# create a valid device
@@ -89,7 +93,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.assertEqual(1, len(res))
self.assertEqual(res[0], "cur_device")
- def test_background_remove_hidden_devices_from_device_inbox(self):
+ def test_background_remove_hidden_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete hidden devices
from device_inboxes works properly."""
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 5773172ab8..9f33afcca0 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -45,7 +45,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.store: EventsWorkerStore = hs.get_datastores().main
@@ -68,7 +68,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.event_ids.append(event.event_id)
- def test_simple(self):
+ def test_simple(self) -> None:
with LoggingContext(name="test") as ctx:
res = self.get_success(
self.store.have_seen_events(
@@ -90,7 +90,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
- def test_persisting_event_invalidates_cache(self):
+ def test_persisting_event_invalidates_cache(self) -> None:
"""
Test to make sure that the `have_seen_event` cache
is invalidated after we persist an event and returns
@@ -138,7 +138,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
- def test_invalidate_cache_by_room_id(self):
+ def test_invalidate_cache_by_room_id(self) -> None:
"""
Test to make sure that all events associated with the given `(room_id,)`
are invalidated in the `have_seen_event` cache.
@@ -175,7 +175,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main
self.user = self.register_user("user", "pass")
@@ -189,7 +189,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# Reset the event cache so the tests start with it empty
self.get_success(self.store._get_event_cache.clear())
- def test_simple(self):
+ def test_simple(self) -> None:
"""Test that we cache events that we pull from the DB."""
with LoggingContext("test") as ctx:
@@ -198,7 +198,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
- def test_event_ref(self):
+ def test_event_ref(self) -> None:
"""Test that we reuse events that are still in memory but have fallen
out of the cache, rather than requesting them from the DB.
"""
@@ -223,7 +223,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
- def test_dedupe(self):
+ def test_dedupe(self) -> None:
"""Test that if we request the same event multiple times we only pull it
out once.
"""
@@ -241,7 +241,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"""Test event fetching during a database outage."""
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main
self.room_id = f"!room:{hs.hostname}"
@@ -377,7 +377,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main
self.user = self.register_user("user", "pass")
@@ -412,7 +412,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
unblock: "Deferred[None]" = Deferred()
original_runWithConnection = self.store.db_pool.runWithConnection
- async def runWithConnection(*args, **kwargs):
+ # Don't bother with the types here, we just pass into the original function.
+ async def runWithConnection(*args, **kwargs): # type: ignore[no-untyped-def]
await unblock
return await original_runWithConnection(*args, **kwargs)
@@ -441,7 +442,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
- def test_first_get_event_cancelled(self):
+ def test_first_get_event_cancelled(self) -> None:
"""Test cancellation of the first `get_event` call sharing a database fetch.
The first `get_event` call is the one which initiates the fetch. We expect the
@@ -467,7 +468,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
# The second `get_event` call should complete successfully.
self.get_success(get_event2)
- def test_second_get_event_cancelled(self):
+ def test_second_get_event_cancelled(self) -> None:
"""Test cancellation of the second `get_event` call sharing a database fetch."""
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
# Cancel the second `get_event` call.
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 3cc2a58d8d..56cb49d9b5 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -15,18 +15,20 @@
from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
+from synapse.util import Clock
from tests import unittest
class LockTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def test_acquire_contention(self):
+ def test_acquire_contention(self) -> None:
# Track the number of tasks holding the lock.
# Should be at most 1.
in_lock = 0
@@ -34,7 +36,7 @@ class LockTestCase(unittest.HomeserverTestCase):
release_lock: "Deferred[None]" = Deferred()
- async def task():
+ async def task() -> None:
nonlocal in_lock
nonlocal max_in_lock
@@ -76,7 +78,7 @@ class LockTestCase(unittest.HomeserverTestCase):
# At most one task should have held the lock at a time.
self.assertEqual(max_in_lock, 1)
- def test_simple_lock(self):
+ def test_simple_lock(self) -> None:
"""Test that we can take out a lock and that while we hold it nobody
else can take it out.
"""
@@ -103,7 +105,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None))
- def test_maintain_lock(self):
+ def test_maintain_lock(self) -> None:
"""Test that we don't time out locks while they're still active"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -119,7 +121,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.get_success(lock.__aexit__(None, None, None))
- def test_timeout_lock(self):
+ def test_timeout_lock(self) -> None:
"""Test that we time out locks if they're not updated for ages"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -139,7 +141,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.assertFalse(self.get_success(lock.is_still_valid()))
- def test_drop(self):
+ def test_drop(self) -> None:
"""Test that dropping the context manager means we stop renewing the lock"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -153,7 +155,7 @@ class LockTestCase(unittest.HomeserverTestCase):
lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock2)
- def test_shutdown(self):
+ def test_shutdown(self) -> None:
"""Test that shutting down Synapse releases the locks"""
# Acquire two locks
lock = self.get_success(self.store.try_acquire_lock("name", "key1"))
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index c4f12d81d7..68026e2830 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -33,7 +33,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -47,7 +47,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
table: str,
receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]],
expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]],
- ):
+ ) -> None:
"""Test that the background update to uniqueify non-thread receipts in
the given receipts table works properly.
@@ -154,7 +154,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
f"Background update did not remove all duplicate receipts from {table}",
)
- def test_background_receipts_linearized_unique_index(self):
+ def test_background_receipts_linearized_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in
`receipts_linearized` works properly.
"""
@@ -177,7 +177,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
},
)
- def test_background_receipts_graph_unique_index(self):
+ def test_background_receipts_graph_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in
`receipts_graph` works properly.
"""
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 1edb619630..7d961fac64 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -14,10 +14,14 @@
import json
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import RoomTypes
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage.databases.main.room import _BackgroundUpdates
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -30,7 +34,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -40,7 +44,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
return room_id
- def test_background_populate_rooms_creator_column(self):
+ def test_background_populate_rooms_creator_column(self) -> None:
"""Test that the background update to populate the rooms creator column
works properly.
"""
@@ -95,7 +99,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(room_creator_after, self.user_id)
- def test_background_add_room_type_column(self):
+ def test_background_add_room_type_column(self) -> None:
"""Test that the background update to populate the `room_type` column in
`room_stats_state` works properly.
"""
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 09cb06d614..8bbf936ae9 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -106,7 +106,7 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
{(1, "user1", "hello"), (2, "user2", "bleb")},
)
- def test_simple_update_many(self):
+ def test_simple_update_many(self) -> None:
"""
simple_update_many performs many updates at once.
"""
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 72bf5b3d31..1bfd11ceae 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -14,13 +14,17 @@
from typing import Iterable, Optional, Set
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import AccountDataTypes
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
class IgnoredUsersTestCase(unittest.HomeserverTestCase):
- def prepare(self, hs, reactor, clock):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user = "@user:test"
@@ -55,7 +59,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
expected_ignored_user_ids,
)
- def test_ignoring_users(self):
+ def test_ignoring_users(self) -> None:
"""Basic adding/removing of users from the ignore list."""
self._update_ignore_list("@other:test", "@another:remote")
self.assert_ignored(self.user, {"@other:test", "@another:remote"})
@@ -82,7 +86,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Check the removed user.
self.assert_ignorers("@another:remote", {self.user})
- def test_caching(self):
+ def test_caching(self) -> None:
"""Ensure that caching works properly between different users."""
# The first user ignores a user.
self._update_ignore_list("@other:test")
@@ -99,7 +103,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", {"@second:test"})
- def test_invalid_data(self):
+ def test_invalid_data(self) -> None:
"""Invalid data ends up clearing out the ignored users list."""
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 1047ed09c8..5e1324a169 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -26,7 +26,7 @@ from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
from synapse.events import EventBase
from synapse.server import HomeServer
-from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection, make_conn
from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
@@ -39,7 +39,7 @@ from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
- def setUp(self):
+ def setUp(self) -> None:
super(ApplicationServiceStoreTestCase, self).setUp()
self.as_yaml_files: List[str] = []
@@ -73,7 +73,9 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
super(ApplicationServiceStoreTestCase, self).tearDown()
- def _add_appservice(self, as_token, id, url, hs_token, sender) -> None:
+ def _add_appservice(
+ self, as_token: str, id: str, url: str, hs_token: str, sender: str
+ ) -> None:
as_yaml = {
"url": url,
"as_token": as_token,
@@ -135,7 +137,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
database, make_conn(db_config, self.engine, "test"), self.hs
)
- def _add_service(self, url, as_token, id) -> None:
+ def _add_service(self, url: str, as_token: str, id: str) -> None:
as_yaml = {
"url": url,
"as_token": as_token,
@@ -149,7 +151,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
- def _set_state(self, id: str, state: ApplicationServiceState):
+ def _set_state(self, id: str, state: ApplicationServiceState) -> defer.Deferred:
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
@@ -157,7 +159,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
(id, state.value),
)
- def _insert_txn(self, as_id, txn_id, events):
+ def _insert_txn(
+ self, as_id: str, txn_id: int, events: List[Mock]
+ ) -> "defer.Deferred[None]":
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
@@ -448,12 +452,14 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, database: DatabasePool, db_conn, hs) -> None:
+ def __init__(
+ self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: HomeServer
+ ) -> None:
super().__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
- def _write_config(self, suffix, **kwargs) -> str:
+ def _write_config(self, suffix: str, **kwargs: str) -> str:
vals = {
"id": "id" + suffix,
"url": "url" + suffix,
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 40e58f8199..256d28e4c9 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from collections import OrderedDict
+from typing import Generator
from unittest.mock import Mock
from twisted.internet import defer
@@ -30,7 +30,7 @@ from tests.utils import default_config
class SQLBaseStoreTestCase(unittest.TestCase):
"""Test the "simple" SQL generating methods in SQLBaseStore."""
- def setUp(self):
+ def setUp(self) -> None:
self.db_pool = Mock(spec=["runInteraction"])
self.mock_txn = Mock()
self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
@@ -38,12 +38,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_conn.rollback.return_value = None
# Our fake runInteraction just runs synchronously inline
- def runInteraction(func, *args, **kwargs):
+ def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_txn, *args, **kwargs))
self.db_pool.runInteraction = runInteraction
- def runWithConnection(func, *args, **kwargs):
+ def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_conn, *args, **kwargs))
self.db_pool.runWithConnection = runWithConnection
@@ -62,7 +62,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]
@defer.inlineCallbacks
- def test_insert_1col(self):
+ def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -76,7 +76,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_insert_3cols(self):
+ def test_insert_3cols(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_select_one_1col(self):
+ def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
@@ -108,7 +108,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_select_one_3col(self):
+ def test_select_one_3col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
@@ -126,7 +126,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_select_one_missing(self):
+ def test_select_one_missing(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
@@ -142,7 +144,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertFalse(ret)
@defer.inlineCallbacks
- def test_select_list(self):
+ def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 3
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
@@ -159,7 +161,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_update_one_1col(self):
+ def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -176,7 +178,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_update_one_4cols(self):
+ def test_update_one_4cols(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -193,7 +197,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_delete_one(self):
+ def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index b998ad42d9..d570684c99 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -15,11 +15,16 @@
import os.path
from unittest.mock import Mock, patch
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage import prepare_database
+from synapse.storage.types import Cursor
from synapse.types import UserID, create_requester
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -29,7 +34,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
Test the background update to clean forward extremities table.
"""
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
@@ -39,7 +46,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
- def run_background_update(self):
+ def run_background_update(self) -> None:
"""Re run the background update to clean up the extremities."""
# Make sure we don't clash with in progress updates.
self.assertTrue(
@@ -54,7 +61,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"delete_forward_extremities.sql",
)
- def run_delta_file(txn):
+ def run_delta_file(txn: Cursor) -> None:
prepare_database.executescript(txn, schema_path)
self.get_success(
@@ -84,7 +91,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
(room_id,)
)
- def test_soft_failed_extremities_handled_correctly(self):
+ def test_soft_failed_extremities_handled_correctly(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -114,7 +121,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.assertEqual(latest_event_ids, [event_id_4])
- def test_basic_cleanup(self):
+ def test_basic_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -149,7 +156,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(latest_event_ids, [event_id_b])
- def test_chain_of_fail_cleanup(self):
+ def test_chain_of_fail_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -187,7 +194,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(latest_event_ids, [event_id_b])
- def test_forked_graph_cleanup(self):
+ def test_forked_graph_cleanup(self) -> None:
r"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -252,12 +259,14 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["cleanup_extremities_with_dummy_events"] = True
return self.setup_test_homeserver(config=config)
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
self.event_creator_handler = homeserver.get_event_creation_handler()
@@ -273,7 +282,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
- def test_send_dummy_event(self):
+ def test_send_dummy_event(self) -> None:
self._create_extremity_rich_graph()
# Pump the reactor repeatedly so that the background updates have a
@@ -286,7 +295,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
- def test_send_dummy_events_when_insufficient_power(self):
+ def test_send_dummy_events_when_insufficient_power(self) -> None:
self._create_extremity_rich_graph()
# Criple power levels
self.helper.send_state(
@@ -317,7 +326,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
- def test_expiry_logic(self):
+ def test_expiry_logic(self) -> None:
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
expires old entries correctly.
"""
@@ -357,7 +366,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
0,
)
- def _create_extremity_rich_graph(self):
+ def _create_extremity_rich_graph(self) -> None:
"""Helper method to create bushy graph on demand"""
event_id_start = self.create_and_send_event(self.room_id, self.user)
@@ -372,7 +381,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
)
self.assertEqual(len(latest_event_ids), 50)
- def _enable_consent_checking(self):
+ def _enable_consent_checking(self) -> None:
"""Helper method to enable consent checking"""
self.event_creator._block_events_without_consent_error = "No consent from user"
consent_uri_builder = Mock()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index a9af1babed..7f7f4ef892 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
+from synapse.server import HomeServer
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.types import UserID
+from synapse.util import Clock
from tests import unittest
from tests.server import make_request
@@ -30,14 +35,10 @@ from tests.unittest import override_config
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- return hs
-
- def prepare(self, hs, reactor, clock):
- self.store = self.hs.get_datastores().main
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
- def test_insert_new_client_ip(self):
+ def test_insert_new_client_ip(self) -> None:
self.reactor.advance(12345678)
user_id = "@user:id"
@@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
- def test_insert_new_client_ip_none_device_id(self):
+ def test_insert_new_client_ip_none_device_id(self) -> None:
"""
An insert with a device ID of NULL will not create a new entry, but
update an existing entry in the user_ips table.
@@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
- def test_get_last_client_ip_by_device(self, after_persisting: bool):
+ def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
self.reactor.advance(12345678)
@@ -169,8 +170,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
)
- last_seen = self.clock.time_msec()
-
if after_persisting:
# Trigger the storage loop
self.reactor.advance(10)
@@ -191,7 +190,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
"device_id": device_id,
"ip": None,
"user_agent": None,
- "last_seen": last_seen,
+ "last_seen": None,
},
],
)
@@ -213,7 +212,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
},
)
- def test_get_last_client_ip_by_device_combined_data(self):
+ def test_get_last_client_ip_by_device_combined_data(self) -> None:
"""Test that `get_last_client_ip_by_device` combines persisted and unpersisted
data together correctly
"""
@@ -312,7 +311,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
- def test_get_user_ip_and_agents(self, after_persisting: bool):
+ def test_get_user_ip_and_agents(self, after_persisting: bool) -> None:
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
self.reactor.advance(12345678)
@@ -352,7 +351,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
- def test_get_user_ip_and_agents_combined_data(self):
+ def test_get_user_ip_and_agents_combined_data(self) -> None:
"""Test that `get_user_ip_and_agents` combines persisted and unpersisted data
together correctly
"""
@@ -429,7 +428,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
- def test_disabled_monthly_active_user(self):
+ def test_disabled_monthly_active_user(self) -> None:
user_id = "@user:server"
self.get_success(
self.store.insert_client_ip(
@@ -440,7 +439,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
- def test_adding_monthly_active_user_when_full(self):
+ def test_adding_monthly_active_user_when_full(self) -> None:
lots_of_users = 100
user_id = "@user:server"
@@ -456,7 +455,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
- def test_adding_monthly_active_user_when_space(self):
+ def test_adding_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@@ -473,7 +472,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertTrue(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
- def test_updating_monthly_active_user_when_space(self):
+ def test_updating_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
@@ -491,7 +490,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
- def test_devices_last_seen_bg_update(self):
+ def test_devices_last_seen_bg_update(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()
@@ -576,7 +575,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
- def test_old_user_ips_pruned(self):
+ def test_old_user_ips_pruned(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()
@@ -639,11 +638,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(result, [])
# But we should still get the correct values for the device
- result = self.get_success(
+ result2 = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, device_id)]
+ r = result2[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
@@ -663,15 +662,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- return hs
-
- def prepare(self, hs, reactor, clock):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user_id = self.register_user("bob", "abc123", True)
- def test_request_with_xforwarded(self):
+ def test_request_with_xforwarded(self) -> None:
"""
The IP in X-Forwarded-For is entered into the client IPs table.
"""
@@ -681,14 +676,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
{"request": XForwardedForRequest},
)
- def test_request_from_getPeer(self):
+ def test_request_from_getPeer(self) -> None:
"""
The IP returned by getPeer is entered into the client IPs table, if
there's no X-Forwarded-For header.
"""
self._runtest({}, "127.0.0.1", {})
- def _runtest(self, headers, expected_ip, make_request_args):
+ def _runtest(
+ self,
+ headers: Dict[bytes, bytes],
+ expected_ip: str,
+ make_request_args: Dict[str, Any],
+ ) -> None:
device_id = "bleb"
access_token = self.login("bob", "abc123", device_id=device_id)
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index a40fc20ef9..543cce6b3e 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -31,7 +31,7 @@ from tests import unittest
class TupleComparisonClauseTestCase(unittest.TestCase):
- def test_native_tuple_comparison(self):
+ def test_native_tuple_comparison(self) -> None:
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 8e7db2c4ec..f03807c8f9 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -12,17 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, List, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.api.errors
from synapse.api.constants import EduTypes
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class DeviceStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def add_device_change(self, user_id, device_ids, host):
+ def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
"""
@@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
)
- def test_store_new_device(self):
+ def test_store_new_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res,
)
- def test_get_devices_by_user(self):
+ def test_get_devices_by_user(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res["device2"],
)
- def test_count_devices_by_users(self):
+ def test_count_devices_by_users(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(3, res)
- def test_get_device_updates_by_remote(self):
+ def test_get_device_updates_by_remote(self) -> None:
device_ids = ["device_id1", "device_id2"]
# Add two device updates with sequential `stream_id`s
@@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
- def test_get_device_updates_by_remote_can_limit_properly(self):
+ def test_get_device_updates_by_remote_can_limit_properly(self) -> None:
"""
Tests that `get_device_updates_by_remote` returns an appropriate
stream_id to resume fetching from (without skipping any results).
@@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(device_updates, [])
- def _check_devices_in_updates(self, expected_device_ids, device_updates):
+ def _check_devices_in_updates(
+ self,
+ expected_device_ids: Collection[str],
+ device_updates: List[Tuple[str, JsonDict]],
+ ) -> None:
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
@@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
- def test_update_device(self):
+ def test_update_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
self.get_success(self.store.update_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase):
# check it worked
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 2", res["display_name"])
- def test_update_unknown_device(self):
+ def test_update_unknown_device(self) -> None:
exc = self.get_failure(
self.store.update_device(
"user_id", "unknown_device_id", new_display_name="display_name 2"
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 20bf3ca17b..8bedc6bdf3 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -12,19 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class DirectoryStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
- def test_room_to_alias(self):
+ def test_room_to_alias(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@@ -36,7 +40,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)
- def test_alias_to_room(self):
+ def test_alias_to_room(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@@ -48,7 +52,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_association_from_room_alias(self.alias))),
)
- def test_delete_alias(self):
+ def test_delete_alias(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index fb96ab3a2f..9cb326d90a 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
from synapse.storage.databases.main.e2e_room_keys import RoomKey
+from synapse.util import Clock
from tests import unittest
@@ -26,12 +30,12 @@ room_key: RoomKey = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.store = hs.get_datastores().main
return hs
- def test_room_keys_version_delete(self):
+ def test_room_keys_version_delete(self) -> None:
# test that deleting a room key backup deletes the keys
version1 = self.get_success(
self.store.create_e2e_room_keys_version(
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 0f04493ad0..5fde3b9c78 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+
from tests.unittest import HomeserverTestCase
class EndToEndKeyStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def test_key_without_device_name(self):
+ def test_key_without_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}
@@ -35,7 +40,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)
- def test_reupload_key(self):
+ def test_reupload_key(self) -> None:
now = 1470174257070
json = {"key": "value"}
@@ -53,7 +58,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
)
self.assertFalse(changed)
- def test_get_key_with_device_name(self):
+ def test_get_key_with_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}
@@ -70,7 +75,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
- def test_multiple_devices(self):
+ def test_multiple_devices(self) -> None:
now = 1470174257070
self.get_success(self.store.store_device("user1", "device1", None))
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index de9f4af2de..c070278db8 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -14,6 +14,7 @@
from typing import Dict, List, Set, Tuple
+from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from synapse.api.constants import EventTypes
@@ -22,18 +23,22 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.events import _LinkMap
+from synapse.storage.types import Cursor
from synapse.types import create_requester
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class EventChainStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self._next_stream_ordering = 1
- def test_simple(self):
+ def test_simple(self) -> None:
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""
@@ -232,7 +237,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
),
)
- def test_out_of_order_events(self):
+ def test_out_of_order_events(self) -> None:
"""Test that we handle persisting events that we don't have the full
auth chain for yet (which should only happen for out of band memberships).
"""
@@ -378,7 +383,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def persist(
self,
events: List[EventBase],
- ):
+ ) -> None:
"""Persist the given events and check that the links generated match
those given.
"""
@@ -389,7 +394,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
e.internal_metadata.stream_ordering = self._next_stream_ordering
self._next_stream_ordering += 1
- def _persist(txn):
+ def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
@@ -456,7 +461,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
class LinkMapTestCase(unittest.TestCase):
- def test_simple(self):
+ def test_simple(self) -> None:
"""Basic tests for the LinkMap."""
link_map = _LinkMap()
@@ -492,7 +497,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -559,7 +564,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Delete the chain cover info.
- def _delete_tables(txn):
+ def _delete_tables(txn: Cursor) -> None:
txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links")
@@ -567,7 +572,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
return room_id, [state1, state2]
- def test_background_update_single_room(self):
+ def test_background_update_single_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -602,7 +607,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
- def test_background_update_multiple_rooms(self):
+ def test_background_update_multiple_rooms(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -640,7 +645,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
- def test_background_update_single_large_room(self):
+ def test_background_update_single_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -693,7 +698,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
- def test_background_update_multiple_large_room(self):
+ def test_background_update_multiple_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 853db930d6..7fd3e01364 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,7 @@
# limitations under the License.
import datetime
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, cast
import attr
from parameterized import parameterized
@@ -26,11 +26,12 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersion,
)
-from synapse.events import _EventInternalMetadata
+from synapse.events import EventBase, _EventInternalMetadata
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
+from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import Clock, json_encoder
@@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def test_get_prev_events_for_room(self):
+ def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local"
# add a bunch of events and hashes to act as forward extremities
- def insert_event(txn, i):
+ def insert_event(txn: Cursor, i: int) -> None:
event_id = "$event_%i:local" % i
txn.execute(
@@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range(0, 10):
self.assertEqual("$event_%i:local" % (19 - i), r[i])
- def test_get_rooms_with_many_extremities(self):
+ def test_get_rooms_with_many_extremities(self) -> None:
room1 = "#room1"
room2 = "#room2"
room3 = "#room3"
- def insert_event(txn, i, room_id):
+ def insert_event(txn: Cursor, i: int, room_id: str) -> None:
event_id = "$event_%i:local" % i
txn.execute(
(
@@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J
- auth_graph = {
+ auth_graph: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
@@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Mark the room as maybe having a cover index.
- def store_room(txn):
+ def store_room(txn: LoggingTransaction) -> None:
self.store.db_pool.simple_insert_txn(
txn,
"rooms",
@@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn):
+ def insert_event(txn: LoggingTransaction) -> None:
stream_ordering = 0
for event_id in auth_graph:
@@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[
- FakeEvent(event_id, room_id, auth_graph[event_id])
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
],
)
@@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return room_id
@parameterized.expand([(True,), (False,)])
- def test_auth_chain_ids(self, use_chain_cover_index: bool):
+ def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index)
# a and b have the same auth chain.
@@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertCountEqual(auth_chain_ids, ["i", "j"])
@parameterized.expand([(True,), (False,)])
- def test_auth_difference(self, use_chain_cover_index: bool):
+ def test_auth_difference(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index)
# Now actually test that various combinations give the right result:
@@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.assertSetEqual(difference, set())
- def test_auth_difference_partial_cover(self):
+ def test_auth_difference_partial_cover(self) -> None:
"""Test that we correctly handle rooms where not all events have a chain
cover calculated. This can happen in some obscure edge cases, including
during the background update that calculates the chain cover for old
@@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J
- auth_graph = {
+ auth_graph: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
@@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn):
+ def insert_event(txn: LoggingTransaction) -> None:
# First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn(
txn,
@@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[
- FakeEvent(event_id, room_id, auth_graph[event_id])
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
if event_id != "b"
],
@@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
- [FakeEvent("b", room_id, auth_graph["b"])],
+ [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)
self.store.db_pool.simple_update_txn(
@@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
@parameterized.expand(
[(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
)
- def test_prune_inbound_federation_queue(self, room_version: RoomVersion):
+ def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None:
"""Test that pruning of inbound federation queues work"""
room_id = "some_room_id"
@@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
stream_ordering += 1
- def populate_db(txn: LoggingTransaction):
+ def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn(
@@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
- def test_get_backfill_points_in_room(self):
+ def test_get_backfill_points_in_room(self) -> None:
"""
Test to make sure only backfill points that are older and come before
the `current_depth` are returned.
@@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
self,
- ):
+ ) -> None:
"""
Test to make sure that events we have attempted to backfill (and within
backoff timeout duration) do not show up as an event to backfill again.
@@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
self,
- ):
+ ) -> None:
"""
Test to make sure after we fake attempt to backfill event "b3" many times,
we can see retry and see the "b3" again after the backoff timeout duration
@@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"5": 7,
}
- def populate_db(txn: LoggingTransaction):
+ def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn(
@@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
- def test_get_insertion_event_backward_extremities_in_room(self):
+ def test_get_insertion_event_backward_extremities_in_room(self) -> None:
"""
Test to make sure only insertion event backward extremities that are
older and come before the `current_depth` are returned.
@@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
self,
- ):
+ ) -> None:
"""
Test to make sure that insertion events we have attempted to backfill
(and within backoff timeout duration) do not show up as an event to
@@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
self,
- ):
+ ) -> None:
"""
Test to make sure after we fake attempt to backfill event
"insertion_eventA" many times, we can see retry and see the
@@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertEqual(backfill_event_ids, ["insertion_eventA"])
- def test_get_event_ids_to_not_pull_from_backoff(
- self,
- ):
+ def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
"""
Test to make sure only event IDs we should backoff from are returned.
"""
@@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
self,
- ):
+ ) -> None:
"""
Test to make sure no event IDs are returned after the backoff duration has
elapsed.
@@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(event_ids_to_backoff, [])
-@attr.s
+@attr.s(auto_attribs=True)
class FakeEvent:
- event_id = attr.ib()
- room_id = attr.ib()
- auth_events = attr.ib()
+ event_id: str
+ room_id: str
+ auth_events: List[str]
type = "foo"
state_key = "foo"
internal_metadata = _EventInternalMetadata({})
- def auth_event_ids(self):
+ def auth_event_ids(self) -> List[str]:
return self.auth_events
- def is_state(self):
+ def is_state(self) -> bool:
return True
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 6f1135eef4..a91411168c 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase
class ExtremStatisticsTestCase(HomeserverTestCase):
- def test_exposed_to_prometheus(self):
+ def test_exposed_to_prometheus(self) -> None:
"""
Forward extremity counts are exposed via Prometheus.
"""
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 3ce4f35cb7..05661a537d 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import StateMap
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -29,7 +36,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
@@ -67,7 +76,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that the current extremities is the remote event.
self.assert_extremities([self.remote_event_1.event_id])
- def persist_event(self, event, state=None):
+ def persist_event(
+ self, event: EventBase, state: Optional[StateMap[str]] = None
+ ) -> None:
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(
@@ -78,14 +89,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
self.get_success(self._persistence.persist_event(event, context))
- def assert_extremities(self, expected_extremities):
+ def assert_extremities(self, expected_extremities: List[str]) -> None:
"""Assert the current extremities for the room"""
extremities = self.get_success(
self.store.get_prev_events_for_room(self.room_id)
)
self.assertCountEqual(extremities, expected_extremities)
- def test_prune_gap(self):
+ def test_prune_gap(self) -> None:
"""Test that we drop extremities after a gap when we see an event from
the same domain.
"""
@@ -117,7 +128,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
- def test_do_not_prune_gap_if_state_different(self):
+ def test_do_not_prune_gap_if_state_different(self) -> None:
"""Test that we don't prune extremities after a gap if the resolved
state is different.
"""
@@ -161,7 +172,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
- def test_prune_gap_if_old(self):
+ def test_prune_gap_if_old(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity
is "old"
"""
@@ -197,7 +208,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
- def test_do_not_prune_gap_if_other_server(self):
+ def test_do_not_prune_gap_if_other_server(self) -> None:
"""Test that we do not drop extremities after a gap when we see an event
from a different domain.
"""
@@ -229,7 +240,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
- def test_prune_gap_if_dummy_remote(self):
+ def test_prune_gap_if_dummy_remote(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity
is a local dummy event and only points to remote events.
"""
@@ -271,7 +282,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
- def test_prune_gap_if_dummy_local(self):
+ def test_prune_gap_if_dummy_local(self) -> None:
"""Test that we don't drop extremities after a gap when the previous
extremity is a local dummy event and points to local events.
"""
@@ -315,7 +326,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
- def test_do_not_prune_gap_if_not_dummy(self):
+ def test_do_not_prune_gap_if_not_dummy(self) -> None:
"""Test that we do not drop extremities after a gap when the previous extremity
is not a dummy event.
"""
@@ -359,12 +370,14 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
- def test_remote_user_rooms_cache_invalidated(self):
+ def test_remote_user_rooms_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_rooms_for_user` cache
is invalidated for remote users.
"""
@@ -411,7 +424,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
self.assertEqual(set(rooms), set())
- def test_room_remote_user_cache_invalidated(self):
+ def test_room_remote_user_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_users_in_room` cache
is invalidated for remote users.
"""
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 9059095525..aa4b5bd3b1 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -13,6 +13,7 @@
# limitations under the License.
import signedjson.key
+import signedjson.types
import unpaddedbase64
from twisted.internet.defer import Deferred
@@ -22,7 +23,9 @@ from synapse.storage.keys import FetchKeyResult
import tests.unittest
-def decode_verify_key_base64(key_id: str, key_base64: str):
+def decode_verify_key_base64(
+ key_id: str, key_base64: str
+) -> signedjson.types.VerifyKey:
key_bytes = unpaddedbase64.decode_base64(key_base64)
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
@@ -36,7 +39,7 @@ KEY_2 = decode_verify_key_base64(
class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
- def test_get_server_verify_keys(self):
+ def test_get_server_verify_keys(self) -> None:
store = self.hs.get_datastores().main
key_id_1 = "ed25519:key1"
@@ -71,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
- def test_cache(self):
+ def test_cache(self) -> None:
"""Check that updates correctly invalidate the cache."""
store = self.hs.get_datastores().main
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index c55c4db970..2827738379 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -53,7 +53,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.reactor.advance(FORTY_DAYS)
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
- def test_initialise_reserved_users(self):
+ def test_initialise_reserved_users(self) -> None:
threepids = self.hs.config.server.mau_limits_reserved_threepids
# register three users, of which two have reserved 3pids, and a third
@@ -133,7 +133,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
active_count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(active_count, 3)
- def test_can_insert_and_count_mau(self):
+ def test_can_insert_and_count_mau(self) -> None:
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)
@@ -143,7 +143,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 1)
- def test_appservice_user_not_counted_in_mau(self):
+ def test_appservice_user_not_counted_in_mau(self) -> None:
self.get_success(
self.store.register_user(
user_id="@appservice_user:server", appservice_id="wibble"
@@ -158,7 +158,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)
- def test_user_last_seen_monthly_active(self):
+ def test_user_last_seen_monthly_active(self) -> None:
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
@@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertIsNone(result)
@override_config({"max_mau_value": 5})
- def test_reap_monthly_active_users(self):
+ def test_reap_monthly_active_users(self) -> None:
initial_users = 10
for i in range(initial_users):
self.get_success(
@@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# Note that below says mau_limit (no s), this is the name of the config
# value, although it gets stored on the config object as mau_limits.
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
- def test_reap_monthly_active_users_reserved_users(self):
+ def test_reap_monthly_active_users_reserved_users(self) -> None:
"""Tests that reaping correctly handles reaping where reserved users are
present"""
threepids = self.hs.config.server.mau_limits_reserved_threepids
@@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, self.hs.config.server.max_mau_value)
- def test_populate_monthly_users_is_guest(self):
+ def test_populate_monthly_users_is_guest(self) -> None:
# Test that guest users are not added to mau list
user_id = "@user_id:host"
@@ -260,7 +260,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
- def test_populate_monthly_users_should_update(self):
+ def test_populate_monthly_users_should_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@@ -273,7 +273,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
- def test_populate_monthly_users_should_not_update(self):
+ def test_populate_monthly_users_should_not_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@@ -286,7 +286,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
- def test_get_reserved_real_user_account(self):
+ def test_get_reserved_real_user_account(self) -> None:
# Test no reserved users, or reserved threepids
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), 0)
@@ -326,7 +326,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids))
- def test_support_user_not_add_to_mau_limits(self):
+ def test_support_user_not_add_to_mau_limits(self) -> None:
support_user_id = "@support:test"
count = self.get_success(self.store.get_monthly_active_count())
@@ -347,7 +347,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config(
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
)
- def test_track_monthly_users_without_cap(self):
+ def test_track_monthly_users_without_cap(self) -> None:
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(0, count)
@@ -358,14 +358,14 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(2, count)
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
- def test_no_users_when_not_tracking(self):
+ def test_no_users_when_not_tracking(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
self.store.upsert_monthly_active_user.assert_not_called()
- def test_get_monthly_active_count_by_service(self):
+ def test_get_monthly_active_count_by_service(self) -> None:
appservice1_user1 = "@appservice1_user1:example.com"
appservice1_user2 = "@appservice1_user2:example.com"
@@ -413,7 +413,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(result[service2], 1)
self.assertEqual(result[native], 1)
- def test_get_monthly_active_users_by_service(self):
+ def test_get_monthly_active_users_by_service(self) -> None:
# (No users, no filtering) -> empty result
result = self.get_success(self.store.get_monthly_active_users_by_service())
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 9c1182ed16..010cc74c31 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client import room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -23,17 +27,17 @@ class PurgeTests(HomeserverTestCase):
user_id = "@red:server"
servlets = [room.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
self._storage_controllers = self.hs.get_storage_controllers()
- def test_purge_history(self):
+ def test_purge_history(self) -> None:
"""
Purging a room history will delete everything before the topological point.
"""
@@ -63,7 +67,7 @@ class PurgeTests(HomeserverTestCase):
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
self.get_success(self.store.get_event(last["event_id"]))
- def test_purge_history_wont_delete_extrems(self):
+ def test_purge_history_wont_delete_extrems(self) -> None:
"""
Purging a room history will delete everything before the topological point.
"""
@@ -77,6 +81,7 @@ class PurgeTests(HomeserverTestCase):
token = self.get_success(
self.store.get_topological_token_for_event(last["event_id"])
)
+ assert token.topological is not None
event = f"t{token.topological + 1}-{token.stream + 1}"
# Purge everything before this topological token
@@ -94,7 +99,7 @@ class PurgeTests(HomeserverTestCase):
self.get_success(self.store.get_event(third["event_id"]))
self.get_success(self.store.get_event(last["event_id"]))
- def test_purge_room(self):
+ def test_purge_room(self) -> None:
"""
Purging a room will delete everything about it.
"""
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index 81253d0361..d8d84152dc 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -14,8 +14,12 @@
from typing import Collection, Optional
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import ReceiptTypes
+from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util import Clock
from tests.test_utils.event_injection import create_event
from tests.unittest import HomeserverTestCase
@@ -25,7 +29,9 @@ OUR_USER_ID = "@our:test"
class ReceiptTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver) -> None:
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
super().prepare(reactor, clock, homeserver)
self.store = homeserver.get_datastores().main
@@ -135,11 +141,11 @@ class ReceiptTestCase(HomeserverTestCase):
)
self.assertEqual(res, {})
- res = self.get_last_unthreaded_receipt(
+ res2 = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
- self.assertEqual(res, None)
+ self.assertIsNone(res2)
def test_get_receipts_for_user(self) -> None:
# Send some events into the first room
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6c4e63b77c..df4740f9d9 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -11,27 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional
+from typing import List, Optional, cast
from canonicaljson import json
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase, _EventInternalMetadata
+from synapse.events.builder import EventBuilder
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomID, UserID
+from synapse.util import Clock
from tests import unittest
from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config["redaction_retention_period"] = "30d"
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self._storage = hs.get_storage_controllers()
+ storage = hs.get_storage_controllers()
+ assert storage.persistence is not None
+ self._persistence = storage.persistence
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.depth = 1
- def inject_room_member(
+ def inject_room_member( # type: ignore[override]
self,
- room,
- user,
- membership,
- replaces_state=None,
- extra_content: Optional[dict] = None,
- ):
+ room: RoomID,
+ user: UserID,
+ membership: str,
+ extra_content: Optional[JsonDict] = None,
+ ) -> EventBase:
content = {"membership": membership}
content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
@@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def inject_message(self, room, user, body):
+ def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase:
self.depth += 1
builder = self.event_builder_factory.for_room_version(
@@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def inject_redaction(self, room, event_id, user, reason):
+ def inject_redaction(
+ self, room: RoomID, event_id: str, user: UserID, reason: str
+ ) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def test_redact(self):
+ def test_redact(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_message(self.room1, self.u_alice, "t")
@@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)
- def test_redact_join(self):
+ def test_redact_join(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_room_member(
@@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)
- def test_circular_redaction(self):
+ def test_circular_redaction(self) -> None:
redaction_event_id1 = "$redaction1_id:test"
redaction_event_id2 = "$redaction2_id:test"
class EventIdManglingBuilder:
- def __init__(self, base_builder, event_id):
+ def __init__(self, base_builder: EventBuilder, event_id: str):
self._base_builder = base_builder
self._event_id = event_id
@@ -227,67 +236,73 @@ class RedactionTestCase(unittest.HomeserverTestCase):
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
- ):
+ ) -> EventBase:
built_event = await self._base_builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
- built_event._event_id = self._event_id
+ built_event._event_id = self._event_id # type: ignore[attr-defined]
built_event._dict["event_id"] = self._event_id
assert built_event.event_id == self._event_id
return built_event
@property
- def room_id(self):
+ def room_id(self) -> str:
return self._base_builder.room_id
@property
- def type(self):
+ def type(self) -> str:
return self._base_builder.type
@property
- def internal_metadata(self):
+ def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
- EventIdManglingBuilder(
- self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Redaction,
- "sender": self.u_alice.to_string(),
- "room_id": self.room1.to_string(),
- "content": {"reason": "test"},
- "redacts": redaction_event_id2,
- },
+ cast(
+ EventBuilder,
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id2,
+ },
+ ),
+ redaction_event_id1,
),
- redaction_event_id1,
)
)
)
- self.get_success(self._storage.persistence.persist_event(event_1, context_1))
+ self.get_success(self._persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
- EventIdManglingBuilder(
- self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Redaction,
- "sender": self.u_alice.to_string(),
- "room_id": self.room1.to_string(),
- "content": {"reason": "test"},
- "redacts": redaction_event_id1,
- },
+ cast(
+ EventBuilder,
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id1,
+ },
+ ),
+ redaction_event_id2,
),
- redaction_event_id2,
)
)
)
- self.get_success(self._storage.persistence.persist_event(event_2, context_2))
+ self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
)
- def test_redact_censor(self):
+ def test_redact_censor(self) -> None:
"""Test that a redacted event gets censored in the DB after a month"""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.assert_dict({"content": {}}, json.loads(event_json))
- def test_redact_redaction(self):
+ def test_redact_redaction(self) -> None:
"""Tests that we can redact a redaction and can fetch it again."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.store.get_event(first_redact_event.event_id, allow_none=True)
)
- def test_store_redacted_redaction(self):
+ def test_store_redacted_redaction(self) -> None:
"""Tests that we can store a redacted redaction."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self._storage.persistence.persist_event(redaction_event, context)
- )
+ self.get_success(self._persistence.persist_event(redaction_event, context))
# Now lets jump to the future where we have censored the redaction event
# in the DB.
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index 0baa54312e..966aafea6f 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -14,10 +14,15 @@
from typing import List
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.app.generic_worker import GenericWorkerServer
+from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
from synapse.storage.schema import SCHEMA_VERSION
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -39,13 +44,13 @@ def fake_listdir(filepath: str) -> List[str]:
class WorkerSchemaTests(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
- def default_config(self):
+ def default_config(self) -> JsonDict:
conf = super().default_config()
# Mark this as a worker app.
@@ -53,7 +58,7 @@ class WorkerSchemaTests(HomeserverTestCase):
return conf
- def test_rolling_back(self):
+ def test_rolling_back(self) -> None:
"""Test that workers can start if the DB is a newer schema version"""
db_pool = self.hs.get_datastores().main.db_pool
@@ -70,7 +75,7 @@ class WorkerSchemaTests(HomeserverTestCase):
prepare_database(db_conn, db_pool.engine, self.hs.config)
- def test_not_upgraded_old_schema_version(self):
+ def test_not_upgraded_old_schema_version(self) -> None:
"""Test that workers don't start if the DB has an older schema version"""
db_pool = self.hs.get_datastores().main.db_pool
db_conn = LoggingDatabaseConnection(
@@ -87,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase):
with self.assertRaises(PrepareDatabaseException):
prepare_database(db_conn, db_pool.engine, self.hs.config)
- def test_not_upgraded_current_schema_version_with_outstanding_deltas(self):
+ def test_not_upgraded_current_schema_version_with_outstanding_deltas(self) -> None:
"""
Test that workers don't start if the DB is on the current schema version,
but there are still outstanding delta migrations to run.
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3405efb6a8..71ec74eadc 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.room_versions import RoomVersions
+from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID, UserID
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class RoomStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastores().main
@@ -37,30 +41,34 @@ class RoomStoreTestCase(HomeserverTestCase):
)
)
- def test_get_room(self):
+ def test_get_room(self) -> None:
+ res = self.get_success(self.store.get_room(self.room.to_string()))
+ assert res is not None
self.assertDictContainsSubset(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (self.get_success(self.store.get_room(self.room.to_string()))),
+ res,
)
- def test_get_room_unknown_room(self):
+ def test_get_room_unknown_room(self) -> None:
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
- def test_get_room_with_stats(self):
+ def test_get_room_with_stats(self) -> None:
+ res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
+ assert res is not None
self.assertDictContainsSubset(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"public": True,
},
- (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
+ res,
)
- def test_get_room_with_stats_unknown_room(self):
+ def test_get_room_with_stats_unknown_room(self) -> None:
self.assertIsNone(
- (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
+ self.get_success(self.store.get_room_with_stats("!uknown:test"))
)
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index ef850daa73..14d872514d 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -39,7 +39,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
room.register_servlets,
]
- def test_null_byte(self):
+ def test_null_byte(self) -> None:
"""
Postgres/SQLite don't like null bytes going into the search tables. Internally
we replace those with a space.
@@ -86,7 +86,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
if isinstance(store.database_engine, PostgresEngine):
self.assertIn("alice", result.get("highlights"))
- def test_non_string(self):
+ def test_non_string(self) -> None:
"""Test that non-string `value`s are not inserted into `event_search`.
This is particularly important when using sqlite, since a sqlite column can hold
@@ -157,7 +157,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
self.assertEqual(f.value.code, 404)
@skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite")
- def test_sqlite_non_string_deletion_background_update(self):
+ def test_sqlite_non_string_deletion_background_update(self) -> None:
"""Test the background update to delete bad rows from `event_search`."""
store = self.hs.get_datastores().main
@@ -350,7 +350,7 @@ class MessageSearchTest(HomeserverTestCase):
"results array length should match count",
)
- def test_postgres_web_search_for_phrase(self):
+ def test_postgres_web_search_for_phrase(self) -> None:
"""
Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
This test is skipped unless the postgres instance supports websearch_to_tsquery.
@@ -364,7 +364,7 @@ class MessageSearchTest(HomeserverTestCase):
self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES)
- def test_sqlite_search(self):
+ def test_sqlite_search(self) -> None:
"""
Test sqlite searching for phrases.
"""
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5564161750..d4e6d4236c 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -16,10 +16,15 @@ import logging
from frozendict import frozendict
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.server import HomeServer
from synapse.storage.state import StateFilter
-from synapse.types import RoomID, UserID
+from synapse.types import JsonDict, RoomID, StateMap, UserID
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase, TestCase
@@ -27,7 +32,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
@@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase):
)
)
- def inject_state_event(self, room, sender, typ, state_key, content):
+ def inject_state_event(
+ self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
+ ) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
+ assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
return event
- def assertStateMapEqual(self, s1, s2):
+ def assertStateMapEqual(
+ self, s1: StateMap[EventBase], s2: StateMap[EventBase]
+ ) -> None:
for t in s1:
# just compare event IDs for simplicity
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
- def test_get_state_groups_ids(self):
+ def test_get_state_groups_ids(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
- self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
+ self.storage.state.get_state_groups_ids(
+ self.room.to_string(), [e2.event_id]
+ )
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
- def test_get_state_groups(self):
+ def test_get_state_groups(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
- self.storage.state.get_state_groups(self.room, [e2.event_id])
+ self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
- def test_get_state_for_event(self):
+ def test_get_state_for_event(self) -> None:
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase):
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
- ):
+ ) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
- def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+ def test_state_filter_difference_no_include_other_minus_no_include_other(
+ self,
+ ) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
@@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_include_other_minus_no_include_other(self):
+ def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
@@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_include_other_minus_include_other(self):
+ def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
@@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_no_include_other_minus_include_other(self):
+ def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
@@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_simple_cases(self):
+ def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
@@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase):
class StateFilterTestCase(TestCase):
- def test_return_expanded(self):
+ def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 34fa810cf6..bc090ebce0 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -14,11 +14,15 @@
from typing import List
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.filtering import Filter
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -37,12 +41,14 @@ class PaginationTestCase(HomeserverTestCase):
login.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = {"msc3874_enabled": True}
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@@ -130,7 +136,7 @@ class PaginationTestCase(HomeserverTestCase):
return [ev.event_id for ev in events]
- def test_filter_relation_senders(self):
+ def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
@@ -146,7 +152,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
- def test_filter_relation_type(self):
+ def test_filter_relation_type(self) -> None:
# Messages which have annotations.
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
@@ -167,7 +173,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
- def test_filter_relation_senders_and_type(self):
+ def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
"related_by_senders": [self.second_user_id],
@@ -176,7 +182,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertEqual(chunk, [self.event_id_1])
- def test_duplicate_relation(self):
+ def test_duplicate_relation(self) -> None:
"""An event should only be returned once if there are multiple relations to it."""
self.helper.send_event(
room_id=self.room_id,
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index e05daa285e..db9ee9955e 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -12,17 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
from synapse.storage.databases.main.transactions import DestinationRetryTimings
+from synapse.util import Clock
from synapse.util.retryutils import MAX_RETRY_INTERVAL
from tests.unittest import HomeserverTestCase
class TransactionStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_get_set_transactions(self):
+ def test_get_set_transactions(self) -> None:
"""Tests that we can successfully get a non-existent entry for
destination retries, as well as testing tht we can set and get
correctly.
@@ -44,18 +50,18 @@ class TransactionStoreTestCase(HomeserverTestCase):
r,
)
- def test_initial_set_transactions(self):
+ def test_initial_set_transactions(self) -> None:
"""Tests that we can successfully set the destination retries (there
was a bug around invalidating the cache that broke this)
"""
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d)
- def test_large_destination_retry(self):
+ def test_large_destination_retry(self) -> None:
d = self.store.set_destination_retry_timings(
"example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
)
self.get_success(d)
- d = self.store.get_destination_retry_timings("example.com")
- self.get_success(d)
+ d2 = self.store.get_destination_retry_timings("example.com")
+ self.get_success(d2)
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index ace82cbf42..15ea4770bd 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -12,21 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.types import Cursor
+from synapse.util import Clock
+
from tests import unittest
class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
"""Test SQL transaction limit doesn't break transactions."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(db_txn_limit=1000)
- def test_config(self):
+ def test_config(self) -> None:
db_config = self.hs.config.database.get_single_database()
self.assertEqual(db_config.config["txn_limit"], 1000)
- def test_select(self):
- def do_select(txn):
+ def test_select(self) -> None:
+ def do_select(txn: Cursor) -> None:
txn.execute("SELECT 1")
db_pool = self.hs.get_datastores().databases[0]
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 88c7d5fec0..3ba896ecf3 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
from typing import Any, Dict, Set, Tuple
from unittest import mock
from unittest.mock import Mock, patch
@@ -30,6 +31,12 @@ from synapse.util import Clock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
+try:
+ import icu
+except ImportError:
+ icu = None # type: ignore
+
+
ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
@@ -467,3 +474,39 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
r["results"][0],
{"user_id": BELA, "display_name": "Bela", "avatar_url": None},
)
+
+
+class UserDirectoryICUTestCase(HomeserverTestCase):
+ if not icu:
+ skip = "Requires PyICU"
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.user_dir_helper = GetUserDirectoryTables(self.store)
+
+ def test_icu_word_boundary(self) -> None:
+ """Tests that we correctly detect word boundaries when ICU (International
+ Components for Unicode) support is available.
+ """
+
+ display_name = "Gáo"
+
+ # This word is not broken down correctly by Python's regular expressions,
+ # likely because á is actually a lowercase a followed by a U+0301 combining
+ # acute accent. This is specifically something that ICU support fixes.
+ matches = re.findall(r"([\w\-]+)", display_name, re.UNICODE)
+ self.assertEqual(len(matches), 2)
+
+ self.get_success(
+ self.store.update_profile_in_user_dir(ALICE, display_name, None)
+ )
+ self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE,)))
+
+ # Check that searching for this user yields the correct result.
+ r = self.get_success(self.store.search_user_dir(BOB, display_name, 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(len(r["results"]), 1)
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": ALICE, "display_name": display_name, "avatar_url": None},
+ )
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index cae14151c0..0e3fc2a77f 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict
+from typing import Collection, Dict
from unittest import mock
from twisted.internet.defer import CancelledError, ensureDeferred
@@ -31,7 +31,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
# the results to be returned by the mocked get_partial_state_events
self._events_dict: Dict[str, bool] = {}
- async def get_partial_state_events(events):
+ async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
return {e: self._events_dict[e] for e in events}
self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
@@ -39,7 +39,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker = PartialStateEventsTracker(self.mock_store)
- def test_does_not_block_for_full_state_events(self):
+ def test_does_not_block_for_full_state_events(self) -> None:
self._events_dict = {"event1": False, "event2": False}
self.successResultOf(
@@ -50,7 +50,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
["event1", "event2"]
)
- def test_blocks_for_partial_state_events(self):
+ def test_blocks_for_partial_state_events(self) -> None:
self._events_dict = {"event1": True, "event2": False}
d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@@ -62,12 +62,12 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d)
- def test_un_partial_state_race(self):
+ def test_un_partial_state_race(self) -> None:
# if the event is un-partial-stated between the initial check and the
# registration of the listener, it should not block.
self._events_dict = {"event1": True, "event2": False}
- async def get_partial_state_events(events):
+ async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
res = {e: self._events_dict[e] for e in events}
# change the result for next time
self._events_dict = {"event1": False, "event2": False}
@@ -79,19 +79,19 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)
- def test_un_partial_state_during_get_partial_state_events(self):
+ def test_un_partial_state_during_get_partial_state_events(self) -> None:
# we should correctly handle a call to notify_un_partial_stated during the
# second call to get_partial_state_events.
self._events_dict = {"event1": True, "event2": False}
- async def get_partial_state_events1(events):
+ async def get_partial_state_events1(events: Collection[str]) -> Dict[str, bool]:
self.mock_store.get_partial_state_events.side_effect = (
get_partial_state_events2
)
return {e: self._events_dict[e] for e in events}
- async def get_partial_state_events2(events):
+ async def get_partial_state_events2(events: Collection[str]) -> Dict[str, bool]:
self.tracker.notify_un_partial_stated("event1")
self._events_dict["event1"] = False
return {e: self._events_dict[e] for e in events}
@@ -102,7 +102,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
self._events_dict = {"event1": True, "event2": False}
d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@@ -127,12 +127,12 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.tracker = PartialCurrentStateTracker(self.mock_store)
- def test_does_not_block_for_full_state_rooms(self):
+ def test_does_not_block_for_full_state_rooms(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
- def test_blocks_for_partial_room_state(self):
+ def test_blocks_for_partial_room_state(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
d = ensureDeferred(self.tracker.await_full_state("room_id"))
@@ -144,10 +144,10 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("room_id")
self.successResultOf(d)
- def test_un_partial_state_race(self):
+ def test_un_partial_state_race(self) -> None:
# We should correctly handle race between awaiting the state and us
# un-partialling the state
- async def is_partial_state_room(events):
+ async def is_partial_state_room(room_id: str) -> bool:
self.tracker.notify_un_partial_stated("room_id")
return True
@@ -155,7 +155,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
|