diff --git a/CHANGES.md b/CHANGES.md
index 4a356442e3..9a30a2e901 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,13 @@
+Synapse 1.15.1 (2020-06-16)
+===========================
+
+Bugfixes
+--------
+
+- Fix a bug introduced in v1.15.0 that would crash Synapse on start when using certain password auth providers. ([\#7684](https://github.com/matrix-org/synapse/issues/7684))
+- Fix a bug introduced in v1.15.0 which meant that some 3PID management endpoints were not accessible on the correct URL. ([\#7685](https://github.com/matrix-org/synapse/issues/7685))
+
+
Synapse 1.15.0 (2020-06-11)
===========================
diff --git a/changelog.d/7606.bugfix b/changelog.d/7606.bugfix
new file mode 100644
index 0000000000..6c488c3996
--- /dev/null
+++ b/changelog.d/7606.bugfix
@@ -0,0 +1 @@
+Remove `user_id` from the response to `GET /_matrix/client/r0/presence/{userId}/status` to match the specification.
diff --git a/changelog.d/7639.feature b/changelog.d/7639.feature
new file mode 100644
index 0000000000..ca80fc0ca2
--- /dev/null
+++ b/changelog.d/7639.feature
@@ -0,0 +1 @@
+Add an option to enable encryption by default for new rooms.
diff --git a/changelog.d/7648.bugfix b/changelog.d/7648.bugfix
new file mode 100644
index 0000000000..ff2417bfb6
--- /dev/null
+++ b/changelog.d/7648.bugfix
@@ -0,0 +1 @@
+In working mode, ensure that replicated data has not already been received.
diff --git a/changelog.d/7652.doc b/changelog.d/7652.doc
new file mode 100644
index 0000000000..c3ebbeb262
--- /dev/null
+++ b/changelog.d/7652.doc
@@ -0,0 +1 @@
+Spelling correction in sample_config.yaml.
diff --git a/changelog.d/7657.misc b/changelog.d/7657.misc
new file mode 100644
index 0000000000..384c05d2cd
--- /dev/null
+++ b/changelog.d/7657.misc
@@ -0,0 +1 @@
+Clean-up the login fallback code.
diff --git a/changelog.d/7663.bugfix b/changelog.d/7663.bugfix
new file mode 100644
index 0000000000..b58316b34f
--- /dev/null
+++ b/changelog.d/7663.bugfix
@@ -0,0 +1 @@
+Fix intermittent exception during startup, introduced in Synapse 1.14.0.
diff --git a/changelog.d/7664.misc b/changelog.d/7664.misc
new file mode 100644
index 0000000000..bbf0908109
--- /dev/null
+++ b/changelog.d/7664.misc
@@ -0,0 +1 @@
+Increase the default SAML session expirary time to 15 minutes.
diff --git a/changelog.d/7681.misc b/changelog.d/7681.misc
new file mode 100644
index 0000000000..e474fc39cd
--- /dev/null
+++ b/changelog.d/7681.misc
@@ -0,0 +1 @@
+Refactor handling of `listeners` configuration settings.
diff --git a/changelog.d/7687.bugfix b/changelog.d/7687.bugfix
new file mode 100644
index 0000000000..0413aff0b3
--- /dev/null
+++ b/changelog.d/7687.bugfix
@@ -0,0 +1 @@
+Accept the proper field (`phone`) for the `m.id.phone` identifier type. The legacy field of `number` is still accepted as a fallback. Bug introduced in v0.20.0-rc1.
diff --git a/changelog.d/7688.bugfix b/changelog.d/7688.bugfix
new file mode 100644
index 0000000000..2ed718a087
--- /dev/null
+++ b/changelog.d/7688.bugfix
@@ -0,0 +1 @@
+Fix "Starting db txn 'get_completed_ui_auth_stages' from sentinel context" warning. The bug was introduced in 1.13.0rc1.
diff --git a/changelog.d/7689.bugfix b/changelog.d/7689.bugfix
new file mode 100644
index 0000000000..10ad4c8c13
--- /dev/null
+++ b/changelog.d/7689.bugfix
@@ -0,0 +1 @@
+Compare the URI and method during user interactive authentication (instead of the URI twice). Bug introduced in 1.13.0rc1.
diff --git a/changelog.d/7691.bugfix b/changelog.d/7691.bugfix
new file mode 100644
index 0000000000..2a8a480c53
--- /dev/null
+++ b/changelog.d/7691.bugfix
@@ -0,0 +1 @@
+Fix a long standing bug where the response to the `GET room_keys/version` endpoint had the incorrect type for the `etag` field.
diff --git a/changelog.d/7692.misc b/changelog.d/7692.misc
new file mode 100644
index 0000000000..ef6cbe0005
--- /dev/null
+++ b/changelog.d/7692.misc
@@ -0,0 +1 @@
+Replace uses of `six.iterkeys`/`iteritems`/`itervalues` with `keys()`/`items()`/`values()`.
diff --git a/changelog.d/7698.bugfix b/changelog.d/7698.bugfix
new file mode 100644
index 0000000000..32de7459eb
--- /dev/null
+++ b/changelog.d/7698.bugfix
@@ -0,0 +1 @@
+Fix logged error during device resync in opentracing. Broke in v1.14.0.
diff --git a/changelog.d/7701.bugfix b/changelog.d/7701.bugfix
new file mode 100644
index 0000000000..e5b10f75fd
--- /dev/null
+++ b/changelog.d/7701.bugfix
@@ -0,0 +1 @@
+Do not break push rule evaluation when receiving an event with a non-string body. This is a long-standing bug.
diff --git a/debian/changelog b/debian/changelog
index 182a50ee37..3e83e9be9a 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.15.1) stable; urgency=medium
+
+ * New synapse release 1.15.1.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 16 Jun 2020 10:27:50 +0100
+
matrix-synapse-py3 (1.15.0) stable; urgency=medium
* New synapse release 1.15.0.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 847926c146..e698166283 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -283,7 +283,7 @@ listeners:
# number of monthly active users.
#
# 'limit_usage_by_mau' disables/enables monthly active user blocking. When
-# anabled and a limit is reached the server returns a 'ResourceLimitError'
+# enabled and a limit is reached the server returns a 'ResourceLimitError'
# with error type Codes.RESOURCE_LIMIT_EXCEEDED
#
# 'max_mau_value' is the hard limit of monthly active users above which
@@ -1629,7 +1629,7 @@ saml2_config:
# The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
- # The default is 5 minutes.
+ # The default is 15 minutes.
#
#saml_session_lifetime: 5m
@@ -2148,6 +2148,26 @@ spam_checker:
# example_stop_events_from: ['@bad:example.com']
+## Rooms ##
+
+# Controls whether locally-created rooms should be end-to-end encrypted by
+# default.
+#
+# Possible options are "all", "invite", and "off". They are defined as:
+#
+# * "all": any locally-created room
+# * "invite": any room created with the "private_chat" or "trusted_private_chat"
+# room creation presets
+# * "off": this option will take no effect
+#
+# The default value is "off".
+#
+# Note that this option will only affect rooms created after it is set. It
+# will also not affect rooms created by other servers.
+#
+#encryption_enabled_by_default_for_room_type: invite
+
+
# Uncomment to allow non-server-admin users to create groups on this server
#
#enable_group_creation: true
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 1d9d85a727..4d39996a2e 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
-__version__ = "1.15.0"
+__version__ = "1.15.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d9e7736b8d..e615533ea3 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -16,8 +16,6 @@
import logging
from typing import Optional
-from six import itervalues
-
import pymacaroons
from netaddr import IPAddress
@@ -90,7 +88,7 @@ class Auth(object):
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check(
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 5ec4a77ccd..6a6d32c302 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -150,3 +150,8 @@ class EventContentFields(object):
# Timestamp to delete the event after
# cf https://github.com/matrix-org/matrix-doc/pull/2228
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
+
+
+class RoomEncryptionAlgorithms(object):
+ MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
+ DEFAULT = MEGOLM_V1_AES_SHA2
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 2c34815f55..367330fd91 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -20,7 +20,6 @@
import logging
from typing import Dict, List
-from six import iteritems
from six.moves import http_client
from canonicaljson import json
@@ -505,7 +504,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
A dict representing the error response JSON.
"""
err = {"error": msg, "errcode": code}
- for key, value in iteritems(kwargs):
+ for key, value in kwargs.items():
err[key] = value
return err
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index dedff81af3..373a80a4a7 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -20,6 +20,7 @@ import signal
import socket
import sys
import traceback
+from typing import Iterable
from daemonize import Daemonize
from typing_extensions import NoReturn
@@ -29,6 +30,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
+from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.util.async_helpers import Linearizer
@@ -234,7 +236,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
-def start(hs, listeners=None):
+def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
"""
Start a Synapse server or worker.
@@ -245,8 +247,8 @@ def start(hs, listeners=None):
notify systemd.
Args:
- hs (synapse.server.HomeServer)
- listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml)
+ hs: homeserver instance
+ listeners: Listener configuration ('listeners' in homeserver.yaml)
"""
try:
# Set up the SIGHUP machinery.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f3ec2a34ec..27a3fc9ed6 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -37,6 +37,7 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
+from synapse.config.server import ListenerConfig
from synapse.federation import send_queue
from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.presence import (
@@ -514,13 +515,18 @@ class GenericWorkerSlavedStore(
class GenericWorkerServer(HomeServer):
DATASTORE_CLASS = GenericWorkerSlavedStore
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
+ def _listen_http(self, listener_config: ListenerConfig):
+ port = listener_config.port
+ bind_addresses = listener_config.bind_addresses
+
+ assert listener_config.http_options is not None
+
+ site_tag = listener_config.http_options.tag
+ if site_tag is None:
+ site_tag = port
resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
+ for res in listener_config.http_options.resources:
+ for name in res.names:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
@@ -590,7 +596,7 @@ class GenericWorkerServer(HomeServer):
" repository is disabled. Ignoring."
)
- if name == "openid" and "federation" not in res["names"]:
+ if name == "openid" and "federation" not in res.names:
# Only load the openid resource separately if federation resource
# is not specified since federation resource includes openid
# resource.
@@ -625,19 +631,19 @@ class GenericWorkerServer(HomeServer):
logger.info("Synapse worker now listening on port %d", port)
- def start_listening(self, listeners):
+ def start_listening(self, listeners: Iterable[ListenerConfig]):
for listener in listeners:
- if listener["type"] == "http":
+ if listener.type == "http":
self._listen_http(listener)
- elif listener["type"] == "manhole":
+ elif listener.type == "manhole":
_base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
manhole(
username="matrix", password="rabbithole", globals={"hs": self}
),
)
- elif listener["type"] == "metrics":
+ elif listener.type == "metrics":
if not self.get_config().enable_metrics:
logger.warning(
(
@@ -646,9 +652,9 @@ class GenericWorkerServer(HomeServer):
)
)
else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
+ _base.listen_metrics(listener.bind_addresses, listener.port)
else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unsupported listener type: %s", listener.type)
self.get_tcp_replication().start_replication(self)
@@ -738,6 +744,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
except Exception:
logger.exception("Error processing replication")
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
+ await super().on_position(stream_name, instance_name, token)
+ # Also call on_rdata to ensure that stream positions are properly reset.
+ await self.on_rdata(stream_name, instance_name, token, [])
+
def stop_pusher(self, user_id, app_id, pushkey):
if not self.notify_pushers:
return
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8454d74858..299134d00f 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -23,8 +23,7 @@ import math
import os
import resource
import sys
-
-from six import iteritems
+from typing import Iterable
from prometheus_client import Gauge
@@ -50,6 +49,7 @@ from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import ListenerConfig
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import (
@@ -89,24 +89,24 @@ def gz_wrap(r):
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
- def _listener_http(self, config, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- tls = listener_config.get("tls", False)
- site_tag = listener_config.get("tag", port)
+ def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig):
+ port = listener_config.port
+ bind_addresses = listener_config.bind_addresses
+ tls = listener_config.tls
+ site_tag = listener_config.http_options.tag
+ if site_tag is None:
+ site_tag = port
resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "openid" and "federation" in res["names"]:
+ for res in listener_config.http_options.resources:
+ for name in res.names:
+ if name == "openid" and "federation" in res.names:
# Skip loading openid resource if federation is defined
# since federation resource will include openid
continue
- resources.update(
- self._configure_named_resource(name, res.get("compress", False))
- )
+ resources.update(self._configure_named_resource(name, res.compress))
- additional_resources = listener_config.get("additional_resources", {})
+ additional_resources = listener_config.http_options.additional_resources
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = ModuleApi(self, self.get_auth_handler())
for path, resmodule in additional_resources.items():
@@ -278,7 +278,7 @@ class SynapseHomeServer(HomeServer):
return resources
- def start_listening(self, listeners):
+ def start_listening(self, listeners: Iterable[ListenerConfig]):
config = self.get_config()
if config.redis_enabled:
@@ -288,25 +288,25 @@ class SynapseHomeServer(HomeServer):
self.get_tcp_replication().start_replication(self)
for listener in listeners:
- if listener["type"] == "http":
+ if listener.type == "http":
self._listening_services.extend(self._listener_http(config, listener))
- elif listener["type"] == "manhole":
+ elif listener.type == "manhole":
listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
manhole(
username="matrix", password="rabbithole", globals={"hs": self}
),
)
- elif listener["type"] == "replication":
+ elif listener.type == "replication":
services = listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
ReplicationStreamProtocolFactory(self),
)
for s in services:
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
- elif listener["type"] == "metrics":
+ elif listener.type == "metrics":
if not self.get_config().enable_metrics:
logger.warning(
(
@@ -315,9 +315,11 @@ class SynapseHomeServer(HomeServer):
)
)
else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
+ _base.listen_metrics(listener.bind_addresses, listener.port)
else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
+ # this shouldn't happen, as the listener type should have been checked
+ # during parsing
+ logger.warning("Unrecognized listener type: %s", listener.type)
# Gauges to expose monthly active user control metrics
@@ -525,7 +527,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
- for name, count in iteritems(daily_user_type_results):
+ for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
room_count = yield hs.get_datastore().get_room_count()
@@ -537,7 +539,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users()
- for name, count in iteritems(r30_results):
+ for name, count in r30_results.items():
stats["r30_users_" + name] = count
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 0672538796..aff5b21ab2 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -15,6 +15,7 @@
import os
import re
+import threading
from typing import Callable, Dict
from ._base import Config, ConfigError
@@ -25,6 +26,9 @@ _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
# Map from canonicalised cache name to cache.
_CACHES = {}
+# a lock on the contents of _CACHES
+_CACHES_LOCK = threading.Lock()
+
_DEFAULT_FACTOR_SIZE = 0.5
_DEFAULT_EVENT_CACHE_SIZE = "10K"
@@ -66,7 +70,10 @@ def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
# Some caches have '*' in them which we strip out.
cache_name = _canonicalise_cache_name(cache_name)
- _CACHES[cache_name] = cache_resize_callback
+ # sometimes caches are initialised from background threads, so we need to make
+ # sure we don't conflict with another thread running a resize operation
+ with _CACHES_LOCK:
+ _CACHES[cache_name] = cache_resize_callback
# Ensure all loaded caches are sized appropriately
#
@@ -87,7 +94,8 @@ class CacheConfig(Config):
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
)
properties.resize_all_caches_func = None
- _CACHES.clear()
+ with _CACHES_LOCK:
+ _CACHES.clear()
def generate_config_section(self, **kwargs):
return """\
@@ -193,6 +201,8 @@ class CacheConfig(Config):
For each cache, run the mapped callback function with either
a specific cache factor or the default, global one.
"""
- for cache_name, callback in _CACHES.items():
- new_factor = self.cache_factors.get(cache_name, self.global_factor)
- callback(new_factor)
+ # block other threads from modifying _CACHES while we iterate it.
+ with _CACHES_LOCK:
+ for cache_name, callback in _CACHES.items():
+ new_factor = self.cache_factors.get(cache_name, self.global_factor)
+ callback(new_factor)
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 2c7b3a699f..264c274c52 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -36,6 +36,7 @@ from .ratelimiting import RatelimitConfig
from .redis import RedisConfig
from .registration import RegistrationConfig
from .repository import ContentRepositoryConfig
+from .room import RoomConfig
from .room_directory import RoomDirectoryConfig
from .saml2_config import SAML2Config
from .server import ServerConfig
@@ -79,6 +80,7 @@ class HomeServerConfig(RootConfig):
PasswordAuthProviderConfig,
PushConfig,
SpamCheckerConfig,
+ RoomConfig,
GroupsConfig,
UserDirectoryConfig,
ConsentConfig,
diff --git a/synapse/config/room.py b/synapse/config/room.py
new file mode 100644
index 0000000000..6aa4de0672
--- /dev/null
+++ b/synapse/config/room.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.constants import RoomCreationPreset
+
+from ._base import Config, ConfigError
+
+logger = logging.Logger(__name__)
+
+
+class RoomDefaultEncryptionTypes(object):
+ """Possible values for the encryption_enabled_by_default_for_room_type config option"""
+
+ ALL = "all"
+ INVITE = "invite"
+ OFF = "off"
+
+
+class RoomConfig(Config):
+ section = "room"
+
+ def read_config(self, config, **kwargs):
+ # Whether new, locally-created rooms should have encryption enabled
+ encryption_for_room_type = config.get(
+ "encryption_enabled_by_default_for_room_type",
+ RoomDefaultEncryptionTypes.OFF,
+ )
+ if encryption_for_room_type == RoomDefaultEncryptionTypes.ALL:
+ self.encryption_enabled_by_default_for_room_presets = [
+ RoomCreationPreset.PRIVATE_CHAT,
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ RoomCreationPreset.PUBLIC_CHAT,
+ ]
+ elif encryption_for_room_type == RoomDefaultEncryptionTypes.INVITE:
+ self.encryption_enabled_by_default_for_room_presets = [
+ RoomCreationPreset.PRIVATE_CHAT,
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ ]
+ elif encryption_for_room_type == RoomDefaultEncryptionTypes.OFF:
+ self.encryption_enabled_by_default_for_room_presets = []
+ else:
+ raise ConfigError(
+ "Invalid value for encryption_enabled_by_default_for_room_type"
+ )
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ ## Rooms ##
+
+ # Controls whether locally-created rooms should be end-to-end encrypted by
+ # default.
+ #
+ # Possible options are "all", "invite", and "off". They are defined as:
+ #
+ # * "all": any locally-created room
+ # * "invite": any room created with the "private_chat" or "trusted_private_chat"
+ # room creation presets
+ # * "off": this option will take no effect
+ #
+ # The default value is "off".
+ #
+ # Note that this option will only affect rooms created after it is set. It
+ # will also not affect rooms created by other servers.
+ #
+ #encryption_enabled_by_default_for_room_type: invite
+ """
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index d0a19751e8..293643b2de 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -160,7 +160,7 @@ class SAML2Config(Config):
# session lifetime: in milliseconds
self.saml2_session_lifetime = self.parse_duration(
- saml2_config.get("saml_session_lifetime", "5m")
+ saml2_config.get("saml_session_lifetime", "15m")
)
template_dir = saml2_config.get("template_dir")
@@ -286,7 +286,7 @@ class SAML2Config(Config):
# The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
- # The default is 5 minutes.
+ # The default is 15 minutes.
#
#saml_session_lifetime: 5m
diff --git a/synapse/config/server.py b/synapse/config/server.py
index b96b68f685..43ab5d62d6 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
import attr
import yaml
@@ -57,6 +57,64 @@ on how to configure the new listener.
--------------------------------------------------------------------------------"""
+KNOWN_LISTENER_TYPES = {
+ "http",
+ "metrics",
+ "manhole",
+ "replication",
+}
+
+KNOWN_RESOURCES = {
+ "client",
+ "consent",
+ "federation",
+ "keys",
+ "media",
+ "metrics",
+ "openid",
+ "replication",
+ "static",
+ "webclient",
+}
+
+
+@attr.s(frozen=True)
+class HttpResourceConfig:
+ names = attr.ib(
+ type=List[str],
+ factory=list,
+ validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore
+ )
+ compress = attr.ib(
+ type=bool,
+ default=False,
+ validator=attr.validators.optional(attr.validators.instance_of(bool)), # type: ignore[arg-type]
+ )
+
+
+@attr.s(frozen=True)
+class HttpListenerConfig:
+ """Object describing the http-specific parts of the config of a listener"""
+
+ x_forwarded = attr.ib(type=bool, default=False)
+ resources = attr.ib(type=List[HttpResourceConfig], factory=list)
+ additional_resources = attr.ib(type=Dict[str, dict], factory=dict)
+ tag = attr.ib(type=str, default=None)
+
+
+@attr.s(frozen=True)
+class ListenerConfig:
+ """Object describing the configuration of a single listener."""
+
+ port = attr.ib(type=int, validator=attr.validators.instance_of(int))
+ bind_addresses = attr.ib(type=List[str])
+ type = attr.ib(type=str, validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
+ tls = attr.ib(type=bool, default=False)
+
+ # http_options is only populated if type=http
+ http_options = attr.ib(type=Optional[HttpListenerConfig], default=None)
+
+
class ServerConfig(Config):
section = "server"
@@ -385,38 +443,21 @@ class ServerConfig(Config):
}
]
- self.listeners = [] # type: List[dict]
- for listener in config.get("listeners", []):
- if not isinstance(listener.get("port", None), int):
- raise ConfigError(
- "Listener configuration is lacking a valid 'port' option"
- )
+ self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]
- if listener.setdefault("tls", False):
- # no_tls is not really supported any more, but let's grandfather it in
- # here.
- if config.get("no_tls", False):
+ # no_tls is not really supported any more, but let's grandfather it in
+ # here.
+ if config.get("no_tls", False):
+ l2 = []
+ for listener in self.listeners:
+ if listener.tls:
logger.info(
- "Ignoring TLS-enabled listener on port %i due to no_tls"
+ "Ignoring TLS-enabled listener on port %i due to no_tls",
+ listener.port,
)
- continue
-
- bind_address = listener.pop("bind_address", None)
- bind_addresses = listener.setdefault("bind_addresses", [])
-
- # if bind_address was specified, add it to the list of addresses
- if bind_address:
- bind_addresses.append(bind_address)
-
- # if we still have an empty list of addresses, use the default list
- if not bind_addresses:
- if listener["type"] == "metrics":
- # the metrics listener doesn't support IPv6
- bind_addresses.append("0.0.0.0")
else:
- bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
-
- self.listeners.append(listener)
+ l2.append(listener)
+ self.listeners = l2
if not self.web_client_location:
_warn_if_webclient_configured(self.listeners)
@@ -452,43 +493,41 @@ class ServerConfig(Config):
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
+ http_options = HttpListenerConfig(
+ resources=[
+ HttpResourceConfig(names=["client"], compress=gzip_responses),
+ HttpResourceConfig(names=["federation"]),
+ ],
+ )
+
self.listeners.append(
- {
- "port": bind_port,
- "bind_addresses": [bind_host],
- "tls": True,
- "type": "http",
- "resources": [
- {"names": ["client"], "compress": gzip_responses},
- {"names": ["federation"], "compress": False},
- ],
- }
+ ListenerConfig(
+ port=bind_port,
+ bind_addresses=[bind_host],
+ tls=True,
+ type="http",
+ http_options=http_options,
+ )
)
unsecure_port = config.get("unsecure_port", bind_port - 400)
if unsecure_port:
self.listeners.append(
- {
- "port": unsecure_port,
- "bind_addresses": [bind_host],
- "tls": False,
- "type": "http",
- "resources": [
- {"names": ["client"], "compress": gzip_responses},
- {"names": ["federation"], "compress": False},
- ],
- }
+ ListenerConfig(
+ port=unsecure_port,
+ bind_addresses=[bind_host],
+ tls=False,
+ type="http",
+ http_options=http_options,
+ )
)
manhole = config.get("manhole")
if manhole:
self.listeners.append(
- {
- "port": manhole,
- "bind_addresses": ["127.0.0.1"],
- "type": "manhole",
- "tls": False,
- }
+ ListenerConfig(
+ port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+ )
)
metrics_port = config.get("metrics_port")
@@ -496,13 +535,14 @@ class ServerConfig(Config):
logger.warning(METRICS_PORT_WARNING)
self.listeners.append(
- {
- "port": metrics_port,
- "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
- "tls": False,
- "type": "http",
- "resources": [{"names": ["metrics"], "compress": False}],
- }
+ ListenerConfig(
+ port=metrics_port,
+ bind_addresses=[config.get("metrics_bind_host", "127.0.0.1")],
+ type="http",
+ http_options=HttpListenerConfig(
+ resources=[HttpResourceConfig(names=["metrics"])]
+ ),
+ )
)
_check_resource_config(self.listeners)
@@ -528,7 +568,7 @@ class ServerConfig(Config):
)
def has_tls_listener(self) -> bool:
- return any(listener["tls"] for listener in self.listeners)
+ return any(listener.tls for listener in self.listeners)
def generate_config_section(
self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
@@ -862,7 +902,7 @@ class ServerConfig(Config):
# number of monthly active users.
#
# 'limit_usage_by_mau' disables/enables monthly active user blocking. When
- # anabled and a limit is reached the server returns a 'ResourceLimitError'
+ # enabled and a limit is reached the server returns a 'ResourceLimitError'
# with error type Codes.RESOURCE_LIMIT_EXCEEDED
#
# 'max_mau_value' is the hard limit of monthly active users above which
@@ -1155,6 +1195,44 @@ def read_gc_thresholds(thresholds):
)
+def parse_listener_def(listener: Any) -> ListenerConfig:
+ """parse a listener config from the config file"""
+ listener_type = listener["type"]
+
+ port = listener.get("port")
+ if not isinstance(port, int):
+ raise ConfigError("Listener configuration is lacking a valid 'port' option")
+
+ tls = listener.get("tls", False)
+
+ bind_addresses = listener.get("bind_addresses", [])
+ bind_address = listener.get("bind_address")
+ # if bind_address was specified, add it to the list of addresses
+ if bind_address:
+ bind_addresses.append(bind_address)
+
+ # if we still have an empty list of addresses, use the default list
+ if not bind_addresses:
+ if listener_type == "metrics":
+ # the metrics listener doesn't support IPv6
+ bind_addresses.append("0.0.0.0")
+ else:
+ bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
+
+ http_config = None
+ if listener_type == "http":
+ http_config = HttpListenerConfig(
+ x_forwarded=listener.get("x_forwarded", False),
+ resources=[
+ HttpResourceConfig(**res) for res in listener.get("resources", [])
+ ],
+ additional_resources=listener.get("additional_resources", {}),
+ tag=listener.get("tag"),
+ )
+
+ return ListenerConfig(port, bind_addresses, listener_type, tls, http_config)
+
+
NO_MORE_WEB_CLIENT_WARNING = """
Synapse no longer includes a web client. To enable a web client, configure
web_client_location. To remove this warning, remove 'webclient' from the 'listeners'
@@ -1162,40 +1240,27 @@ configuration.
"""
-def _warn_if_webclient_configured(listeners):
+def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
for listener in listeners:
- for res in listener.get("resources", []):
- for name in res.get("names", []):
+ if not listener.http_options:
+ continue
+ for res in listener.http_options.resources:
+ for name in res.names:
if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
-KNOWN_RESOURCES = (
- "client",
- "consent",
- "federation",
- "keys",
- "media",
- "metrics",
- "openid",
- "replication",
- "static",
- "webclient",
-)
-
-
-def _check_resource_config(listeners):
+def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
resource_names = {
res_name
for listener in listeners
- for res in listener.get("resources", [])
- for res_name in res.get("names", [])
+ if listener.http_options
+ for res in listener.http_options.resources
+ for res_name in res.names
}
for resource in resource_names:
- if resource not in KNOWN_RESOURCES:
- raise ConfigError("Unknown listener resource '%s'" % (resource,))
if resource == "consent":
try:
check_requirements("resources.consent")
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index ed06b91a54..dbc661630c 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -16,6 +16,7 @@
import attr
from ._base import Config, ConfigError
+from .server import ListenerConfig, parse_listener_def
@attr.s
@@ -52,7 +53,9 @@ class WorkerConfig(Config):
if self.worker_app == "synapse.app.homeserver":
self.worker_app = None
- self.worker_listeners = config.get("worker_listeners", [])
+ self.worker_listeners = [
+ parse_listener_def(x) for x in config.get("worker_listeners", [])
+ ]
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")
self.worker_log_config = config.get("worker_log_config")
@@ -75,24 +78,11 @@ class WorkerConfig(Config):
manhole = config.get("worker_manhole")
if manhole:
self.worker_listeners.append(
- {
- "port": manhole,
- "bind_addresses": ["127.0.0.1"],
- "type": "manhole",
- "tls": False,
- }
+ ListenerConfig(
+ port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+ )
)
- if self.worker_listeners:
- for listener in self.worker_listeners:
- bind_address = listener.pop("bind_address", None)
- bind_addresses = listener.setdefault("bind_addresses", [])
-
- if bind_address:
- bind_addresses.append(bind_address)
- elif not bind_addresses:
- bind_addresses.append("")
-
# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map") or {}
self.instance_map = {
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 533ba327f5..cc5deca75b 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -20,8 +20,6 @@ import os
from distutils.util import strtobool
from typing import Dict, Optional, Type
-import six
-
from unpaddedbase64 import encode_base64
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
@@ -290,7 +288,7 @@ class EventBase(metaclass=abc.ABCMeta):
return list(self._dict.items())
def keys(self):
- return six.iterkeys(self._dict)
+ return self._dict.keys()
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 7c5f620d09..f94cdcbaba 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -14,8 +14,6 @@
# limitations under the License.
from typing import Optional, Union
-from six import iteritems
-
import attr
from frozendict import frozendict
@@ -341,7 +339,7 @@ def _encode_state_dict(state_dict):
if state_dict is None:
return None
- return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)]
+ return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
def _decode_state_dict(input):
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index c0012c6872..b2ab5bd6a4 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -93,8 +93,8 @@ class FederationBase(object):
# *actual* redacted copy to be on the safe side.)
redacted_event = prune_event(pdu)
if set(redacted_event.keys()) == set(pdu.keys()) and set(
- six.iterkeys(redacted_event.content)
- ) == set(six.iterkeys(pdu.content)):
+ redacted_event.content.keys()
+ ) == set(pdu.content.keys()):
logger.info(
"Event %s seems to have been redacted; using our redacted "
"copy",
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 32a8a2ee46..6920c23723 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -18,7 +18,6 @@ import logging
from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
import six
-from six import iteritems
from canonicaljson import json
from prometheus_client import Counter
@@ -534,9 +533,9 @@ class FederationServer(FederationBase):
",".join(
(
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
+ for user_id, user_keys in json_result.items()
+ for device_id, device_keys in user_keys.items()
+ for key_id, _ in device_keys.items()
)
),
)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 52f4f54215..6bbd762681 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -33,8 +33,6 @@ import logging
from collections import namedtuple
from typing import Dict, List, Tuple, Type
-from six import iteritems
-
from sortedcontainers import SortedDict
from twisted.internet import defer
@@ -327,7 +325,7 @@ class FederationRemoteSendQueue(object):
# stream position.
keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
- for ((destination, edu_key), pos) in iteritems(keyed_edus):
+ for ((destination, edu_key), pos) in keyed_edus.items():
rows.append(
(
pos,
@@ -530,10 +528,10 @@ def process_rows_for_federation(transaction_queue, rows):
states=[state], destinations=destinations
)
- for destination, edu_map in iteritems(buff.keyed_edus):
+ for destination, edu_map in buff.keyed_edus.items():
for key, edu in edu_map.items():
transaction_queue.send_edu(edu, key)
- for destination, edu_list in iteritems(buff.edus):
+ for destination, edu_list in buff.edus.items():
for edu in edu_list:
transaction_queue.send_edu(edu, None)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index d473576902..5b8faea4e7 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -16,8 +16,6 @@
import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
-from six import itervalues
-
from prometheus_client import Counter
from twisted.internet import defer
@@ -218,7 +216,7 @@ class FederationSender(object):
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
- for evs in itervalues(events_by_room)
+ for evs in events_by_room.values()
],
consumeErrors=True,
)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index fe62f78e67..ac1b64caff 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,8 +15,6 @@
import logging
-from six import itervalues
-
from prometheus_client import Counter
from twisted.internet import defer
@@ -125,7 +123,7 @@ class ApplicationServicesHandler(object):
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
- for evs in itervalues(events_by_room)
+ for evs in events_by_room.values()
],
consumeErrors=True,
)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 119678e67b..b01124fe42 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -297,7 +297,7 @@ class AuthHandler(BaseHandler):
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
- method = request.uri.decode("utf-8")
+ method = request.method.decode("utf-8")
# If there's no session ID, create a new session.
if not sid:
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 230d170258..31346b56c3 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -17,8 +17,6 @@
import logging
from typing import Any, Dict, Optional
-from six import iteritems, itervalues
-
from twisted.internet import defer
from synapse.api import errors
@@ -159,7 +157,7 @@ class DeviceWorkerHandler(BaseHandler):
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
- for key, event_id in iteritems(current_state_ids):
+ for key, event_id in current_state_ids.items():
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -182,7 +180,7 @@ class DeviceWorkerHandler(BaseHandler):
log_kv(
{"event": "encountered empty previous state", "room_id": room_id}
)
- for key, event_id in iteritems(current_state_ids):
+ for key, event_id in current_state_ids.items():
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -198,10 +196,10 @@ class DeviceWorkerHandler(BaseHandler):
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
- for state_dict in itervalues(prev_state_ids):
+ for state_dict in prev_state_ids.values():
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
- for key, event_id in iteritems(current_state_ids):
+ for key, event_id in current_state_ids.items():
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -211,14 +209,14 @@ class DeviceWorkerHandler(BaseHandler):
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
- for key, event_id in iteritems(current_state_ids):
+ for key, event_id in current_state_ids.items():
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
- for state_dict in itervalues(prev_state_ids):
+ for state_dict in prev_state_ids.values():
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
@@ -693,6 +691,7 @@ class DeviceListUpdater(object):
return False
+ @trace
@defer.inlineCallbacks
def _maybe_retry_device_resync(self):
"""Retry to resync device lists that are out of sync, except if another retry is
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 774a252619..a7e60cbc26 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -17,8 +17,6 @@
import logging
-from six import iteritems
-
import attr
from canonicaljson import encode_canonical_json, json
from signedjson.key import decode_verify_key_bytes
@@ -135,7 +133,7 @@ class E2eKeysHandler(object):
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
- for user_id, device_ids in iteritems(remote_queries):
+ for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
@@ -145,9 +143,9 @@ class E2eKeysHandler(object):
user_ids_not_in_cache,
remote_results,
) = yield self.store.get_user_devices_from_cache(query_list)
- for user_id, devices in iteritems(remote_results):
+ for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
- for device_id, device in iteritems(devices):
+ for device_id, device in devices.items():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
@@ -446,9 +444,9 @@ class E2eKeysHandler(object):
",".join(
(
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
+ for user_id, user_keys in json_result.items()
+ for device_id, device_keys in user_keys.items()
+ for key_id, _ in device_keys.items()
)
),
)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 9abaf13b8f..f55470a707 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -16,8 +16,6 @@
import logging
-from six import iteritems
-
from twisted.internet import defer
from synapse.api.errors import (
@@ -205,8 +203,8 @@ class E2eRoomKeysHandler(object):
)
to_insert = [] # batch the inserts together
changed = False # if anything has changed, we need to update the etag
- for room_id, room in iteritems(room_keys["rooms"]):
- for session_id, room_key in iteritems(room["sessions"]):
+ for room_id, room in room_keys["rooms"].items():
+ for session_id, room_key in room["sessions"].items():
if not isinstance(room_key["is_verified"], bool):
msg = (
"is_verified must be a boolean in keys for session %s in"
@@ -351,6 +349,7 @@ class E2eRoomKeysHandler(object):
raise
res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
+ res["etag"] = str(res["etag"])
return res
@trace
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bbf23345e2..55ed640283 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,8 +21,6 @@ import itertools
import logging
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
-import six
-from six import iteritems, itervalues
from six.moves import http_client, zip
import attr
@@ -33,7 +31,12 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.api.constants import (
+ EventTypes,
+ Membership,
+ RejectedReason,
+ RoomEncryptionAlgorithms,
+)
from synapse.api.errors import (
AuthError,
CodeMessageException,
@@ -396,7 +399,7 @@ class FederationHandler(BaseHandler):
)
event_map.update(evs)
- state = [event_map[e] for e in six.itervalues(state_map)]
+ state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "
@@ -745,7 +748,10 @@ class FederationHandler(BaseHandler):
if device:
keys = device.get("keys", {}).get("keys", {})
- if event.content.get("algorithm") == "m.megolm.v1.aes-sha2":
+ if (
+ event.content.get("algorithm")
+ == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2
+ ):
# For this algorithm we expect a curve25519 key.
key_name = "curve25519:%s" % (device_id,)
current_keys = [keys.get(key_name)]
@@ -1004,7 +1010,7 @@ class FederationHandler(BaseHandler):
"""
joined_users = [
(state_key, int(event.depth))
- for (e_type, state_key), event in iteritems(state)
+ for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
@@ -1094,16 +1100,16 @@ class FederationHandler(BaseHandler):
states = dict(zip(event_ids, [s.state for s in states]))
state_map = await self.store.get_events(
- [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
+ [e_id for ids in states.values() for e_id in ids.values()],
get_prev_content=False,
)
states = {
key: {
k: state_map[e_id]
- for k, e_id in iteritems(state_dict)
+ for k, e_id in state_dict.items()
if e_id in state_map
}
- for key, state_dict in iteritems(states)
+ for key, state_dict in states.items()
}
for e_id, _ in sorted_extremeties_tuple:
@@ -1735,7 +1741,7 @@ class FederationHandler(BaseHandler):
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
- _, state = list(iteritems(state_groups)).pop()
+ _, state = list(state_groups.items()).pop()
results = {(e.type, e.state_key): e for e in state}
if event.is_state():
@@ -2098,7 +2104,7 @@ class FederationHandler(BaseHandler):
room_version, state_sets, event
)
current_state_ids = {
- k: e.event_id for k, e in iteritems(current_state_ids)
+ k: e.event_id for k, e in current_state_ids.items()
}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
@@ -2114,7 +2120,7 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [
- e for k, e in iteritems(current_state_ids) if k in auth_types
+ e for k, e in current_state_ids.items() if k in auth_types
]
current_auth_events = await self.store.get_events(current_state_ids)
@@ -2430,7 +2436,7 @@ class FederationHandler(BaseHandler):
else:
event_key = None
state_updates = {
- k: a.event_id for k, a in iteritems(auth_events) if k != event_key
+ k: a.event_id for k, a in auth_events.items() if k != event_key
}
current_state_ids = await context.get_current_state_ids()
@@ -2441,7 +2447,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = dict(prev_state_ids)
- prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
+ prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
# create a new state group as a delta from the existing one.
prev_group = context.state_group
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index ebe8d25bd8..7cb106e365 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -16,8 +16,6 @@
import logging
-from six import iteritems
-
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import get_domain_from_id
@@ -227,7 +225,7 @@ class GroupsLocalWorkerHandler(object):
results = {}
failed_results = []
- for destination, dest_user_ids in iteritems(destinations):
+ for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 649ca1f08a..354da9a3b5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -17,7 +17,7 @@
import logging
from typing import Optional, Tuple
-from six import iteritems, itervalues, string_types
+from six import string_types
from canonicaljson import encode_canonical_json, json
@@ -246,7 +246,7 @@ class MessageHandler(object):
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
- for user_id, profile in iteritems(users_with_profile)
+ for user_id, profile in users_with_profile.items()
}
def maybe_schedule_expiry(self, event):
@@ -988,7 +988,7 @@ class EventCreationHandler(object):
state_to_include_ids = [
e_id
- for k, e_id in iteritems(current_state_ids)
+ for k, e_id in current_state_ids.items()
if k[0] in self.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
@@ -1002,7 +1002,7 @@ class EventCreationHandler(object):
"content": e.content,
"sender": e.sender,
}
- for e in itervalues(state_to_include)
+ for e in state_to_include.values()
]
invitee = UserID.from_string(event.state_key)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index d7442c62a7..7fbc229502 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -15,8 +15,6 @@
# limitations under the License.
import logging
-from six import iteritems
-
from twisted.internet import defer
from twisted.python.failure import Failure
@@ -145,7 +143,7 @@ class PaginationHandler(object):
logger.debug("[purge] Rooms to purge: %s", rooms)
- for room_id, retention_policy in iteritems(rooms):
+ for room_id, retention_policy in rooms.items():
logger.info("[purge] Attempting to purge messages in room %s", room_id)
if room_id in self._purges_in_progress_by_room:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 3594f3b00f..2e8914be14 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -27,8 +27,6 @@ import logging
from contextlib import contextmanager
from typing import Dict, Iterable, List, Set
-from six import iteritems, itervalues
-
from prometheus_client import Counter
from typing_extensions import ContextManager
@@ -170,14 +168,14 @@ class BasePresenceHandler(abc.ABC):
for user_id in user_ids
}
- missing = [user_id for user_id, state in iteritems(states) if not state]
+ missing = [user_id for user_id, state in states.items() if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = await self.store.get_presence_for_users(missing)
states.update(res)
- missing = [user_id for user_id, state in iteritems(states) if not state]
+ missing = [user_id for user_id, state in states.items() if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id) for user_id in missing
@@ -632,7 +630,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states(
[
prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
- for prev_state in itervalues(prev_states)
+ for prev_state in prev_states.values()
]
)
self.external_process_last_updated_ms.pop(process_id, None)
@@ -1087,7 +1085,7 @@ class PresenceEventSource(object):
return (list(updates.values()), max_token)
else:
return (
- [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE],
+ [s for s in updates.values() if s.state != PresenceState.OFFLINE],
max_token,
)
@@ -1323,11 +1321,11 @@ def get_interested_remotes(store, states, state_handler):
# hosts in those rooms.
room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
- for room_id, states in iteritems(room_ids_to_states):
+ for room_id, states in room_ids_to_states.items():
hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states))
- for user_id, states in iteritems(users_to_states):
+ for user_id, states in users_to_states.items():
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 6a340f9177..11798b9df0 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -24,9 +24,14 @@ import string
from collections import OrderedDict
from typing import Tuple
-from six import iteritems, string_types
+from six import string_types
-from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.api.constants import (
+ EventTypes,
+ JoinRules,
+ RoomCreationPreset,
+ RoomEncryptionAlgorithms,
+)
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events.utils import copy_power_levels_contents
@@ -56,33 +61,6 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler):
-
- PRESETS_DICT = {
- RoomCreationPreset.PRIVATE_CHAT: {
- "join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
- "original_invitees_have_ops": False,
- "guest_can_join": True,
- "encryption_alg": "m.megolm.v1.aes-sha2",
- "power_level_content_override": {"invite": 0},
- },
- RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
- "join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
- "original_invitees_have_ops": True,
- "guest_can_join": True,
- "encryption_alg": "m.megolm.v1.aes-sha2",
- "power_level_content_override": {"invite": 0},
- },
- RoomCreationPreset.PUBLIC_CHAT: {
- "join_rules": JoinRules.PUBLIC,
- "history_visibility": "shared",
- "original_invitees_have_ops": False,
- "guest_can_join": False,
- "power_level_content_override": {},
- },
- }
-
def __init__(self, hs):
super(RoomCreationHandler, self).__init__(hs)
@@ -91,6 +69,39 @@ class RoomCreationHandler(BaseHandler):
self.room_member_handler = hs.get_room_member_handler()
self.config = hs.config
+ # Room state based off defined presets
+ self._presets_dict = {
+ RoomCreationPreset.PRIVATE_CHAT: {
+ "join_rules": JoinRules.INVITE,
+ "history_visibility": "shared",
+ "original_invitees_have_ops": False,
+ "guest_can_join": True,
+ "power_level_content_override": {"invite": 0},
+ },
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
+ "join_rules": JoinRules.INVITE,
+ "history_visibility": "shared",
+ "original_invitees_have_ops": True,
+ "guest_can_join": True,
+ "power_level_content_override": {"invite": 0},
+ },
+ RoomCreationPreset.PUBLIC_CHAT: {
+ "join_rules": JoinRules.PUBLIC,
+ "history_visibility": "shared",
+ "original_invitees_have_ops": False,
+ "guest_can_join": False,
+ "power_level_content_override": {},
+ },
+ }
+
+ # Modify presets to selectively enable encryption by default per homeserver config
+ for preset_name, preset_config in self._presets_dict.items():
+ encrypted = (
+ preset_name
+ in self.config.encryption_enabled_by_default_for_room_presets
+ )
+ preset_config["encrypted"] = encrypted
+
self._replication = hs.get_replication_data_handler()
# linearizer to stop two upgrades happening at once
@@ -378,7 +389,7 @@ class RoomCreationHandler(BaseHandler):
# map from event_id to BaseEvent
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
- for k, old_event_id in iteritems(old_room_state_ids):
+ for k, old_event_id in old_room_state_ids.items():
old_event = old_room_state_events.get(old_event_id)
if old_event:
initial_state[k] = old_event.content
@@ -431,7 +442,7 @@ class RoomCreationHandler(BaseHandler):
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
- for k, old_event in iteritems(old_room_member_state_events):
+ for k, old_event in old_room_member_state_events.items():
# Only transfer ban events
if (
"membership" in old_event.content
@@ -817,7 +828,7 @@ class RoomCreationHandler(BaseHandler):
)
return last_stream_id
- config = RoomCreationHandler.PRESETS_DICT[preset_config]
+ config = self._presets_dict[preset_config]
creator_id = creator.user.to_string()
@@ -908,11 +919,11 @@ class RoomCreationHandler(BaseHandler):
etype=etype, state_key=state_key, content=content
)
- if "encryption_alg" in config:
+ if config["encrypted"]:
last_sent_stream_id = await send(
etype=EventTypes.RoomEncryption,
state_key="",
- content={"algorithm": config["encryption_alg"]},
+ content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
)
return last_sent_stream_id
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 4cbc02b0d0..5e05be6181 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -17,8 +17,6 @@ import logging
from collections import namedtuple
from typing import Any, Dict, Optional
-from six import iteritems
-
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@@ -271,7 +269,7 @@ class RoomListHandler(BaseHandler):
event_map = yield self.store.get_events(
[
event_id
- for key, event_id in iteritems(current_state_ids)
+ for key, event_id in current_state_ids.items()
if key[0]
in (
EventTypes.Create,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6bdb24baff..4c7524493e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -18,8 +18,6 @@ import itertools
import logging
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
-from six import iteritems, itervalues
-
import attr
from prometheus_client import Counter
@@ -390,7 +388,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
- event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
+ event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
@@ -408,7 +406,7 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
- event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
+ event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
return now_token, ephemeral_by_room
@@ -454,7 +452,7 @@ class SyncHandler(object):
current_state_ids_map = await self.state.get_current_state_ids(
room_id
)
- current_state_ids = frozenset(itervalues(current_state_ids_map))
+ current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client(
self.storage,
@@ -509,7 +507,7 @@ class SyncHandler(object):
current_state_ids_map = await self.state.get_current_state_ids(
room_id
)
- current_state_ids = frozenset(itervalues(current_state_ids_map))
+ current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client(
self.storage,
@@ -909,7 +907,7 @@ class SyncHandler(object):
logger.debug("filtering state from %r...", state_ids)
state_ids = {
t: event_id
- for t, event_id in iteritems(state_ids)
+ for t, event_id in state_ids.items()
if cache.get(t[1]) != event_id
}
logger.debug("...to %r", state_ids)
@@ -1430,7 +1428,7 @@ class SyncHandler(object):
if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
- joined_sync.timeline.events, itervalues(joined_sync.state)
+ joined_sync.timeline.events, joined_sync.state.values()
)
for event in it:
if event.type == EventTypes.Member:
@@ -1505,7 +1503,7 @@ class SyncHandler(object):
newly_left_rooms = []
room_entries = []
invited = []
- for room_id, events in iteritems(mem_change_events_by_room_id):
+ for room_id, events in mem_change_events_by_room_id.items():
logger.debug(
"Membership changes in %s: [%s]",
room_id,
@@ -1993,17 +1991,17 @@ def _calculate_state(
event_id_to_key = {
e: key
for key, e in itertools.chain(
- iteritems(timeline_contains),
- iteritems(previous),
- iteritems(timeline_start),
- iteritems(current),
+ timeline_contains.items(),
+ previous.items(),
+ timeline_start.items(),
+ current.items(),
)
}
- c_ids = set(itervalues(current))
- ts_ids = set(itervalues(timeline_start))
- p_ids = set(itervalues(previous))
- tc_ids = set(itervalues(timeline_contains))
+ c_ids = set(current.values())
+ ts_ids = set(timeline_start.values())
+ p_ids = set(previous.values())
+ tc_ids = set(timeline_contains.values())
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
@@ -2017,7 +2015,7 @@ def _calculate_state(
if lazy_load_members:
p_ids.difference_update(
- e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member
+ e for t, e in timeline_start.items() if t[0] == EventTypes.Member
)
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 12423b909a..521b6d620d 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -15,8 +15,6 @@
import logging
-from six import iteritems, iterkeys
-
import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
@@ -289,7 +287,7 @@ class UserDirectoryHandler(StateDeltasHandler):
users_with_profile = await self.state.get_current_users_in_room(room_id)
# Remove every user from the sharing tables for that room.
- for user_id in iterkeys(users_with_profile):
+ for user_id in users_with_profile.keys():
await self.store.remove_user_who_share_room(user_id, room_id)
# Then, re-add them to the tables.
@@ -298,7 +296,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# which when ran over an entire room, will result in the same values
# being added multiple times. The batching upserts shouldn't make this
# too bad, though.
- for user_id, profile in iteritems(users_with_profile):
+ for user_id, profile in users_with_profile.items():
await self._handle_new_user(room_id, user_id, profile)
async def _handle_new_user(self, room_id, user_id, profile):
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 167293c46d..cbc37eac6e 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,6 +19,7 @@ from typing import Optional
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
+from synapse.config.server import ListenerConfig
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
@@ -350,7 +351,7 @@ class SynapseSite(Site):
self,
logger_name,
site_tag,
- config,
+ config: ListenerConfig,
resource,
server_version_string,
*args,
@@ -360,7 +361,8 @@ class SynapseSite(Site):
self.site_tag = site_tag
- proxied = config.get("x_forwarded", False)
+ assert config.http_options is not None
+ proxied = config.http_options.x_forwarded
self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 9cf31f96b3..087a49d65d 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -22,8 +22,6 @@ import threading
import time
from typing import Callable, Dict, Iterable, Optional, Tuple, Union
-import six
-
import attr
from prometheus_client import Counter, Gauge, Histogram
from prometheus_client.core import (
@@ -83,7 +81,7 @@ class LaterGauge(object):
return
if isinstance(calls, dict):
- for k, v in six.iteritems(calls):
+ for k, v in calls.items():
g.add_metric(k, v)
else:
g.add_metric([], calls)
@@ -194,7 +192,7 @@ class InFlightGauge(object):
gauge = GaugeMetricFamily(
"_".join([self.name, name]), "", labels=self.labels
)
- for key, metrics in six.iteritems(metrics_by_key):
+ for key, metrics in metrics_by_key.items():
gauge.add_metric(key, getattr(metrics, name))
yield gauge
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index ecdf1ad69f..a7849cefa5 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -126,7 +126,7 @@ class ModuleApi(object):
'errcode' property for more information on the reason for failure
Returns:
- Deferred[str]: user_id
+ defer.Deferred[str]: user_id
"""
return defer.ensureDeferred(
self._hs.get_registration_handler().register_user(
@@ -149,10 +149,12 @@ class ModuleApi(object):
Returns:
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
"""
- return self._hs.get_registration_handler().register_device(
- user_id=user_id,
- device_id=device_id,
- initial_display_name=initial_display_name,
+ return defer.ensureDeferred(
+ self._hs.get_registration_handler().register_device(
+ user_id=user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ )
)
def record_user_external_id(
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index e75d964ac8..43ffe6faf0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -17,8 +17,6 @@
import logging
from collections import namedtuple
-from six import iteritems, itervalues
-
from prometheus_client import Counter
from twisted.internet import defer
@@ -130,7 +128,7 @@ class BulkPushRuleEvaluator(object):
event, prev_state_ids, for_verification=False
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -162,7 +160,7 @@ class BulkPushRuleEvaluator(object):
condition_cache = {}
- for uid, rules in iteritems(rules_by_user):
+ for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
@@ -395,7 +393,7 @@ class RulesForRoom(object):
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
- for event_id in itervalues(member_event_ids):
+ for event_id in member_event_ids.values():
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
@@ -404,7 +402,7 @@ class RulesForRoom(object):
interested_in_user_ids = {
user_id
- for user_id, membership in itervalues(members)
+ for user_id, membership in members.values()
if membership == Membership.JOIN
}
@@ -415,7 +413,7 @@ class RulesForRoom(object):
)
user_ids = {
- uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
+ uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
}
logger.debug("With pushers: %r", user_ids)
@@ -436,7 +434,7 @@ class RulesForRoom(object):
)
ret_rules_by_user.update(
- item for item in iteritems(rules_by_user) if item[0] is not None
+ item for item in rules_by_user.items() if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 11032491af..aeac257a6e 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -131,7 +131,7 @@ class PushRuleEvaluatorForEvent(object):
# XXX: optimisation: cache our pattern regexps
if condition["key"] == "content.body":
body = self._event.content.get("body", None)
- if not body:
+ if not body or not isinstance(body, str):
return False
return _glob_matches(pattern, body, word_boundary=True)
@@ -147,7 +147,7 @@ class PushRuleEvaluatorForEvent(object):
return False
body = self._event.content.get("body", None)
- if not body:
+ if not body or not isinstance(body, str):
return False
# Similar to _glob_matches, but do not treat display_name as a glob.
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 8b4312e5a3..8ec1a619a2 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -68,9 +68,8 @@ REQUIREMENTS = [
"phonenumbers>=8.2.0",
"six>=1.10",
"prometheus_client>=0.0.18,<0.8.0",
- # we use attr.s(slots), which arrived in 16.0.0
- # Twisted 18.7.0 requires attrs>=17.4.0
- "attrs>=17.4.0",
+ # we use attr.validators.deep_iterable, which arrived in 19.1.0
+ "attrs>=19.1.0",
"netaddr>=0.7.18",
"Jinja2>=2.9",
"bleach>=1.4.3",
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index c04f622816..ea5937a20c 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -149,7 +149,7 @@ class RdataCommand(Command):
class PositionCommand(Command):
- """Sent by the server to tell the client the stream postition without
+ """Sent by the server to tell the client the stream position without
needing to send an RDATA.
Format::
@@ -188,7 +188,7 @@ class ErrorCommand(_SimpleCommand):
class PingCommand(_SimpleCommand):
- """Sent by either side as a keep alive. The data is arbitary (often timestamp)
+ """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
"""
NAME = "PING"
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index cbcf46f3ae..e6a2e2598b 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -112,8 +112,8 @@ class ReplicationCommandHandler:
"replication_position", clock=self._clock
)
- # Map of stream to batched updates. See RdataCommand for info on how
- # batching works.
+ # Map of stream name to batched updates. See RdataCommand for info on
+ # how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
# The factory used to create connections.
@@ -123,7 +123,8 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
- # For each connection, the incoming streams that are coming from that connection
+ # For each connection, the incoming stream names that are coming from
+ # that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge(
@@ -310,7 +311,28 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
- await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+ stream = self._streams.get(stream_name)
+ if not stream:
+ logger.error("Got RDATA for unknown stream: %s", stream_name)
+ return
+
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
+
+ # Discard this data if this token is earlier than the current
+ # position. Note that streams can be reset (in which case you
+ # expect an earlier token), but that must be preceded by a
+ # POSITION command.
+ if cmd.token <= current_token:
+ logger.debug(
+ "Discarding RDATA from stream %s at position %s before previous position %s",
+ stream_name,
+ cmd.token,
+ current_token,
+ )
+ else:
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index dceb2792fa..c2c9a9c3aa 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -60,10 +60,18 @@ def login_id_thirdparty_from_phone(identifier):
Returns: Login identifier dict of type 'm.id.threepid'
"""
- if "country" not in identifier or "number" not in identifier:
+ if "country" not in identifier or (
+ # The specification requires a "phone" field, while Synapse used to require a "number"
+ # field. Accept both for backwards compatibility.
+ "phone" not in identifier
+ and "number" not in identifier
+ ):
raise SynapseError(400, "Invalid phone-type identifier")
- msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
+ # Accept both "phone" and "number" as valid keys in m.id.phone
+ phone_number = identifier.get("phone", identifier["number"])
+
+ msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index eec16f8ad8..7cf007d35e 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -51,7 +51,9 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(state, self.clock.time_msec())
+ state = format_user_presence_state(
+ state, self.clock.time_msec(), include_user_id=False
+ )
return 200, state
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 8d081718e3..d4fa959657 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -761,7 +761,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidAddRestServlet(RestServlet):
- PATTERNS = client_patterns("/account/3pid/add$", releases=(), unstable=True)
+ PATTERNS = client_patterns("/account/3pid/add$")
def __init__(self, hs):
super(ThreepidAddRestServlet, self).__init__()
@@ -835,7 +835,7 @@ class ThreepidAddRestServlet(RestServlet):
class ThreepidBindRestServlet(RestServlet):
- PATTERNS = client_patterns("/account/3pid/bind$", releases=(), unstable=True)
+ PATTERNS = client_patterns("/account/3pid/bind$")
def __init__(self, hs):
super(ThreepidBindRestServlet, self).__init__()
@@ -864,7 +864,7 @@ class ThreepidBindRestServlet(RestServlet):
class ThreepidUnbindRestServlet(RestServlet):
- PATTERNS = client_patterns("/account/3pid/unbind$", releases=(), unstable=True)
+ PATTERNS = client_patterns("/account/3pid/unbind$")
def __init__(self, hs):
super(ThreepidUnbindRestServlet, self).__init__()
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index fd10d42f2f..4ee8c60257 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -20,8 +20,6 @@ import os
import shutil
from typing import Dict, Tuple
-from six import iteritems
-
import twisted.internet.error
import twisted.web.http
from twisted.web.resource import Resource
@@ -606,7 +604,7 @@ class MediaRepository(object):
thumbnails[(t_width, t_height, r_type)] = r_method
# Now we generate the thumbnails for each dimension, store it
- for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
+ for (t_width, t_height, t_type), t_method in thumbnails.items():
# Generate the thumbnail
if t_method == "crop":
t_byte_source = await defer_to_thread(
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 3bf330da49..e7e8b8e688 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from six import iteritems, string_types
+from six import string_types
from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder
@@ -121,7 +121,7 @@ def copy_with_str_subst(x, substitutions):
if isinstance(x, string_types):
return x % substitutions
if isinstance(x, dict):
- return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)}
+ return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()}
if isinstance(x, (list, tuple)):
return [copy_with_str_subst(y) for y in x]
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 73f2cedb5c..4404ceff93 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -14,8 +14,6 @@
# limitations under the License.
import logging
-from six import iteritems
-
from synapse.api.constants import (
EventTypes,
LimitBlockingTypes,
@@ -214,7 +212,7 @@ class ResourceLimitsServerNotices(object):
referenced_events = list(pinned_state_event.content.get("pinned", []))
events = await self._store.get_events(referenced_events)
- for event_id, event in iteritems(events):
+ for event_id, event in events.items():
if event.type != EventTypes.Message:
continue
if event.content.get("msgtype") == ServerNoticeMsgType:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 2fa529fcd0..50fd843f66 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -18,8 +18,6 @@ import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Set
-from six import iteritems, itervalues
-
import attr
from frozendict import frozendict
from prometheus_client import Histogram
@@ -144,7 +142,7 @@ class StateHandler(object):
list(state.values()), get_prev_content=False
)
state = {
- key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
+ key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
return state
@@ -423,7 +421,7 @@ class StateHandler(object):
state_res_store=StateResolutionStore(self.store),
)
- new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)}
+ new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
return new_state
@@ -505,8 +503,8 @@ class StateResolutionHandler(object):
# resolve_events_with_store do it?
new_state = {}
conflicted_state = False
- for st in itervalues(state_groups_ids):
- for key, e_id in iteritems(st):
+ for st in state_groups_ids.values():
+ for key, e_id in st.items():
if key in new_state:
conflicted_state = True
break
@@ -520,7 +518,7 @@ class StateResolutionHandler(object):
new_state = yield resolve_events_with_store(
room_id,
room_version,
- list(itervalues(state_groups_ids)),
+ list(state_groups_ids.values()),
event_map=event_map,
state_res_store=state_res_store,
)
@@ -561,12 +559,12 @@ def _make_state_cache_entry(new_state, state_groups_ids):
# not get persisted.
# first look for exact matches
- new_state_event_ids = set(itervalues(new_state))
- for sg, state in iteritems(state_groups_ids):
+ new_state_event_ids = set(new_state.values())
+ for sg, state in state_groups_ids.items():
if len(new_state_event_ids) != len(state):
continue
- old_state_event_ids = set(itervalues(state))
+ old_state_event_ids = set(state.values())
if new_state_event_ids == old_state_event_ids:
# got an exact match.
return _StateCacheEntry(state=new_state, state_group=sg)
@@ -579,8 +577,8 @@ def _make_state_cache_entry(new_state, state_groups_ids):
prev_group = None
delta_ids = None
- for old_group, old_state in iteritems(state_groups_ids):
- n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v}
+ for old_group, old_state in state_groups_ids.items():
+ n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
delta_ids = n_delta_ids
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 9bf98d06f2..7b531a8337 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,8 +17,6 @@ import hashlib
import logging
from typing import Callable, Dict, List, Optional
-from six import iteritems, iterkeys, itervalues
-
from twisted.internet import defer
from synapse import event_auth
@@ -70,11 +68,11 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state = _seperate(state_sets)
needed_events = {
- event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
+ event_id for event_ids in conflicted_state.values() for event_id in event_ids
}
needed_event_count = len(needed_events)
if event_map is not None:
- needed_events -= set(iterkeys(event_map))
+ needed_events -= set(event_map.keys())
logger.info(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
@@ -102,11 +100,11 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state, state_map
)
- new_needed_events = set(itervalues(auth_events))
+ new_needed_events = set(auth_events.values())
new_needed_event_count = len(new_needed_events)
new_needed_events -= needed_events
if event_map is not None:
- new_needed_events -= set(iterkeys(event_map))
+ new_needed_events -= set(event_map.keys())
logger.info(
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
@@ -152,7 +150,7 @@ def _seperate(state_sets):
conflicted_state = {}
for state_set in state_set_iterator:
- for key, value in iteritems(state_set):
+ for key, value in state_set.items():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
@@ -178,7 +176,7 @@ def _seperate(state_sets):
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
- for event_ids in itervalues(conflicted_state):
+ for event_ids in conflicted_state.values():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
@@ -194,7 +192,7 @@ def _resolve_with_state(
unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
):
conflicted_state = {}
- for key, event_ids in iteritems(conflicted_state_ids):
+ for key, event_ids in conflicted_state_ids.items():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
@@ -203,7 +201,7 @@ def _resolve_with_state(
auth_events = {
key: state_map[ev_id]
- for key, ev_id in iteritems(auth_event_ids)
+ for key, ev_id in auth_event_ids.items()
if ev_id in state_map
}
@@ -214,7 +212,7 @@ def _resolve_with_state(
raise
new_state = unconflicted_state_ids
- for key, event in iteritems(resolved_state):
+ for key, event in resolved_state.items():
new_state[key] = event.event_id
return new_state
@@ -238,21 +236,21 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(events, auth_events)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 18484e2fa6..e25bc5d264 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -18,8 +18,6 @@ import itertools
import logging
from typing import Dict, List, Optional
-from six import iteritems, itervalues
-
from twisted.internet import defer
import synapse.state
@@ -87,7 +85,7 @@ def resolve_events_with_store(
full_conflicted_set = set(
itertools.chain(
- itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff
+ itertools.chain.from_iterable(conflicted_state.values()), auth_diff
)
)
@@ -572,7 +570,7 @@ def lexicographical_topological_sort(graph, key):
# `(key(node), node)` so that sorting does the right thing
zero_outdegree = []
- for node, edges in iteritems(graph):
+ for node, edges in graph.items():
if len(edges) == 0:
zero_outdegree.append((key(node), node))
diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html
index 6fefdaaff7..9e6daf38ac 100644
--- a/synapse/static/client/login/index.html
+++ b/synapse/static/client/login/index.html
@@ -1,24 +1,24 @@
<!doctype html>
<html>
<head>
-<title> Login </title>
-<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<link rel="stylesheet" href="style.css">
-<script src="js/jquery-3.4.1.min.js"></script>
-<script src="js/login.js"></script>
+ <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
+ <title> Login </title>
+ <meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+ <link rel="stylesheet" href="style.css">
+ <script src="js/jquery-3.4.1.min.js"></script>
+ <script src="js/login.js"></script>
</head>
<body onload="matrixLogin.onLoad()">
- <center>
- <br/>
+ <div id="container">
<h1 id="title"></h1>
- <span id="feedback" style="color: #f00"></span>
+ <span id="feedback"></span>
<div id="loading">
<img src="spinner.gif" />
</div>
- <div id="sso_flow" class="login_flow" style="display:none">
+ <div id="sso_flow" class="login_flow" style="display: none;">
Single-sign on:
<form id="sso_form" action="/_matrix/client/r0/login/sso/redirect" method="get">
<input id="sso_redirect_url" type="hidden" name="redirectUrl" value=""/>
@@ -26,9 +26,9 @@
</form>
</div>
- <div id="password_flow" class="login_flow" style="display:none">
+ <div id="password_flow" class="login_flow" style="display: none;">
Password Authentication:
- <form onsubmit="matrixLogin.password_login(); return false;">
+ <form onsubmit="matrixLogin.passwordLogin(); return false;">
<input id="user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" />
<br/>
<input id="password" size="32" type="password" placeholder="Password"/>
@@ -38,9 +38,9 @@
</form>
</div>
- <div id="no_login_types" type="button" class="login_flow" style="display:none">
+ <div id="no_login_types" type="button" class="login_flow" style="display: none;">
Log in currently unavailable.
</div>
- </center>
+ </div>
</body>
</html>
diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js
index ba8048b23f..3678670ec7 100644
--- a/synapse/static/client/login/js/login.js
+++ b/synapse/static/client/login/js/login.js
@@ -5,11 +5,11 @@ window.matrixLogin = {
};
// Titles get updated through the process to give users feedback.
-var TITLE_PRE_AUTH = "Log in with one of the following methods";
-var TITLE_POST_AUTH = "Logging in...";
+const TITLE_PRE_AUTH = "Log in with one of the following methods";
+const TITLE_POST_AUTH = "Logging in...";
// The cookie used to store the original query parameters when using SSO.
-var COOKIE_KEY = "synapse_login_fallback_qs";
+const COOKIE_KEY = "synapse_login_fallback_qs";
/*
* Submit a login request.
@@ -20,9 +20,9 @@ var COOKIE_KEY = "synapse_login_fallback_qs";
* login request, e.g. device_id.
* callback: (Optional) Function to call on successful login.
*/
-var submitLogin = function(type, data, extra, callback) {
+function submitLogin(type, data, extra, callback) {
console.log("Logging in with " + type);
- set_title(TITLE_POST_AUTH);
+ setTitle(TITLE_POST_AUTH);
// Add the login type.
data.type = type;
@@ -41,12 +41,15 @@ var submitLogin = function(type, data, extra, callback) {
}
matrixLogin.onLogin(response);
}).fail(errorFunc);
-};
+}
-var errorFunc = function(err) {
+/*
+ * Display an error to the user and show the login form again.
+ */
+function errorFunc(err) {
// We want to show the error to the user rather than redirecting immediately to the
// SSO portal (if SSO is the only login option), so we inhibit the redirect.
- show_login(true);
+ showLogin(true);
if (err.responseJSON && err.responseJSON.error) {
setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
@@ -54,27 +57,42 @@ var errorFunc = function(err) {
else {
setFeedbackString("Request failed: " + err.status);
}
-};
+}
-var setFeedbackString = function(text) {
+/*
+ * Display an error to the user.
+ */
+function setFeedbackString(text) {
$("#feedback").text(text);
-};
+}
-var show_login = function(inhibit_redirect) {
- // Set the redirect to come back to this page, a login token will get added
- // and handled after the redirect.
- var this_page = window.location.origin + window.location.pathname;
- $("#sso_redirect_url").val(this_page);
+/*
+ * (Maybe) Show the login forms.
+ *
+ * This actually does a few unrelated functions:
+ *
+ * * Configures the SSO redirect URL to come back to this page.
+ * * Configures and shows the SSO form, if the server supports SSO.
+ * * Otherwise, shows the password form.
+ */
+function showLogin(inhibitRedirect) {
+ setTitle(TITLE_PRE_AUTH);
- // If inhibit_redirect is false, and SSO is the only supported login method,
+ // If inhibitRedirect is false, and SSO is the only supported login method,
// we can redirect straight to the SSO page.
if (matrixLogin.serverAcceptsSso) {
+ // Set the redirect to come back to this page, a login token will get
+ // added as a query parameter and handled after the redirect.
+ $("#sso_redirect_url").val(window.location.origin + window.location.pathname);
+
// Before submitting SSO, set the current query parameters into a cookie
// for retrieval later.
var qs = parseQsFromUrl();
setCookie(COOKIE_KEY, JSON.stringify(qs));
- if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) {
+ // If password is not supported and redirects are allowed, then submit
+ // the form (redirecting to the SSO provider).
+ if (!inhibitRedirect && !matrixLogin.serverAcceptsPassword) {
$("#sso_form").submit();
return;
}
@@ -87,30 +105,39 @@ var show_login = function(inhibit_redirect) {
$("#password_flow").show();
}
+ // If neither password or SSO are supported, show an error to the user.
if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) {
$("#no_login_types").show();
}
- set_title(TITLE_PRE_AUTH);
-
$("#loading").hide();
-};
+}
-var show_spinner = function() {
+/*
+ * Hides the forms and shows a loading throbber.
+ */
+function showSpinner() {
$("#password_flow").hide();
$("#sso_flow").hide();
$("#no_login_types").hide();
$("#loading").show();
-};
+}
-var set_title = function(title) {
+/*
+ * Helper to show the page's main title.
+ */
+function setTitle(title) {
$("#title").text(title);
-};
+}
-var fetch_info = function(cb) {
+/*
+ * Query the login endpoint for the homeserver's supported flows.
+ *
+ * This populates matrixLogin.serverAccepts* variables.
+ */
+function fetchLoginFlows(cb) {
$.get(matrixLogin.endpoint, function(response) {
- var serverAcceptsPassword = false;
- for (var i=0; i<response.flows.length; i++) {
+ for (var i = 0; i < response.flows.length; i++) {
var flow = response.flows[i];
if ("m.login.sso" === flow.type) {
matrixLogin.serverAcceptsSso = true;
@@ -126,27 +153,41 @@ var fetch_info = function(cb) {
}).fail(errorFunc);
}
+/*
+ * Called on load to fetch login flows and attempt SSO login (if a token is available).
+ */
matrixLogin.onLoad = function() {
- fetch_info(function() {
- if (!try_token()) {
- show_login(false);
+ fetchLoginFlows(function() {
+ // (Maybe) attempt logging in via SSO if a token is available.
+ if (!tryTokenLogin()) {
+ showLogin(false);
}
});
};
-matrixLogin.password_login = function() {
+/*
+ * Submit simple user & password login.
+ */
+matrixLogin.passwordLogin = function() {
var user = $("#user_id").val();
var pwd = $("#password").val();
setFeedbackString("");
- show_spinner();
+ showSpinner();
submitLogin(
"m.login.password",
{user: user, password: pwd},
parseQsFromUrl());
};
+/*
+ * The onLogin function gets called after a succesful login.
+ *
+ * It is expected that implementations override this to be notified when the
+ * login is complete. The response to the login call is provided as the single
+ * parameter.
+ */
matrixLogin.onLogin = function(response) {
// clobber this function
console.warn("onLogin - This function should be replaced to proceed.");
@@ -155,7 +196,7 @@ matrixLogin.onLogin = function(response) {
/*
* Process the query parameters from the current URL into an object.
*/
-var parseQsFromUrl = function() {
+function parseQsFromUrl() {
var pos = window.location.href.indexOf("?");
if (pos == -1) {
return {};
@@ -174,12 +215,12 @@ var parseQsFromUrl = function() {
result[key] = val;
});
return result;
-};
+}
/*
* Process the cookies and return an object.
*/
-var parseCookies = function() {
+function parseCookies() {
var allCookies = document.cookie;
var result = {};
allCookies.split(";").forEach(function(part) {
@@ -196,32 +237,32 @@ var parseCookies = function() {
result[key] = val;
});
return result;
-};
+}
/*
* Set a cookie that is valid for 1 hour.
*/
-var setCookie = function(key, value) {
+function setCookie(key, value) {
// The maximum age is set in seconds.
var maxAge = 60 * 60;
// Set the cookie, this defaults to the current domain and path.
document.cookie = key + "=" + encodeURIComponent(value) + ";max-age=" + maxAge + ";sameSite=lax";
-};
+}
/*
* Removes a cookie by key.
*/
-var deleteCookie = function(key) {
+function deleteCookie(key) {
// Delete a cookie by setting the expiration to 0. (Note that the value
// doesn't matter.)
document.cookie = key + "=deleted;expires=0";
-};
+}
/*
* Submits the login token if one is found in the query parameters. Returns a
* boolean of whether the login token was found or not.
*/
-var try_token = function() {
+function tryTokenLogin() {
// Check if the login token is in the query parameters.
var qs = parseQsFromUrl();
@@ -233,18 +274,18 @@ var try_token = function() {
// Retrieve the original query parameters (from before the SSO redirect).
// They are stored as JSON in a cookie.
var cookies = parseCookies();
- var original_query_params = JSON.parse(cookies[COOKIE_KEY] || "{}")
+ var originalQueryParams = JSON.parse(cookies[COOKIE_KEY] || "{}")
// If the login is successful, delete the cookie.
- var callback = function() {
+ function callback() {
deleteCookie(COOKIE_KEY);
}
submitLogin(
"m.login.token",
{token: loginToken},
- original_query_params,
+ originalQueryParams,
callback);
return true;
-};
+}
diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css
index 1cce5ed950..83e4f6abc8 100644
--- a/synapse/static/client/login/style.css
+++ b/synapse/static/client/login/style.css
@@ -31,20 +31,44 @@ form {
margin: 10px 0 0 0;
}
+/*
+ * Add some padding to the viewport.
+ */
+#container {
+ padding: 10px;
+}
+/*
+ * Center all direct children of the main form.
+ */
+#container > * {
+ display: block;
+ margin-left: auto;
+ margin-right: auto;
+ text-align: center;
+}
+
+/*
+ * A wrapper around each login flow.
+ */
.login_flow {
width: 300px;
text-align: left;
padding: 10px;
margin-bottom: 40px;
- -webkit-border-radius: 10px;
- -moz-border-radius: 10px;
border-radius: 10px;
-
- -webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
- -moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
background-color: #f8f8f8;
border: 1px #ccc solid;
}
+
+/*
+ * Used to show error content.
+ */
+#feedback {
+ /* Red text. */
+ color: #ff0000;
+ /* A little space to not overlap the box-shadow. */
+ margin-bottom: 20px;
+}
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 71f8d43a76..995d4764a9 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -15,8 +15,6 @@
import logging
-from six import iteritems
-
from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -421,7 +419,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
):
self.database_engine.lock_table(txn, "user_ips")
- for entry in iteritems(to_update):
+ for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
@@ -530,7 +528,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
"user_agent": user_agent,
"last_seen": last_seen,
}
- for (access_token, ip), (user_agent, last_seen) in iteritems(results)
+ for (access_token, ip), (user_agent, last_seen) in results.items()
]
@wrap_as_background_process("prune_old_user_ips")
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index fb9f798e29..0ff0542453 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -17,8 +17,6 @@
import logging
from typing import List, Optional, Set, Tuple
-from six import iteritems
-
from canonicaljson import json
from twisted.internet import defer
@@ -208,7 +206,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
# add the updated cross-signing keys to the results list
- for user_id, result in iteritems(cross_signing_keys_by_user):
+ for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
results.append(("org.matrix.signing_key_update", result))
@@ -269,7 +267,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
results = []
- for user_id, user_devices in iteritems(devices):
+ for user_id, user_devices in devices.items():
# The prev_id for the first row is always the last row before
# `from_stream_id`
prev_id = yield self._get_last_device_update_for_remote_user(
@@ -493,7 +491,7 @@ class DeviceWorkerStore(SQLBaseStore):
if devices:
user_devices = devices[user_id]
results = []
- for device_id, device in iteritems(user_devices):
+ for device_id, device in user_devices.items():
result = {"device_id": device_id}
key_json = device.get("key_json", None)
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 20698bfd16..1a0842d4b0 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -16,8 +16,6 @@
# limitations under the License.
from typing import Dict, List
-from six import iteritems
-
from canonicaljson import encode_canonical_json, json
from twisted.enterprise.adbapi import Connection
@@ -64,9 +62,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
rv = {}
- for user_id, device_keys in iteritems(results):
+ for user_id, device_keys in results.items():
rv[user_id] = {}
- for device_id, device_info in iteritems(device_keys):
+ for device_id, device_info in device_keys.items():
r = db_to_json(device_info.pop("key_json"))
r["unsigned"] = {}
display_name = device_info["device_display_name"]
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 0321274de2..bc9f4f08ea 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -16,8 +16,6 @@
import logging
-from six import iteritems
-
from canonicaljson import json
from twisted.internet import defer
@@ -455,7 +453,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql,
(
_gen_entry(user_id, actions)
- for user_id, actions in iteritems(user_id_actions)
+ for user_id, actions in user_id_actions.items()
),
)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index a6572571b4..8a13101f1d 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -21,7 +21,7 @@ from collections import OrderedDict, namedtuple
from functools import wraps
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
-from six import integer_types, iteritems, text_type
+from six import integer_types, text_type
from six.moves import range
import attr
@@ -232,10 +232,10 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
- for room_id, new_state in iteritems(current_state_for_room):
+ for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
- for room_id, latest_event_ids in iteritems(new_forward_extremeties):
+ for room_id, latest_event_ids in new_forward_extremeties.items():
self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
)
@@ -461,7 +461,7 @@ class PersistEventsStore:
state_delta_by_room: Dict[str, DeltaState],
stream_id: int,
):
- for room_id, delta_state in iteritems(state_delta_by_room):
+ for room_id, delta_state in state_delta_by_room.items():
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
@@ -545,7 +545,7 @@ class PersistEventsStore:
""",
[
(room_id, key[0], key[1], ev_id, ev_id)
- for key, ev_id in iteritems(to_insert)
+ for key, ev_id in to_insert.items()
],
)
@@ -642,7 +642,7 @@ class PersistEventsStore:
def _update_forward_extremities_txn(
self, txn, new_forward_extremities, max_stream_order
):
- for room_id, new_extrem in iteritems(new_forward_extremities):
+ for room_id, new_extrem in new_forward_extremities.items():
self.db.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
@@ -655,7 +655,7 @@ class PersistEventsStore:
table="event_forward_extremities",
values=[
{"event_id": ev_id, "room_id": room_id}
- for room_id, new_extrem in iteritems(new_forward_extremities)
+ for room_id, new_extrem in new_forward_extremities.items()
for ev_id in new_extrem
],
)
@@ -672,7 +672,7 @@ class PersistEventsStore:
"event_id": event_id,
"stream_ordering": max_stream_order,
}
- for room_id, new_extrem in iteritems(new_forward_extremities)
+ for room_id, new_extrem in new_forward_extremities.items()
for event_id in new_extrem
],
)
@@ -727,7 +727,7 @@ class PersistEventsStore:
event.depth, depth_updates.get(event.room_id, event.depth)
)
- for room_id, depth in iteritems(depth_updates):
+ for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
def _update_outliers_txn(self, txn, events_and_contexts):
@@ -1497,11 +1497,11 @@ class PersistEventsStore:
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
- for event_id, state_group_id in iteritems(state_groups)
+ for event_id, state_group_id in state_groups.items()
],
)
- for event_id, state_group_id in iteritems(state_groups):
+ for event_id, state_group_id in state_groups.items():
txn.call_after(
self.store._get_state_group_for_event.prefill,
(event_id,),
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index ab70776977..efb1a4fb4c 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -19,8 +19,6 @@ import logging
import re
from typing import List, Optional
-from six import iterkeys
-
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -833,7 +831,7 @@ class RegistrationWorkerStore(SQLBaseStore):
last_send_attempt, validated_at
FROM threepid_validation_session WHERE %s
""" % (
- " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+ " AND ".join("%s = ?" % k for k in keyvalues.keys()),
)
if validated is not None:
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 137ebac833..44bab65eac 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -17,8 +17,6 @@
import logging
from typing import Iterable, List, Set
-from six import iteritems, itervalues
-
from canonicaljson import json
from twisted.internet import defer
@@ -544,7 +542,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = {}
member_event_ids = [
e_id
- for key, e_id in iteritems(current_state_ids)
+ for key, e_id in current_state_ids.items()
if key[0] == EventTypes.Member
]
@@ -561,7 +559,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = dict(prev_res)
member_event_ids = [
e_id
- for key, e_id in iteritems(context.delta_ids)
+ for key, e_id in context.delta_ids.items()
if key[0] == EventTypes.Member
]
for etype, state_key in context.delta_ids:
@@ -1101,7 +1099,7 @@ class _JoinedHostsCache(object):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
+ for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue
@@ -1131,7 +1129,7 @@ class _JoinedHostsCache(object):
self.state_group = state_entry.state_group
else:
self.state_group = object()
- self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
+ self._len = sum(len(v) for v in self.hosts_to_joined_users.values())
return frozenset(self.hosts_to_joined_users)
def __len__(self):
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
index 1d8ee22fb1..ec2f38c373 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -186,7 +186,7 @@ class UIAuthWorkerStore(SQLBaseStore):
# The clientdict gets stored as JSON.
clientdict_json = json.dumps(clientdict)
- self.db.simple_update_one(
+ await self.db.simple_update_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
updatevalues={"clientdict": clientdict_json},
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
index ff000bc9ec..be1fe97d79 100644
--- a/synapse/storage/data_stores/state/bg_updates.py
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -15,8 +15,6 @@
import logging
-from six import iteritems
-
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
@@ -280,7 +278,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
delta_state = {
key: value
- for key, value in iteritems(curr_state)
+ for key, value in curr_state.items()
if prev_state.get(key, None) != value
}
@@ -316,7 +314,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(delta_state)
+ for key, state_id in delta_state.items()
],
)
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index f3ad1e4369..b720212e55 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -17,7 +17,6 @@ import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple
-from six import iteritems
from six.moves import range
from twisted.internet import defer
@@ -263,7 +262,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
- for group, group_state_dict in iteritems(group_to_state_dict):
+ for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
@@ -341,11 +340,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
else:
non_member_types = non_member_filter.concrete_types()
- for group, group_state_dict in iteritems(group_to_state_dict):
+ for group, group_state_dict in group_to_state_dict.items():
state_dict_members = {}
state_dict_non_members = {}
- for k, v in iteritems(group_state_dict):
+ for k, v in group_state_dict.items():
if k[0] == EventTypes.Member:
state_dict_members[k] = v
else:
@@ -432,7 +431,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(delta_ids)
+ for key, state_id in delta_ids.items()
],
)
else:
@@ -447,7 +446,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(current_state_ids)
+ for key, state_id in current_state_ids.items()
],
)
@@ -458,7 +457,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
current_member_state_ids = {
s: ev
- for (s, ev) in iteritems(current_state_ids)
+ for (s, ev) in current_state_ids.items()
if s[0] == EventTypes.Member
}
txn.call_after(
@@ -470,7 +469,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
current_non_member_state_ids = {
s: ev
- for (s, ev) in iteritems(current_state_ids)
+ for (s, ev) in current_state_ids.items()
if s[0] != EventTypes.Member
}
txn.call_after(
@@ -555,7 +554,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(curr_state)
+ for key, state_id in curr_state.items()
],
)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b112ff3df2..645a70934c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -29,7 +29,6 @@ from typing import (
TypeVar,
)
-from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
from prometheus_client import Histogram
@@ -259,7 +258,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration_secs, limit=3):
counters = []
- for name, (count, cum_time) in iteritems(self.current_counters):
+ for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(
(
@@ -1053,7 +1052,7 @@ class Database(object):
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
- sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
txn.execute(sql, list(keyvalues.values()))
else:
txn.execute(sql)
@@ -1191,7 +1190,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
- for key, value in iteritems(keyvalues):
+ for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -1212,7 +1211,7 @@ class Database(object):
@staticmethod
def simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
where = ""
@@ -1351,7 +1350,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
- for key, value in iteritems(keyvalues):
+ for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -1388,7 +1387,7 @@ class Database(object):
txn.close()
if cache:
- min_val = min(itervalues(cache))
+ min_val = min(cache.values())
else:
min_val = max_value
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index f159400a87..92dfd709bc 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -20,7 +20,6 @@ import logging
from collections import deque, namedtuple
from typing import Iterable, List, Optional, Set, Tuple
-from six import iteritems
from six.moves import range
from prometheus_client import Counter, Histogram
@@ -218,7 +217,7 @@ class EventsPersistenceStorage(object):
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
- for room_id, evs_ctxs in iteritems(partitioned):
+ for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
)
@@ -319,7 +318,7 @@ class EventsPersistenceStorage(object):
(event, context)
)
- for room_id, ev_ctx_rm in iteritems(events_by_room):
+ for room_id, ev_ctx_rm in events_by_room.items():
latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
room_id
)
@@ -674,7 +673,7 @@ class EventsPersistenceStorage(object):
to_insert = {
key: ev_id
- for key, ev_id in iteritems(current_state)
+ for key, ev_id in current_state.items()
if ev_id != existing_state.get(key)
}
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c522c80922..dc568476f4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -16,8 +16,6 @@
import logging
from typing import Iterable, List, TypeVar
-from six import iteritems, itervalues
-
import attr
from twisted.internet import defer
@@ -51,7 +49,7 @@ class StateFilter(object):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
- self.types = {k: v for k, v in iteritems(self.types) if v is not None}
+ self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod
def all():
@@ -150,7 +148,7 @@ class StateFilter(object):
has_non_member_wildcard = self.include_others or any(
state_keys is None
- for t, state_keys in iteritems(self.types)
+ for t, state_keys in self.types.items()
if t != EventTypes.Member
)
@@ -199,7 +197,7 @@ class StateFilter(object):
# First we build up a lost of clauses for each type/state_key combo
clauses = []
- for etype, state_keys in iteritems(self.types):
+ for etype, state_keys in self.types.items():
if state_keys is None:
clauses.append("(type = ?)")
where_args.append(etype)
@@ -251,7 +249,7 @@ class StateFilter(object):
return dict(state_dict)
filtered_state = {}
- for k, v in iteritems(state_dict):
+ for k, v in state_dict.items():
typ, state_key = k
if typ in self.types:
state_keys = self.types[typ]
@@ -279,7 +277,7 @@ class StateFilter(object):
"""
return self.include_others or any(
- state_keys is None for state_keys in itervalues(self.types)
+ state_keys is None for state_keys in self.types.values()
)
def concrete_types(self):
@@ -292,7 +290,7 @@ class StateFilter(object):
"""
return [
(t, s)
- for t, state_keys in iteritems(self.types)
+ for t, state_keys in self.types.items()
if state_keys is not None
for s in state_keys
]
@@ -324,7 +322,7 @@ class StateFilter(object):
member_filter = StateFilter.none()
non_member_filter = StateFilter(
- types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
+ types={k: v for k, v in self.types.items() if k != EventTypes.Member},
include_others=self.include_others,
)
@@ -366,7 +364,7 @@ class StateGroupStorage(object):
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
- groups = set(itervalues(event_to_groups))
+ groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups)
return group_to_state
@@ -400,8 +398,8 @@ class StateGroupStorage(object):
state_event_map = yield self.stores.main.get_events(
[
ev_id
- for group_ids in itervalues(group_to_ids)
- for ev_id in itervalues(group_ids)
+ for group_ids in group_to_ids.values()
+ for ev_id in group_ids.values()
],
get_prev_content=False,
)
@@ -409,10 +407,10 @@ class StateGroupStorage(object):
return {
group: [
state_event_map[v]
- for v in itervalues(event_id_map)
+ for v in event_id_map.values()
if v in state_event_map
]
- for group, event_id_map in iteritems(group_to_ids)
+ for group, event_id_map in group_to_ids.items()
}
def _get_state_groups_from_groups(
@@ -444,23 +442,23 @@ class StateGroupStorage(object):
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
- groups = set(itervalues(event_to_groups))
+ groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter
)
state_event_map = yield self.stores.main.get_events(
- [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
+ [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
event_to_state = {
event_id: {
k: state_event_map[v]
- for k, v in iteritems(group_to_state[group])
+ for k, v in group_to_state[group].items()
if v in state_event_map
}
- for event_id, group in iteritems(event_to_groups)
+ for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
@@ -481,14 +479,14 @@ class StateGroupStorage(object):
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
- groups = set(itervalues(event_to_groups))
+ groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter
)
event_to_state = {
event_id: group_to_state[group]
- for event_id, group in iteritems(event_to_groups)
+ for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index cd48262420..64f35fc288 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -21,8 +21,6 @@ import threading
from typing import Any, Tuple, Union, cast
from weakref import WeakValueDictionary
-from six import itervalues
-
from prometheus_client import Gauge
from typing_extensions import Protocol
@@ -281,7 +279,7 @@ class Cache(object):
def invalidate_all(self):
self.check_thread()
self.cache.clear()
- for entry in itervalues(self._pending_deferred_cache):
+ for entry in self._pending_deferred_cache.values():
entry.invalidate()
self._pending_deferred_cache.clear()
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2726b67b6d..89a3420f92 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -16,8 +16,6 @@
import logging
from collections import OrderedDict
-from six import iteritems, itervalues
-
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import register_cache
@@ -150,7 +148,7 @@ class ExpiringCache(object):
keys_to_delete = set()
- for key, cache_entry in iteritems(self._cache):
+ for key, cache_entry in self._cache.items():
if now - cache_entry.time > self._expiry_ms:
keys_to_delete.add(key)
@@ -170,7 +168,7 @@ class ExpiringCache(object):
def __len__(self):
if self.iterable:
- return sum(len(entry.value) for entry in itervalues(self._cache))
+ return sum(len(entry.value) for entry in self._cache.values())
else:
return len(self._cache)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 2ea4e4e911..ecd9948e79 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,7 +1,5 @@
from typing import Dict
-from six import itervalues
-
SENTINEL = object()
@@ -81,7 +79,7 @@ def iterate_tree_cache_entry(d):
can contain dicts.
"""
if isinstance(d, dict):
- for value_d in itervalues(d):
+ for value_d in d.values():
for value in iterate_tree_cache_entry(value_d):
yield value
else:
diff --git a/synapse/visibility.py b/synapse/visibility.py
index bab41182b9..780927cda1 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -16,7 +16,6 @@
import logging
import operator
-from six import iteritems, itervalues
from six.moves import map
from twisted.internet import defer
@@ -298,7 +297,7 @@ def filter_events_for_server(
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
- for ev in itervalues(state):
+ for ev in state.values():
if ev.type != EventTypes.Member:
continue
try:
@@ -332,7 +331,7 @@ def filter_events_for_server(
)
visibility_ids = set()
- for sids in itervalues(event_to_state_ids):
+ for sids in event_to_state_ids.values():
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist:
visibility_ids.add(hist)
@@ -345,7 +344,7 @@ def filter_events_for_server(
event_map = yield storage.main.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
- for e in itervalues(event_map)
+ for e in event_map.values()
)
if not check_history_visibility_only:
@@ -394,8 +393,8 @@ def filter_events_for_server(
#
event_id_to_state_key = {
event_id: key
- for key_to_eid in itervalues(event_to_state_ids)
- for key, event_id in iteritems(key_to_eid)
+ for key_to_eid in event_to_state_ids.values()
+ for key, event_id in key_to_eid.items()
}
def include(typ, state_key):
@@ -409,20 +408,16 @@ def filter_events_for_server(
return state_key[idx + 1 :] == server_name
event_map = yield storage.main.get_events(
- [
- e_id
- for e_id, key in iteritems(event_id_to_state_key)
- if include(key[0], key[1])
- ]
+ [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
)
event_to_state = {
e_id: {
key: event_map[inner_e_id]
- for key, inner_e_id in iteritems(key_to_eid)
+ for key, inner_e_id in key_to_eid.items()
if inner_e_id in event_map
}
- for e_id, key_to_eid in iteritems(event_to_state_ids)
+ for e_id, key_to_eid in event_to_state_ids.items()
}
to_return = []
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index be20a89682..641093d349 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -30,6 +30,16 @@ class FrontendProxyTests(HomeserverTestCase):
def default_config(self):
c = super().default_config()
c["worker_app"] = "synapse.app.frontend_proxy"
+
+ c["worker_listeners"] = [
+ {
+ "type": "http",
+ "port": 8080,
+ "bind_addresses": ["0.0.0.0"],
+ "resources": [{"names": ["client"]}],
+ }
+ ]
+
return c
def test_listen_http_with_presence_enabled(self):
@@ -39,14 +49,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Presence is on
self.hs.config.use_presence = True
- config = {
- "port": 8080,
- "bind_addresses": ["0.0.0.0"],
- "resources": [{"names": ["client"]}],
- }
-
# Listen with the config
- self.hs._listen_http(config)
+ self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
@@ -67,14 +71,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Presence is off
self.hs.config.use_presence = False
- config = {
- "port": 8080,
- "bind_addresses": ["0.0.0.0"],
- "resources": [{"names": ["client"]}],
- }
-
# Listen with the config
- self.hs._listen_http(config)
+ self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 7364f9f1ec..0f016c32eb 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -18,6 +18,7 @@ from parameterized import parameterized
from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
+from synapse.config.server import parse_listener_def
from tests.unittest import HomeserverTestCase
@@ -35,6 +36,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# have to tell the FederationHandler not to try to access stuff that is only
# in the primary store.
conf["worker_app"] = "yes"
+
return conf
@parameterized.expand(
@@ -53,12 +55,13 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
"""
config = {
"port": 8080,
+ "type": "http",
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": names}],
}
# Listen with the config
- self.hs._listen_http(config)
+ self.hs._listen_http(parse_listener_def(config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
@@ -101,12 +104,13 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
"""
config = {
"port": 8080,
+ "type": "http",
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": names}],
}
# Listen with the config
- self.hs._listener_http(config, config)
+ self.hs._listener_http(self.hs.get_config(), parse_listener_def(config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ff12539041..1a9bd5f37d 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -21,6 +21,7 @@ from signedjson.types import BaseKey, SigningKey
from twisted.internet import defer
+from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt
@@ -536,7 +537,10 @@ def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
return {
"user_id": user_id,
"device_id": device_id,
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {
"curve25519:" + device_id: "curve25519+key",
key_id(sk): encode_pubkey(sk),
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index e1e144b2e7..6c1dc72bd1 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
import synapse.handlers.e2e_keys
import synapse.storage
from synapse.api import errors
+from synapse.api.constants import RoomEncryptionAlgorithms
from tests import unittest, utils
@@ -222,7 +223,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_1 = {
"user_id": local_user,
"device_id": "abc",
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {
"ed25519:abc": "base64+ed25519+key",
"curve25519:abc": "base64+curve25519+key",
@@ -232,7 +236,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_2 = {
"user_id": local_user,
"device_id": "def",
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {
"ed25519:def": "base64+ed25519+key",
"curve25519:def": "base64+curve25519+key",
@@ -315,7 +322,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key = {
"user_id": local_user,
"device_id": device_id,
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey},
"signatures": {local_user: {"ed25519:xyz": "something"}},
}
@@ -392,7 +402,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"device_id": device_id,
"algorithms": [
"m.olm.curve25519-aes-sha2",
- "m.megolm.v1.aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
],
"keys": {
"curve25519:xyz": "curve25519+key",
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 70f172eb02..822ea42dde 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -96,6 +96,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
version_etag = res["etag"]
+ self.assertIsInstance(version_etag, str)
del res["etag"]
self.assertDictEqual(
res,
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 8e6b0b7536..07092f026a 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -24,8 +24,10 @@ EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
# The expected number of state events in a fresh private room.
#
-# Note: we increase this by 1 on the dinsic branch as we send
-# a "im.vector.room.access_rules" state event into new private rooms
+# Note: we increase this by 2 on the dinsic branch as we send
+# a "im.vector.room.access_rules" state event into new private rooms,
+# and an encryption state event as all private rooms are encrypted
+# by default
EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 0c5cdbd33a..ddee8d9e3a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -17,12 +17,13 @@ from mock import Mock
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import UserTypes
+from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, account_validity, user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
+from tests.unittest import override_config
class UserDirectoryTestCase(unittest.HomeserverTestCase):
@@ -147,6 +148,94 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
+ @override_config({"encryption_enabled_by_default_for_room_type": "all"})
+ def test_encrypted_by_default_config_option_all(self):
+ """Tests that invite-only and non-invite-only rooms have encryption enabled by
+ default when the config option encryption_enabled_by_default_for_room_type is "all".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
+ def test_encrypted_by_default_config_option_invite(self):
+ """Tests that only new, invite-only rooms have encryption enabled by default when
+ the config option encryption_enabled_by_default_for_room_type is "invite".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "off"})
+ def test_encrypted_by_default_config_option_off(self):
+ """Tests that neither new invite-only nor non-invite-only rooms have encryption
+ enabled by default when the config option
+ encryption_enabled_by_default_for_room_type is "off".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
def test_spam_checker(self):
"""
A user which fails to the spam checks will not appear in search results.
diff --git a/tests/module_api/__init__.py b/tests/module_api/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/module_api/__init__.py
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
new file mode 100644
index 0000000000..807cd65dd6
--- /dev/null
+++ b/tests/module_api/test_api.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.module_api import ModuleApi
+
+from tests.unittest import HomeserverTestCase
+
+
+class ModuleApiTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+ self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler())
+
+ def test_can_register_user(self):
+ """Tests that an external module can register a user"""
+ # Register a new user
+ user_id, access_token = self.get_success(
+ self.module_api.register(
+ "bob", displayname="Bobberino", emails=["bob@bobinator.bob"]
+ )
+ )
+
+ # Check that the new user exists with all provided attributes
+ self.assertEqual(user_id, "@bob:test")
+ self.assertTrue(access_token)
+ self.assertTrue(self.store.get_user_by_id(user_id))
+
+ # Check that the email was assigned
+ emails = self.get_success(self.store.user_get_threepids(user_id))
+ self.assertEqual(len(emails), 1)
+
+ email = emails[0]
+ self.assertEqual(email["medium"], "email")
+ self.assertEqual(email["address"], "bob@bobinator.bob")
+
+ # Should these be 0?
+ self.assertEqual(email["validated_at"], 0)
+ self.assertEqual(email["added_at"], 0)
+
+ # Check that the displayname was assigned
+ displayname = self.get_success(self.store.get_profile_displayname("bob"))
+ self.assertEqual(displayname, "Bobberino")
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 9ae6a87d7b..af35d23aea 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -21,7 +21,7 @@ from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def setUp(self):
+ def _get_evaluator(self, content):
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -29,37 +29,58 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"sender": "@user:test",
"state_key": "",
"room_id": "@room:test",
- "content": {"body": "foo bar baz"},
+ "content": content,
},
RoomVersions.V1,
)
room_member_count = 0
sender_power_level = 0
power_levels = {}
- self.evaluator = PushRuleEvaluatorForEvent(
+ return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels
)
def test_display_name(self):
"""Check for a matching display name in the body of the event."""
+ evaluator = self._get_evaluator({"body": "foo bar baz"})
+
condition = {
"kind": "contains_display_name",
}
# Blank names are skipped.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
+ self.assertFalse(evaluator.matches(condition, "@user:test", ""))
# Check a display name that doesn't match.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
+ self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))
# Check a display name which matches.
- self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
# A display name that matches, but not a full word does not result in a match.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
+ self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))
# A display name should not be interpreted as a regular expression.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
+ self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))
# A display name with spaces should work fine.
- self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
+
+ def test_no_body(self):
+ """Not having a body shouldn't break the evaluator."""
+ evaluator = self._get_evaluator({})
+
+ condition = {
+ "kind": "contains_display_name",
+ }
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+ def test_invalid_body(self):
+ """A non-string body should not break the evaluator."""
+ condition = {
+ "kind": "contains_display_name",
+ }
+
+ for body in (1, True, {"foo": "bar"}):
+ evaluator = self._get_evaluator({"body": body})
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 51bf0ef4e9..097e1653b4 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -17,6 +17,7 @@ from typing import List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
+from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
@@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# also one state event
state_event = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -174,11 +170,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -327,11 +318,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
prev_events = [e.event_id]
pl_events.append(e)
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -378,6 +364,64 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
+ def test_backwards_stream_id(self):
+ """
+ Test that RDATA that comes after the current position should be discarded.
+ """
+ # disconnect, so that we can stack up some changes
+ self.disconnect()
+
+ # Generate an events. We inject them using inject_event so that they are
+ # not send out over replication until we call self.replicate().
+ event = self._inject_test_event()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # We should have received the expected single row (as well as various
+ # cache invalidation updates which we ignore).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
+ # There should be a single received row.
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows[0]
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, event.event_id)
+
+ # Reset the data.
+ self.test_handler.received_rdata_rows = []
+
+ # Save the current token for later.
+ worker_events_stream = self.worker_hs.get_replication_streams()["events"]
+ prev_token = worker_events_stream.current_token("master")
+
+ # Manually send an old RDATA command, which should get dropped. This
+ # re-uses the row from above, but with an earlier stream token.
+ self.hs.get_tcp_replication().send_command(
+ RdataCommand("events", "master", 1, row)
+ )
+
+ # No updates have been received (because it was discard as old).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+ self.assertEqual(len(received_rows), 0)
+
+ # Ensure the stream has not gone backwards.
+ current_token = worker_events_stream.current_token("master")
+ self.assertGreaterEqual(current_token, prev_token)
+
event_count = 0
def _inject_test_event(
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index fd62b26356..5acfb3e53e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -16,10 +16,15 @@ from mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.tcp.streams import TypingStream
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
+USER_ID_2 = "@da-ba-dee:blue"
+
+ROOM_ID = "!bar:blue"
+ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase):
@@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self):
typing = self.hs.get_typing_handler()
- room_id = "!bar:blue"
-
self.reconnect()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
self.reactor.advance(0)
@@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
# Now let's disconnect and insert some data.
@@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.test_handler.on_rdata.reset_mock()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called()
@@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([], row.user_ids)
+
+ def test_reset(self):
+ """
+ Test what happens when a typing stream resets.
+
+ This is emulated by jumping the stream ahead, then reconnecting (which
+ sends the proper position and RDATA).
+ """
+ typing = self.hs.get_typing_handler()
+
+ self.reconnect()
+
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
+
+ self.reactor.advance(0)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([USER_ID], row.user_ids)
+
+ # Push the stream forward a bunch so it can be reset.
+ for i in range(100):
+ typing._push_update(
+ member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
+ )
+ self.reactor.advance(0)
+
+ # Disconnect.
+ self.disconnect()
+
+ # Reset the typing handler
+ self.hs.get_replication_streams()["typing"].last_token = 0
+ self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+ typing._latest_room_serial = 0
+ typing._typing_stream_change_cache = StreamChangeCache(
+ "TypingStreamChangeCache", typing._latest_room_serial
+ )
+ typing._reset()
+
+ # Reconnect.
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ # Reset the test code.
+ self.test_handler.on_rdata.reset_mock()
+ self.test_handler.on_rdata.assert_not_called()
+
+ # Push additional data.
+ typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
+ self.reactor.advance(0)
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0]
+ self.assertEqual(ROOM_ID_2, row.room_id)
self.assertEqual([], row.user_ids)
+
+ # The token should have been reset.
+ self.assertEqual(token, 1)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c662195eec..89dcc58b99 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -30,7 +30,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
room_creator = self.homeserver.get_room_creation_handler()
room_deferred = ensureDeferred(
room_creator.create_room(
- our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+ our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
)
self.reactor.advance(0.1)
diff --git a/tests/test_server.py b/tests/test_server.py
index e9a43b1e45..adae3c6e08 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -24,6 +24,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.config.server import parse_listener_def
from synapse.http.server import (
DirectServeResource,
JsonResource,
@@ -189,7 +190,13 @@ class OptionsResourceTests(unittest.TestCase):
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource.
- site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+ site = SynapseSite(
+ "test",
+ "site_tag",
+ parse_listener_def({"type": "http", "port": 0}),
+ self.resource,
+ "1.0",
+ )
request.site = site
resource = site.getResourceFor(request)
@@ -348,7 +355,9 @@ class SiteTestCase(unittest.HomeserverTestCase):
# time out the request while it's 'processing'
base_resource = Resource()
base_resource.putChild(b"", HangingResource())
- site = SynapseSite("test", "site_tag", {}, base_resource, "1.0")
+ site = SynapseSite(
+ "test", "site_tag", self.hs.config.listeners[0], base_resource, "1.0"
+ )
server = site.buildProtocol(None)
client = AccumulatingProtocol()
diff --git a/tests/unittest.py b/tests/unittest.py
index 6b6f224e9c..3175a3fa02 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -229,7 +229,7 @@ class HomeserverTestCase(TestCase):
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
site_tag="test",
- config={},
+ config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
)
diff --git a/tests/utils.py b/tests/utils.py
index 59c020a051..f5b118df1c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -168,6 +168,9 @@ def default_config(name, parse=False):
# background, which upsets the test runner.
"update_user_directory": False,
"caches": {"global_factor": 1},
+ "listeners": [{"port": 0, "type": "http"}],
+ # Enable encryption by default in private rooms
+ "encryption_enabled_by_default_for_room_type": "invite",
}
if parse:
|