diff --git a/CHANGES.md b/CHANGES.md
index 89fee07db9..9264614f39 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,10 +1,24 @@
+For the next release
+====================
+
+Removal warning
+---------------
+
+Some older clients used a
+[disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken)
+(`:`) in the `client_secret` parameter of various endpoints. The incorrect
+behaviour was allowed for backwards compatibility, but is now being removed
+from Synapse as most users have updated their client. Further context can be
+found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
+
+
Synapse 1.19.1rc1 (2020-08-25)
==============================
Bugfixes
--------
-- Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
+- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153))
diff --git a/changelog.d/7377.misc b/changelog.d/7377.misc
new file mode 100644
index 0000000000..b3ec08855b
--- /dev/null
+++ b/changelog.d/7377.misc
@@ -0,0 +1 @@
+Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.
diff --git a/changelog.d/7991.misc b/changelog.d/7991.misc
new file mode 100644
index 0000000000..1562e3af9e
--- /dev/null
+++ b/changelog.d/7991.misc
@@ -0,0 +1 @@
+Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on.
diff --git a/changelog.d/8034.feature b/changelog.d/8034.feature
new file mode 100644
index 0000000000..813e6d0903
--- /dev/null
+++ b/changelog.d/8034.feature
@@ -0,0 +1 @@
+Add support for shadow-banning users (ignoring any message send requests).
diff --git a/changelog.d/8095.feature b/changelog.d/8095.feature
new file mode 100644
index 0000000000..813e6d0903
--- /dev/null
+++ b/changelog.d/8095.feature
@@ -0,0 +1 @@
+Add support for shadow-banning users (ignoring any message send requests).
diff --git a/changelog.d/8104.bugfix b/changelog.d/8104.bugfix
new file mode 100644
index 0000000000..e32e2996c4
--- /dev/null
+++ b/changelog.d/8104.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.7.2 impacting message retention policies that would allow federated homeservers to dictate a retention period that's lower than the configured minimum allowed duration in the configuration file.
diff --git a/changelog.d/8110.bugfix b/changelog.d/8110.bugfix
new file mode 100644
index 0000000000..5269a232e1
--- /dev/null
+++ b/changelog.d/8110.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.12.0 which could cause `/sync` requests to fail with a 404 if you had a very old outstanding room invite.
diff --git a/changelog.d/8124.misc b/changelog.d/8124.misc
new file mode 100644
index 0000000000..9fac710205
--- /dev/null
+++ b/changelog.d/8124.misc
@@ -0,0 +1 @@
+Reduce the amount of whitespace in JSON stored and sent in responses.
diff --git a/changelog.d/8127.misc b/changelog.d/8127.misc
new file mode 100644
index 0000000000..cb557122aa
--- /dev/null
+++ b/changelog.d/8127.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.storage.database`.
diff --git a/changelog.d/8132.misc b/changelog.d/8132.misc
new file mode 100644
index 0000000000..7afa267c69
--- /dev/null
+++ b/changelog.d/8132.misc
@@ -0,0 +1 @@
+Micro-optimisations to get_auth_chain_ids.
diff --git a/changelog.d/8135.bugfix b/changelog.d/8135.bugfix
new file mode 100644
index 0000000000..9d5c60ea00
--- /dev/null
+++ b/changelog.d/8135.bugfix
@@ -0,0 +1 @@
+Clarify the error code if a user tries to register with a numeric ID. This bug was introduced in v1.15.0.
diff --git a/changelog.d/8139.bugfix b/changelog.d/8139.bugfix
new file mode 100644
index 0000000000..21f65d87b7
--- /dev/null
+++ b/changelog.d/8139.bugfix
@@ -0,0 +1 @@
+Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0.
diff --git a/changelog.d/8140.misc b/changelog.d/8140.misc
new file mode 100644
index 0000000000..78d8834328
--- /dev/null
+++ b/changelog.d/8140.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.state`.
diff --git a/changelog.d/8142.feature b/changelog.d/8142.feature
new file mode 100644
index 0000000000..813e6d0903
--- /dev/null
+++ b/changelog.d/8142.feature
@@ -0,0 +1 @@
+Add support for shadow-banning users (ignoring any message send requests).
diff --git a/changelog.d/8147.docker b/changelog.d/8147.docker
new file mode 100644
index 0000000000..dcc951d8f5
--- /dev/null
+++ b/changelog.d/8147.docker
@@ -0,0 +1 @@
+Added curl for healthcheck support and readme updates for the change. Contributed by @maquis196.
diff --git a/changelog.d/8152.feature b/changelog.d/8152.feature
new file mode 100644
index 0000000000..813e6d0903
--- /dev/null
+++ b/changelog.d/8152.feature
@@ -0,0 +1 @@
+Add support for shadow-banning users (ignoring any message send requests).
diff --git a/changelog.d/8158.feature b/changelog.d/8158.feature
new file mode 100644
index 0000000000..47c4c39167
--- /dev/null
+++ b/changelog.d/8158.feature
@@ -0,0 +1 @@
+ Add support for shadow-banning users (ignoring any message send requests).
diff --git a/changelog.d/8161.misc b/changelog.d/8161.misc
new file mode 100644
index 0000000000..89ff274de3
--- /dev/null
+++ b/changelog.d/8161.misc
@@ -0,0 +1 @@
+Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface.
diff --git a/changelog.d/8163.misc b/changelog.d/8163.misc
new file mode 100644
index 0000000000..b3ec08855b
--- /dev/null
+++ b/changelog.d/8163.misc
@@ -0,0 +1 @@
+Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.
diff --git a/changelog.d/8164.misc b/changelog.d/8164.misc
new file mode 100644
index 0000000000..55bc079cdb
--- /dev/null
+++ b/changelog.d/8164.misc
@@ -0,0 +1 @@
+Add functions to `MultiWriterIdGen` used by events stream.
diff --git a/changelog.d/8167.misc b/changelog.d/8167.misc
new file mode 100644
index 0000000000..e2ed9be7a4
--- /dev/null
+++ b/changelog.d/8167.misc
@@ -0,0 +1 @@
+Fix tests that were broken due to the merge of 1.19.1.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 8b3a4246a5..432d56a8ee 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -55,6 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \
FROM docker.io/python:${PYTHON_VERSION}-slim
RUN apt-get update && apt-get install -y \
+ curl \
libpq5 \
xmlsec1 \
gosu \
@@ -69,3 +70,6 @@ VOLUME ["/data"]
EXPOSE 8008/tcp 8009/tcp 8448/tcp
ENTRYPOINT ["/start.py"]
+
+HEALTHCHECK --interval=1m --timeout=5s \
+ CMD curl -fSs http://localhost:8008/health || exit 1
diff --git a/docker/README.md b/docker/README.md
index 008a9ff708..d0da34778e 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -162,3 +162,32 @@ docker build -t matrixdotorg/synapse -f docker/Dockerfile .
You can choose to build a different docker image by changing the value of the `-f` flag to
point to another Dockerfile.
+
+## Disabling the healthcheck
+
+If you are using a non-standard port or tls inside docker you can disable the healthcheck
+whilst running the above `docker run` commands.
+
+```
+ --no-healthcheck
+```
+## Setting custom healthcheck on docker run
+
+If you wish to point the healthcheck at a different port with docker command, add the following
+
+```
+ --health-cmd 'curl -fSs http://localhost:1234/health'
+```
+
+## Setting the healthcheck in docker-compose file
+
+You can add the following to set a custom healthcheck in a docker compose file.
+You will need version >2.1 for this to work.
+
+```
+healthcheck:
+ test: ["CMD", "curl", "-fSs", "http://localhost:8008/health"]
+ interval: 1m
+ timeout: 10s
+ retries: 3
+```
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index be05128b3e..d6e3194cda 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -108,7 +108,7 @@ The api is::
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
-To use it, you will need to authenticate by providing an `access_token` for a
+To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
The parameter ``from`` is optional but used for pagination, denoting the
@@ -119,8 +119,11 @@ from a previous call.
The parameter ``limit`` is optional but is used for pagination, denoting the
maximum number of items to return in this call. Defaults to ``100``.
-The parameter ``user_id`` is optional and filters to only users with user IDs
-that contain this value.
+The parameter ``user_id`` is optional and filters to only return users with user IDs
+that contain this value. This parameter is ignored when using the ``name`` parameter.
+
+The parameter ``name`` is optional and filters to only return users with user ID localparts
+**or** displaynames that contain this value.
The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
Defaults to ``true`` to include guest users.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 13376c8a42..91bc2265df 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -446,11 +446,10 @@ retention:
# min_lifetime: 1d
# max_lifetime: 1y
- # Retention policy limits. If set, a user won't be able to send a
- # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
- # that's not within this range. This is especially useful in closed federations,
- # in which server admins can make sure every federating server applies the same
- # rules.
+ # Retention policy limits. If set, and the state of a room contains a
+ # 'm.room.retention' event in its state which contains a 'min_lifetime' or a
+ # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
+ # to these limits when running purge jobs.
#
#allowed_lifetime_min: 1d
#allowed_lifetime_max: 1y
@@ -476,12 +475,19 @@ retention:
# (e.g. every 12h), but not want that purge to be performed by a job that's
# iterating over every room it knows, which could be heavy on the server.
#
+ # If any purge job is configured, it is strongly recommended to have at least
+ # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
+ # set, or one job without 'shortest_max_lifetime' and one job without
+ # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
+ # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
+ # room's policy to these values is done after the policies are retrieved from
+ # Synapse's database (which is done using the range specified in a purge job's
+ # configuration).
+ #
#purge_jobs:
- # - shortest_max_lifetime: 1d
- # longest_max_lifetime: 3d
+ # - longest_max_lifetime: 3d
# interval: 12h
# - shortest_max_lifetime: 3d
- # longest_max_lifetime: 1y
# interval: 1d
# Inhibits the /requestToken endpoints from returning an error that might leak
diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi
new file mode 100644
index 0000000000..3f3af59f26
--- /dev/null
+++ b/stubs/frozendict.pyi
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Stub for frozendict.
+
+from typing import (
+ Any,
+ Hashable,
+ Iterable,
+ Iterator,
+ Mapping,
+ overload,
+ Tuple,
+ TypeVar,
+)
+
+_KT = TypeVar("_KT", bound=Hashable) # Key type.
+_VT = TypeVar("_VT") # Value type.
+
+class frozendict(Mapping[_KT, _VT]):
+ @overload
+ def __init__(self, **kwargs: _VT) -> None: ...
+ @overload
+ def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
+ @overload
+ def __init__(
+ self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
+ ) -> None: ...
+ def __getitem__(self, key: _KT) -> _VT: ...
+ def __contains__(self, key: Any) -> bool: ...
+ def copy(self, **add_or_replace: Any) -> frozendict: ...
+ def __iter__(self) -> Iterator[_KT]: ...
+ def __len__(self) -> int: ...
+ def __repr__(self) -> str: ...
+ def __hash__(self) -> int: ...
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 28a078a7b4..67009fff39 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -605,3 +605,11 @@ class HttpResponseException(CodeMessageException):
errmsg = j.pop("error", self.msg)
return ProxiedRequestError(self.code, errmsg, errcode, j)
+
+
+class ShadowBanError(Exception):
+ """
+ Raised when a shadow-banned user attempts to perform an action.
+
+ This should be caught and a proper "fake" success response sent to the user.
+ """
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 5a6a55cc4d..9d903ba996 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1048,11 +1048,10 @@ class ServerConfig(Config):
# min_lifetime: 1d
# max_lifetime: 1y
- # Retention policy limits. If set, a user won't be able to send a
- # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
- # that's not within this range. This is especially useful in closed federations,
- # in which server admins can make sure every federating server applies the same
- # rules.
+ # Retention policy limits. If set, and the state of a room contains a
+ # 'm.room.retention' event in its state which contains a 'min_lifetime' or a
+ # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
+ # to these limits when running purge jobs.
#
#allowed_lifetime_min: 1d
#allowed_lifetime_max: 1y
@@ -1078,12 +1077,19 @@ class ServerConfig(Config):
# (e.g. every 12h), but not want that purge to be performed by a job that's
# iterating over every room it knows, which could be heavy on the server.
#
+ # If any purge job is configured, it is strongly recommended to have at least
+ # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
+ # set, or one job without 'shortest_max_lifetime' and one job without
+ # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
+ # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
+ # room's policy to these values is done after the policies are retrieved from
+ # Synapse's database (which is done using the range specified in a purge job's
+ # configuration).
+ #
#purge_jobs:
- # - shortest_max_lifetime: 1d
- # longest_max_lifetime: 3d
+ # - longest_max_lifetime: 3d
# interval: 12h
# - shortest_max_lifetime: 3d
- # longest_max_lifetime: 1y
# interval: 1d
# Inhibits the /requestToken endpoints from returning an error that might leak
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index cc5deca75b..67db763dbf 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -133,6 +133,8 @@ class _EventInternalMetadata(object):
rejection. This is needed as those events are marked as outliers, but
they still need to be processed as if they're new events (e.g. updating
invite state in the database, relaying to clients, etc).
+
+ (Added in synapse 0.99.0, so may be unreliable for events received before that)
"""
return self._dict.get("out_of_band_membership", False)
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 3c11e317fd..1cc9eda6a9 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,9 +15,10 @@
# limitations under the License.
import inspect
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
-from synapse.spam_checker_api import SpamCheckerApi
+from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
+from synapse.types import Collection
MYPY = False
if MYPY:
@@ -219,3 +220,33 @@ class SpamChecker(object):
return True
return False
+
+ def check_registration_for_spam(
+ self,
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ ) -> RegistrationBehaviour:
+ """Checks if we should allow the given registration request.
+
+ Args:
+ email_threepid: The email threepid used for registering, if any
+ username: The request user name, if any
+ request_info: List of tuples of user agent and IP that
+ were used during the registration process.
+
+ Returns:
+ Enum for how the request should be handled
+ """
+
+ for spam_checker in self.spam_checkers:
+ # For backwards compatibility, only run if the method exists on the
+ # spam checker
+ checker = getattr(spam_checker, "check_registration_for_spam", None)
+ if checker:
+ behaviour = checker(email_threepid, username, request_info)
+ assert isinstance(behaviour, RegistrationBehaviour)
+ if behaviour != RegistrationBehaviour.ALLOW:
+ return behaviour
+
+ return RegistrationBehaviour.ALLOW
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 588d222f36..5ce3874fba 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -74,15 +74,14 @@ class EventValidator(object):
)
if event.type == EventTypes.Retention:
- self._validate_retention(event, config)
+ self._validate_retention(event)
- def _validate_retention(self, event, config):
+ def _validate_retention(self, event):
"""Checks that an event that defines the retention policy for a room respects the
- boundaries imposed by the server's administrator.
+ format enforced by the spec.
Args:
event (FrozenEvent): The event to validate.
- config (Config): The homeserver's configuration.
"""
min_lifetime = event.content.get("min_lifetime")
max_lifetime = event.content.get("max_lifetime")
@@ -95,32 +94,6 @@ class EventValidator(object):
errcode=Codes.BAD_JSON,
)
- if (
- config.retention_allowed_lifetime_min is not None
- and min_lifetime < config.retention_allowed_lifetime_min
- ):
- raise SynapseError(
- code=400,
- msg=(
- "'min_lifetime' can't be lower than the minimum allowed"
- " value enforced by the server's administrator"
- ),
- errcode=Codes.BAD_JSON,
- )
-
- if (
- config.retention_allowed_lifetime_max is not None
- and min_lifetime > config.retention_allowed_lifetime_max
- ):
- raise SynapseError(
- code=400,
- msg=(
- "'min_lifetime' can't be greater than the maximum allowed"
- " value enforced by the server's administrator"
- ),
- errcode=Codes.BAD_JSON,
- )
-
if max_lifetime is not None:
if not isinstance(max_lifetime, int):
raise SynapseError(
@@ -129,32 +102,6 @@ class EventValidator(object):
errcode=Codes.BAD_JSON,
)
- if (
- config.retention_allowed_lifetime_min is not None
- and max_lifetime < config.retention_allowed_lifetime_min
- ):
- raise SynapseError(
- code=400,
- msg=(
- "'max_lifetime' can't be lower than the minimum allowed value"
- " enforced by the server's administrator"
- ),
- errcode=Codes.BAD_JSON,
- )
-
- if (
- config.retention_allowed_lifetime_max is not None
- and max_lifetime > config.retention_allowed_lifetime_max
- ):
- raise SynapseError(
- code=400,
- msg=(
- "'max_lifetime' can't be greater than the maximum allowed"
- " value enforced by the server's administrator"
- ),
- errcode=Codes.BAD_JSON,
- )
-
if (
min_lifetime is not None
and max_lifetime is not None
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index e53b6ac456..4662008bfd 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -329,10 +329,10 @@ class FederationSender(object):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
- domains = await self.state.get_current_hosts_in_room(room_id)
+ domains_set = await self.state.get_current_hosts_in_room(room_id)
domains = [
d
- for d in domains
+ for d in domains_set
if d != self.server_name
and self._federation_shard_config.should_handle(self._instance_name, d)
]
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 68d6870e40..654f58ddae 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -364,6 +364,14 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+
+ await self.store.add_user_agent_ip_to_ui_auth_session(
+ session.session_id, user_agent, clientip
+ )
+
if not authdict:
raise InteractiveAuthIncompleteError(
session.session_id, self._auth_dict_for_flows(flows, session.session_id)
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 786e608fa2..a4cc4b9a5a 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -35,6 +35,7 @@ class CasHandler:
"""
def __init__(self, hs):
+ self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -210,8 +211,16 @@ class CasHandler:
else:
if not registered_user_id:
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(
+ b"User-Agent", default=[b""]
+ )[0].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=user_display_name
+ localpart=localpart,
+ default_display_name=user_display_name,
+ user_agent_ips=(user_agent, ip_address),
)
await self._auth_handler.complete_sso_login(
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 610b08d00b..dcb4c82244 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -16,8 +16,6 @@
import logging
from typing import Any, Dict
-from canonicaljson import json
-
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@@ -27,6 +25,7 @@ from synapse.logging.opentracing import (
start_active_span,
)
from synapse.types import UserID, get_domain_from_id
+from synapse.util import json_encoder
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
@@ -174,7 +173,7 @@ class DeviceMessageHandler(object):
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
- "org.matrix.opentracing_context": json.dumps(context),
+ "org.matrix.opentracing_context": json_encoder.encode(context),
}
log_kv({"local_messages": local_messages})
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index af9936f7e2..bc61bb6acb 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -23,6 +23,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
NotFoundError,
+ ShadowBanError,
StoreError,
SynapseError,
)
@@ -200,6 +201,8 @@ class DirectoryHandler(BaseHandler):
try:
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
+ except ShadowBanError as e:
+ logger.info("Failed to update alias events due to shadow-ban: %s", e)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@@ -293,6 +296,9 @@ class DirectoryHandler(BaseHandler):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
"""
alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 29863c029b..1865e14677 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2144,10 +2144,10 @@ class FederationHandler(BaseHandler):
)
state_sets = list(state_sets.values())
state_sets.append(state)
- current_state_ids = await self.state_handler.resolve_events(
+ current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
- current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+ current_state_ids = {k: e.event_id for k, e in current_states.items()}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@@ -2159,9 +2159,11 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
- current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
+ current_state_ids_list = [
+ e for k, e in current_state_ids.items() if k in auth_types
+ ]
- auth_events_map = await self.store.get_events(current_state_ids)
+ auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values()
}
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d5b12403f9..755a52a50d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -34,6 +35,7 @@ from synapse.api.errors import (
Codes,
ConsentNotGivenError,
NotFoundError,
+ ShadowBanError,
SynapseError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
@@ -648,24 +650,35 @@ class EventCreationHandler(object):
event: EventBase,
context: EventContext,
ratelimit: bool = True,
+ ignore_shadow_ban: bool = False,
) -> int:
"""
Persists and notifies local clients and federation of an event.
Args:
- requester
- event the event to send.
- context: the context of the event.
+ requester: The requester sending the event.
+ event: The event to send.
+ context: The context of the event.
ratelimit: Whether to rate limit this send.
+ ignore_shadow_ban: True if shadow-banned users should be allowed to
+ send this event.
Return:
The stream_id of the persisted event.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500, "Tried to send member event through non-member codepath"
)
+ if not ignore_shadow_ban and requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -719,12 +732,28 @@ class EventCreationHandler(object):
event_dict: dict,
ratelimit: bool = True,
txn_id: Optional[str] = None,
+ ignore_shadow_ban: bool = False,
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
+
+ Args:
+ requester: The requester sending the event.
+ event_dict: An entire event.
+ ratelimit: Whether to rate limit this send.
+ txn_id: The transaction ID.
+ ignore_shadow_ban: True if shadow-banned users should be allowed to
+ send this event.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
"""
+ if not ignore_shadow_ban and requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
# We limit the number of concurrent event sends in a room so that we
# don't fork the DAG too much. If we don't limit then we can end up in
@@ -743,7 +772,11 @@ class EventCreationHandler(object):
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
stream_id = await self.send_nonmember_event(
- requester, event, context, ratelimit=ratelimit
+ requester,
+ event,
+ context,
+ ratelimit=ratelimit,
+ ignore_shadow_ban=ignore_shadow_ban,
)
return event, stream_id
@@ -1183,8 +1216,14 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False
+ # Since this is a dummy-event it is OK if it is sent by a
+ # shadow-banned user.
await self.send_nonmember_event(
- requester, event, context, ratelimit=False
+ requester,
+ event,
+ context,
+ ratelimit=False,
+ ignore_shadow_ban=True,
)
dummy_event_sent = True
break
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index dd3703cbd2..c5bd2fea68 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -93,6 +93,7 @@ class OidcHandler:
"""
def __init__(self, hs: "HomeServer"):
+ self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth(
@@ -689,9 +690,17 @@ class OidcHandler:
self._render_error(request, "invalid_token", str(e))
return
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
# Call the mapper to register/login the user
try:
- user_id = await self._map_userinfo_to_user(userinfo, token)
+ user_id = await self._map_userinfo_to_user(
+ userinfo, token, user_agent, ip_address
+ )
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
@@ -828,7 +837,9 @@ class OidcHandler:
now = self._clock.time_msec()
return now < expiry
- async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+ async def _map_userinfo_to_user(
+ self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
+ ) -> str:
"""Maps a UserInfo object to a mxid.
UserInfo should have a claim that uniquely identifies users. This claim
@@ -843,6 +854,8 @@ class OidcHandler:
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
Raises:
MappingException: if there was an error while mapping some properties
@@ -899,7 +912,9 @@ class OidcHandler:
# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=attributes["display_name"],
+ localpart=localpart,
+ default_display_name=attributes["display_name"],
+ user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 487420bb5d..ac3418d69d 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -82,6 +82,9 @@ class PaginationHandler(object):
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
+ self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
+ self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
+
if hs.config.retention_enabled:
# Run the purge jobs described in the configuration file.
for job in hs.config.retention_purge_jobs:
@@ -111,7 +114,7 @@ class PaginationHandler(object):
the range to handle (inclusive). If None, it means that the range has no
upper limit.
"""
- # We want the storage layer to to include rooms with no retention policy in its
+ # We want the storage layer to include rooms with no retention policy in its
# return value only if a default retention policy is defined in the server's
# configuration and that policy's 'max_lifetime' is either lower (or equal) than
# max_ms or higher than min_ms (or both).
@@ -152,13 +155,32 @@ class PaginationHandler(object):
)
continue
- max_lifetime = retention_policy["max_lifetime"]
+ # If max_lifetime is None, it means that the room has no retention policy.
+ # Given we only retrieve such rooms when there's a default retention policy
+ # defined in the server's configuration, we can safely assume that's the
+ # case and use it for this room.
+ max_lifetime = (
+ retention_policy["max_lifetime"] or self._retention_default_max_lifetime
+ )
- if max_lifetime is None:
- # If max_lifetime is None, it means that include_null equals True,
- # therefore we can safely assume that there is a default policy defined
- # in the server's configuration.
- max_lifetime = self._retention_default_max_lifetime
+ # Cap the effective max_lifetime to be within the range allowed in the
+ # config.
+ # We do this in two steps:
+ # 1. Make sure it's higher or equal to the minimum allowed value, and if
+ # it's not replace it with that value. This is because the server
+ # operator can be required to not delete information before a given
+ # time, e.g. to comply with freedom of information laws.
+ # 2. Make sure the resulting value is lower or equal to the maximum allowed
+ # value, and if it's not replace it with that value. This is because the
+ # server operator can be required to delete any data after a specific
+ # amount of time.
+ if self._retention_allowed_lifetime_min is not None:
+ max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime)
+
+ if self._retention_allowed_lifetime_max is not None:
+ max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max)
+
+ logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime)
# Figure out what token we should start purging at.
ts = self.clock.time_msec() - max_lifetime
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 24e1940ee5..1846068150 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
@@ -1318,7 +1318,7 @@ async def get_interested_parties(
async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
-) -> List[Tuple[List[str], List[UserPresenceState]]]:
+) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@@ -1334,7 +1334,7 @@ async def get_interested_remotes(
each tuple the list of UserPresenceState should be sent to each
destination
"""
- hosts_and_states = []
+ hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index e17b402e68..83baddb4fc 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -26,6 +26,7 @@ from synapse.replication.http.register import (
ReplicationPostRegisterActionsServlet,
ReplicationRegisterServlet,
)
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
@@ -53,6 +54,8 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
+ self.spam_checker = hs.get_spam_checker()
+
self._show_in_user_directory = self.hs.config.show_users_in_user_directory
if hs.config.worker_app:
@@ -139,7 +142,9 @@ class RegistrationHandler(BaseHandler):
try:
int(localpart)
raise SynapseError(
- 400, "Numeric user IDs are reserved for guest users."
+ 400,
+ "Numeric user IDs are reserved for guest users.",
+ errcode=Codes.INVALID_USERNAME,
)
except ValueError:
pass
@@ -157,7 +162,7 @@ class RegistrationHandler(BaseHandler):
address=None,
bind_emails=[],
by_admin=False,
- shadow_banned=False,
+ user_agent_ips=None,
):
"""Registers a new client on the server.
@@ -175,7 +180,8 @@ class RegistrationHandler(BaseHandler):
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
- shadow_banned (bool): Shadow-ban the created user.
+ user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+ during the registration process.
Returns:
str: user_id
Raises:
@@ -183,6 +189,24 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
+ result = self.spam_checker.check_registration_for_spam(
+ threepid, localpart, user_agent_ips or [],
+ )
+
+ if result == RegistrationBehaviour.DENY:
+ logger.info(
+ "Blocked registration of %r", localpart,
+ )
+ # We return a 429 to make it not obvious that they've been
+ # denied.
+ raise SynapseError(429, "Rate limited")
+
+ shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
+ if shadow_banned:
+ logger.info(
+ "Shadow banning registration of %r", localpart,
+ )
+
# do not check_auth_blocking if the call is coming through the Admin API
if not by_admin:
await self.auth.check_auth_blocking(threepid=threepid)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b1dd3af7b1..3fbd99473f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,6 +20,7 @@
import itertools
import logging
import math
+import random
import string
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
@@ -135,6 +136,9 @@ class RoomCreationHandler(BaseHandler):
Returns:
the new room id
+
+ Raises:
+ ShadowBanError if the requester is shadow-banned.
"""
await self.ratelimit(requester)
@@ -170,6 +174,15 @@ class RoomCreationHandler(BaseHandler):
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
+ """
+ Args:
+ requester: the user requesting the upgrade
+ old_room_id: the id of the room to be replaced
+ new_versions: the version to upgrade the room to
+
+ Raises:
+ ShadowBanError if the requester is shadow-banned.
+ """
user_id = requester.user.to_string()
# start by allocating a new room id
@@ -256,6 +269,9 @@ class RoomCreationHandler(BaseHandler):
old_room_id: the id of the room to be replaced
new_room_id: the id of the replacement room
old_room_state: the state map for the old room
+
+ Raises:
+ ShadowBanError if the requester is shadow-banned.
"""
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
@@ -644,6 +660,8 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
+ invite_3pid_list = config.get("invite_3pid", [])
+ invite_list = config.get("invite", [])
for i in invite_list:
try:
uid = UserID.from_string(i)
@@ -651,6 +669,14 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
+ if (invite_list or invite_3pid_list) and requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+
+ # Allow the request to go through, but remove any associated invites.
+ invite_3pid_list = []
+ invite_list = []
+
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
@@ -768,6 +794,8 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
+ # Note that update_membership with an action of "invite" can raise a
+ # ShadowBanError, but this was handled above by emptying invite_list.
_, last_stream_id = await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
@@ -783,6 +811,8 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
+ # Note that do_3pid_invite can raise a ShadowBanError, but this was
+ # handled above by emptying invite_3pid_list.
last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
@@ -843,11 +873,13 @@ class RoomCreationHandler(BaseHandler):
async def send(etype: str, content: JsonDict, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
+ # Allow these events to be sent even if the user is shadow-banned to
+ # allow the room creation to complete.
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- creator, event, ratelimit=False
+ creator, event, ratelimit=False, ignore_shadow_ban=True,
)
return last_stream_id
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index cae2c36d28..8ee9b2063d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -15,14 +15,21 @@
import abc
import logging
+import random
from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from synapse import types
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ LimitExceededError,
+ ShadowBanError,
+ SynapseError,
+)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
@@ -31,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
-from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
+from synapse.types import (
+ Collection,
+ JsonDict,
+ Requester,
+ RoomAlias,
+ RoomID,
+ StateMap,
+ UserID,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -303,7 +318,31 @@ class RoomMemberHandler(object):
new_room: bool = False,
require_consent: bool = True,
) -> Tuple[str, int]:
- """Update a user's membership in a room"""
+ """Update a user's membership in a room.
+
+ Params:
+ requester: The user who is performing the update.
+ target: The user whose membership is being updated.
+ room_id: The room ID whose membership is being updated.
+ action: The membership change, see synapse.api.constants.Membership.
+ txn_id: The transaction ID, if given.
+ remote_room_hosts: Remote servers to send the update to.
+ third_party_signed: Information from a 3PID invite.
+ ratelimit: Whether to rate limit the request.
+ content: The content of the created event.
+ require_consent: Whether consent is required.
+
+ Returns:
+ A tuple of the new event ID and stream ID.
+
+ Raises:
+ ShadowBanError if a shadow-banned requester attempts to send an invite.
+ """
+ if action == Membership.INVITE and requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -741,9 +780,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
- async def _can_guest_join(
- self, current_state_ids: Dict[Tuple[str, str], str]
- ) -> bool:
+ async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
@@ -811,6 +848,25 @@ class RoomMemberHandler(object):
new_room: bool = False,
id_access_token: Optional[str] = None,
) -> int:
+ """Invite a 3PID to a room.
+
+ Args:
+ room_id: The room to invite the 3PID to.
+ inviter: The user sending the invite.
+ medium: The 3PID's medium.
+ address: The 3PID's address.
+ id_server: The identity server to use.
+ requester: The user making the request.
+ txn_id: The transaction ID this is part of, or None if this is not
+ part of a transaction.
+ id_access_token: The optional identity server access token.
+
+ Returns:
+ The new stream ID.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
+ """
if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
@@ -818,6 +874,11 @@ class RoomMemberHandler(object):
403, "Invites have been disabled on this server", Codes.FORBIDDEN
)
+ if requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
await self.base_handler.ratelimit(requester)
@@ -865,6 +926,8 @@ class RoomMemberHandler(object):
raise SynapseError(403, "Invites have been disabled on this server")
if invitee:
+ # Note that update_membership with an action of "invite" can raise
+ # a ShadowBanError, but this was done above already.
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
@@ -970,9 +1033,7 @@ class RoomMemberHandler(object):
)
return stream_id
- async def _is_host_in_room(
- self, current_state_ids: Dict[Tuple[str, str], str]
- ) -> bool:
+ async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -1103,7 +1164,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
# The room is too large. Leave.
- requester = types.create_requester(user, None, False, None)
+ requester = types.create_requester(user, None, False, False, None)
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index c1fcb98454..b426199aa6 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -54,6 +54,7 @@ class Saml2SessionData:
class SamlHandler:
def __init__(self, hs: "synapse.server.HomeServer"):
+ self.hs = hs
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
@@ -133,8 +134,14 @@ class SamlHandler:
# the dict.
self.expire_sessions()
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
user_id, current_session = await self._map_saml_response_to_user(
- resp_bytes, relay_state
+ resp_bytes, relay_state, user_agent, ip_address
)
# Complete the interactive auth session or the login.
@@ -147,7 +154,11 @@ class SamlHandler:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
- self, resp_bytes: str, client_redirect_url: str
+ self,
+ resp_bytes: str,
+ client_redirect_url: str,
+ user_agent: str,
+ ip_address: str,
) -> Tuple[str, Optional[Saml2SessionData]]:
"""
Given a sample response, retrieve the cached session and user for it.
@@ -155,6 +166,8 @@ class SamlHandler:
Args:
resp_bytes: The SAML response.
client_redirect_url: The redirect URL passed in by the client.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
Returns:
Tuple of the user ID and SAML session associated with this response.
@@ -291,6 +304,7 @@ class SamlHandler:
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
+ user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index abe532d350..d39ac62168 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -172,12 +172,11 @@ from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Type
import attr
-from canonicaljson import json
from twisted.internet import defer
from synapse.config import ConfigError
-from synapse.util import json_decoder
+from synapse.util import json_decoder, json_encoder
if TYPE_CHECKING:
from synapse.http.site import SynapseRequest
@@ -693,7 +692,7 @@ def active_span_context_as_string():
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
- return json.dumps(carrier)
+ return json_encoder.encode(carrier)
@only_if_tracing
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 7c292ef3f9..09726d52d6 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -316,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet):
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
+ # update_membership with an action of "invite" can raise a
+ # ShadowBanError. This is not handled since it is assumed that
+ # an admin isn't going to call this API with a shadow-banned user.
await self.room_member_handler.update_membership(
requester=requester,
target=fake_requester.user,
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index cc0bdfa5c9..f3e77da850 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -73,6 +73,7 @@ class UsersRestServletV2(RestServlet):
The parameters `from` and `limit` are required only for pagination.
By default, a `limit` of 100 is used.
The parameter `user_id` can be used to filter by user id.
+ The parameter `name` can be used to filter by user id or display name.
The parameter `guests` can be used to exclude guest users.
The parameter `deactivated` can be used to include deactivated users.
"""
@@ -89,11 +90,12 @@ class UsersRestServletV2(RestServlet):
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
user_id = parse_string(request, "user_id", default=None)
+ name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
users, total = await self.store.get_users_paginate(
- start, limit, user_id, guests, deactivated
+ start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
if len(users) >= limit:
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index bc914d920e..7ed1ccb5a0 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -27,6 +27,7 @@ from synapse.api.errors import (
Codes,
HttpResponseException,
InvalidClientCredentialsError,
+ ShadowBanError,
SynapseError,
)
from synapse.api.filtering import Filter
@@ -45,6 +46,7 @@ from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from synapse.util import json_decoder
+from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
@@ -199,23 +201,26 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if state_key is not None:
event_dict["state_key"] = state_key
- if event_type == EventTypes.Member:
- membership = content.get("membership", None)
- event_id, _ = await self.room_member_handler.update_membership(
- requester,
- target=UserID.from_string(state_key),
- room_id=room_id,
- action=membership,
- content=content,
- )
- else:
- (
- event,
- _,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, txn_id=txn_id
- )
- event_id = event.event_id
+ try:
+ if event_type == EventTypes.Member:
+ membership = content.get("membership", None)
+ event_id, _ = await self.room_member_handler.update_membership(
+ requester,
+ target=UserID.from_string(state_key),
+ room_id=room_id,
+ action=membership,
+ content=content,
+ )
+ else:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict, txn_id=txn_id
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
set_tag("event_id", event_id)
ret = {"event_id": event_id}
@@ -248,12 +253,19 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, txn_id=txn_id
- )
+ try:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict, txn_id=txn_id
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- set_tag("event_id", event.event_id)
- return 200, {"event_id": event.event_id}
+ set_tag("event_id", event_id)
+ return 200, {"event_id": event_id}
def on_GET(self, request, room_id, event_type, txn_id):
return 200, "Not implemented"
@@ -719,17 +731,21 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- await self.room_member_handler.do_3pid_invite(
- room_id,
- requester.user,
- content["medium"],
- content["address"],
- content["id_server"],
- requester,
- txn_id,
- new_room=False,
- id_access_token=content.get("id_access_token"),
- )
+ try:
+ await self.room_member_handler.do_3pid_invite(
+ room_id,
+ requester.user,
+ content["medium"],
+ content["address"],
+ content["id_server"],
+ requester,
+ txn_id,
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
+ )
+ except ShadowBanError:
+ # Pretend the request succeeded.
+ pass
return 200, {}
target = requester.user
@@ -741,15 +757,19 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if "reason" in content:
event_content = {"reason": content["reason"]}
- await self.room_member_handler.update_membership(
- requester=requester,
- target=target,
- room_id=room_id,
- action=membership_action,
- txn_id=txn_id,
- third_party_signed=content.get("third_party_signed", None),
- content=event_content,
- )
+ try:
+ await self.room_member_handler.update_membership(
+ requester=requester,
+ target=target,
+ room_id=room_id,
+ action=membership_action,
+ txn_id=txn_id,
+ third_party_signed=content.get("third_party_signed", None),
+ content=event_content,
+ )
+ except ShadowBanError:
+ # Pretend the request succeeded.
+ pass
return_value = {}
@@ -787,20 +807,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.Redaction,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- "redacts": event_id,
- },
- txn_id=txn_id,
- )
+ try:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.Redaction,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ "redacts": event_id,
+ },
+ txn_id=txn_id,
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- set_tag("event_id", event.event_id)
- return 200, {"event_id": event.event_id}
+ set_tag("event_id", event_id)
+ return 200, {"event_id": event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
set_tag("txn_id", txn_id)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 6b945e1849..570fa0a2eb 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
import re
from http import HTTPStatus
from typing import TYPE_CHECKING
@@ -122,6 +123,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -491,6 +495,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -563,6 +570,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index e0d83a962d..a4c08c8ec5 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -17,6 +17,7 @@
import hmac
import logging
+import random
import re
from typing import List, Union
@@ -133,6 +134,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -207,6 +211,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
@@ -658,6 +665,10 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE,
)
+ entries = await self.store.get_user_agents_ips_to_ui_auth_session(
+ session_id
+ )
+
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password_hash=password_hash,
@@ -665,6 +676,7 @@ class RegisterRestServlet(RestServlet):
default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
+ user_agent_ips=entries,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 89002ffbff..e29f49f7f5 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -22,7 +22,7 @@ any time to reflect changes in the MSC.
import logging
from synapse.api.constants import EventTypes, RelationTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import ShadowBanError, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_integer,
@@ -35,6 +35,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
+from synapse.util.stringutils import random_string
from ._base import client_patterns
@@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
- event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict=event_dict, txn_id=txn_id
- )
+ try:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict=event_dict, txn_id=txn_id
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- return 200, {"event_id": event.event_id}
+ return 200, {"event_id": event_id}
class RelationPaginationServlet(RestServlet):
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index f357015a70..39a5518614 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -15,13 +15,14 @@
import logging
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.util import stringutils
from ._base import client_patterns
@@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet):
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("new_version",))
- new_version = content["new_version"]
new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
if new_version is None:
@@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet):
Codes.UNSUPPORTED_ROOM_VERSION,
)
- new_room_id = await self._room_creation_handler.upgrade_room(
- requester, room_id, new_version
- )
+ try:
+ new_room_id = await self._room_creation_handler.upgrade_room(
+ requester, room_id, new_version
+ )
+ except ShadowBanError:
+ # Generate a random room ID.
+ new_room_id = stringutils.random_string(18)
ret = {"replacement_room": new_room_id}
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 20177b44e7..e15e13b756 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
from twisted.web.resource import Resource
from synapse.http.server import set_cors_headers
+from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@@ -67,4 +67,4 @@ class WellKnownResource(Resource):
logger.debug("returning: %s", r)
request.setHeader(b"Content-Type", b"application/json")
- return json.dumps(r).encode("utf-8")
+ return json_encoder.encode(r).encode("utf-8")
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 7f63f1bfa0..9be92e2565 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from enum import Enum
from twisted.internet import defer
@@ -25,6 +26,16 @@ if MYPY:
logger = logging.getLogger(__name__)
+class RegistrationBehaviour(Enum):
+ """
+ Enum to define whether a registration request should allowed, denied, or shadow-banned.
+ """
+
+ ALLOW = "allow"
+ SHADOW_BAN = "shadow_ban"
+ DENY = "deny"
+
+
class SpamCheckerApi(object):
"""A proxy object that gets passed to spam checkers so they can get
access to rooms and other relevant information.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index dba8d91eef..a601303fa3 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,11 +16,22 @@
import logging
from collections import namedtuple
-from typing import Awaitable, Dict, Iterable, List, Optional, Set
+from typing import (
+ Awaitable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Union,
+ overload,
+)
import attr
from frozendict import frozendict
from prometheus_client import Histogram
+from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -30,7 +41,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
-from synapse.types import StateMap
+from synapse.types import Collection, StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -68,8 +79,14 @@ def _gen_state_id():
class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
- def __init__(self, state, state_group, prev_group=None, delta_ids=None):
- # dict[(str, str), str] map from (type, state_key) to event_id
+ def __init__(
+ self,
+ state: StateMap[str],
+ state_group: Optional[int],
+ prev_group: Optional[int] = None,
+ delta_ids: Optional[StateMap[str]] = None,
+ ):
+ # A map from (type, state_key) to event_id.
self.state = frozendict(state)
# the ID of a state group if one and only one is involved.
@@ -107,24 +124,49 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
+ @overload
async def get_current_state(
- self, room_id, event_type=None, state_key="", latest_event_ids=None
- ):
- """ Retrieves the current state for the room. This is done by
+ self,
+ room_id: str,
+ event_type: Literal[None] = None,
+ state_key: str = "",
+ latest_event_ids: Optional[List[str]] = None,
+ ) -> StateMap[EventBase]:
+ ...
+
+ @overload
+ async def get_current_state(
+ self,
+ room_id: str,
+ event_type: str,
+ state_key: str = "",
+ latest_event_ids: Optional[List[str]] = None,
+ ) -> Optional[EventBase]:
+ ...
+
+ async def get_current_state(
+ self,
+ room_id: str,
+ event_type: Optional[str] = None,
+ state_key: str = "",
+ latest_event_ids: Optional[List[str]] = None,
+ ) -> Union[Optional[EventBase], StateMap[EventBase]]:
+ """Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send
next before receiving any new events.
- If `event_type` is specified, then the method returns only the one
- event (or None) with that `event_type` and `state_key`.
-
Returns:
- map from (type, state_key) to event
+ If `event_type` is specified, then the method returns only the one
+ event (or None) with that `event_type` and `state_key`.
+
+ Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+ assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
@@ -140,34 +182,30 @@ class StateHandler(object):
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
- state = {
+ return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
- return state
-
- async def get_current_state_ids(self, room_id, latest_event_ids=None):
+ async def get_current_state_ids(
+ self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
+ ) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
- room_id (str):
-
- latest_event_ids (iterable[str]|None): if given, the forward
- extremities to resolve. If None, we look them up from the
- database (via a cache)
+ room_id:
+ latest_event_ids: if given, the forward extremities to resolve. If
+ None, we look them up from the database (via a cache).
Returns:
- Deferred[dict[(str, str), str)]]: the state dict, mapping from
- (event_type, state_key) -> event_id
+ the state dict, mapping from (event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+ assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
- state = ret.state
-
- return state
+ return dict(ret.state)
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None
@@ -183,32 +221,34 @@ class StateHandler(object):
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+ assert latest_event_ids is not None
+
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
- joined_users = await self.store.get_joined_users_from_state(room_id, entry)
- return joined_users
+ return await self.store.get_joined_users_from_state(room_id, entry)
- async def get_current_hosts_in_room(self, room_id):
+ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
- async def get_hosts_in_room_at_events(self, room_id, event_ids):
+ async def get_hosts_in_room_at_events(
+ self, room_id: str, event_ids: List[str]
+ ) -> Set[str]:
"""Get the hosts that were in a room at the given event ids
Args:
- room_id (str):
- event_ids (list[str]):
+ room_id:
+ event_ids:
Returns:
- Deferred[list[str]]: the hosts in the room at the given events
+ The hosts in the room at the given events
"""
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
- joined_hosts = await self.store.get_joined_hosts(room_id, entry)
- return joined_hosts
+ return await self.store.get_joined_hosts(room_id, entry)
async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
- ):
+ ) -> EventContext:
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
@@ -221,7 +261,7 @@ class StateHandler(object):
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
Returns:
- synapse.events.snapshot.EventContext:
+ The event context.
"""
if event.internal_metadata.is_outlier():
@@ -275,7 +315,7 @@ class StateHandler(object):
event.room_id, event.prev_event_ids()
)
- state_ids_before_event = entry.state
+ state_ids_before_event = dict(entry.state)
state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
@@ -346,19 +386,18 @@ class StateHandler(object):
)
@measure_func()
- async def resolve_state_groups_for_events(self, room_id, event_ids):
+ async def resolve_state_groups_for_events(
+ self, room_id: str, event_ids: Iterable[str]
+ ) -> _StateCacheEntry:
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Args:
- room_id (str)
- event_ids (list[str])
- explicit_room_version (str|None): If set uses the the given room
- version to choose the resolution algorithm. If None, then
- checks the database for room version.
+ room_id
+ event_ids
Returns:
- Deferred[_StateCacheEntry]: resolved state
+ The resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -394,7 +433,12 @@ class StateHandler(object):
)
return result
- async def resolve_events(self, room_version, state_sets, event):
+ async def resolve_events(
+ self,
+ room_version: str,
+ state_sets: Collection[Iterable[EventBase]],
+ event: EventBase,
+ ) -> StateMap[EventBase]:
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@@ -414,9 +458,7 @@ class StateHandler(object):
state_res_store=StateResolutionStore(self.store),
)
- new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
-
- return new_state
+ return {key: state_map[ev_id] for key, ev_id in new_state.items()}
class StateResolutionHandler(object):
@@ -444,7 +486,12 @@ class StateResolutionHandler(object):
@log_function
async def resolve_state_groups(
- self, room_id, room_version, state_groups_ids, event_map, state_res_store
+ self,
+ room_id: str,
+ room_version: str,
+ state_groups_ids: Dict[int, StateMap[str]],
+ event_map: Optional[Dict[str, EventBase]],
+ state_res_store: "StateResolutionStore",
):
"""Resolves conflicts between a set of state groups
@@ -452,13 +499,13 @@ class StateResolutionHandler(object):
not be called for a single state group
Args:
- room_id (str): room we are resolving for (used for logging and sanity checks)
- room_version (str): version of the room
- state_groups_ids (dict[int, dict[(str, str), str]]):
- map from state group id to the state in that state group
+ room_id: room we are resolving for (used for logging and sanity checks)
+ room_version: version of the room
+ state_groups_ids:
+ A map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
- event_map(dict[str,FrozenEvent]|None):
+ event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
@@ -466,10 +513,10 @@ class StateResolutionHandler(object):
If None, all events will be fetched via state_res_store.
- state_res_store (StateResolutionStore)
+ state_res_store
Returns:
- _StateCacheEntry: resolved state
+ The resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
@@ -530,21 +577,22 @@ class StateResolutionHandler(object):
return cache
-def _make_state_cache_entry(new_state, state_groups_ids):
+def _make_state_cache_entry(
+ new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
+) -> _StateCacheEntry:
"""Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed
_StateCacheEntry.
Args:
- new_state (dict[(str, str), str]): resolved state map (mapping from
- (type, state_key) to event_id)
+ new_state: resolved state map (mapping from (type, state_key) to event_id)
- state_groups_ids (dict[int, dict[(str, str), str]]):
- map from state group id to the state in that state group
- (where 'state' is a map from state key to event id)
+ state_groups_ids:
+ map from state group id to the state in that state group (where
+ 'state' is a map from state key to event id)
Returns:
- _StateCacheEntry
+ The cache entry.
"""
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
@@ -585,7 +633,7 @@ def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
- state_sets: List[StateMap[str]],
+ state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
) -> Awaitable[StateMap[str]]:
@@ -633,15 +681,17 @@ class StateResolutionStore(object):
store = attr.ib()
- def get_events(self, event_ids, allow_rejected=False):
+ def get_events(
+ self, event_ids: Iterable[str], allow_rejected: bool = False
+ ) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
Args:
- event_ids (list): The event_ids of the events to fetch
- allow_rejected (bool): If True return rejected events.
+ event_ids: The event_ids of the events to fetch
+ allow_rejected: If True return rejected events.
Returns:
- Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
+ An awaitable which resolves to a dict from event_id to event.
"""
return self.store.get_events(
@@ -651,7 +701,9 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
- def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ def get_auth_chain_difference(
+ self, state_sets: List[Set[str]]
+ ) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -660,7 +712,7 @@ class StateResolutionStore(object):
chain.
Returns:
- Deferred[Set[str]]: Set of event IDs.
+ An awaitable that resolves to a set of event IDs.
"""
return self.store.get_auth_chain_difference(state_sets)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index ab5e24841d..0eb7fdd9e5 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,7 +15,17 @@
import hashlib
import logging
-from typing import Awaitable, Callable, Dict, List, Optional
+from typing import (
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
from synapse import event_auth
from synapse.api.constants import EventTypes
@@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
async def resolve_events_with_store(
room_id: str,
- state_sets: List[StateMap[str]],
+ state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable[[List[str]], Awaitable],
-):
+ state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+) -> StateMap[str]:
"""
Args:
room_id: the room we are working in
@@ -56,8 +66,7 @@ async def resolve_events_with_store(
an Awaitable that resolves to a dict of event_id to event.
Returns:
- Deferred[dict[(str, str), str]]:
- a map from (type, state_key) to event_id.
+ A map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -75,8 +84,8 @@ async def resolve_events_with_store(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
)
- # dict[str, FrozenEvent]: a map from state event id to event. Only includes
- # the state events which are in conflict (and those in event_map)
+ # A map from state event id to event. Only includes the state events which
+ # are in conflict (and those in event_map).
state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@@ -91,8 +100,6 @@ async def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
- #
- # dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@@ -122,29 +129,30 @@ async def resolve_events_with_store(
)
-def _seperate(state_sets):
+def _seperate(
+ state_sets: Iterable[StateMap[str]],
+) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
- state_sets(iterable[dict[(str, str), str]]):
+ state_sets:
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
- (dict[(str, str), str], dict[(str, str), set[str]]):
- A tuple of (unconflicted_state, conflicted_state), where:
+ A tuple of (unconflicted_state, conflicted_state), where:
- unconflicted_state is a dict mapping (type, state_key)->event_id
- for unconflicted state keys.
+ unconflicted_state is a dict mapping (type, state_key)->event_id
+ for unconflicted state keys.
- conflicted_state is a dict mapping (type, state_key) to a set of
- event ids for conflicted state keys.
+ conflicted_state is a dict mapping (type, state_key) to a set of
+ event ids for conflicted state keys.
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
- conflicted_state = {}
+ conflicted_state = {} # type: StateMap[Set[str]]
for state_set in state_set_iterator:
for key, value in state_set.items():
@@ -171,7 +179,21 @@ def _seperate(state_sets):
return unconflicted_state, conflicted_state
-def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
+def _create_auth_events_from_maps(
+ unconflicted_state: StateMap[str],
+ conflicted_state: StateMap[Set[str]],
+ state_map: Dict[str, EventBase],
+) -> StateMap[str]:
+ """
+
+ Args:
+ unconflicted_state: The unconflicted state map.
+ conflicted_state: The conflicted state map.
+ state_map:
+
+ Returns:
+ A map from state key to event id.
+ """
auth_events = {}
for event_ids in conflicted_state.values():
for event_id in event_ids:
@@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
- event_id = unconflicted_state.get(key, None)
- if event_id:
- auth_events[key] = event_id
+ auth_event_id = unconflicted_state.get(key, None)
+ if auth_event_id:
+ auth_events[key] = auth_event_id
return auth_events
def _resolve_with_state(
- unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
+ unconflicted_state_ids: StateMap[str],
+ conflicted_state_ids: StateMap[Set[str]],
+ auth_event_ids: StateMap[str],
+ state_map: Dict[str, EventBase],
):
conflicted_state = {}
for key, event_ids in conflicted_state_ids.items():
@@ -215,7 +240,9 @@ def _resolve_with_state(
return new_state
-def _resolve_state_events(conflicted_state, auth_events):
+def _resolve_state_events(
+ conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
+) -> StateMap[EventBase]:
""" This is where we actually decide which of the conflicted state to
use.
@@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
return resolved_state
-def _resolve_auth_events(events, auth_events):
+def _resolve_auth_events(
+ events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
reverse = list(reversed(_ordered_events(events)))
auth_keys = {
@@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
return event
-def _resolve_normal_events(events, auth_events):
+def _resolve_normal_events(
+ events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
@@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
return event
-def _ordered_events(events):
+def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
def key_func(e):
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 6634955cdc..0e9ffbd6e6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,7 +16,21 @@
import heapq
import itertools
import logging
-from typing import Dict, List, Optional
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ overload,
+)
+
+from typing_extensions import Literal
import synapse.state
from synapse import event_auth
@@ -40,10 +54,10 @@ async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
- state_sets: List[StateMap[str]],
+ state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
-):
+) -> StateMap[str]:
"""Resolves the state using the v2 state resolution algorithm
Args:
@@ -63,8 +77,7 @@ async def resolve_events_with_store(
state_res_store:
Returns:
- Deferred[dict[(str, str), str]]:
- a map from (type, state_key) to event_id.
+ A map from (type, state_key) to event_id.
"""
logger.debug("Computing conflicted state")
@@ -171,18 +184,23 @@ async def resolve_events_with_store(
return resolved_state
-async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
+async def _get_power_level_for_sender(
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
"""Return the power level of the sender of the given event according to
their auth events.
Args:
- room_id (str)
- event_id (str)
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ room_id
+ event_id
+ event_map
+ state_res_store
Returns:
- Deferred[int]
+ The power level.
"""
event = await _get_event(room_id, event_id, event_map, state_res_store)
@@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
return int(level)
-async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+async def _get_auth_chain_difference(
+ state_sets: Sequence[StateMap[str]],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
Args:
- state_sets (list)
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ state_sets
+ event_map
+ state_res_store
Returns:
- Deferred[set[str]]: Set of event IDs
+ Set of event IDs
"""
difference = await state_res_store.get_auth_chain_difference(
@@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
return difference
-def _seperate(state_sets):
+def _seperate(
+ state_sets: Iterable[StateMap[str]],
+) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Return the unconflicted and conflicted state. This is different than in
the original algorithm, as this defines a key to be conflicted if one of
the state sets doesn't have that key.
Args:
- state_sets (list)
+ state_sets
Returns:
- tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
- conflicted state dict is a map from type/state_key to set of event IDs
+ A tuple of unconflicted and conflicted state. The conflicted state dict
+ is a map from type/state_key to set of event IDs
"""
unconflicted_state = {}
conflicted_state = {}
@@ -260,18 +284,20 @@ def _seperate(state_sets):
event_ids.discard(None)
conflicted_state[key] = event_ids
- return unconflicted_state, conflicted_state
+ # mypy doesn't understand that discarding None above means that conflicted
+ # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
+ return unconflicted_state, conflicted_state # type: ignore
-def _is_power_event(event):
+def _is_power_event(event: EventBase) -> bool:
"""Return whether or not the event is a "power event", as defined by the
v2 state resolution algorithm
Args:
- event (FrozenEvent)
+ event
Returns:
- boolean
+ True if the event is a power event.
"""
if (event.type, event.state_key) in (
(EventTypes.PowerLevels, ""),
@@ -288,19 +314,23 @@ def _is_power_event(event):
async def _add_event_and_auth_chain_to_graph(
- graph, room_id, event_id, event_map, state_res_store, auth_diff
-):
+ graph: Dict[str, Set[str]],
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ auth_diff: Set[str],
+) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
Args:
- graph (dict[str, set[str]]): A map from event ID to the events auth
- event IDs
- room_id (str): the room we are working in
- event_id (str): Event to add to the graph
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
- auth_diff (set[str]): Set of event IDs that are in the auth difference.
+ graph: A map from event ID to the events auth event IDs
+ room_id: the room we are working in
+ event_id: Event to add to the graph
+ event_map
+ state_res_store
+ auth_diff: Set of event IDs that are in the auth difference.
"""
state = [event_id]
@@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
async def _reverse_topological_power_sort(
- clock, room_id, event_ids, event_map, state_res_store, auth_diff
-):
+ clock: Clock,
+ room_id: str,
+ event_ids: Iterable[str],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ auth_diff: Set[str],
+) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
- clock (Clock)
- room_id (str): the room we are working in
- event_ids (list[str]): The events to sort
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
- auth_diff (set[str]): Set of event IDs that are in the auth difference.
+ clock
+ room_id: the room we are working in
+ event_ids: The events to sort
+ event_map
+ state_res_store
+ auth_diff: Set of event IDs that are in the auth difference.
Returns:
- Deferred[list[str]]: The sorted list
+ The sorted list
"""
- graph = {}
+ graph = {} # type: Dict[str, Set[str]]
for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
@@ -372,22 +407,28 @@ async def _reverse_topological_power_sort(
async def _iterative_auth_checks(
- clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
-):
+ clock: Clock,
+ room_id: str,
+ room_version: str,
+ event_ids: List[str],
+ base_state: StateMap[str],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> StateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
- clock (Clock)
- room_id (str)
- room_version (str)
- event_ids (list[str]): Ordered list of events to apply auth checks to
- base_state (StateMap[str]): The set of state to start with
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ clock
+ room_id
+ room_version
+ event_ids: Ordered list of events to apply auth checks to
+ base_state: The set of state to start with
+ event_map
+ state_res_store
Returns:
- Deferred[StateMap[str]]: Returns the final updated state
+ Returns the final updated state
"""
resolved_state = base_state.copy()
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -439,21 +480,26 @@ async def _iterative_auth_checks(
async def _mainline_sort(
- clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
-):
+ clock: Clock,
+ room_id: str,
+ event_ids: List[str],
+ resolved_power_event_id: Optional[str],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> List[str]:
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
- clock (Clock)
- room_id (str): room we're working in
- event_ids (list[str]): Events to sort
- resolved_power_event_id (str): The final resolved power level event ID
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ clock
+ room_id: room we're working in
+ event_ids: Events to sort
+ resolved_power_event_id: The final resolved power level event ID
+ event_map
+ state_res_store
Returns:
- Deferred[list[str]]: The sorted list
+ The sorted list
"""
if not event_ids:
# It's possible for there to be no event IDs here to sort, so we can
@@ -505,59 +551,90 @@ async def _mainline_sort(
async def _get_mainline_depth_for_event(
- event, mainline_map, event_map, state_res_store
-):
+ event: EventBase,
+ mainline_map: Dict[str, int],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
"""Get the mainline depths for the given event based on the mainline map
Args:
- event (FrozenEvent)
- mainline_map (dict[str, int]): Map from event_id to mainline depth for
- events in the mainline.
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ event
+ mainline_map: Map from event_id to mainline depth for events in the mainline.
+ event_map
+ state_res_store
Returns:
- Deferred[int]
+ The mainline depth
"""
room_id = event.room_id
+ tmp_event = event # type: Optional[EventBase]
# We do an iterative search, replacing `event with the power level in its
# auth events (if any)
- while event:
+ while tmp_event:
depth = mainline_map.get(event.event_id)
if depth is not None:
return depth
- auth_events = event.auth_event_ids()
- event = None
+ auth_events = tmp_event.auth_event_ids()
+ tmp_event = None
for aid in auth_events:
aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
- event = aev
+ tmp_event = aev
break
# Didn't find a power level auth event, so we just return 0
return 0
-async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
+@overload
+async def _get_event(
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ allow_none: Literal[False] = False,
+) -> EventBase:
+ ...
+
+
+@overload
+async def _get_event(
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ allow_none: Literal[True],
+) -> Optional[EventBase]:
+ ...
+
+
+async def _get_event(
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ allow_none: bool = False,
+) -> Optional[EventBase]:
"""Helper function to look up event in event_map, falling back to looking
it up in the store
Args:
- room_id (str)
- event_id (str)
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
- allow_none (bool): if the event is not found, return None rather than raising
+ room_id
+ event_id
+ event_map
+ state_res_store
+ allow_none: if the event is not found, return None rather than raising
an exception
Returns:
- Deferred[Optional[FrozenEvent]]
+ The event, or none if the event does not exist (and allow_none is True).
"""
if event_id not in event_map:
events = await state_res_store.get_events([event_id], allow_rejected=True)
@@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
return event
-def lexicographical_topological_sort(graph, key):
+def lexicographical_topological_sort(
+ graph: Dict[str, Set[str]], key: Callable[[str], Any]
+) -> Generator[str, None, None]:
"""Performs a lexicographic reverse topological sort on the graph.
This returns a reverse topological sort (i.e. if node A references B then B
@@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
NOTE: `graph` is modified during the sort.
Args:
- graph (dict[str, set[str]]): A representation of the graph where each
- node is a key in the dict and its value are the nodes edges.
- key (func): A function that takes a node and returns a value that is
- comparable and used to order nodes
+ graph: A representation of the graph where each node is a key in the
+ dict and its value are the nodes edges.
+ key: A function that takes a node and returns a value that is comparable
+ and used to order nodes
Yields:
- str: The next node in the topological sort
+ The next node in the topological sort
"""
# Note, this is basically Kahn's algorithm except we look at nodes with no
# outgoing edges, c.f.
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
outdegree_map = graph
- reverse_graph = {}
+ reverse_graph = {} # type: Dict[str, Set[str]]
# Lists of nodes with zero out degree. Is actually a tuple of
# `(key(node), node)` so that sorting does the right thing
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 90a1f9e8b1..56818f4df8 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -16,9 +16,8 @@
import logging
from typing import Optional
-from canonicaljson import json
-
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import json_encoder
from . import engines
@@ -457,7 +456,7 @@ class BackgroundUpdater(object):
progress(dict): The progress of the update.
"""
- progress_json = json.dumps(progress)
+ progress_json = json_encoder.encode(progress)
self.db_pool.simple_update_one_txn(
txn,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b9aef96b08..bc327e344e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,6 +28,7 @@ from typing import (
Optional,
Tuple,
TypeVar,
+ Union,
)
from prometheus_client import Histogram
@@ -125,7 +126,7 @@ class LoggingTransaction:
method.
Args:
- txn: The database transcation object to wrap.
+ txn: The database transaction object to wrap.
name: The name of this transactions for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
@@ -160,7 +161,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
- def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@@ -171,7 +172,9 @@ class LoggingTransaction:
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
- def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_on_exception(
+ self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+ ):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
@@ -195,7 +198,7 @@ class LoggingTransaction:
def description(self) -> Any:
return self.txn.description
- def execute_batch(self, sql, args):
+ def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
@@ -204,17 +207,17 @@ class LoggingTransaction:
for val in args:
self.execute(sql, val)
- def execute(self, sql: str, *args: Any):
+ def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql: str, *args: Any):
+ def executemany(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
- def _do_execute(self, func, sql, *args):
+ def _do_execute(self, func, sql: str, *args: Any) -> None:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -240,7 +243,7 @@ class LoggingTransaction:
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
- def close(self):
+ def close(self) -> None:
self.txn.close()
@@ -249,13 +252,13 @@ class PerformanceCounters(object):
self.current_counters = {}
self.previous_counters = {}
- def update(self, key, duration_secs):
+ def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
- def interval(self, interval_duration_secs, limit=3):
+ def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
counters = []
for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
@@ -279,6 +282,9 @@ class PerformanceCounters(object):
return top_n_counters
+R = TypeVar("R")
+
+
class DatabasePool(object):
"""Wraps a single physical database and connection pool.
@@ -327,12 +333,12 @@ class DatabasePool(object):
self._check_safe_to_upsert,
)
- def is_running(self):
+ def is_running(self) -> bool:
"""Is the database pool currently running
"""
return self._db_pool.running
- async def _check_safe_to_upsert(self):
+ async def _check_safe_to_upsert(self) -> None:
"""
Is it safe to use native UPSERT?
@@ -363,7 +369,7 @@ class DatabasePool(object):
self._check_safe_to_upsert,
)
- def start_profiling(self):
+ def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
def loop():
@@ -387,8 +393,15 @@ class DatabasePool(object):
self._clock.looping_call(loop, 10000)
def new_transaction(
- self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
- ):
+ self,
+ conn: Connection,
+ desc: str,
+ after_callbacks: List[_CallbackListEntry],
+ exception_callbacks: List[_CallbackListEntry],
+ func: "Callable[..., R]",
+ *args: Any,
+ **kwargs: Any
+ ) -> R:
start = monotonic_time()
txn_id = self._TXN_ID
@@ -537,7 +550,9 @@ class DatabasePool(object):
return result
- async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
+ async def runWithConnection(
+ self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ ) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
@@ -576,11 +591,11 @@ class DatabasePool(object):
)
@staticmethod
- def cursor_to_dict(cursor):
+ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
"""Converts a SQL cursor into an list of dicts.
Args:
- cursor : The DBAPI cursor which has executed a query.
+ cursor: The DBAPI cursor which has executed a query.
Returns:
A list of dicts where the key is the column header.
"""
@@ -588,7 +603,7 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
- def execute(self, desc, decoder, query, *args):
+ def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
"""Runs a single query for a result set.
Args:
@@ -597,7 +612,7 @@ class DatabasePool(object):
query - The query string to execute
*args - Query args.
Returns:
- The result of decoder(results)
+ Deferred which results to the result of decoder(results)
"""
def interaction(txn):
@@ -612,20 +627,25 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ async def simple_insert(
+ self,
+ table: str,
+ values: Dict[str, Any],
+ or_ignore: bool = False,
+ desc: str = "simple_insert",
+ ) -> bool:
"""Executes an INSERT query on the named table.
Args:
- table : string giving the table name
- values : dict of new column names and values for them
- or_ignore : bool stating whether an exception should be raised
+ table: string giving the table name
+ values: dict of new column names and values for them
+ or_ignore: bool stating whether an exception should be raised
when a conflicting row already exists. If True, False will be
returned by the function instead
- desc : string giving a description of the transaction
+ desc: string giving a description of the transaction
Returns:
- bool: Whether the row was inserted or not. Only useful when
- `or_ignore` is True
+ Whether the row was inserted or not. Only useful when `or_ignore` is True
"""
try:
await self.runInteraction(desc, self.simple_insert_txn, table, values)
@@ -638,7 +658,9 @@ class DatabasePool(object):
return True
@staticmethod
- def simple_insert_txn(txn, table, values):
+ def simple_insert_txn(
+ txn: LoggingTransaction, table: str, values: Dict[str, Any]
+ ) -> None:
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -649,11 +671,15 @@ class DatabasePool(object):
txn.execute(sql, vals)
- def simple_insert_many(self, table, values, desc):
+ def simple_insert_many(
+ self, table: str, values: List[Dict[str, Any]], desc: str
+ ) -> defer.Deferred:
return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod
- def simple_insert_many_txn(txn, table, values):
+ def simple_insert_many_txn(
+ txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+ ) -> None:
if not values:
return
@@ -683,13 +709,13 @@ class DatabasePool(object):
async def simple_upsert(
self,
- table,
- keyvalues,
- values,
- insertion_values={},
- desc="simple_upsert",
- lock=True,
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ desc: str = "simple_upsert",
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
`lock` should generally be set to True (the default), but can be set
@@ -703,16 +729,14 @@ class DatabasePool(object):
this table.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key columns and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
attempts = 0
while True:
@@ -739,29 +763,34 @@ class DatabasePool(object):
)
def simple_upsert_txn(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
Args:
txn: The transaction to use.
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
- return self.simple_upsert_txn_native_upsert(
+ self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
+ return None
else:
return self.simple_upsert_txn_emulated(
txn,
@@ -773,18 +802,23 @@ class DatabasePool(object):
)
def simple_upsert_txn_emulated(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> bool:
"""
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- bool: Return True if a new entry was created, False if an existing
+ Returns True if a new entry was created, False if an existing
one was updated.
"""
# We need to lock the table :(, unless we're *really* careful
@@ -842,19 +876,21 @@ class DatabasePool(object):
return True
def simple_upsert_txn_native_upsert(
- self, txn, table, keyvalues, values, insertion_values={}
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ ) -> None:
"""
Use the native UPSERT functionality in recent PostgreSQL versions.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- Returns:
- None
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
"""
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
@@ -985,18 +1021,22 @@ class DatabasePool(object):
return txn.execute_batch(sql, args)
def simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
- ):
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ desc: str = "simple_select_one",
+ ) -> defer.Deferred:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcols : list of strings giving the names of the columns to return
-
- allow_none : If true, return None instead of failing if the SELECT
- statement returns no rows
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcols: list of strings giving the names of the columns to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
"""
return self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
@@ -1004,19 +1044,22 @@ class DatabasePool(object):
def simple_select_one_onecol(
self,
- table,
- keyvalues,
- retcol,
- allow_none=False,
- desc="simple_select_one_onecol",
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: bool = False,
+ desc: str = "simple_select_one_onecol",
+ ) -> defer.Deferred:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcol : string giving the name of the column to return
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcol: string giving the name of the column to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
+ desc: description of the transaction, for logging and metrics
"""
return self.runInteraction(
desc,
@@ -1029,8 +1072,13 @@ class DatabasePool(object):
@classmethod
def simple_select_one_onecol_txn(
- cls, txn, table, keyvalues, retcol, allow_none=False
- ):
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: bool = False,
+ ) -> Optional[Any]:
ret = cls.simple_select_onecol_txn(
txn, table=table, keyvalues=keyvalues, retcol=retcol
)
@@ -1044,7 +1092,12 @@ class DatabasePool(object):
raise StoreError(404, "No row found")
@staticmethod
- def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ def simple_select_onecol_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ ) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
@@ -1056,15 +1109,19 @@ class DatabasePool(object):
return [r[0] for r in txn]
def simple_select_onecol(
- self, table, keyvalues, retcol, desc="simple_select_onecol"
- ):
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcol: str,
+ desc: str = "simple_select_onecol",
+ ) -> defer.Deferred:
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
Args:
- table (str): table name
- keyvalues (dict|None): column names and values to select the rows with
- retcol (str): column whos value we wish to retrieve.
+ table: table name
+ keyvalues: column names and values to select the rows with
+ retcol: column whos value we wish to retrieve.
Returns:
Deferred: Results in a list
@@ -1073,16 +1130,22 @@ class DatabasePool(object):
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)
- def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+ def simple_select_list(
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ desc: str = "simple_select_list",
+ ) -> defer.Deferred:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
@@ -1091,17 +1154,23 @@ class DatabasePool(object):
)
@classmethod
- def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ def simple_select_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1118,25 +1187,25 @@ class DatabasePool(object):
async def simple_select_many_batch(
self,
- table,
- column,
- iterable,
- retcols,
- keyvalues={},
- desc="simple_select_many_batch",
- batch_size=100,
- ):
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ retcols: Iterable[str],
+ keyvalues: Dict[str, Any] = {},
+ desc: str = "simple_select_many_batch",
+ batch_size: int = 100,
+ ) -> List[Any]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Filters rows by if value of `column` is in `iterable`.
Args:
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list
+ keyvalues: dict of column names and values to select the rows with
+ retcols: list of strings giving the names of the columns to return
"""
results = [] # type: List[Dict[str, Any]]
@@ -1165,19 +1234,27 @@ class DatabasePool(object):
return results
@classmethod
- def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ def simple_select_many_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Filters rows by if value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list
+ keyvalues: dict of column names and values to select the rows with
+ retcols: list of strings giving the names of the columns to return
"""
if not iterable:
return []
@@ -1198,13 +1275,24 @@ class DatabasePool(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
- def simple_update(self, table, keyvalues, updatevalues, desc):
+ def simple_update(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str,
+ ) -> defer.Deferred:
return self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
- def simple_update_txn(txn, table, keyvalues, updatevalues):
+ def simple_update_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> int:
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
@@ -1221,31 +1309,32 @@ class DatabasePool(object):
return txn.rowcount
def simple_update_one(
- self, table, keyvalues, updatevalues, desc="simple_update_one"
- ):
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str = "simple_update_one",
+ ) -> defer.Deferred:
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- updatevalues : dict giving column names and values to update
- retcols : optional list of column names to return
-
- If present, retcols gives a list of column names on which to perform
- a SELECT statement *before* performing the UPDATE statement. The values
- of these will be returned in a dict.
-
- These are performed within the same transaction, allowing an atomic
- get-and-set. This can be used to implement compare-and-set by putting
- the update column in the 'keyvalues' dict as well.
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ updatevalues: dict giving column names and values to update
"""
return self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
- def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ def simple_update_one_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> None:
rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0:
@@ -1253,8 +1342,18 @@ class DatabasePool(object):
if rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
+ # Ideally we could use the overload decorator here to specify that the
+ # return type is only optional if allow_none is True, but this does not work
+ # when you call a static method from an instance.
+ # See https://github.com/python/mypy/issues/7781
@staticmethod
- def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+ def simple_select_one_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ ) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
@@ -1273,24 +1372,28 @@ class DatabasePool(object):
return dict(zip(retcols, row))
- def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+ def simple_delete_one(
+ self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+ ) -> defer.Deferred:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
"""
return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod
- def simple_delete_one_txn(txn, table, keyvalues):
+ def simple_delete_one_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
@@ -1303,11 +1406,13 @@ class DatabasePool(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- def simple_delete(self, table, keyvalues, desc):
+ def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str):
return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
@staticmethod
- def simple_delete_txn(txn, table, keyvalues):
+ def simple_delete_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> int:
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1316,26 +1421,39 @@ class DatabasePool(object):
txn.execute(sql, list(keyvalues.values()))
return txn.rowcount
- def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ def simple_delete_many(
+ self,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ desc: str,
+ ) -> defer.Deferred:
return self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
- def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ def simple_delete_many_txn(
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ ) -> int:
"""Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+ iterable: list
+ keyvalues: dict of column names and values to select the rows with
Returns:
- int: Number rows deleted
+ Number rows deleted
"""
if not iterable:
return 0
@@ -1356,8 +1474,14 @@ class DatabasePool(object):
return txn.rowcount
def get_cache_dict(
- self, db_conn, table, entity_column, stream_column, max_value, limit=100000
- ):
+ self,
+ db_conn: Connection,
+ table: str,
+ entity_column: str,
+ stream_column: str,
+ max_value: int,
+ limit: int = 100000,
+ ) -> Tuple[Dict[Any, int], int]:
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
@@ -1390,34 +1514,34 @@ class DatabasePool(object):
def simple_select_list_paginate(
self,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- desc="simple_select_list_paginate",
- ):
+ table: str,
+ orderby: str,
+ start: int,
+ limit: int,
+ retcols: Iterable[str],
+ filters: Optional[Dict[str, Any]] = None,
+ keyvalues: Optional[Dict[str, Any]] = None,
+ order_direction: str = "ASC",
+ desc: str = "simple_select_list_paginate",
+ ) -> defer.Deferred:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Args:
- table (str): the table name
- filters (dict[str, T] | None):
+ table: the table name
+ orderby: Column to order the results by.
+ start: Index to begin the query at.
+ limit: Number of results to return.
+ retcols: the names of the columns to return
+ filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
@@ -1437,16 +1561,16 @@ class DatabasePool(object):
@classmethod
def simple_select_list_paginate_txn(
cls,
- txn,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- ):
+ txn: LoggingTransaction,
+ table: str,
+ orderby: str,
+ start: int,
+ limit: int,
+ retcols: Iterable[str],
+ filters: Optional[Dict[str, Any]] = None,
+ keyvalues: Optional[Dict[str, Any]] = None,
+ order_direction: str = "ASC",
+ ) -> List[Dict[str, Any]]:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
@@ -1457,21 +1581,22 @@ class DatabasePool(object):
using 'AND'.
Args:
- txn : Transaction object
- table (str): the table name
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- filters (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ orderby: Column to order the results by.
+ start: Index to begin the query at.
+ limit: Number of results to return.
+ retcols: the names of the columns to return
+ filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ order_direction: Whether the results should be ordered "ASC" or "DESC".
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ The result as a list of dictionaries.
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1497,16 +1622,23 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn)
- def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+ def simple_search_list(
+ self,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+ desc="simple_search_list",
+ ):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ table: the table name
+ term: term for searching the table matched to a column.
+ col: column to query term should be matched to
+ retcols: the names of the columns to return
+
Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None
"""
@@ -1516,19 +1648,26 @@ class DatabasePool(object):
)
@classmethod
- def simple_search_list_txn(cls, txn, table, term, col, retcols):
+ def simple_search_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+ ) -> Union[List[Dict[str, Any]], int]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ txn: Transaction object
+ table: the table name
+ term: term for searching the table matched to a column.
+ col: column to query term should be matched to
+ retcols: the names of the columns to return
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
+ 0 if no term is given, otherwise a list of dictionaries.
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
@@ -1541,7 +1680,7 @@ class DatabasePool(object):
def make_in_list_sql_clause(
- database_engine, column: str, iterable: Iterable
+ database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 17fa470919..0934ae276c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -498,7 +498,7 @@ class DataStore(
)
def get_users_paginate(
- self, start, limit, name=None, guests=True, deactivated=False
+ self, start, limit, user_id=None, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
@@ -507,7 +507,8 @@ class DataStore(
Args:
start (int): start number to begin the query from
limit (int): number of rows to retrieve
- name (string): filter for user names
+ user_id (string): search for user_id. ignored if name is not None
+ name (string): search for local part of user_id or display name
guests (bool): whether to in include guest users
deactivated (bool): whether to include deactivated users
Returns:
@@ -516,11 +517,14 @@ class DataStore(
def get_users_paginate_txn(txn):
filters = []
- args = []
+ args = [self.hs.config.server_name]
if name:
+ filters.append("(name LIKE ? OR displayname LIKE ?)")
+ args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ elif user_id:
filters.append("name LIKE ?")
- args.append("%" + name + "%")
+ args.extend(["%" + user_id + "%"])
if not guests:
filters.append("is_guest = 0")
@@ -530,20 +534,23 @@ class DataStore(
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
- sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
- txn.execute(sql, args)
- count = txn.fetchone()[0]
-
- args = [self.hs.config.server_name] + args + [limit, start]
- sql = """
- SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+ sql_base = """
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
- ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
+ sql = "SELECT COUNT(*) as total_users " + sql_base
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = (
+ "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ + sql_base
+ + " ORDER BY u.name LIMIT ? OFFSET ?"
+ )
+ args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
return users, count
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..04042a2c98 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 02568a2391..77723f7d4d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -16,13 +16,12 @@
import logging
import re
-from canonicaljson import json
-
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore(
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
- event_ids = json.dumps([e.event_id for e in events])
+ event_ids = json_encoder.encode([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..bb85637a95 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9a786e2929..03b45dbc4d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
THe new stream ID.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ with await self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+ with await self._device_list_id_gen.get_next_mult(
+ len(device_ids)
+ ) as stream_ids:
await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
@@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with self._device_list_id_gen.get_next_mult(
+ with await self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..385868bdab 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
"""Set a user's cross-signing key.
Args:
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
key (dict): the key data
+ stream_id (int)
"""
# the 'key' dict will look something like:
# {
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
# and finally, store the key itself
- with self._cross_signing_id_gen.get_next() as stream_id:
- self.db_pool.simple_insert_txn(
- txn,
- "e2e_cross_signing_keys",
- values={
- "user_id": user_id,
- "keytype": key_type,
- "keydata": json_encoder.encode(key),
- "stream_id": stream_id,
- },
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ "e2e_cross_signing_keys",
+ values={
+ "user_id": user_id,
+ "keytype": key_type,
+ "keydata": json_encoder.encode(key),
+ "stream_id": stream_id,
+ },
+ )
self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
- def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ async def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
Args:
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
- return self.db_pool.runInteraction(
- "add_e2e_cross_signing_key",
- self._set_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- key,
- )
+
+ with await self._cross_signing_id_gen.get_next() as stream_id:
+ return await self.db_pool.runInteraction(
+ "add_e2e_cross_signing_key",
+ self._set_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ key,
+ stream_id,
+ )
def store_e2e_cross_signing_signatures(self, user_id, signatures):
"""Stores cross-signing signatures.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4826be630c..e6a97b018c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Set, Tuple
from synapse.api.errors import StoreError
+from synapse.events import EventBase
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
@@ -30,12 +32,14 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- async def get_auth_chain(self, event_ids, include_given=False):
+ async def get_auth_chain(
+ self, event_ids: Collection[str], include_given: bool = False
+ ) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
Returns:
list of events
@@ -45,43 +49,34 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
return await self.get_events_as_list(event_ids)
- def get_auth_chain_ids(
- self,
- event_ids: List[str],
- include_given: bool = False,
- ignore_events: Optional[Set[str]] = None,
- ):
+ async def get_auth_chain_ids(
+ self, event_ids: Collection[str], include_given: bool = False,
+ ) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids: state events
include_given: include the given events in result
- ignore_events: Set of events to exclude from the returned auth
- chain. This is useful if the caller will just discard the
- given events anyway, and saves us from figuring out their auth
- chains if not required.
Returns:
list of event_ids
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids,
include_given,
- ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
- if ignore_events is None:
- ignore_events = set()
-
+ def _get_auth_chain_ids_txn(
+ self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
if include_given:
results = set(event_ids)
else:
results = set()
- base_sql = "SELECT auth_id FROM event_auth WHERE "
+ base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
@@ -93,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(base_sql + clause, args)
new_front.update(r[0] for r in txn)
- new_front -= ignore_events
new_front -= results
front = new_front
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b90e6de2d5..6313b41eef 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -153,11 +153,11 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = await self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4a3333c0db..e1241a724b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -620,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row["room_version_id"]
if not room_version_id:
- # this should only happen for out-of-band membership events
- if not internal_metadata.get("out_of_band_membership"):
- logger.warning(
- "Room %s for event %s is unknown", d["room_id"], event_id
+ # this should only happen for out-of-band membership events which
+ # arrived before #6983 landed. For all other events, we should have
+ # an entry in the 'rooms' table.
+ #
+ # However, the 'out_of_band_membership' flag is unreliable for older
+ # invites, so just accept it for all membership events.
+ #
+ if d["type"] != EventTypes.Member:
+ raise Exception(
+ "Room %s for event %s is unknown" % (d["room_id"], event_id)
)
- continue
- # take a wild stab at the room version based on the event format
+ # so, assuming this is an out-of-band-invite that arrived before #6983
+ # landed, we know that the room version must be v5 or earlier (because
+ # v6 hadn't been invented at that point, so invites from such rooms
+ # would have been rejected.)
+ #
+ # The main reason we need to know the room version here (other than
+ # choosing the right python Event class) is in case the event later has
+ # to be redacted - and all the room versions up to v5 used the same
+ # redaction algorithm.
+ #
+ # So, the following approximations should be adequate.
+
if format_version == EventFormatVersions.V1:
+ # if it's event format v1 then it must be room v1 or v2
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
+ # if it's event format v2 then it must be room v3
room_version = RoomVersions.V3
else:
+ # if it's event format v3 then it must be room v4 or v5
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 0e3b8739c6..a488e0924b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with self._group_updates_id_gen.get_next() as next_id:
+ with await self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4e3ec02d14..c9f655dfb7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = await self._presence_id_gen.get_next_mult(
len(presence_states)
)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index a585e54812..2fb5b02d7d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with self._push_rules_stream_id_gen.get_next() as stream_id:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with self._push_rules_stream_id_gen.get_next() as stream_id:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
)
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
- with self._push_rules_stream_id_gen.get_next() as stream_id:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with self._push_rules_stream_id_gen.get_next() as stream_id:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 1126fd0751..c388468273 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
- with self._pushers_id_gen.get_next() as stream_id:
+ with await self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with self._pushers_id_gen.get_next() as stream_id:
+ with await self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 19ad1c056f..6821476ee0 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
- stream_id_manager = self._receipts_id_gen.get_next()
- with stream_id_manager as stream_id:
+ with await self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 5986d32b18..336b578e23 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -968,6 +968,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
+ self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
if self._account_validity.enabled:
self._clock.call_later(
@@ -1381,15 +1382,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
if not row:
- raise ThreepidValidationError(400, "Unknown session_id")
+ if self._ignore_unknown_session_error:
+ # If we need to inhibit the error caused by an incorrect session ID,
+ # use None as placeholder values for the client secret and the
+ # validation timestamp.
+ # It shouldn't be an issue because they're both only checked after
+ # the token check, which should fail. And if it doesn't for some
+ # reason, the next check is on the client secret, which is NOT NULL,
+ # so we don't have to worry about the client secret matching by
+ # accident.
+ row = {"client_secret": None, "validated_at": None}
+ else:
+ raise ThreepidValidationError(400, "Unknown session_id")
+
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
- if retrieved_client_secret != client_secret:
- raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id"
- )
-
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_token",
@@ -1405,6 +1413,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expires = row["expires"]
next_link = row["next_link"]
+ if retrieved_client_secret != client_secret:
+ raise ThreepidValidationError(
+ 400, "This client_secret does not match the provided session_id"
+ )
+
# If the session is already validated, no need to revalidate
if validated_at:
return next_link
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0142a856d5..99a8a9fab0 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -21,10 +21,6 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from canonicaljson import json
-
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
@@ -32,6 +28,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchStore
from synapse.types import ThirdPartyInstanceID
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -342,23 +339,22 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
- @defer.inlineCallbacks
- def is_room_published(self, room_id):
+ async def is_room_published(self, room_id: str) -> bool:
"""Check whether a room has been published in the local public room
directory.
Args:
- room_id (str)
+ room_id
Returns:
- bool: Whether the room is currently published in the room directory
+ Whether the room is currently published in the room directory
"""
# Get room information
- room_info = yield self.get_room(room_id)
+ room_info = await self.get_room(room_id)
if not room_info:
- defer.returnValue(False)
+ return False
# Check the is_public value
- defer.returnValue(room_info.get("is_public", False))
+ return room_info.get("is_public", False)
async def get_rooms_paginate(
self,
@@ -572,7 +568,7 @@ class RoomWorkerStore(SQLBaseStore):
# maximum, in order not to filter out events we should filter out when sending to
# the client.
if not self.config.retention_enabled:
- defer.returnValue({"min_lifetime": None, "max_lifetime": None})
+ return {"min_lifetime": None, "max_lifetime": None}
def get_retention_policy_for_room_txn(txn):
txn.execute(
@@ -1155,7 +1151,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1222,7 +1218,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1302,7 +1298,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
@@ -1335,7 +1331,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"event_id": event_id,
"user_id": user_id,
"reason": reason,
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
},
desc="add_event_report",
)
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+ session_id TEXT NOT NULL,
+ ip TEXT NOT NULL,
+ user_agent TEXT NOT NULL,
+ UNIQUE (session_id, ip, user_agent),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e4e0a0c433..0c34bbf21a 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,11 +17,10 @@
import logging
from typing import Dict, List, Tuple
-from canonicaljson import json
-
from synapse.storage._base import db_to_json
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (user_id, room_id))
tags = []
for tag, content in txn:
- tags.append(json.dumps(tag) + ":" + content)
+ tags.append(json_encoder.encode(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, (user_id, room_id, tag_json)))
@@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
self.db_pool.simple_upsert_txn(
@@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..9eef8e57c5 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,15 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import attr
-from canonicaljson import json
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
-from synapse.util import stringutils as stringutils
+from synapse.util import json_encoder, stringutils
@attr.s
@@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore):
StoreError if a unique session ID cannot be generated.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore):
await self.db_pool.simple_upsert(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id, "stage_type": stage_type},
- values={"result": json.dumps(result)},
+ values={"result": json_encoder.encode(result)},
desc="mark_ui_auth_stage_complete",
)
except self.db_pool.engine.module.IntegrityError:
@@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore):
The dictionary from the client root level, not the 'auth' key.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
await self.db_pool.simple_update_one(
table="ui_auth_sessions",
@@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore):
value,
)
- def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+ def _set_ui_auth_session_data_txn(
+ self, txn: LoggingTransaction, session_id: str, key: str, value: Any
+ ):
# Get the current value.
result = self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
- )
+ ) # type: Dict[str, Any] # type: ignore
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
@@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
- updatevalues={"serverdict": json.dumps(serverdict)},
+ updatevalues={"serverdict": json_encoder.encode(serverdict)},
)
async def get_ui_auth_session_data(
@@ -258,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default)
+ async def add_user_agent_ip_to_ui_auth_session(
+ self, session_id: str, user_agent: str, ip: str,
+ ):
+ """Add the given user agent / IP to the tracking table
+ """
+ await self.db_pool.simple_upsert(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+ values={},
+ desc="add_user_agent_ip_to_ui_auth_session",
+ )
+
+ async def get_user_agents_ips_to_ui_auth_session(
+ self, session_id: str,
+ ) -> List[Tuple[str, str]]:
+ """Get the given user agents / IPs used during the ui auth process
+
+ Returns:
+ List of user_agent/ip pairs
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ )
+ return [(row["user_agent"], row["ip"]) for row in rows]
+
class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int):
@@ -275,12 +305,23 @@ class UIAuthStore(UIAuthWorkerStore):
expiration_time,
)
- def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+ def _delete_old_ui_auth_sessions_txn(
+ self, txn: LoggingTransaction, expiration_time: int
+ ):
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()]
+ # Delete the corresponding IP/user agents.
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_ips",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 0bf772d4d1..5b07847773 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,9 +14,10 @@
# limitations under the License.
import contextlib
+import heapq
import threading
from collections import deque
-from typing import Dict, Set
+from typing import Dict, List, Set
from typing_extensions import Deque
@@ -80,7 +81,7 @@ class StreamIdGenerator(object):
upwards, -1 to grow downwards.
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -95,10 +96,10 @@ class StreamIdGenerator(object):
)
self._unfinished_ids = deque() # type: Deque[int]
- def get_next(self):
+ async def get_next(self):
"""
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -117,10 +118,10 @@ class StreamIdGenerator(object):
return manager()
- def get_next_mult(self, n):
+ async def get_next_mult(self, n):
"""
Usage:
- with stream_id_gen.get_next(n) as stream_ids:
+ with await stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # We track the max position where we know everything before has been
+ # persisted. This is done by a) looking at the min across all instances
+ # and b) noting that if we have seen a run of persisted positions
+ # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+ #
+ # Note: There is no guarentee that the IDs generated by the sequence
+ # will be gapless; gaps can form when e.g. a transaction was rolled
+ # back. This means that sometimes we won't be able to skip forward the
+ # position even though everything has been persisted. However, since
+ # gaps should be relatively rare it's still worth doing the book keeping
+ # that allows us to skip forwards when there are gapless runs of
+ # positions.
+ self._persisted_upto_position = (
+ min(self._current_positions.values()) if self._current_positions else 0
+ )
+ self._known_persisted_positions = [] # type: List[int]
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
def _load_current_ids(
@@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
return current_positions
- def _load_next_id_txn(self, txn):
+ def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
+ def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+ return self._sequence_gen.get_next_mult_txn(txn, n)
+
async def get_next(self):
"""
Usage:
@@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
return manager()
+ async def get_next_mult(self, n: int):
+ """
+ Usage:
+ with await stream_id_gen.get_next_mult(5) as stream_ids:
+ # ... persist events ...
+ """
+ next_ids = await self._db.runInteraction(
+ "_load_next_mult_id", self._load_next_mult_id_txn, n
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ assert max(self.get_positions().values(), default=0) < min(next_ids)
+
+ with self._lock:
+ self._unfinished_ids.update(next_ids)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_ids
+ finally:
+ for i in next_ids:
+ self._mark_id_as_finished(i)
+
+ return manager()
+
def get_next_txn(self, txn: LoggingTransaction):
"""
Usage:
@@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
)
+
+ self._add_persisted_position(new_id)
+
+ def get_persisted_upto_position(self) -> int:
+ """Get the max position where all previous positions have been
+ persisted.
+
+ Note: In the worst case scenario this will be equal to the minimum
+ position across writers. This means that the returned position here can
+ lag if one writer doesn't write very often.
+ """
+
+ with self._lock:
+ return self._persisted_upto_position
+
+ def _add_persisted_position(self, new_id: int):
+ """Record that we have persisted a position.
+
+ This is used to keep the `_current_positions` up to date.
+ """
+
+ # We require that the lock is locked by caller
+ assert self._lock.locked()
+
+ heapq.heappush(self._known_persisted_positions, new_id)
+
+ # We move the current min position up if the minimum current positions
+ # of all instances is higher (since by definition all positions less
+ # that that have been persisted).
+ min_curr = min(self._current_positions.values())
+ self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+ # We now iterate through the seen positions, discarding those that are
+ # less than the current min positions, and incrementing the min position
+ # if its exactly one greater.
+ #
+ # This is also where we discard items from `_known_persisted_positions`
+ # (to ensure the list doesn't infinitely grow).
+ while self._known_persisted_positions:
+ if self._known_persisted_positions[0] <= self._persisted_upto_position:
+ heapq.heappop(self._known_persisted_positions)
+ elif (
+ self._known_persisted_positions[0] == self._persisted_upto_position + 1
+ ):
+ heapq.heappop(self._known_persisted_positions)
+ self._persisted_upto_position += 1
+ else:
+ # There was a gap in seen positions, so there is nothing more to
+ # do.
+ break
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 63dfea4220..ffc1894748 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import threading
-from typing import Callable, Optional
+from typing import Callable, List, Optional
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0]
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ txn.execute(
+ "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+ )
+ return [i for (i,) in txn]
+
GetFirstCallbackType = Callable[[Cursor], int]
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..f92f3b8c15 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
+ user_agent = "Browser"
+ ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
@@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+ request.getClientIP.return_value = ip_address
+
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
@@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 6dfea58cff..ea66196bb6 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,6 +17,7 @@ from mock import Mock
from twisted.internet import defer
+from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
@@ -25,16 +26,18 @@ from synapse.rest.client.v2_alpha.register import (
_map_email_to_displayname,
register_servlets,
)
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers(object):
+class RegistrationHandlers:
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
@@ -485,6 +488,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_spam_checker_deny(self):
+ """A spam checker can deny registration, which results in an error."""
+
+ class DenyAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.DENY
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [DenyAll()]
+
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+ def test_spam_checker_shadow_ban(self):
+ """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+ class BanAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanAll()]
+
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+ # Get an access token.
+ token = self.macaroon_generator.generate_access_token(user_id)
+ self.get_success(
+ self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+ )
+
+ # Ensure the user was marked as shadow-banned.
+ request = Mock(args={})
+ request.args[b"access_token"] = [token.encode("ascii")]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ auth = Auth(self.hs)
+ requester = self.get_success(auth.get_user_by_req(request))
+
+ self.assertTrue(requester.shadow_banned)
+
def test_email_to_displayname_mapping(self):
"""Test that custom emails are mapped to new user displaynames correctly"""
self._check_mapping(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 46c3810e70..48f750d357 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_spam_checker(self):
"""
- A user which fails to the spam checks will not appear in search results.
+ A user which fails the spam checks will not appear in search results.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
- class AllowAll(object):
+ class AllowAll:
def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll(object):
+ class BlockAll:
def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ab91baeacc..c7e287c61e 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -46,50 +46,63 @@ class RetentionTestCase(unittest.HomeserverTestCase):
}
self.hs = self.setup_test_homeserver(config=config)
+
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
- def test_retention_state_event(self):
- """Tests that the server configuration can limit the values a user can set to the
- room's retention policy.
+ self.store = self.hs.get_datastore()
+ self.serializer = self.hs.get_event_client_serializer()
+ self.clock = self.hs.get_clock()
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_day_ms * 4},
+ body={"max_lifetime": lifetime},
tok=self.token,
- expect_code=400,
)
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_with_state_event_outside_allowed(self):
+ """Tests that the server configuration can override the policy for a room when
+ running the purge jobs.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set a max_lifetime higher than the maximum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_hour_ms},
+ body={"max_lifetime": one_day_ms * 4},
tok=self.token,
- expect_code=400,
)
- def test_retention_event_purged_with_state_event(self):
- """Tests that expired events are correctly purged when the room's retention policy
- is defined by a state event.
- """
- room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Check that the event is purged after waiting for the maximum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
- # Set the room's retention period to 2 days.
- lifetime = one_day_ms * 2
+ # Set a max_lifetime lower than the minimum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": lifetime},
+ body={"max_lifetime": one_hour_ms},
tok=self.token,
)
- self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+ # Check that the event is purged after waiting for the minimum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
@@ -141,7 +154,27 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
- def _test_retention_event_purged(self, room_id, increment):
+ def _test_retention_event_purged(self, room_id: str, increment: float):
+ """Run the following test scenario to test the message retention policy support:
+
+ 1. Send event 1
+ 2. Increment time by `increment`
+ 3. Send event 2
+ 4. Increment time by `increment`
+ 5. Check that event 1 has been purged
+ 6. Check that event 2 has not been purged
+ 7. Check that state events that were sent before event 1 aren't purged.
+ The main reason for sending a second event is because currently Synapse won't
+ purge the latest message in a room because it would otherwise result in a lack of
+ forward extremities for this room. It's also a good thing to ensure the purge jobs
+ aren't too greedy and purge messages they shouldn't.
+
+ Args:
+ room_id: The ID of the room to test retention in.
+ increment: The number of milliseconds to advance the clock each time. Must be
+ defined so that events in the room aren't purged if they are `increment`
+ old but are purged if they are `increment * 2` old.
+ """
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
@@ -157,7 +190,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
- expired_event = self.get_event(room_id, expired_event_id)
+ expired_event = self.get_event(expired_event_id)
self.assertEqual(
expired_event.get("content", {}).get("body"), "1", expired_event
)
@@ -175,26 +208,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# one should still be kept.
self.reactor.advance(increment / 1000)
- # Check that the event has been purged from the database.
- self.get_event(room_id, expired_event_id, expected_code=404)
+ # Check that the first event has been purged from the database, i.e. that we
+ # can't retrieve it anymore, because it has expired.
+ self.get_event(expired_event_id, expect_none=True)
- # Check that the event that hasn't been purged can still be retrieved.
- valid_event = self.get_event(room_id, valid_event_id)
+ # Check that the event that hasn't expired can still be retrieved.
+ valid_event = self.get_event(valid_event_id)
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
# Check that we can still access state events that were sent before the event that
# has been purged.
self.get_event(room_id, create_event.event_id)
- def get_event(self, room_id, event_id, expected_code=200):
- url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+ def get_event(self, event_id, expect_none=False):
+ event = self.get_success(self.store.get_event(event_id, allow_none=True))
- request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
+ if expect_none:
+ self.assertIsNone(event)
+ return {}
- self.assertEqual(channel.code, expected_code, channel.result)
+ self.assertIsNotNone(event)
- return channel.json_body
+ time_now = self.clock.time_msec()
+ serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+
+ return serialized
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e74bddc1e5..68c4a6a8f7 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -21,13 +21,13 @@
import json
from urllib import parse as urlparse
-from mock import Mock
+from mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
-from synapse.rest.client.v2_alpha import account
+from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
@@ -684,38 +684,39 @@ class RoomJoinRatelimitTestCase(RoomBase):
]
@unittest.override_config(
- {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit(self):
"""Tests that local joins are actually rate-limited."""
- for i in range(5):
+ for i in range(3):
self.helper.create_room_as(self.user_id)
self.helper.create_room_as(self.user_id, expect_code=429)
@unittest.override_config(
- {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_profile_change(self):
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
- # Create and join more rooms than the rate-limiting config allows in a second.
+ # Create and join as many rooms as the rate-limiting config allows in a second.
room_ids = [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
- self.reactor.advance(1)
- room_ids = room_ids + [
- self.helper.create_room_as(self.user_id),
- self.helper.create_room_as(self.user_id),
- self.helper.create_room_as(self.user_id),
- ]
+ # Let some time for the rate-limiter to forget about our multi-join.
+ self.reactor.advance(2)
+ # Add one to make sure we're joined to more rooms than the config allows us to
+ # join in a second.
+ room_ids.append(self.helper.create_room_as(self.user_id))
# Create a profile for the user, since it hasn't been done on registration.
store = self.hs.get_datastore()
- store.create_profile(UserID.from_string(self.user_id).localpart)
+ self.get_success(
+ store.create_profile(UserID.from_string(self.user_id).localpart)
+ )
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
@@ -738,7 +739,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
self.assertEquals(channel.json_body["displayname"], "John Doe")
@unittest.override_config(
- {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_idempotent(self):
"""Tests that the room join endpoints remain idempotent despite rate-limiting
@@ -754,7 +755,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
for path in paths_to_test:
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
- for i in range(6):
+ for i in range(4):
request, channel = self.make_request("POST", path % room_id, {})
self.render(request)
self.assertEquals(channel.code, 200)
@@ -2059,3 +2060,158 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ShadowBannedTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.banned_user_id = self.register_user("banned", "test")
+ self.banned_access_token = self.login("banned", "test")
+
+ self.store = self.hs.get_datastore()
+
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="users",
+ keyvalues={"name": self.banned_user_id},
+ updatevalues={"shadow_banned": True},
+ desc="shadow_ban",
+ )
+ )
+
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def test_invite(self):
+ """Invites from shadow-banned users don't actually get sent."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ self.helper.invite(
+ room=room_id,
+ src=self.banned_user_id,
+ tok=self.banned_access_token,
+ targ=self.other_user_id,
+ )
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ def test_invite_3pid(self):
+ """Ensure that a 3PID invite does not attempt to contact the identity server."""
+ identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler.lookup_3pid = Mock(
+ side_effect=AssertionError("This should not get called")
+ )
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/invite" % (room_id,),
+ {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # This should have raised an error earlier, but double check this wasn't called.
+ identity_handler.lookup_3pid.assert_not_called()
+
+ def test_create_room(self):
+ """Invitations during a room creation should be discarded, but the room still gets created."""
+ # The room creation is successful.
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ {"visibility": "public", "invite": [self.other_user_id]},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ room_id = channel.json_body["room_id"]
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ # Since a real room was created, the other user should be able to join it.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ # Both users should be in the room.
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+ def test_message(self):
+ """Messages from shadow-banned users don't actually get sent."""
+
+ room_id = self.helper.create_room_as(
+ self.other_user_id, tok=self.other_access_token
+ )
+
+ # The user should be in the room.
+ self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+ # Sending a message should complete successfully.
+ result = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "with right label"},
+ tok=self.banned_access_token,
+ )
+ self.assertIn("event_id", result)
+ event_id = result["event_id"]
+
+ latest_events = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+ self.assertNotIn(event_id, latest_events)
+
+ def test_upgrade(self):
+ """A room upgrade should fail, but look like it succeeded."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+ {"new_version": "6"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ # A new room_id should be returned.
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(new_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 7a05194653..9b9a183e7f 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+ def test_get_persisted_upto_position(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions.
+ """
+
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 5)
+
+ id_gen = self._create_id_generator("first")
+
+ # Min is 3 and there is a gap between 5, so we expect it to be 3.
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ # We advance "first" straight to 6. Min is now 5 but there is no gap so
+ # we expect it to be 6
+ id_gen.advance("first", 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+ # No gap, so we expect 7.
+ id_gen.advance("second", 7)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # We haven't seen 8 yet, so we expect 7 still.
+ id_gen.advance("second", 9)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # Now that we've seen 7, 8 and 9 we can got straight to 9.
+ id_gen.advance("first", 8)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+ # Jump forward with gaps. The minimum is 11, even though we haven't seen
+ # 10 we know that everything before 11 must be persisted.
+ id_gen.advance("first", 11)
+ id_gen.advance("second", 15)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 11)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 840db66072..58f827d8d3 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import UserTypes
+from synapse.api.errors import ThreepidValidationError
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase):
)
res = yield self.store.is_support_user(SUPPORT_USER)
self.assertTrue(res)
+
+ @defer.inlineCallbacks
+ def test_3pid_inhibit_invalid_validation_session_error(self):
+ """Tests that enabling the configuration option to inhibit 3PID errors on
+ /requestToken also inhibits validation errors caused by an unknown session ID.
+ """
+
+ # Check that, with the config setting set to false (the default value), a
+ # validation error is caused by the unknown session ID.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Unknown session_id", e)
+
+ # Set the config setting to true.
+ self.store._ignore_unknown_session_error = True
+
+ # Check that now the validation error is caused by the token not matching.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Validation token not found or has expired", e)
diff --git a/tox.ini b/tox.ini
index 050e36bc82..f8ecd1aa69 100644
--- a/tox.ini
+++ b/tox.ini
@@ -210,6 +210,7 @@ commands = mypy \
synapse/server.py \
synapse/server_notices \
synapse/spam_checker_api \
+ synapse/state \
synapse/storage/databases/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
|