diff --git a/changelog.d/8553.docker b/changelog.d/8553.docker
new file mode 100644
index 0000000000..f99c4207b8
--- /dev/null
+++ b/changelog.d/8553.docker
@@ -0,0 +1 @@
+Use jemalloc if available in Docker.
diff --git a/changelog.d/9542.bugfix b/changelog.d/9542.bugfix
new file mode 100644
index 0000000000..51b1876f3b
--- /dev/null
+++ b/changelog.d/9542.bugfix
@@ -0,0 +1 @@
+Purge chain cover indexes for events that were purged prior to Synapse v1.29.0.
diff --git a/changelog.d/9561.misc b/changelog.d/9561.misc
new file mode 100644
index 0000000000..6c529a82ee
--- /dev/null
+++ b/changelog.d/9561.misc
@@ -0,0 +1 @@
+Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle.
diff --git a/changelog.d/9563.misc b/changelog.d/9563.misc
new file mode 100644
index 0000000000..7a3493e4a1
--- /dev/null
+++ b/changelog.d/9563.misc
@@ -0,0 +1 @@
+Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.
diff --git a/changelog.d/9568.misc b/changelog.d/9568.misc
new file mode 100644
index 0000000000..561963de93
--- /dev/null
+++ b/changelog.d/9568.misc
@@ -0,0 +1 @@
+Do not have mypy ignore type hints from unpaddedbase64.
diff --git a/changelog.d/9573.feature b/changelog.d/9573.feature
new file mode 100644
index 0000000000..5214b50d41
--- /dev/null
+++ b/changelog.d/9573.feature
@@ -0,0 +1 @@
+Add Prometheus metrics for the number of users successfully registering and logging in.
diff --git a/changelog.d/9576.misc b/changelog.d/9576.misc
new file mode 100644
index 0000000000..bc257d05b7
--- /dev/null
+++ b/changelog.d/9576.misc
@@ -0,0 +1 @@
+Improve efficiency of calculating the auth chain in large rooms.
diff --git a/changelog.d/9580.doc b/changelog.d/9580.doc
new file mode 100644
index 0000000000..f9c8b328b3
--- /dev/null
+++ b/changelog.d/9580.doc
@@ -0,0 +1 @@
+Clarify the spam checker modules documentation example to mention that `parse_config` is a required method.
diff --git a/changelog.d/9583.bugfix b/changelog.d/9583.bugfix
new file mode 100644
index 0000000000..51b1876f3b
--- /dev/null
+++ b/changelog.d/9583.bugfix
@@ -0,0 +1 @@
+Purge chain cover indexes for events that were purged prior to Synapse v1.29.0.
diff --git a/changelog.d/9586.misc b/changelog.d/9586.misc
new file mode 100644
index 0000000000..2def9d5f55
--- /dev/null
+++ b/changelog.d/9586.misc
@@ -0,0 +1 @@
+Convert `synapse.types.Requester` to an `attrs` class.
diff --git a/changelog.d/9587.bugfix b/changelog.d/9587.bugfix
new file mode 100644
index 0000000000..d8f04c4f21
--- /dev/null
+++ b/changelog.d/9587.bugfix
@@ -0,0 +1 @@
+Fix re-activating an account via the admin API when local passwords are disabled.
\ No newline at end of file
diff --git a/changelog.d/9590.misc b/changelog.d/9590.misc
new file mode 100644
index 0000000000..186396c45b
--- /dev/null
+++ b/changelog.d/9590.misc
@@ -0,0 +1 @@
+Add logging for redis connection setup.
diff --git a/changelog.d/9591.misc b/changelog.d/9591.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9591.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/changelog.d/9596.misc b/changelog.d/9596.misc
new file mode 100644
index 0000000000..fc19a95f75
--- /dev/null
+++ b/changelog.d/9596.misc
@@ -0,0 +1 @@
+Improve logging when processing incoming transactions.
diff --git a/changelog.d/9597.bugfix b/changelog.d/9597.bugfix
new file mode 100644
index 0000000000..349dc9d664
--- /dev/null
+++ b/changelog.d/9597.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages.
diff --git a/changelog.d/9601.feature b/changelog.d/9601.feature
new file mode 100644
index 0000000000..5078d63ffa
--- /dev/null
+++ b/changelog.d/9601.feature
@@ -0,0 +1 @@
+Optimise handling of incomplete room history for incoming federation.
diff --git a/changelog.d/9604.doc b/changelog.d/9604.doc
new file mode 100644
index 0000000000..d413e38b72
--- /dev/null
+++ b/changelog.d/9604.doc
@@ -0,0 +1 @@
+Clarify the sample configuration for `stats` settings.
diff --git a/changelog.d/9604.misc b/changelog.d/9604.misc
new file mode 100644
index 0000000000..0583988588
--- /dev/null
+++ b/changelog.d/9604.misc
@@ -0,0 +1 @@
+Remove unused `stats.retention` setting, and emit a warning if stats are disabled.
diff --git a/changelog.d/9608.misc b/changelog.d/9608.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9608.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/changelog.d/9617.feature b/changelog.d/9617.feature
new file mode 100644
index 0000000000..b462a32b92
--- /dev/null
+++ b/changelog.d/9617.feature
@@ -0,0 +1 @@
+Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)).
diff --git a/changelog.d/9618.misc b/changelog.d/9618.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9618.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/changelog.d/9619.misc b/changelog.d/9619.misc
new file mode 100644
index 0000000000..50267bfbc4
--- /dev/null
+++ b/changelog.d/9619.misc
@@ -0,0 +1 @@
+Prevent attempting to bundle aggregations for state events in /context APIs.
\ No newline at end of file
diff --git a/changelog.d/9620.bugfix b/changelog.d/9620.bugfix
new file mode 100644
index 0000000000..427580f4ad
--- /dev/null
+++ b/changelog.d/9620.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`.
diff --git a/changelog.d/9623.bugfix b/changelog.d/9623.bugfix
new file mode 100644
index 0000000000..ecccb46105
--- /dev/null
+++ b/changelog.d/9623.bugfix
@@ -0,0 +1 @@
+Fix an Internal Server Error on `GET /_synapse/client/saml2/authn_response` requests.
diff --git a/changelog.d/9626.feature b/changelog.d/9626.feature
new file mode 100644
index 0000000000..eacba6201b
--- /dev/null
+++ b/changelog.d/9626.feature
@@ -0,0 +1 @@
+Tell spam checker modules about the SSO IdP a user registered through if one was used.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index d619ee08ed..def4501541 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -69,6 +69,7 @@ RUN apt-get update && apt-get install -y \
libpq5 \
libwebp6 \
xmlsec1 \
+ libjemalloc2 \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /install /usr/local
diff --git a/docker/README.md b/docker/README.md
index 7b138df4d3..3a7dc585e7 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -204,3 +204,8 @@ healthcheck:
timeout: 10s
retries: 3
```
+
+## Using jemalloc
+
+Jemalloc is embedded in the image and will be used instead of the default allocator.
+You can read more about jemalloc in the Synapse [README](../README.md).
\ No newline at end of file
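
As a quick, hedged way to confirm the allocator is actually in use, one can look for libjemalloc in the main process's memory maps. This assumes Synapse ends up as PID 1 inside the container (which is how the image's entrypoint behaves, but is stated here as an assumption):

```python
# Illustrative check only (Linux): run inside the container, e.g. from a
# `docker exec` Python shell. Assumes the Synapse process is PID 1.
with open("/proc/1/maps") as maps:
    loaded = any("libjemalloc" in line for line in maps)
print("jemalloc loaded:", loaded)
```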
diff --git a/docker/start.py b/docker/start.py
index 0d2c590b88..16d6a8208a 100755
--- a/docker/start.py
+++ b/docker/start.py
@@ -3,6 +3,7 @@
import codecs
import glob
import os
+import platform
import subprocess
import sys
@@ -213,6 +214,13 @@ def main(args, environ):
if "-m" not in args:
args = ["-m", synapse_worker] + args
+ jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),)
+
+ if os.path.isfile(jemallocpath):
+ environ["LD_PRELOAD"] = jemallocpath
+ else:
+ log("Could not find %s, will not use" % (jemallocpath,))
+
# if there are no config files passed to synapse, try adding the default file
if not any(p.startswith("--config-path") or p.startswith("-c") for p in args):
config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
@@ -248,9 +256,9 @@ running with 'migrate_config'. See the README for more details.
args = ["python"] + args
if ownership is not None:
args = ["gosu", ownership] + args
- os.execv("/usr/sbin/gosu", args)
+ os.execve("/usr/sbin/gosu", args, environ)
else:
- os.execv("/usr/local/bin/python", args)
+ os.execve("/usr/local/bin/python", args, environ)
if __name__ == "__main__":
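
To make the intent of this `start.py` change clearer in isolation, here is a minimal, hypothetical sketch of the same technique: look for the architecture-specific libjemalloc path, export it via `LD_PRELOAD`, and hand the modified environment to `os.execve` (plain `os.execv` would not pass it on, since `environ` here is a local dict rather than `os.environ`). The binary path and arguments below are placeholders, not the image's real entrypoint.

```python
import os
import platform

def exec_with_jemalloc(argv, environ):
    """Exec `argv`, preloading jemalloc when the shared library is installed."""
    jemalloc = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),)
    if os.path.isfile(jemalloc):
        # LD_PRELOAD makes the dynamic linker load jemalloc ahead of glibc's
        # malloc, so the new process allocates through jemalloc.
        environ["LD_PRELOAD"] = jemalloc
    # execve (unlike execv) takes an explicit environment mapping, so the
    # LD_PRELOAD change actually reaches the new process image.
    os.execve(argv[0], argv, environ)

# Hypothetical usage (commented out because execve replaces this process):
# exec_with_jemalloc(["/usr/local/bin/python", "-m", "synapse.app.homeserver"],
#                    dict(os.environ))
```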
diff --git a/docs/openid.md b/docs/openid.md
index 01205d1220..cfaafc5015 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -226,7 +226,7 @@ Synapse config:
oidc_providers:
- idp_id: github
idp_name: Github
- idp_brand: "org.matrix.github" # optional: styling hint for clients
+ idp_brand: "github" # optional: styling hint for clients
discover: false
issuer: "https://github.com/"
client_id: "your-client-id" # TO BE FILLED
@@ -252,7 +252,7 @@ oidc_providers:
oidc_providers:
- idp_id: google
idp_name: Google
- idp_brand: "org.matrix.google" # optional: styling hint for clients
+ idp_brand: "google" # optional: styling hint for clients
issuer: "https://accounts.google.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
@@ -299,7 +299,7 @@ Synapse config:
oidc_providers:
- idp_id: gitlab
idp_name: Gitlab
- idp_brand: "org.matrix.gitlab" # optional: styling hint for clients
+ idp_brand: "gitlab" # optional: styling hint for clients
issuer: "https://gitlab.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
@@ -334,7 +334,7 @@ Synapse config:
```yaml
- idp_id: facebook
idp_name: Facebook
- idp_brand: "org.matrix.facebook" # optional: styling hint for clients
+ idp_brand: "facebook" # optional: styling hint for clients
discover: false
issuer: "https://facebook.com"
client_id: "your-client-id" # TO BE FILLED
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index c32ee4a897..7de000f4a4 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1919,7 +1919,7 @@ oidc_providers:
#
#- idp_id: github
# idp_name: Github
- # idp_brand: org.matrix.github
+ # idp_brand: github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
@@ -2645,19 +2645,20 @@ user_directory:
-# Local statistics collection. Used in populating the room directory.
+# Settings for local room and user statistics collection. See
+# docs/room_and_user_statistics.md.
#
-# 'bucket_size' controls how large each statistics timeslice is. It can
-# be defined in a human readable short form -- e.g. "1d", "1y".
-#
-# 'retention' controls how long historical statistics will be kept for.
-# It can be defined in a human readable short form -- e.g. "1d", "1y".
-#
-#
-#stats:
-# enabled: true
-# bucket_size: 1d
-# retention: 1y
+stats:
+ # Uncomment the following to disable room and user statistics. Note that doing
+ # so may cause certain features (such as the room directory) not to work
+ # correctly.
+ #
+ #enabled: false
+
+ # The size of each timeslice in the room_stats_historical and
+ # user_stats_historical tables, as a time period. Defaults to "1d".
+ #
+ #bucket_size: 1h
# Server Notices room configuration
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index e615ac9910..52947f605e 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -14,6 +14,7 @@ The Python class is instantiated with two objects:
* An instance of `synapse.module_api.ModuleApi`.
It then implements methods which return a boolean to alter behavior in Synapse.
+All the methods must be defined.
There's a generic method for checking every event (`check_event_for_spam`), as
well as some specific methods:
@@ -24,6 +25,7 @@ well as some specific methods:
* `user_may_publish_room`
* `check_username_for_spam`
* `check_registration_for_spam`
+* `check_media_file_for_spam`
The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
@@ -31,6 +33,10 @@ are documented in the `synapse.events.spamcheck.SpamChecker` class.
The `ModuleApi` class provides a way for the custom spam checker class to
call back into the homeserver internals.
+Additionally, a `parse_config` method is mandatory and receives the plugin config
+dictionary. After parsing, it must return an object which will be
+passed to `__init__` later.
+
### Example
```python
@@ -41,6 +47,10 @@ class ExampleSpamChecker:
self.config = config
self.api = api
+ @staticmethod
+ def parse_config(config):
+ return config
+
async def check_event_for_spam(self, foo):
return False # allow all events
@@ -59,7 +69,13 @@ class ExampleSpamChecker:
async def check_username_for_spam(self, user_profile):
return False # allow all usernames
- async def check_registration_for_spam(self, email_threepid, username, request_info):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
return RegistrationBehaviour.ALLOW # allow all registrations
async def check_media_file_for_spam(self, file_wrapper, file_info):
diff --git a/mypy.ini b/mypy.ini
index f31cd432e6..e0685e097c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -117,9 +117,6 @@ ignore_missing_imports = True
[mypy-saml2.*]
ignore_missing_imports = True
-[mypy-unpaddedbase64]
-ignore_missing_imports = True
-
[mypy-canonicaljson]
ignore_missing_imports = True
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 618548a305..080ca40287 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -17,7 +17,9 @@
"""
from typing import Any, List, Optional, Type, Union
-class RedisProtocol:
+from twisted.internet import protocol
+
+class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(
@@ -52,7 +54,7 @@ def lazyConnection(
class ConnectionHandler: ...
-class RedisFactory:
+class RedisFactory(protocol.ReconnectingClientFactory):
continueTrying: bool
handler: RedisProtocol
pool: List[RedisProtocol]
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 968cf6f174..e10e33fd23 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -164,7 +164,7 @@ class Auth:
async def get_user_by_req(
self,
- request: Request,
+ request: SynapseRequest,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 7f5e449eb2..2bfb537c15 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -237,7 +237,7 @@ class OIDCConfig(Config):
#
#- idp_id: github
# idp_name: Github
- # idp_brand: org.matrix.github
+ # idp_brand: github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
@@ -272,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"idp_icon": {"type": "string"},
"idp_brand": {
"type": "string",
- # MSC2758-style namespaced identifier
+ "minLength": 1,
+ "maxLength": 255,
+ "pattern": "^[a-z][a-z0-9_.-]*$",
+ },
+ "idp_unstable_brand": {
+ "type": "string",
"minLength": 1,
"maxLength": 255,
"pattern": "^[a-z][a-z0-9_.-]*$",
@@ -466,6 +471,7 @@ def _parse_oidc_config_dict(
idp_name=oidc_config.get("idp_name", "OIDC"),
idp_icon=idp_icon,
idp_brand=oidc_config.get("idp_brand"),
+ unstable_idp_brand=oidc_config.get("unstable_idp_brand"),
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
@@ -512,6 +518,9 @@ class OidcProviderConfig:
# Optional brand identifier for this IdP.
idp_brand = attr.ib(type=Optional[str])
+ # Optional brand identifier for the unstable API (see MSC2858).
+ unstable_idp_brand = attr.ib(type=Optional[str])
+
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index b559bfa411..2258329a52 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -13,10 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
+import logging
from ._base import Config
+ROOM_STATS_DISABLED_WARN = """\
+WARNING: room/user statistics have been disabled via the stats.enabled
+configuration setting. This means that certain features (such as the room
+directory) will not operate correctly. Future versions of Synapse may ignore
+this setting.
+
+To fix this warning, remove the stats.enabled setting from your configuration
+file.
+--------------------------------------------------------------------------------"""
+
+logger = logging.getLogger(__name__)
+
class StatsConfig(Config):
"""Stats Configuration
@@ -28,30 +40,29 @@ class StatsConfig(Config):
def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400 * 1000
- self.stats_retention = sys.maxsize
stats_config = config.get("stats", None)
if stats_config:
self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
self.stats_bucket_size = self.parse_duration(
stats_config.get("bucket_size", "1d")
)
- self.stats_retention = self.parse_duration(
- stats_config.get("retention", "%ds" % (sys.maxsize,))
- )
+ if not self.stats_enabled:
+ logger.warning(ROOM_STATS_DISABLED_WARN)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
- # Local statistics collection. Used in populating the room directory.
+ # Settings for local room and user statistics collection. See
+ # docs/room_and_user_statistics.md.
#
- # 'bucket_size' controls how large each statistics timeslice is. It can
- # be defined in a human readable short form -- e.g. "1d", "1y".
- #
- # 'retention' controls how long historical statistics will be kept for.
- # It can be defined in a human readable short form -- e.g. "1d", "1y".
- #
- #
- #stats:
- # enabled: true
- # bucket_size: 1d
- # retention: 1y
+ stats:
+ # Uncomment the following to disable room and user statistics. Note that doing
+ # so may cause certain features (such as the room directory) not to work
+ # correctly.
+ #
+ #enabled: false
+
+ # The size of each timeslice in the room_stats_historical and
+ # user_stats_historical tables, as a time period. Defaults to "1d".
+ #
+ #bucket_size: 1h
"""
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 8cfc0bb3cb..a9185987a2 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,6 +15,7 @@
# limitations under the License.
import inspect
+import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.rest.media.v1._base import FileInfo
@@ -27,6 +28,8 @@ if TYPE_CHECKING:
import synapse.events
import synapse.server
+logger = logging.getLogger(__name__)
+
class SpamChecker:
def __init__(self, hs: "synapse.server.HomeServer"):
@@ -190,6 +193,7 @@ class SpamChecker:
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str] = None,
) -> RegistrationBehaviour:
"""Checks if we should allow the given registration request.
@@ -198,6 +202,9 @@ class SpamChecker:
username: The request user name, if any
request_info: List of tuples of user agent and IP that
were used during the registration process.
+            auth_provider_id: The SSO IdP the user used, if any (e.g. "oidc",
+                "saml", "cas"). Note this does not include users registered
+                via a password provider.
Returns:
Enum for how the request should be handled
@@ -208,9 +215,25 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
- behaviour = await maybe_awaitable(
- checker(email_threepid, username, request_info)
- )
+ # Provide auth_provider_id if the function supports it
+ checker_args = inspect.signature(checker)
+ if len(checker_args.parameters) == 4:
+ d = checker(
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ )
+ elif len(checker_args.parameters) == 3:
+ d = checker(email_threepid, username, request_info)
+ else:
+ logger.error(
+ "Invalid signature for %s.check_registration_for_spam. Denying registration",
+ spam_checker.__module__,
+ )
+ return RegistrationBehaviour.DENY
+
+ behaviour = await maybe_awaitable(d)
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
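
The signature inspection above is what keeps older spam-checker modules working while newer ones receive `auth_provider_id`. A self-contained sketch of that dispatch pattern, with hypothetical checker functions that are not real modules:

```python
import inspect

def call_checker(checker, email_threepid, username, request_info, auth_provider_id):
    """Pass auth_provider_id only if the callback declares a parameter for it."""
    num_params = len(inspect.signature(checker).parameters)
    if num_params == 4:
        return checker(email_threepid, username, request_info, auth_provider_id)
    if num_params == 3:
        return checker(email_threepid, username, request_info)
    raise TypeError("unexpected check_registration_for_spam signature")

# Hypothetical legacy checker: predates the auth_provider_id argument.
def legacy_checker(email_threepid, username, request_info):
    return "ALLOW"

# Hypothetical new-style checker: can take the SSO IdP into account.
def new_checker(email_threepid, username, request_info, auth_provider_id):
    return "DENY" if auth_provider_id == "untrusted-idp" else "ALLOW"

print(call_checker(legacy_checker, None, "alice", [], "oidc"))        # ALLOW
print(call_checker(new_checker, None, "alice", [], "untrusted-idp"))  # DENY
```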
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index f8e368f81b..98caf2a1a4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -112,10 +112,11 @@ class FederationServer(FederationBase):
# with FederationHandlerRegistry.
hs.get_directory_handler()
- self._federation_ratelimiter = hs.get_federation_ratelimiter()
-
self._server_linearizer = Linearizer("fed_server")
- self._transaction_linearizer = Linearizer("fed_txn_handler")
+
+ # origins that we are currently processing a transaction from.
+ # a dict from origin to txn id.
+ self._active_transactions = {} # type: Dict[str, str]
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
@@ -169,6 +170,33 @@ class FederationServer(FederationBase):
logger.debug("[%s] Got transaction", transaction_id)
+        # Reject malformed transactions early: too many PDUs or EDUs gets a 400.
+ if len(transaction.pdus) > 50 or ( # type: ignore
+ hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
+ ):
+ logger.info("Transaction PDU or EDU count too large. Returning 400")
+ return 400, {}
+
+ # we only process one transaction from each origin at a time. We need to do
+ # this check here, rather than in _on_incoming_transaction_inner so that we
+ # don't cache the rejection in _transaction_resp_cache (so that if the txn
+ # arrives again later, we can process it).
+ current_transaction = self._active_transactions.get(origin)
+ if current_transaction and current_transaction != transaction_id:
+ logger.warning(
+ "Received another txn %s from %s while still processing %s",
+ transaction_id,
+ origin,
+ current_transaction,
+ )
+ return 429, {
+ "errcode": Codes.UNKNOWN,
+ "error": "Too many concurrent transactions",
+ }
+
+ # CRITICAL SECTION: we must now not await until we populate _active_transactions
+ # in _on_incoming_transaction_inner.
+
# We wrap in a ResponseCache so that we de-duplicate retried
# transactions.
return await self._transaction_resp_cache.wrap(
@@ -182,26 +210,18 @@ class FederationServer(FederationBase):
async def _on_incoming_transaction_inner(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
- # Use a linearizer to ensure that transactions from a remote are
- # processed in order.
- with await self._transaction_linearizer.queue(origin):
- # We rate limit here *after* we've queued up the incoming requests,
- # so that we don't fill up the ratelimiter with blocked requests.
- #
- # This is important as the ratelimiter allows N concurrent requests
- # at a time, and only starts ratelimiting if there are more requests
- # than that being processed at a time. If we queued up requests in
- # the linearizer/response cache *after* the ratelimiting then those
- # queued up requests would count as part of the allowed limit of N
- # concurrent requests.
- with self._federation_ratelimiter.ratelimit(origin) as d:
- await d
-
- result = await self._handle_incoming_transaction(
- origin, transaction, request_time
- )
+ # CRITICAL SECTION: the first thing we must do (before awaiting) is
+ # add an entry to _active_transactions.
+ assert origin not in self._active_transactions
+ self._active_transactions[origin] = transaction.transaction_id # type: ignore
- return result
+ try:
+ result = await self._handle_incoming_transaction(
+ origin, transaction, request_time
+ )
+ return result
+ finally:
+ del self._active_transactions[origin]
async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
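
The `_active_transactions` map replaces the per-origin linearizer: a second, different transaction from the same origin is rejected with 429 before anything is cached, and the entry is added before any `await` so the check cannot race. A minimal, self-contained sketch of that guard (simplified names, not Synapse's API):

```python
import asyncio
from typing import Awaitable, Callable, Dict

class PerOriginGuard:
    """Allow at most one in-flight transaction per origin; reject others with 429."""

    def __init__(self) -> None:
        self._active: Dict[str, str] = {}  # origin -> txn id being processed

    async def handle(
        self, origin: str, txn_id: str, process: Callable[[], Awaitable[None]]
    ) -> int:
        if origin in self._active:
            return 429  # another transaction from this origin is still in flight
        # No await between the check above and this assignment, mirroring the
        # "critical section" comments in the real code.
        self._active[origin] = txn_id
        try:
            await process()
            return 200
        finally:
            del self._active[origin]

async def main() -> None:
    guard = PerOriginGuard()

    async def slow() -> None:
        await asyncio.sleep(0.05)

    first = asyncio.create_task(guard.handle("example.org", "txn-1", slow))
    await asyncio.sleep(0)  # let the first handler register itself
    print(await guard.handle("example.org", "txn-2", slow))  # 429
    print(await first)                                       # 200

asyncio.run(main())
```

(The real server additionally de-duplicates retries of the *same* transaction id through `_transaction_resp_cache`, which the sketch omits.)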
@@ -227,19 +247,6 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
- # Reject if PDU count > 50 or EDU count > 100
- if len(transaction.pdus) > 50 or ( # type: ignore
- hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
- ):
-
- logger.info("Transaction PDU or EDU count too large. Returning 400")
-
- response = {}
- await self.transaction_actions.set_response(
- origin, transaction, 400, response
- )
- return 400, response
-
# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
@@ -335,34 +342,41 @@ class FederationServer(FederationBase):
# impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id: str):
- logger.debug("Processing PDUs for %s", room_id)
- try:
- await self.check_server_matches_acl(origin_host, room_id)
- except AuthError as e:
- logger.warning("Ignoring PDUs for room %s from banned server", room_id)
- for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- pdu_results[event_id] = e.error_dict()
- return
+ with nested_logging_context(room_id):
+ logger.debug("Processing PDUs for %s", room_id)
- for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- with pdu_process_time.time():
- with nested_logging_context(event_id):
- try:
- await self._handle_received_pdu(origin, pdu)
- pdu_results[event_id] = {}
- except FederationError as e:
- logger.warning("Error handling PDU %s: %s", event_id, e)
- pdu_results[event_id] = {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- pdu_results[event_id] = {"error": str(e)}
- logger.error(
- "Failed to handle PDU %s",
- event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
- )
+ try:
+ await self.check_server_matches_acl(origin_host, room_id)
+ except AuthError as e:
+ logger.warning(
+ "Ignoring PDUs for room %s from banned server", room_id
+ )
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ pdu_results[event_id] = e.error_dict()
+ return
+
+ for pdu in pdus_by_room[room_id]:
+ pdu_results[pdu.event_id] = await process_pdu(pdu)
+
+ async def process_pdu(pdu: EventBase) -> JsonDict:
+ event_id = pdu.event_id
+ with pdu_process_time.time():
+ with nested_logging_context(event_id):
+ try:
+ await self._handle_received_pdu(origin, pdu)
+ return {}
+ except FederationError as e:
+ logger.warning("Error handling PDU %s: %s", event_id, e)
+ return {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ logger.error(
+ "Failed to handle PDU %s",
+ event_id,
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
+ )
+ return {"error": str(e)}
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
@@ -447,7 +461,7 @@ class FederationServer(FederationBase):
async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
- auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
+ auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
@@ -460,7 +474,9 @@ class FederationServer(FederationBase):
else:
pdus = (await self.state.get_current_state(room_id)).values()
- auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
+ auth_chain = await self.store.get_auth_chain(
+ room_id, [pdu.event_id for pdu in pdus]
+ )
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
@@ -864,7 +880,9 @@ class FederationHandlerRegistry:
self.edu_handlers = (
{}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
- self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
+ self.query_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
@@ -898,7 +916,7 @@ class FederationHandlerRegistry:
self.edu_handlers[edu_type] = handler
def register_query_handler(
- self, query_type: str, handler: Callable[[dict], defer.Deferred]
+ self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@@ -975,7 +993,7 @@ class FederationHandlerRegistry:
# Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type)
- async def on_query(self, query_type: str, args: dict):
+ async def on_query(self, query_type: str, args: dict) -> JsonDict:
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index deb519f3ef..cc0d765e5f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -17,6 +17,7 @@ import datetime
import logging
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
+import attr
from prometheus_client import Counter
from synapse.api.errors import (
@@ -93,6 +94,10 @@ class PerDestinationQueue:
self._destination = destination
self.transmission_loop_running = False
+ # Flag to signal to any running transmission loop that there is new data
+ # queued up to be sent.
+ self._new_data_to_send = False
+
# True whilst we are sending events that the remote homeserver missed
# because it was unreachable. We start in this state so we can perform
# catch-up at startup.
@@ -108,7 +113,7 @@ class PerDestinationQueue:
# destination (we are the only updater so this is safe)
self._last_successful_stream_ordering = None # type: Optional[int]
- # a list of pending PDUs
+ # a queue of pending PDUs
self._pending_pdus = [] # type: List[EventBase]
# XXX this is never actually used: see
@@ -208,6 +213,10 @@ class PerDestinationQueue:
transaction in the background.
"""
+ # Mark that we (may) have new things to send, so that any running
+ # transmission loop will recheck whether there is stuff to send.
+ self._new_data_to_send = True
+
if self.transmission_loop_running:
# XXX: this can get stuck on by a never-ending
# request at which point pending_pdus just keeps growing.
@@ -250,125 +259,41 @@ class PerDestinationQueue:
pending_pdus = []
while True:
- # We have to keep 2 free slots for presence and rr_edus
- limit = MAX_EDUS_PER_TRANSACTION - 2
-
- device_update_edus, dev_list_id = await self._get_device_update_edus(
- limit
- )
-
- limit -= len(device_update_edus)
-
- (
- to_device_edus,
- device_stream_id,
- ) = await self._get_to_device_message_edus(limit)
-
- pending_edus = device_update_edus + to_device_edus
-
- # BEGIN CRITICAL SECTION
- #
- # In order to avoid a race condition, we need to make sure that
- # the following code (from popping the queues up to the point
- # where we decide if we actually have any pending messages) is
- # atomic - otherwise new PDUs or EDUs might arrive in the
- # meantime, but not get sent because we hold the
- # transmission_loop_running flag.
-
- pending_pdus = self._pending_pdus
+ self._new_data_to_send = False
- # We can only include at most 50 PDUs per transactions
- pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
+ async with _TransactionQueueManager(self) as (
+ pending_pdus,
+ pending_edus,
+ ):
+ if not pending_pdus and not pending_edus:
+ logger.debug("TX [%s] Nothing to send", self._destination)
+
+ # If we've gotten told about new things to send during
+ # checking for things to send, we try looking again.
+ # Otherwise new PDUs or EDUs might arrive in the meantime,
+ # but not get sent because we hold the
+ # `transmission_loop_running` flag.
+ if self._new_data_to_send:
+ continue
+ else:
+ return
- pending_edus.extend(self._get_rr_edus(force_flush=False))
- pending_presence = self._pending_presence
- self._pending_presence = {}
- if pending_presence:
- pending_edus.append(
- Edu(
- origin=self._server_name,
- destination=self._destination,
- edu_type="m.presence",
- content={
- "push": [
- format_user_presence_state(
- presence, self._clock.time_msec()
- )
- for presence in pending_presence.values()
- ]
- },
+ if pending_pdus:
+ logger.debug(
+ "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ self._destination,
+ len(pending_pdus),
)
- )
- pending_edus.extend(
- self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
- )
- while (
- len(pending_edus) < MAX_EDUS_PER_TRANSACTION
- and self._pending_edus_keyed
- ):
- _, val = self._pending_edus_keyed.popitem()
- pending_edus.append(val)
-
- if pending_pdus:
- logger.debug(
- "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- self._destination,
- len(pending_pdus),
+ await self._transaction_manager.send_new_transaction(
+ self._destination, pending_pdus, pending_edus
)
- if not pending_pdus and not pending_edus:
- logger.debug("TX [%s] Nothing to send", self._destination)
- self._last_device_stream_id = device_stream_id
- return
-
- # if we've decided to send a transaction anyway, and we have room, we
- # may as well send any pending RRs
- if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
- pending_edus.extend(self._get_rr_edus(force_flush=True))
-
- # END CRITICAL SECTION
-
- success = await self._transaction_manager.send_new_transaction(
- self._destination, pending_pdus, pending_edus
- )
- if success:
sent_transactions_counter.inc()
sent_edus_counter.inc(len(pending_edus))
for edu in pending_edus:
sent_edus_by_type.labels(edu.edu_type).inc()
- # Remove the acknowledged device messages from the database
- # Only bother if we actually sent some device messages
- if to_device_edus:
- await self._store.delete_device_msgs_for_remote(
- self._destination, device_stream_id
- )
- # also mark the device updates as sent
- if device_update_edus:
- logger.info(
- "Marking as sent %r %r", self._destination, dev_list_id
- )
- await self._store.mark_as_sent_devices_by_remote(
- self._destination, dev_list_id
- )
-
- self._last_device_stream_id = device_stream_id
- self._last_device_list_stream_id = dev_list_id
-
- if pending_pdus:
- # we sent some PDUs and it was successful, so update our
- # last_successful_stream_ordering in the destinations table.
- final_pdu = pending_pdus[-1]
- last_successful_stream_ordering = (
- final_pdu.internal_metadata.stream_ordering
- )
- assert last_successful_stream_ordering
- await self._store.set_destination_last_successful_stream_ordering(
- self._destination, last_successful_stream_ordering
- )
- else:
- break
except NotRetryingDestination as e:
logger.debug(
"TX [%s] not ready for retry yet (next retry at %s) - "
@@ -401,7 +326,7 @@ class PerDestinationQueue:
self._pending_presence = {}
self._pending_rrs = {}
- self._start_catching_up()
+ self._start_catching_up()
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
@@ -412,7 +337,6 @@ class PerDestinationQueue:
e,
)
- self._start_catching_up()
except RequestSendFailed as e:
logger.warning(
"TX [%s] Failed to send transaction: %s", self._destination, e
@@ -422,16 +346,12 @@ class PerDestinationQueue:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
-
- self._start_catching_up()
except Exception:
logger.exception("TX [%s] Failed to send transaction", self._destination)
for p in pending_pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
-
- self._start_catching_up()
finally:
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
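
The `_new_data_to_send` flag, combined with the re-check before the loop exits, guards against lost wake-ups: data queued while the loop is busy deciding it has nothing to send would otherwise sit in the queue until the next external trigger, because `transmission_loop_running` is still held. A small, self-contained sketch of the pattern (simplified names, not the real sender):

```python
import asyncio
from typing import List

class Sender:
    def __init__(self) -> None:
        self.queue: List[str] = []
        self.loop_running = False
        self.new_data = False

    def enqueue(self, item: str) -> None:
        self.queue.append(item)
        self.new_data = True  # signal any running loop that there is more work
        if not self.loop_running:
            self.loop_running = True
            asyncio.ensure_future(self._transmission_loop())

    async def _transmission_loop(self) -> None:
        try:
            while True:
                self.new_data = False
                # Gathering work involves awaits (database reads) in the real
                # code, which is exactly when new data can arrive.
                await asyncio.sleep(0)
                batch, self.queue = self.queue[:50], self.queue[50:]
                if not batch:
                    if self.new_data:
                        continue  # something was queued while we were checking
                    return
                await asyncio.sleep(0)  # stand-in for the real network send
                print("sent", batch)
        finally:
            # be very sure the flag is cleared when the loop stops
            self.loop_running = False

async def main() -> None:
    sender = Sender()
    sender.enqueue("pdu-1")
    sender.enqueue("pdu-2")
    await asyncio.sleep(0.01)

asyncio.run(main())
```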
@@ -499,13 +419,10 @@ class PerDestinationQueue:
rooms = [p.room_id for p in catchup_pdus]
logger.info("Catching up rooms to %s: %r", self._destination, rooms)
- success = await self._transaction_manager.send_new_transaction(
+ await self._transaction_manager.send_new_transaction(
self._destination, catchup_pdus, []
)
- if not success:
- return
-
sent_transactions_counter.inc()
final_pdu = catchup_pdus[-1]
self._last_successful_stream_ordering = cast(
@@ -584,3 +501,135 @@ class PerDestinationQueue:
"""
self._catching_up = True
self._pending_pdus = []
+
+
+@attr.s(slots=True)
+class _TransactionQueueManager:
+ """A helper async context manager for pulling stuff off the queues and
+ tracking what was last successfully sent, etc.
+ """
+
+ queue = attr.ib(type=PerDestinationQueue)
+
+ _device_stream_id = attr.ib(type=Optional[int], default=None)
+ _device_list_id = attr.ib(type=Optional[int], default=None)
+ _last_stream_ordering = attr.ib(type=Optional[int], default=None)
+ _pdus = attr.ib(type=List[EventBase], factory=list)
+
+ async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
+ # First we calculate the EDUs we want to send, if any.
+
+ # We start by fetching device related EDUs, i.e device updates and to
+ # device messages. We have to keep 2 free slots for presence and rr_edus.
+ limit = MAX_EDUS_PER_TRANSACTION - 2
+
+ device_update_edus, dev_list_id = await self.queue._get_device_update_edus(
+ limit
+ )
+
+ if device_update_edus:
+ self._device_list_id = dev_list_id
+ else:
+ self.queue._last_device_list_stream_id = dev_list_id
+
+ limit -= len(device_update_edus)
+
+ (
+ to_device_edus,
+ device_stream_id,
+ ) = await self.queue._get_to_device_message_edus(limit)
+
+ if to_device_edus:
+ self._device_stream_id = device_stream_id
+ else:
+ self.queue._last_device_stream_id = device_stream_id
+
+ pending_edus = device_update_edus + to_device_edus
+
+ # Now add the read receipt EDU.
+ pending_edus.extend(self.queue._get_rr_edus(force_flush=False))
+
+ # And presence EDU.
+ if self.queue._pending_presence:
+ pending_edus.append(
+ Edu(
+ origin=self.queue._server_name,
+ destination=self.queue._destination,
+ edu_type="m.presence",
+ content={
+ "push": [
+ format_user_presence_state(
+ presence, self.queue._clock.time_msec()
+ )
+ for presence in self.queue._pending_presence.values()
+ ]
+ },
+ )
+ )
+ self.queue._pending_presence = {}
+
+ # Finally add any other types of EDUs if there is room.
+ pending_edus.extend(
+ self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
+ )
+ while (
+ len(pending_edus) < MAX_EDUS_PER_TRANSACTION
+ and self.queue._pending_edus_keyed
+ ):
+ _, val = self.queue._pending_edus_keyed.popitem()
+ pending_edus.append(val)
+
+ # Now we look for any PDUs to send, by getting up to 50 PDUs from the
+ # queue
+ self._pdus = self.queue._pending_pdus[:50]
+
+ if not self._pdus and not pending_edus:
+ return [], []
+
+ # if we've decided to send a transaction anyway, and we have room, we
+ # may as well send any pending RRs
+ if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
+ pending_edus.extend(self.queue._get_rr_edus(force_flush=True))
+
+ if self._pdus:
+ self._last_stream_ordering = self._pdus[
+ -1
+ ].internal_metadata.stream_ordering
+ assert self._last_stream_ordering
+
+ return self._pdus, pending_edus
+
+ async def __aexit__(self, exc_type, exc, tb):
+ if exc_type is not None:
+ # Failed to send transaction, so we bail out.
+ return
+
+        # The transaction was sent successfully, so remove its PDUs from the queue.
+ if self._pdus:
+ self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :]
+
+        # The send succeeded, so record how far we have got in the various
+        # streams.
+
+ if self._device_stream_id:
+ await self.queue._store.delete_device_msgs_for_remote(
+ self.queue._destination, self._device_stream_id
+ )
+ self.queue._last_device_stream_id = self._device_stream_id
+
+ # also mark the device updates as sent
+ if self._device_list_id:
+ logger.info(
+ "Marking as sent %r %r", self.queue._destination, self._device_list_id
+ )
+ await self.queue._store.mark_as_sent_devices_by_remote(
+ self.queue._destination, self._device_list_id
+ )
+ self.queue._last_device_list_stream_id = self._device_list_id
+
+ if self._last_stream_ordering:
+ # we sent some PDUs and it was successful, so update our
+ # last_successful_stream_ordering in the destinations table.
+ await self.queue._store.set_destination_last_successful_stream_ordering(
+ self.queue._destination, self._last_stream_ordering
+ )
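
Finally, the new `_TransactionQueueManager` relies on the async-context-manager protocol to make the bookkeeping transactional: work is handed out in `__aenter__`, and the queue and stream positions are only advanced in `__aexit__` when the send did not raise. A stripped-down sketch of the same pattern (illustrative names, not the real class):

```python
import asyncio
from typing import List

class BatchManager:
    """Hand out a batch on __aenter__; only drop it from the queue on a clean exit."""

    def __init__(self, queue: List[str], batch_size: int = 50) -> None:
        self._queue = queue
        self._batch: List[str] = []
        self._batch_size = batch_size

    async def __aenter__(self) -> List[str]:
        self._batch = self._queue[: self._batch_size]
        return self._batch

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if exc_type is not None:
            return  # the send failed: leave the items queued for the next attempt
        del self._queue[: len(self._batch)]  # the send succeeded: drop them

async def main() -> None:
    pending = ["pdu-1", "pdu-2", "pdu-3"]
    async with BatchManager(pending, batch_size=2) as batch:
        print("sending", batch)      # ['pdu-1', 'pdu-2']
    print("still queued", pending)   # ['pdu-3']

asyncio.run(main())
```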
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 2a9cd063c4..07b740c2f2 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -69,15 +69,12 @@ class TransactionManager:
destination: str,
pdus: List[EventBase],
edus: List[Edu],
- ) -> bool:
+ ) -> None:
"""
Args:
destination: The destination to send to (e.g. 'example.org')
pdus: In-order list of PDUs to send
edus: List of EDUs to send
-
- Returns:
- True iff the transaction was successful
"""
# Make a transaction-sending opentracing span. This span follows on from
@@ -96,8 +93,6 @@ class TransactionManager:
edu.strip_context()
with start_active_span_follows_from("send_transaction", span_contexts):
- success = True
-
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
@@ -152,44 +147,29 @@ class TransactionManager:
response = await self._transport_layer.send_transaction(
transaction, json_data_cb
)
- code = 200
except HttpResponseException as e:
code = e.code
response = e.response
- if e.code in (401, 404, 429) or 500 <= e.code:
- logger.info(
- "TX [%s] {%s} got %d response", destination, txn_id, code
- )
- raise e
-
- logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
-
- if code == 200:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warning(
- "TX [%s] {%s} Remote returned error for %s: %s",
- destination,
- txn_id,
- e_id,
- r,
- )
- else:
- for p in pdus:
+ set_tag(tags.ERROR, True)
+
+ logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+ raise
+
+ logger.info("TX [%s] {%s} got 200 response", destination, txn_id)
+
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
logger.warning(
- "TX [%s] {%s} Failed to send event %s",
+ "TX [%s] {%s} Remote returned error for %s: %s",
destination,
txn_id,
- p.event_id,
+ e_id,
+ r,
)
- success = False
- if success and pdus and destination in self._federation_metrics_domains:
+ if pdus and destination in self._federation_metrics_domains:
last_pdu = pdus[-1]
last_pdu_ts_metric.labels(server_name=destination).set(
last_pdu.origin_server_ts / 1000
)
-
- set_tag(tags.ERROR, not success)
- return success
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bec0c615d4..fb5f8118f0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -337,7 +337,8 @@ class AuthHandler(BaseHandler):
user is too high to proceed
"""
-
+ if not requester.access_token_id:
+ raise ValueError("Cannot validate a user without an access token")
if self._ui_auth_session_timeout:
last_validated = await self.store.get_access_token_last_validated(
requester.access_token_id
@@ -1213,7 +1214,7 @@ class AuthHandler(BaseHandler):
async def delete_access_tokens_for_user(
self,
user_id: str,
- except_token_id: Optional[str] = None,
+ except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 04972f9cf0..cb67589f7d 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -83,6 +83,7 @@ class CasHandler:
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
+ self.unstable_idp_brand = None
self._sso_handler = hs.get_sso_handler()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 2ead626a4d..598a66f74c 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -201,7 +201,7 @@ class FederationHandler(BaseHandler):
or pdu.internal_metadata.is_outlier()
)
if already_seen:
- logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
+ logger.debug("Already seen pdu")
return
# do some initial sanity-checking of the event. In particular, make
@@ -210,18 +210,14 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
- logger.warning(
- "[%s %s] Received event failed sanity checks", room_id, event_id
- )
+ logger.warning("Received event failed sanity checks")
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if room_id in self.room_queues:
logger.info(
- "[%s %s] Queuing PDU from %s for now: join in progress",
- room_id,
- event_id,
+ "Queuing PDU from %s for now: join in progress",
origin,
)
self.room_queues[room_id].append((pdu, origin))
@@ -236,9 +232,7 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
- "[%s %s] Ignoring PDU from %s as we're not in the room",
- room_id,
- event_id,
+ "Ignoring PDU from %s as we're not in the room",
origin,
)
return None
@@ -250,7 +244,7 @@ class FederationHandler(BaseHandler):
# We only backfill backwards to the min depth.
min_depth = await self.get_min_depth_for_context(pdu.room_id)
- logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
+ logger.debug("min_depth: %d", min_depth)
prevs = set(pdu.prev_event_ids())
seen = await self.store.have_events_in_timeline(prevs)
@@ -267,17 +261,13 @@ class FederationHandler(BaseHandler):
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
- "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
- room_id,
- event_id,
+ "Acquiring room lock to fetch %d missing prev_events: %s",
len(missing_prevs),
shortstr(missing_prevs),
)
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
- "[%s %s] Acquired room lock to fetch %d missing prev_events",
- room_id,
- event_id,
+ "Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
)
@@ -297,9 +287,7 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
logger.info(
- "[%s %s] Found all missing prev_events",
- room_id,
- event_id,
+ "Found all missing prev_events",
)
if prevs - seen:
@@ -329,9 +317,7 @@ class FederationHandler(BaseHandler):
if sent_to_us_directly:
logger.warning(
- "[%s %s] Rejecting: failed to fetch %d prev events: %s",
- room_id,
- event_id,
+ "Rejecting: failed to fetch %d prev events: %s",
len(prevs - seen),
shortstr(prevs - seen),
)
@@ -367,17 +353,16 @@ class FederationHandler(BaseHandler):
# Ask the remote server for the states we don't
# know about
for p in prevs - seen:
- logger.info(
- "Requesting state at missing prev_event %s",
- event_id,
- )
+ logger.info("Requesting state after missing prev_event %s", p)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- (remote_state, _,) = await self._get_state_for_room(
- origin, room_id, p, include_event_in_state=True
+ remote_state = (
+ await self._get_state_after_missing_prev_event(
+ origin, room_id, p
+ )
)
remote_state_map = {
@@ -414,10 +399,7 @@ class FederationHandler(BaseHandler):
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
- "[%s %s] Error attempting to resolve state at missing "
- "prev_events",
- room_id,
- event_id,
+ "Error attempting to resolve state at missing " "prev_events",
exc_info=True,
)
raise FederationError(
@@ -454,9 +436,7 @@ class FederationHandler(BaseHandler):
latest |= seen
logger.info(
- "[%s %s]: Requesting missing events between %s and %s",
- room_id,
- event_id,
+ "Requesting missing events between %s and %s",
shortstr(latest),
event_id,
)
@@ -523,15 +503,11 @@ class FederationHandler(BaseHandler):
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
- logger.warning(
- "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
- )
+ logger.warning("Failed to get prev_events: %s", e)
return
logger.info(
- "[%s %s]: Got %d prev_events: %s",
- room_id,
- event_id,
+ "Got %d prev_events: %s",
len(missing_events),
shortstr(missing_events),
)
@@ -542,9 +518,7 @@ class FederationHandler(BaseHandler):
for ev in missing_events:
logger.info(
- "[%s %s] Handling received prev_event %s",
- room_id,
- event_id,
+ "Handling received prev_event %s",
ev.event_id,
)
with nested_logging_context(ev.event_id):
@@ -553,9 +527,7 @@ class FederationHandler(BaseHandler):
except FederationError as e:
if e.code == 403:
logger.warning(
- "[%s %s] Received prev_event %s failed history check.",
- room_id,
- event_id,
+ "Received prev_event %s failed history check.",
ev.event_id,
)
else:
@@ -566,7 +538,6 @@ class FederationHandler(BaseHandler):
destination: str,
room_id: str,
event_id: str,
- include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.
@@ -574,11 +545,9 @@ class FederationHandler(BaseHandler):
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
- include_event_in_state: if true, the event itself will be included in the
- returned state event list.
Returns:
- A list of events in the state, possibly including the event itself, and
+ A list of events in the state, not including the event itself, and
a list of events in the auth chain for the given event.
"""
(
@@ -590,9 +559,6 @@ class FederationHandler(BaseHandler):
desired_events = set(state_event_ids + auth_event_ids)
- if include_event_in_state:
- desired_events.add(event_id)
-
event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
@@ -609,13 +575,6 @@ class FederationHandler(BaseHandler):
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
- if include_event_in_state:
- remote_event = event_map.get(event_id)
- if not remote_event:
- raise Exception("Unable to get missing prev_event %s" % (event_id,))
- if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
-
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
@@ -689,6 +648,131 @@ class FederationHandler(BaseHandler):
return fetched_events
+ async def _get_state_after_missing_prev_event(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ ) -> List[EventBase]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+
+ Returns:
+ A list of events in the state, including the event itself
+ """
+ # TODO: This function is basically the same as _get_state_for_room. Can
+ # we make backfill() use it, rather than having two code paths? I think the
+ # only difference is that backfill() persists the prev events separately.
+
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ logger.debug(
+ "state_ids returned %i state events, %i auth events",
+ len(state_event_ids),
+ len(auth_event_ids),
+ )
+
+ # start by just trying to fetch the events from the store
+ desired_events = set(state_event_ids)
+ desired_events.add(event_id)
+ logger.debug("Fetching %i events from cache/store", len(desired_events))
+ fetched_events = await self.store.get_events(
+ desired_events, allow_rejected=True
+ )
+
+ missing_desired_events = desired_events - fetched_events.keys()
+ logger.debug(
+ "We are missing %i events (got %i)",
+ len(missing_desired_events),
+ len(fetched_events),
+ )
+
+ # We probably won't need most of the auth events, so let's just check which
+ # we have for now, rather than thrashing the event cache with them all
+ # unnecessarily.
+
+ # TODO: we probably won't actually need all of the auth events, since we
+ # already have a bunch of the state events. It would be nice if the
+ # federation api gave us a way of finding out which we actually need.
+
+ missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events.difference_update(
+ await self.store.have_seen_events(missing_auth_events)
+ )
+ logger.debug("We are also missing %i auth events", len(missing_auth_events))
+
+ missing_events = missing_desired_events | missing_auth_events
+ logger.debug("Fetching %i events from remote", len(missing_events))
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ (await self.store.get_events(missing_desired_events, allow_rejected=True))
+ )
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = [
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ del fetched_events[bad_event_id]
+
+ # if we couldn't get the prev event in question, that's a problem.
+ remote_event = fetched_events.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+
+ # missing state at that event is a warning, not a blocker
+ # XXX: this doesn't sound right? it means that we'll end up with incomplete
+ # state.
+ failed_to_fetch = desired_events - fetched_events.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
+ ]
+
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
+
+ return remote_state
+
async def _process_received_pdu(
self,
origin: str,
@@ -707,10 +791,7 @@ class FederationHandler(BaseHandler):
(ie, we are missing one or more prev_events), the resolved state at the
event
"""
- room_id = event.room_id
- event_id = event.event_id
-
- logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+ logger.debug("Processing event: %s", event)
try:
await self._handle_new_event(origin, event, state=state)
@@ -871,7 +952,6 @@ class FederationHandler(BaseHandler):
destination=dest,
room_id=room_id,
event_id=e_id,
- include_event_in_state=False,
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
@@ -1317,7 +1397,7 @@ class FederationHandler(BaseHandler):
async def on_event_auth(self, event_id: str) -> List[EventBase]:
event = await self.store.get_event(event_id)
auth = await self.store.get_auth_chain(
- list(event.auth_event_ids()), include_given=True
+ event.room_id, list(event.auth_event_ids()), include_given=True
)
return list(auth)
@@ -1580,7 +1660,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
- auth_chain = await self.store.get_auth_chain(state_ids)
+ auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
state = await self.store.get_events(list(prev_state_ids.values()))
@@ -2219,7 +2299,7 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
- list(event.auth_event_ids()), include_given=True
+ room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 825fadb76f..6d8551a6d6 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -29,11 +29,13 @@ from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
MacaroonDeserializationException,
+ MacaroonInitException,
MacaroonInvalidSignatureException,
)
from typing_extensions import TypedDict
from twisted.web.client import readBody
+from twisted.web.http_headers import Headers
from synapse.config import ConfigError
from synapse.config.oidc_config import (
@@ -216,7 +218,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
- except (MacaroonDeserializationException, KeyError) as e:
+ except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e:
logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@@ -329,6 +331,9 @@ class OidcProvider:
# optional brand identifier for this auth provider
self.idp_brand = provider.idp_brand
+ # Optional brand identifier for the unstable API (see MSC2858).
+ self.unstable_idp_brand = provider.unstable_idp_brand
+
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
@@ -538,7 +543,7 @@ class OidcProvider:
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
- headers = {
+ raw_headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
@@ -552,10 +557,10 @@ class OidcProvider:
body = urlencode(args, True)
# Fill the body/headers with credentials
- uri, headers, body = self._client_auth.prepare(
- method="POST", uri=token_endpoint, headers=headers, body=body
+ uri, raw_headers, body = self._client_auth.prepare(
+ method="POST", uri=token_endpoint, headers=raw_headers, body=body
)
- headers = {k: [v] for (k, v) in headers.items()}
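+        # Twisted's Headers type expects each header name to map to a list of values.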
+ headers = Headers({k: [v] for (k, v) in raw_headers.items()})
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index b66f8756b8..1abc8875cb 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,7 +16,7 @@
"""Contains functions for registering clients."""
import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from prometheus_client import Counter
@@ -82,6 +82,7 @@ class RegistrationHandler(BaseHandler):
)
else:
self.device_handler = hs.get_device_handler()
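+            # On the main process, point the register-device "client" directly
+            # at the inner implementation rather than at a replication call.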
+ self._register_device_client = self.register_device_inner
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
@@ -197,8 +198,7 @@ class RegistrationHandler(BaseHandler):
admin api, otherwise False.
user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
- auth_provider_id: The SSO IdP the user used, if any (just used for the
- prometheus metrics).
+ auth_provider_id: The SSO IdP the user used, if any.
Returns:
The registered user_id.
Raises:
@@ -210,6 +210,7 @@ class RegistrationHandler(BaseHandler):
threepid,
localpart,
user_agent_ips or [],
+ auth_provider_id=auth_provider_id,
)
if result == RegistrationBehaviour.DENY:
@@ -678,17 +679,35 @@ class RegistrationHandler(BaseHandler):
Returns:
Tuple of device ID and access token
"""
+ res = await self._register_device_client(
+ user_id=user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ is_guest=is_guest,
+ is_appservice_ghost=is_appservice_ghost,
+ )
- if self.hs.config.worker_app:
- r = await self._register_device_client(
- user_id=user_id,
- device_id=device_id,
- initial_display_name=initial_display_name,
- is_guest=is_guest,
- is_appservice_ghost=is_appservice_ghost,
- )
- return r["device_id"], r["access_token"]
+ login_counter.labels(
+ guest=is_guest,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
+ return res["device_id"], res["access_token"]
+ async def register_device_inner(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ is_guest: bool = False,
+ is_appservice_ghost: bool = False,
+ ) -> Dict[str, str]:
+ """Helper for register_device
+
+ Does the bits that need doing on the main process. Not for use outside this
+ class and RegisterDeviceReplicationServlet.
+ """
+ assert not self.hs.config.worker_app
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
@@ -713,12 +732,7 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost=is_appservice_ghost,
)
- login_counter.labels(
- guest=is_guest,
- auth_provider=(auth_provider_id or ""),
- ).inc()
-
- return (registered_device_id, access_token)
+ return {"device_id": registered_device_id, "access_token": access_token}
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a9645b77d8..ec2ba11c75 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -81,6 +81,7 @@ class SamlHandler(BaseHandler):
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
+ self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 6ef459acff..415b1c2d17 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol):
"""Optional branding identifier"""
return None
+ @property
+ def unstable_idp_brand(self) -> Optional[str]:
+ """Optional brand identifier for the unstable API (see MSC2858)."""
+ return None
+
@abc.abstractmethod
async def handle_redirect_request(
self,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index af34d583ad..1e01e0a9f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -39,12 +39,15 @@ from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
+from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
+ ITCPTransport,
)
+from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -56,7 +59,13 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import (
+ UNKNOWN_LENGTH,
+ IAgent,
+ IBodyProducer,
+ IPolicyForHTTPS,
+ IResponse,
+)
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -151,16 +160,17 @@ class _IPBlacklistingResolver:
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
-
- r = recv()
addresses = [] # type: List[IAddress]
def _callback() -> None:
- r.resolutionBegan(None)
-
has_bad_ip = False
- for i in addresses:
- ip_address = IPAddress(i.host)
+ for address in addresses:
+ # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
+ # should go through this path.
+ if not isinstance(address, (IPv4Address, IPv6Address)):
+ continue
+
+ ip_address = IPAddress(address.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
@@ -175,15 +185,15 @@ class _IPBlacklistingResolver:
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
- for i in addresses:
- r.addressResolved(i)
- r.resolutionComplete()
+ for address in addresses:
+ recv.addressResolved(address)
+ recv.resolutionComplete()
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
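+            # Forward the "resolution has begun" notification straight to the
+            # caller's receiver instead of swallowing it.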
- pass
+ recv.resolutionBegan(resolutionInProgress)
@staticmethod
def addressResolved(address: IAddress) -> None:
@@ -197,7 +207,7 @@ class _IPBlacklistingResolver:
EndpointReceiver, hostname, portNumber=portNumber
)
- return r
+ return recv
@implementer(ISynapseReactor)
@@ -346,7 +356,7 @@ class SimpleHttpClient:
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
- )
+ ) # type: IAgent
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
@@ -752,6 +762,8 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
+ transport = None # type: Optional[ITCPTransport]
+
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
@@ -763,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
+ assert self.transport is not None
self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None:
self._maybe_fail()
- def connectionLost(self, reason: Failure) -> None:
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+ transport = None # type: Optional[ITCPTransport]
+
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -797,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
+ assert self.transport is not None
self.transport.abortConnection()
- def connectionLost(self, reason: Failure) -> None:
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
@@ -868,6 +884,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
return query_str.encode("utf8")
+@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 174ca7be5a..643492ceaf 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
-from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol
+from twisted.internet.tcp import Connection
from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
@@ -52,7 +53,9 @@ class LogProducer:
format: A callable to format the log record to a string.
"""
- transport = attr.ib(type=ITransport)
+ # This is essentially ITCPTransport, but that is missing certain fields
+ # (connected and registerProducer) which are part of the implementation.
+ transport = attr.ib(type=Connection)
_format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
@@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
if self._connection_waiter:
return
- self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors).
@@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
self._connect()
def writer(result: Protocol) -> None:
+ # Force recognising transport as a Connection and not the more
+ # generic ITransport.
+ transport = result.transport # type: Connection # type: ignore
+
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
- if self._producer and result.transport is self._producer.transport:
+ if self._producer and transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
@@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer,
- transport=result.transport,
+ transport=transport,
format=self.format,
)
- result.transport.registerProducer(self._producer, True)
+ transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
- self._connection_waiter.addCallbacks(writer, fail)
+ deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
+ deferred.addCallbacks(writer, fail)
+ self._connection_waiter = deferred
def _handle_pressure(self) -> None:
"""
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 5fec2aaf5d..3dc06a79e8 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -16,8 +16,8 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
-from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, ThrottleParams
@@ -66,7 +66,7 @@ class EmailPusher(Pusher):
self.store = self.hs.get_datastore()
self.email = pusher_config.pushkey
- self.timed_call = None # type: Optional[DelayedCall]
+ self.timed_call = None # type: Optional[IDelayedCall]
self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 36071feb36..4ec1bfa6ea 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
- device_id, access_token = await self.registration_handler.register_device(
+ res = await self.registration_handler.register_device_inner(
user_id,
device_id,
initial_display_name,
@@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_appservice_ghost=is_appservice_ghost,
)
- return 200, {"device_id": device_id, "access_token": access_token}
+ return 200, res
def register_servlets(hs, http_server):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a7245da152..a8894beadf 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
UserIpCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
AccountDataStream,
@@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
- Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+ Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]
@@ -174,7 +174,7 @@ class ReplicationCommandHandler:
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
- self._connections = [] # type: List[AbstractConnection]
+ self._connections = [] # type: List[IReplicationConnection]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
@@ -197,7 +197,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+ self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_command_queue",
@@ -220,7 +220,7 @@ class ReplicationCommandHandler:
self._server_notices_sender = hs.get_server_notices_sender()
def _add_command_to_stream_queue(
- self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+ self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@@ -267,7 +267,7 @@ class ReplicationCommandHandler:
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
- conn: AbstractConnection,
+ conn: IReplicationConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
@@ -302,7 +302,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
- hs.config.redis.redis_host,
+ hs.config.redis.redis_host.encode(),
hs.config.redis.redis_port,
self._factory,
)
@@ -311,7 +311,7 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- hs.get_reactor().connectTCP(host, port, self._factory)
+ hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
@@ -321,10 +321,10 @@ class ReplicationCommandHandler:
"""Get a list of streams that this instances replicates."""
return self._streams_to_replicate
- def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+ def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
- def send_positions_to_connection(self, conn: AbstractConnection):
+ def send_positions_to_connection(self, conn: IReplicationConnection):
"""Send current position of all streams this process is source of to
the connection.
"""
@@ -347,7 +347,7 @@ class ReplicationCommandHandler:
)
def on_USER_SYNC(
- self, conn: AbstractConnection, cmd: UserSyncCommand
+ self, conn: IReplicationConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
@@ -359,21 +359,23 @@ class ReplicationCommandHandler:
return None
def on_CLEAR_USER_SYNC(
- self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+ self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
) -> Optional[Awaitable[None]]:
if self._is_master:
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
- def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
+ def on_FEDERATION_ACK(
+ self, conn: IReplicationConnection, cmd: FederationAckCommand
+ ):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
def on_USER_IP(
- self, conn: AbstractConnection, cmd: UserIpCommand
+ self, conn: IReplicationConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
@@ -395,7 +397,7 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
- def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -412,7 +414,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
- self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
@@ -486,7 +488,7 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
- def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
@@ -496,7 +498,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
- self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
@@ -553,7 +555,9 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
- def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
+ def on_REMOTE_SERVER_UP(
+ self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
+ ):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@@ -576,7 +580,7 @@ class ReplicationCommandHandler:
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
- def new_connection(self, connection: AbstractConnection):
+ def new_connection(self, connection: IReplicationConnection):
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -603,7 +607,7 @@ class ReplicationCommandHandler:
UserSyncCommand(self._instance_id, user_id, True, now)
)
- def lost_connection(self, connection: AbstractConnection):
+ def lost_connection(self, connection: IReplicationConnection):
"""Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
@@ -624,7 +628,7 @@ class ReplicationCommandHandler:
return bool(self._connections)
def send_command(
- self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+ self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
):
"""Send a command to all connected connections.
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e0b4ad314d..825900f64c 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-import abc
import fcntl
import logging
import struct
@@ -54,8 +53,10 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from zope.interface import Interface, implementer
from twisted.internet import task
+from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -121,6 +122,14 @@ class ConnectionStates:
CLOSED = "closed"
+class IReplicationConnection(Interface):
+ """An interface for replication connections."""
+
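+    # Note: zope.interface method declarations take no explicit "self" argument.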
+ def send_command(cmd: Command):
+ """Send the command down the connection"""
+
+
+@implementer(IReplicationConnection)
class BaseReplicationStreamProtocol(LineOnlyReceiver):
"""Base replication protocol shared between client and server.
@@ -137,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""
+    # The transport is going to be an ITCPTransport, but that doesn't have the
+    # (un)registerProducer methods; those are only on the implementation.
+ transport = None # type: Connection
+
delimiter = b"\n"
# Valid commands we expect to receive
@@ -181,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
connected_connections.append(self) # Register connection for metrics
+ assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks
self._send_pending_commands()
@@ -205,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
+ assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
@@ -294,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
+ assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()
@@ -391,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
+ assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()
@@ -495,20 +512,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand())
-class AbstractConnection(abc.ABC):
- """An interface for replication connections."""
-
- @abc.abstractmethod
- def send_command(self, cmd: Command):
- """Send the command down the connection"""
- pass
-
-
-# This tells python that `BaseReplicationStreamProtocol` implements the
-# interface.
-AbstractConnection.register(BaseReplicationStreamProtocol)
-
-
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 7560706b4b..2f4d407f94 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -19,6 +19,11 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
import attr
import txredisapi
+from zope.interface import implementer
+
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.interfaces import IAddress, IConnector
+from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
@@ -32,7 +37,7 @@ from synapse.replication.tcp.commands import (
parse_command_from_line,
)
from synapse.replication.tcp.protocol import (
- AbstractConnection,
+ IReplicationConnection,
tcp_inbound_commands_counter,
tcp_outbound_commands_counter,
)
@@ -62,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
pass
-class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+@implementer(IReplicationConnection)
+class RedisSubscriber(txredisapi.SubscriberProtocol):
"""Connection to redis subscribed to replication stream.
This class fulfils two functions:
@@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`
- (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+ (b) it implements the IReplicationConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom
@@ -253,6 +259,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
except Exception:
logger.warning("Failed to send ping to a redis connection")
+ # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
+ # it's rubbish. We add our own here.
+
+ def startedConnecting(self, connector: IConnector):
+ logger.info(
+ "Connecting to redis server %s", format_address(connector.getDestination())
+ )
+ super().startedConnecting(connector)
+
+ def clientConnectionFailed(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s failed: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s lost: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionLost(connector, reason)
+
+
+def format_address(address: IAddress) -> str:
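+    """Render an IPv4/IPv6 address as host:port, otherwise fall back to str()."""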
+ if isinstance(address, (IPv4Address, IPv6Address)):
+ return "%s:%i" % (address.host, address.port)
+ return str(address)
+
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
@@ -328,6 +365,6 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
- reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
+ reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
return factory.handler
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index e09234c644..7681e55b58 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -15,10 +15,9 @@
import re
-import twisted.web.server
-
-import synapse.api.auth
+from synapse.api.auth import Auth
from synapse.api.errors import AuthError
+from synapse.http.site import SynapseRequest
from synapse.types import UserID
@@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
return patterns
-async def assert_requester_is_admin(
- auth: synapse.api.auth.Auth, request: twisted.web.server.Request
-) -> None:
+async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
"""Verify that the requester is an admin user
Args:
- auth: api.auth.Auth singleton
+ auth: Auth singleton
request: incoming request
Raises:
@@ -53,11 +50,11 @@ async def assert_requester_is_admin(
await assert_user_is_admin(auth, requester.user)
-async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
+async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
"""Verify that the given user is an admin user
Args:
- auth: api.auth.Auth singleton
+ auth: Auth singleton
user_id: user to check
Raises:
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 511c859f64..7fcc48a9d7 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,10 +17,9 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from twisted.web.server import Request
-
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
@@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
self.auth = hs.get_auth()
async def on_POST(
- self, request: Request, server_name: str, media_id: str
+ self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, media_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
self.media_repository = hs.get_media_repository()
async def on_DELETE(
- self, request: Request, server_name: str, media_id: str
+ self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, server_name: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f2c42a0f30..263d8ec076 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -685,7 +685,10 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now
)
results["state"] = await self._event_serializer.serialize_events(
- results["state"], time_now
+ results["state"],
+ time_now,
+ # No need to bundle aggregations for state events
+ bundle_aggregations=False,
)
return 200, results
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 267a993430..2c89b62e25 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -269,7 +269,10 @@ class UserRestServletV2(RestServlet):
target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
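+                # A password is only required for reactivation when local
+                # password logins are enabled.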
- if "password" not in body:
+ if (
+ "password" not in body
+ and self.hs.config.password_localdb_enabled
+ ):
raise SynapseError(
400, "Must provide a password to re-activate an account."
)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 34bc1bd49b..e4c352f572 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+import re
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri
@@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
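+            # The m.login.sso flow lists the available identity providers for
+            # the client to choose from.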
- sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
+ sso_flow = {
+ "type": LoginRestServlet.SSO_TYPE,
+ "identity_providers": [
+ _get_auth_flow_dict_for_idp(
+ idp,
+ )
+ for idp in self._sso_handler.get_identity_providers().values()
+ ],
+ } # type: JsonDict
if self._msc2858_enabled:
+ # backwards-compatibility support for clients which don't
+ # support the stable API yet
sso_flow["org.matrix.msc2858.identity_providers"] = [
- _get_auth_flow_dict_for_idp(idp)
+ _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
for idp in self._sso_handler.get_identity_providers().values()
]
@@ -331,22 +343,38 @@ class LoginRestServlet(RestServlet):
return result
-def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+def _get_auth_flow_dict_for_idp(
+ idp: SsoIdentityProvider, use_unstable_brands: bool = False
+) -> JsonDict:
"""Return an entry for the login flow dict
Returns an entry suitable for inclusion in "identity_providers" in the
response to GET /_matrix/client/r0/login
+
+ Args:
+ idp: the identity provider to describe
+ use_unstable_brands: whether we should use brand identifiers suitable
+ for the unstable API
"""
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
if idp.idp_icon:
e["icon"] = idp.idp_icon
if idp.idp_brand:
e["brand"] = idp.idp_brand
+    # For the unstable API, prefer the unstable brand identifier when one is
+    # defined; otherwise the stable identifier set above is used.
+ if use_unstable_brands and idp.unstable_idp_brand:
+ e["brand"] = idp.unstable_idp_brand
return e
class SsoRedirectServlet(RestServlet):
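+    # As well as the generic /login/(cas|sso)/redirect endpoint, match the r0
+    # per-IdP redirect endpoint, which carries the IdP ID in the path.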
- PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+ PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
+ re.compile(
+ "^"
+ + CLIENT_API_PREFIX
+ + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
+ )
+ ]
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
@@ -364,7 +392,8 @@ class SsoRedirectServlet(RestServlet):
def register(self, http_server: HttpServer) -> None:
super().register(http_server)
if self._msc2858_enabled:
- # expose additional endpoint for MSC2858 support
+ # expose additional endpoint for MSC2858 support: backwards-compat support
+ # for clients which don't yet support the stable endpoints.
http_server.register_paths(
"GET",
client_patterns(
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 9a1df30c29..5884daea6d 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -671,7 +671,10 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now
)
results["state"] = await self._event_serializer.serialize_events(
- results["state"], time_now
+ results["state"],
+ time_now,
+ # No need to bundle aggregations for state events
+ bundle_aggregations=False,
)
return 200, results
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 7aea4cebf5..5901432fad 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -32,6 +32,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.types import GroupID, JsonDict
from ._base import client_patterns
@@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
return 200, group_description
@_validate_group_id
- async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, category_id: Optional[str], room_id: str
+ self,
+ request: SynapseRequest,
+ group_id: str,
+ category_id: Optional[str],
+ room_id: str,
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, category_id: str, room_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_GET(
- self, request: Request, group_id: str, category_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, category_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, category_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_GET(
- self, request: Request, group_id: str, role_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, role_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, role_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, role_id: Optional[str], user_id: str
+ self,
+ request: SynapseRequest,
+ group_id: str,
+ role_id: Optional[str],
+ user_id: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, role_id: str, user_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, room_id: str
+ self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, room_id: str
+ self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, room_id: str, config_key: str
+ self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
@_validate_group_id
- async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id, user_id
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id, user_id
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.store = hs.get_datastore()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 9039662f7e..1eff98ef14 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0641924f18..8b4841ed5d 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -35,6 +35,7 @@ from synapse.api.errors import (
from synapse.config._base import ConfigError
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import random_string
@@ -145,7 +146,7 @@ class MediaRepository:
upload_name: Optional[str],
content: IO,
content_length: int,
- auth_user: str,
+ auth_user: UserID,
) -> str:
"""Store uploaded content for a local user and return the mxc URL
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a074e807dc..b8895aeaa9 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -39,6 +39,7 @@ from synapse.http.server import (
respond_with_json_bytes,
)
from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
@@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource):
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 5e104fac40..ae5aef2f7f 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING:
@@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_POST(self, request: Request) -> None:
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index f6668fb5e3..4dfadf1bfb 100644
--- a/synapse/rest/synapse/client/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
@@ -14,24 +14,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
from synapse.http.server import DirectServeHtmlResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._saml_handler = hs.get_saml_handler()
+ self._sso_handler = hs.get_sso_handler()
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
- self._saml_handler._render_error(
+ self._sso_handler.render_error(
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
diff --git a/synapse/server.py b/synapse/server.py
index 369cc88026..48ac87a124 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -351,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
- return (
- InsecureInterceptableContextFactory()
- if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
- else RegularPolicyForHTTPS()
- )
+ if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+ return InsecureInterceptableContextFactory()
+ return RegularPolicyForHTTPS()
@cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 18ddb92fcc..332193ad1c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) # type: LruCache[str, List[Tuple[str, int]]]
async def get_auth_chain(
- self, event_ids: Collection[str], include_given: bool = False
+ self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
@@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
list of events
"""
event_ids = await self.get_auth_chain_ids(
- event_ids, include_given=include_given
+ room_id, event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
self,
+ room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
Returns:
- An awaitable which resolve to a list of event_ids
+ list of event_ids
"""
+
+ # Check if we have indexed the room so we can use the chain cover
+ # algorithm.
+ room = await self.get_room(room_id)
+ if room["has_auth_chain_index"]:
+ try:
+ return await self.db_pool.runInteraction(
+ "get_auth_chain_ids_chains",
+ self._get_auth_chain_ids_using_cover_index_txn,
+ room_id,
+ event_ids,
+ include_given,
+ )
+ except _NoChainCoverIndex:
+ # For whatever reason we don't actually have a chain cover index
+ # for the events in question, so we fall back to the old method.
+ pass
+
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
@@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given,
)
+ def _get_auth_chain_ids_using_cover_index_txn(
+ self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
+ """Calculates the auth chain IDs using the chain index."""
+
+ # First we look up the chain ID/sequence numbers for the given events.
+
+ initial_events = set(event_ids)
+
+ # The subset of the given events that we have found chain info for.
+ seen_events = set() # type: Set[str]
+
+ # A map from chain ID to max sequence number of the given events.
+ event_chains = {} # type: Dict[int, int]
+
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for event_id, chain_id, sequence_number in txn:
+ seen_events.add(event_id)
+ event_chains[chain_id] = max(
+ sequence_number, event_chains.get(chain_id, 0)
+ )
+
+ # Check that we actually have a chain ID for all the events.
+ events_missing_chain_info = initial_events.difference(seen_events)
+ if events_missing_chain_info:
+ # This can happen due to e.g. downgrade/upgrade of the server. We
+ # raise an exception and fall back to the previous algorithm.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info,
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Now we look up all links for the chains we have, adding chains that
+ # are reachable from any event.
+ sql = """
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM event_auth_chain_links
+ WHERE %s
+ """
+
+ # A map from chain ID to max sequence number *reachable* from any event ID.
+ chains = {} # type: Dict[int, int]
+
+ # Add all linked chains reachable from the initial set of chains.
+ for batch in batch_iter(event_chains, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+ # chains are only reachable if the origin sequence number of
+ # the link is less than or equal to the max sequence number in
+ # the origin chain.
+ if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
+ chains[target_chain_id] = max(
+ target_sequence_number,
+ chains.get(target_chain_id, 0),
+ )
+
+ # Add the initial set of chains, excluding the sequence number of the
+ # initial event itself.
+ for chain_id, seq_no in event_chains.items():
+ chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
+
+ # Now for each chain we figure out the maximum sequence number reachable
+ # from *any* event ID. Events with a sequence number less than or equal
+ # to that are in the auth chain.
+ if include_given:
+ results = initial_events
+ else:
+ results = set()
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We can use `execute_values` to efficiently fetch the matching events
+ # when using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND sequence_number <= max_seq
+ """
+
+ rows = txn.execute_values(sql, chains.items())
+ results.update(r for r, in rows)
+ else:
+ # For SQLite we just fall back to doing a noddy for loop.
+ sql = """
+ SELECT event_id FROM event_auth_chains
+ WHERE chain_id = ? AND sequence_number <= ?
+ """
+ for chain_id, max_no in chains.items():
+ txn.execute(sql, (chain_id, max_no))
+ results.update(r for r, in txn)
+
+ return list(results)
+
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
+ """Calculates the auth chain IDs.
+
+ This is used when we don't have a cover index for the room.
+ """
if include_given:
results = set(event_ids)
else:
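The cover-index path above reduces to a small reachability computation over chains: look up the chain ID / sequence number of each given event, pull in every chain linked from those, and take everything at or below the resulting per-chain maximum. A minimal Python sketch of that middle step, assuming (as the queries above rely on) that the links rows are transitively closed; `event_chains` and `links` are hypothetical in-memory stand-ins for the two SQL queries:

from typing import Dict, List, Tuple

def reachable_chains(
    event_chains: Dict[int, int],
    links: List[Tuple[int, int, int, int]],
) -> Dict[int, int]:
    """Map each chain ID to the max sequence number reachable from the given events.

    event_chains: chain_id -> max sequence number among the given events.
    links: (origin_chain, origin_seq, target_chain, target_seq) rows.
    """
    chains = {}  # type: Dict[int, int]

    # A linked chain is reachable only if the link's origin point is covered
    # by the sequence number we hold in the origin chain.
    for origin_chain, origin_seq, target_chain, target_seq in links:
        if origin_seq <= event_chains.get(origin_chain, 0):
            chains[target_chain] = max(target_seq, chains.get(target_chain, 0))

    # The given events' own chains contribute everything strictly before the
    # events themselves, hence the `- 1`.
    for chain_id, seq in event_chains.items():
        chains[chain_id] = max(seq - 1, chains.get(chain_id, 0))

    return chains

# Any event whose (chain_id, sequence_number) has
# sequence_number <= chains[chain_id] is part of the auth chain, which is what
# the final SELECT against event_auth_chains computes.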
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index cb6b1f8a0c..78367ea58d 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -135,6 +135,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._chain_cover_index,
)
+ self.db_pool.updates.register_background_update_handler(
+ "purged_chain_cover",
+ self._purged_chain_cover_index,
+ )
+
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -932,3 +937,77 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
processed_count=count,
finished_room_map=finished_rooms,
)
+
+ async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
+ """
+ A background update that iterates over the chain cover and deletes the
+ chain cover for events that have been purged.
+
+ This may be due to a room being fully purged or to a retention policy.
+ """
+ current_event_id = progress.get("current_event_id", "")
+
+ def purged_chain_cover_txn(txn) -> int:
+ # The event ID from events will be null if the chain ID / sequence
+ # number points to a purged event.
+ sql = """
+ SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
+ FROM event_auth_chains
+ LEFT JOIN events AS e USING (event_id)
+ WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
+ """
+ txn.execute(sql, (current_event_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ # The event IDs and chain IDs / sequence numbers where the event has
+ # been purged.
+ unreferenced_event_ids = []
+ unreferenced_chain_id_tuples = []
+ event_id = ""
+ for event_id, chain_id, sequence_number, has_event in rows:
+ if not has_event:
+ unreferenced_event_ids.append((event_id,))
+ unreferenced_chain_id_tuples.append((chain_id, sequence_number))
+
+ # Delete the unreferenced auth chains from event_auth_chain_links and
+ # event_auth_chains.
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chains WHERE event_id = ?
+ """,
+ unreferenced_event_ids,
+ )
+ # We should also delete matching target_*, but there is no index on
+ # target_chain_id. Hopefully any purged events are due to a room
+ # being fully purged and they will be removed from the origin_*
+ # searches.
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chain_links WHERE
+ origin_chain_id = ? AND origin_sequence_number = ?
+ """,
+ unreferenced_chain_id_tuples,
+ )
+
+ progress = {
+ "current_event_id": event_id,
+ }
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "purged_chain_cover", progress
+ )
+
+ return len(rows)
+
+ result = await self.db_pool.runInteraction(
+ "_purged_chain_cover_index",
+ purged_chain_cover_txn,
+ )
+
+ if not result:
+ await self.db_pool.updates._end_background_update("purged_chain_cover")
+
+ return result
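The purge update above follows the usual background-update shape: handle one batch keyed on the last event ID seen, record progress, and return the number of rows handled so that a falsy result ends the update. A minimal sketch of that driving loop, with `process_batch` as a hypothetical callable standing in for `purged_chain_cover_txn`:

from typing import Callable, Tuple

def drive_background_update(
    process_batch: Callable[[str, int], Tuple[int, str]],
    batch_size: int = 100,
) -> int:
    """Repeatedly process batches until one comes back empty.

    process_batch(after_event_id, limit) returns (rows_handled, last_event_id).
    """
    after_event_id = ""
    total = 0
    while True:
        rows_handled, after_event_id = process_batch(after_event_id, batch_size)
        if not rows_handled:
            # Mirrors `_end_background_update` being called once a transaction
            # reports zero rows.
            return total
        total += rows_handled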
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index edbe42f2bf..c04e162ccc 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
+
import logging
import threading
from collections import namedtuple
@@ -1044,7 +1044,8 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
set[str]: The events we have already seen.
"""
- results = set()
+ # if the event cache contains the event, obviously we've seen it.
+ results = {x for x in event_ids if self._get_event_cache.contains(x)}
def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE "
@@ -1052,12 +1053,9 @@ class EventsWorkerStore(SQLBaseStore):
txn.database_engine, "e.event_id", chunk
)
txn.execute(sql + clause, args)
- for (event_id,) in txn:
- results.add(event_id)
+ results.update(row[0] for row in txn)
- # break the input up into chunks of 100
- input_iterator = iter(event_ids)
- for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
+ for chunk in batch_iter((x for x in event_ids if x not in results), 100):
await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
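`have_seen_events` now skips anything already in the event cache and chunks the remainder with `batch_iter` instead of hand-rolling `itertools.islice`. A minimal sketch of what such a chunking helper does (the real helper is assumed to live in `synapse.util.iterutils`):

from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    """Yield consecutive tuples of at most `size` items from `iterable`."""
    sourceiter = iter(iterable)
    # iter(callable, sentinel) keeps calling the lambda until it returns the
    # empty tuple, i.e. until the source iterator is exhausted.
    return iter(lambda: tuple(islice(sourceiter, size)), ())

For example, `list(batch_iter(range(5), 2))` gives `[(0, 1), (2, 3), (4,)]`.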
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 0836e4af49..41f4fe7f95 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -331,13 +331,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
txn.executemany(
"""
DELETE FROM event_auth_chain_links WHERE
- (origin_chain_id = ? AND origin_sequence_number = ?) OR
- (target_chain_id = ? AND target_sequence_number = ?)
+ origin_chain_id = ? AND origin_sequence_number = ?
""",
- (
- (chain_id, seq_num, chain_id, seq_num)
- for (chain_id, seq_num) in referenced_chain_id_tuples
- ),
+ referenced_chain_id_tuples,
)
# Now we delete tables which lack an index on room_id but have one on event_id
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 61a7556e56..eba66ff352 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
@@ -1510,7 +1510,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
async def user_delete_access_tokens(
self,
user_id: str,
- except_token_id: Optional[str] = None,
+ except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
@@ -1533,7 +1533,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
- values = [v for _, v in items]
+ values = [v for _, v in items] # type: List[Union[str, int]]
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)
diff --git a/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql
new file mode 100644
index 0000000000..87cb1f3cfd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5910, 'purged_chain_cover', '{}');
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index b921d63d30..0309661841 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -350,11 +350,11 @@ class TransactionStore(TransactionWorkerStore):
self.db_pool.simple_upsert_many_txn(
txn,
- "destination_rooms",
- ["destination", "room_id"],
- rows,
- ["stream_ordering"],
- [(stream_ordering,)] * len(rows),
+ table="destination_rooms",
+ key_names=("destination", "room_id"),
+ key_values=rows,
+ value_names=["stream_ordering"],
+ value_values=[(stream_ordering,)] * len(rows),
)
async def get_destination_last_successful_stream_ordering(
diff --git a/synapse/types.py b/synapse/types.py
index 0216d213c7..b08ce90140 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -83,33 +83,32 @@ class ISynapseReactor(
"""The interfaces necessary for Synapse to function."""
-class Requester(
- namedtuple(
- "Requester",
- [
- "user",
- "access_token_id",
- "is_guest",
- "shadow_banned",
- "device_id",
- "app_service",
- "authenticated_entity",
- ],
- )
-):
+@attr.s(frozen=True, slots=True)
+class Requester:
"""
Represents the user making a request
Attributes:
- user (UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
+ user: id of the user making the request
+ access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- shadow_banned (bool): True if the user making this request has been shadow-banned.
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request has been shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
+ authenticated_entity: The entity that authenticated when making the request.
+ This is different to the user_id when an admin user or the server is
+ "puppeting" the user.
"""
+ user = attr.ib(type="UserID")
+ access_token_id = attr.ib(type=Optional[int])
+ is_guest = attr.ib(type=bool)
+ shadow_banned = attr.ib(type=bool)
+ device_id = attr.ib(type=Optional[str])
+ app_service = attr.ib(type=Optional["ApplicationService"])
+ authenticated_entity = attr.ib(type=str)
+
def serialize(self):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -157,23 +156,23 @@ class Requester(
def create_requester(
user_id: Union[str, "UserID"],
access_token_id: Optional[int] = None,
- is_guest: Optional[bool] = False,
- shadow_banned: Optional[bool] = False,
+ is_guest: bool = False,
+ shadow_banned: bool = False,
device_id: Optional[str] = None,
app_service: Optional["ApplicationService"] = None,
authenticated_entity: Optional[str] = None,
-):
+) -> Requester:
"""
Create a new ``Requester`` object
Args:
- user_id (str|UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
+ user_id: id of the user making the request
+ access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- shadow_banned (bool): True if the user making this request is shadow-banned.
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request is shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 1a3ccb263d..6f96cd7940 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -7,6 +7,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
+from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
from tests.unittest import FederatingHomeserverTestCase, override_config
@@ -49,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
else:
data = json_cb()
self.failed_pdus.extend(data["pdus"])
- raise IOError("Failed to connect because this is a test!")
+ raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination)
def get_destination_room(self, room: str, destination: str = "host2") -> dict:
"""
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..94b6903594 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ def test_spam_checker_receives_sso_type(self):
+ """Test rejecting registration based on SSO type"""
+
+ class BanBadIdPUser:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info, auth_provider_id=None
+ ):
+ # Reject any user coming from CAS whose username contains profanity
+ if auth_provider_id == "cas" and "flimflob" in username:
+ return RegistrationBehaviour.DENY
+ return RegistrationBehaviour.ALLOW
+
+ # Configure a spam checker that denies a certain user on a specific IdP
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanBadIdPUser()]
+
+ f = self.get_failure(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
+ SynapseError,
+ )
+ exception = f.value
+
+ # We return 429 from the spam checker for denied registrations
+ self.assertIsInstance(exception, SynapseError)
+ self.assertEqual(exception.code, 429)
+
+ # Check the same username can register using SAML
+ self.get_success(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
+ )
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
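The new test exercises the `auth_provider_id` argument that registration now passes through to spam checkers. A minimal sketch of a standalone checker using that hook, under the same assumptions as the test; the class name is illustrative and the `RegistrationBehaviour` import path is assumed rather than taken from this patch:

from synapse.spam_checker_api import RegistrationBehaviour

class DenyCasRegistrations:
    """Reject any registration arriving via the "cas" identity provider."""

    def check_registration_for_spam(
        self, email_threepid, username, request_info, auth_provider_id=None
    ):
        # auth_provider_id is None for password registrations and set to the
        # IdP identifier (e.g. "cas" or "saml") for SSO registrations.
        if auth_provider_id == "cas":
            return RegistrationBehaviour.DENY
        return RegistrationBehaviour.ALLOW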
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 21ecb81c99..0ce181a51e 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -16,12 +16,23 @@ from io import BytesIO
from mock import Mock
+from netaddr import IPSet
+
+from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
-from twisted.web.client import ResponseDone
+from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.web.client import Agent, ResponseDone
from twisted.web.iweb import UNKNOWN_LENGTH
-from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+from synapse.api.errors import SynapseError
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ BlacklistingReactorWrapper,
+ BodyExceededMaxSize,
+ read_body_with_max_size,
+)
+from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
@@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
# The data is never consumed.
self.assertEqual(result.getvalue(), b"")
+
+
+class BlacklistingAgentTest(TestCase):
+ def setUp(self):
+ self.reactor, self.clock = get_clock()
+
+ self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
+ self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
+ self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
+
+ # Configure the reactor's DNS resolver.
+ for (domain, ip) in (
+ (self.safe_domain, self.safe_ip),
+ (self.unsafe_domain, self.unsafe_ip),
+ (self.allowed_domain, self.allowed_ip),
+ ):
+ self.reactor.lookups[domain.decode()] = ip.decode()
+ self.reactor.lookups[ip.decode()] = ip.decode()
+
+ self.ip_whitelist = IPSet([self.allowed_ip.decode()])
+ self.ip_blacklist = IPSet(["5.0.0.0/8"])
+
+ def test_reactor(self):
+ """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
+ agent = Agent(
+ BlacklistingReactorWrapper(
+ self.reactor,
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ ),
+ )
+
+ # The unsafe domains and IPs should be rejected.
+ for domain in (self.unsafe_domain, self.unsafe_ip):
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + domain), DNSLookupError
+ )
+
+ # The safe domains and IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
+
+ def test_agent(self):
+ """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
+ agent = BlacklistingAgentWrapper(
+ Agent(self.reactor),
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ )
+
+ # The unsafe IPs should be rejected.
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
+ )
+
+ # The safe and unsafe domains and safe IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.unsafe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 20940c8107..67b7913666 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-import attr
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
@@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
- return request_factory.request
+ return channel.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
- "localhost",
+ b"localhost",
6379,
self.connect_any_redis_attempts,
)
@@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
+ self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
@@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r))
-@attr.s()
-class OneShotRequestFactory:
- """A simple request factory that generates a single `SynapseRequest` and
- stores it for future use. Can only be used once.
- """
-
- request = attr.ib(default=None)
-
- def __call__(self, *args, **kwargs):
- assert self.request is None
-
- self.request = SynapseRequest(*args, **kwargs)
- return self.request
-
-
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
@@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
"""
def __init__(
- self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+ self, reactor: IReactorTime, request_factory: Type[Request], site: Site
):
super().__init__()
self.reactor = reactor
@@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
+ def requestDone(self, request):
+ # Store the request for inspection.
+ self.request = request
+ super().requestDone(request)
+
class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
@@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""
+ transport = None # type: Optional[FakeTransport]
+
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
@@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
def send(self, msg):
"""Send a message back to the client."""
+ assert self.transport is not None
+
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index f235f1bd83..0d9e3bb11d 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -17,7 +17,7 @@ import mock
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
from tests.unittest import HomeserverTestCase
@@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
"""
rch = self.hs.get_tcp_replication()
- # wire up the ReplicationCommandHandler to a mock connection
- mock_connection = mock.Mock(spec=AbstractConnection)
+ # wire up the ReplicationCommandHandler to a mock connection, which needs
+ # to implement IReplicationConnection. (Note that Mock doesn't understand
+ # interfaces, but casting an interface to a list gives the attributes.)
+ mock_connection = mock.Mock(spec=list(IReplicationConnection))
rch.new_connection(mock_connection)
# tell it it received an RDATA row
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 20af3285bd..988821b16f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
- expected_flows = [
- {"type": "m.login.cas"},
- {"type": "m.login.sso"},
- {"type": "m.login.token"},
- {"type": "m.login.password"},
- ] + ADDITIONAL_LOGIN_FLOWS
+ expected_flow_types = [
+ "m.login.cas",
+ "m.login.sso",
+ "m.login.token",
+ "m.login.password",
+ ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
- self.assertCountEqual(channel.json_body["flows"], expected_flows)
+ self.assertCountEqual(
+ [f["type"] for f in channel.json_body["flows"]], expected_flow_types
+ )
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_get_msc2858_login_flows(self):
@@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
- def test_client_idp_redirect_msc2858_disabled(self):
- """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
- channel = self._make_sso_redirect_request(True, "oidc")
- self.assertEqual(channel.code, 400, channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
-
- @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404"""
- channel = self._make_sso_redirect_request(True, "xxx")
+ channel = self._make_sso_redirect_request(False, "xxx")
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
- @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it"""
+ channel = self._make_sso_redirect_request(False, "oidc")
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_msc2858_redirect_to_oidc(self):
+ """Test the unstable API"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
@@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+ def test_client_idp_redirect_msc2858_disabled(self):
+ """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
+ channel = self._make_sso_redirect_request(True, "oidc")
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
def _make_sso_redirect_request(
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
):
diff --git a/tests/server.py b/tests/server.py
index 863f6da738..2287d20076 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple,
+ ITransport,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@@ -467,6 +468,7 @@ def get_clock():
return clock, hs_clock
+@implementer(ITransport)
@attr.s(cmp=False)
class FakeTransport:
"""
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 06000f81a6..d597d712d6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
- @parameterized.expand([(True,), (False,)])
- def test_auth_difference(self, use_chain_cover_index: bool):
+ def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
- # Mark the room as not having a cover index
+ # Mark the room as maybe having a cover index.
def store_room(txn):
self.store.db_pool.simple_insert_txn(
@@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
)
+ return room_id
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_chain_ids(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
+ # a and b have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["a", "b"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ # d and e have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
+ self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
+ self.assertEqual(auth_chain_ids, ["k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
+ self.assertEqual(auth_chain_ids, ["j"])
+
+ # j and k have no parents.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
+ self.assertEqual(auth_chain_ids, [])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
+ self.assertEqual(auth_chain_ids, [])
+
+ # More complex input sequences.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["h", "i"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["k", "j"])
+
+ # e gets returned even though include_given is false, because it is in
+ # the auth chain of b.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "e"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ # Test include_given.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
+ )
+ self.assertCountEqual(auth_chain_ids, ["i", "j"])
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_difference(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
# Now actually test that various combinations give the right result:
difference = self.get_success(
diff --git a/tox.ini b/tox.ini
index a6d10537ae..9ff70fe312 100644
--- a/tox.ini
+++ b/tox.ini
@@ -189,7 +189,5 @@ commands=
[testenv:mypy]
deps =
{[base]deps}
- # Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513
- twisted==20.3.0
extras = all,mypy
commands = mypy