diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index d528450c78..55cce2db22 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -23,8 +23,6 @@ import hmac
import logging
import sys
-from six.moves import input
-
import requests as _requests
import yaml
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 06ade25674..e615533ea3 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -16,8 +16,6 @@
import logging
from typing import Optional
-from six import itervalues
-
import pymacaroons
from netaddr import IPAddress
@@ -90,7 +88,7 @@ class Auth(object):
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check(
@@ -196,6 +194,7 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request)
+
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
@@ -262,11 +261,11 @@ class Auth(object):
except KeyError:
raise MissingClientTokenError()
- @defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
+
if app_service is None:
return None, None
@@ -284,8 +283,12 @@ class Auth(object):
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (yield self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
+ # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
+ # if not (yield self.store.get_user_by_id(user_id)):
+ # raise AuthError(
+ # 403,
+ # "Application service has not registered this user"
+ # )
return user_id, app_service
@defer.inlineCallbacks
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 5ec4a77ccd..6a6d32c302 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -150,3 +150,8 @@ class EventContentFields(object):
# Timestamp to delete the event after
# cf https://github.com/matrix-org/matrix-doc/pull/2228
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
+
+
+class RoomEncryptionAlgorithms(object):
+ MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
+ DEFAULT = MEGOLM_V1_AES_SHA2
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index d54dfb385d..20aff11c81 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,11 +18,9 @@
"""Contains exceptions and error codes."""
import logging
+from http import HTTPStatus
from typing import Dict, List
-from six import iteritems
-from six.moves import http_client
-
from canonicaljson import json
from twisted.web import http
@@ -74,6 +73,13 @@ class Codes(object):
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
BAD_ALIAS = "M_BAD_ALIAS"
+ PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
+ PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
+ PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
+ PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
+ PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
+ PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
+ WEAK_PASSWORD = "M_WEAK_PASSWORD"
class CodeMessageException(RuntimeError):
@@ -174,7 +180,7 @@ class ConsentNotGivenError(SynapseError):
consent_url (str): The URL where the user can give their consent
"""
super(ConsentNotGivenError, self).__init__(
- code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN
+ code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN
)
self._consent_uri = consent_uri
@@ -194,7 +200,7 @@ class UserDeactivatedError(SynapseError):
msg (str): The human-readable error message
"""
super(UserDeactivatedError, self).__init__(
- code=http_client.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED
+ code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED
)
@@ -497,7 +503,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
A dict representing the error response JSON.
"""
err = {"error": msg, "errcode": code}
- for key, value in iteritems(kwargs):
+ for key, value in kwargs.items():
err[key] = value
return err
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 8b64d0a285..f988f62a1e 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -17,8 +17,6 @@
# limitations under the License.
from typing import List
-from six import text_type
-
import jsonschema
from canonicaljson import json
from jsonschema import FormatChecker
@@ -313,7 +311,7 @@ class Filter(object):
content = event.get("content", {})
# check if there is a string url field in the content for filtering purposes
- contains_url = isinstance(content.get("url"), text_type)
+ contains_url = isinstance(content.get("url"), str)
labels = content.get(EventContentFields.LABELS, [])
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index f34434bd67..bd03ebca5a 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -17,8 +17,7 @@
"""Contains the URL paths to prefix various aspects of the server with. """
import hmac
from hashlib import sha256
-
-from six.moves.urllib.parse import urlencode
+from urllib.parse import urlencode
from synapse.config import ConfigError
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index dedff81af3..373a80a4a7 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -20,6 +20,7 @@ import signal
import socket
import sys
import traceback
+from typing import Iterable
from daemonize import Daemonize
from typing_extensions import NoReturn
@@ -29,6 +30,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
+from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.util.async_helpers import Linearizer
@@ -234,7 +236,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
-def start(hs, listeners=None):
+def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
"""
Start a Synapse server or worker.
@@ -245,8 +247,8 @@ def start(hs, listeners=None):
notify systemd.
Args:
- hs (synapse.server.HomeServer)
- listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml)
+ hs: homeserver instance
+ listeners: Listener configuration ('listeners' in homeserver.yaml)
"""
try:
# Set up the SIGHUP machinery.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f3ec2a34ec..27a3fc9ed6 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -37,6 +37,7 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
+from synapse.config.server import ListenerConfig
from synapse.federation import send_queue
from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.presence import (
@@ -514,13 +515,18 @@ class GenericWorkerSlavedStore(
class GenericWorkerServer(HomeServer):
DATASTORE_CLASS = GenericWorkerSlavedStore
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
+ def _listen_http(self, listener_config: ListenerConfig):
+ port = listener_config.port
+ bind_addresses = listener_config.bind_addresses
+
+ assert listener_config.http_options is not None
+
+ site_tag = listener_config.http_options.tag
+ if site_tag is None:
+ site_tag = port
resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
+ for res in listener_config.http_options.resources:
+ for name in res.names:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
@@ -590,7 +596,7 @@ class GenericWorkerServer(HomeServer):
" repository is disabled. Ignoring."
)
- if name == "openid" and "federation" not in res["names"]:
+ if name == "openid" and "federation" not in res.names:
# Only load the openid resource separately if federation resource
# is not specified since federation resource includes openid
# resource.
@@ -625,19 +631,19 @@ class GenericWorkerServer(HomeServer):
logger.info("Synapse worker now listening on port %d", port)
- def start_listening(self, listeners):
+ def start_listening(self, listeners: Iterable[ListenerConfig]):
for listener in listeners:
- if listener["type"] == "http":
+ if listener.type == "http":
self._listen_http(listener)
- elif listener["type"] == "manhole":
+ elif listener.type == "manhole":
_base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
manhole(
username="matrix", password="rabbithole", globals={"hs": self}
),
)
- elif listener["type"] == "metrics":
+ elif listener.type == "metrics":
if not self.get_config().enable_metrics:
logger.warning(
(
@@ -646,9 +652,9 @@ class GenericWorkerServer(HomeServer):
)
)
else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
+ _base.listen_metrics(listener.bind_addresses, listener.port)
else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unsupported listener type: %s", listener.type)
self.get_tcp_replication().start_replication(self)
@@ -738,6 +744,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
except Exception:
logger.exception("Error processing replication")
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
+ await super().on_position(stream_name, instance_name, token)
+ # Also call on_rdata to ensure that stream positions are properly reset.
+ await self.on_rdata(stream_name, instance_name, token, [])
+
def stop_pusher(self, user_id, app_id, pushkey):
if not self.notify_pushers:
return
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 41994dc14b..09291d86ad 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -23,8 +23,7 @@ import math
import os
import resource
import sys
-
-from six import iteritems
+from typing import Iterable
from prometheus_client import Gauge
@@ -50,6 +49,7 @@ from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import ListenerConfig
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import (
@@ -90,24 +90,24 @@ def gz_wrap(r):
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
- def _listener_http(self, config, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- tls = listener_config.get("tls", False)
- site_tag = listener_config.get("tag", port)
+ def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig):
+ port = listener_config.port
+ bind_addresses = listener_config.bind_addresses
+ tls = listener_config.tls
+ site_tag = listener_config.http_options.tag
+ if site_tag is None:
+ site_tag = port
resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "openid" and "federation" in res["names"]:
+ for res in listener_config.http_options.resources:
+ for name in res.names:
+ if name == "openid" and "federation" in res.names:
# Skip loading openid resource if federation is defined
# since federation resource will include openid
continue
- resources.update(
- self._configure_named_resource(name, res.get("compress", False))
- )
+ resources.update(self._configure_named_resource(name, res.compress))
- additional_resources = listener_config.get("additional_resources", {})
+ additional_resources = listener_config.http_options.additional_resources
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = ModuleApi(self, self.get_auth_handler())
for path, resmodule in additional_resources.items():
@@ -279,7 +279,7 @@ class SynapseHomeServer(HomeServer):
return resources
- def start_listening(self, listeners):
+ def start_listening(self, listeners: Iterable[ListenerConfig]):
config = self.get_config()
if config.redis_enabled:
@@ -289,25 +289,25 @@ class SynapseHomeServer(HomeServer):
self.get_tcp_replication().start_replication(self)
for listener in listeners:
- if listener["type"] == "http":
+ if listener.type == "http":
self._listening_services.extend(self._listener_http(config, listener))
- elif listener["type"] == "manhole":
+ elif listener.type == "manhole":
listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
manhole(
username="matrix", password="rabbithole", globals={"hs": self}
),
)
- elif listener["type"] == "replication":
+ elif listener.type == "replication":
services = listen_tcp(
- listener["bind_addresses"],
- listener["port"],
+ listener.bind_addresses,
+ listener.port,
ReplicationStreamProtocolFactory(self),
)
for s in services:
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
- elif listener["type"] == "metrics":
+ elif listener.type == "metrics":
if not self.get_config().enable_metrics:
logger.warning(
(
@@ -316,9 +316,11 @@ class SynapseHomeServer(HomeServer):
)
)
else:
- _base.listen_metrics(listener["bind_addresses"], listener["port"])
+ _base.listen_metrics(listener.bind_addresses, listener.port)
else:
- logger.warning("Unrecognized listener type: %s", listener["type"])
+ # this shouldn't happen, as the listener type should have been checked
+ # during parsing
+ logger.warning("Unrecognized listener type: %s", listener.type)
# Gauges to expose monthly active user control metrics
@@ -526,7 +528,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
- for name, count in iteritems(daily_user_type_results):
+ for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
room_count = yield hs.get_datastore().get_room_count()
@@ -538,7 +540,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users()
- for name, count in iteritems(r30_results):
+ for name, count in r30_results.items():
stats["r30_users_" + name] = count
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 1b13e84425..0323256472 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -15,8 +15,6 @@
import logging
import re
-from six import string_types
-
from twisted.internet import defer
from synapse.api.constants import EventTypes
@@ -156,7 +154,7 @@ class ApplicationService(object):
)
regex = regex_obj.get("regex")
- if isinstance(regex, string_types):
+ if isinstance(regex, str):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 57174da021..da9a5e86d4 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from six.moves import urllib
+import urllib
from prometheus_client import Counter
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 30d1050a91..f2830c609d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -19,11 +19,10 @@ import argparse
import errno
import os
from collections import OrderedDict
+from io import open as io_open
from textwrap import dedent
from typing import Any, MutableMapping, Optional
-from six import integer_types
-
import yaml
@@ -117,7 +116,7 @@ class Config(object):
@staticmethod
def parse_size(value):
- if isinstance(value, integer_types):
+ if isinstance(value, int):
return value
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
@@ -129,7 +128,7 @@ class Config(object):
@staticmethod
def parse_duration(value):
- if isinstance(value, integer_types):
+ if isinstance(value, int):
return value
second = 1000
minute = 60 * second
@@ -181,7 +180,7 @@ class Config(object):
@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
+ with io_open(file_path, encoding="utf-8") as file_stream:
return file_stream.read()
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index ca43e96bd1..8ed3e24258 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -14,9 +14,7 @@
import logging
from typing import Dict
-
-from six import string_types
-from six.moves.urllib import parse as urlparse
+from urllib import parse as urlparse
import yaml
from netaddr import IPSet
@@ -98,17 +96,14 @@ def load_appservices(hostname, config_files):
def _load_appservice(hostname, as_info, config_filename):
required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
for field in required_string_fields:
- if not isinstance(as_info.get(field), string_types):
+ if not isinstance(as_info.get(field), str):
raise KeyError(
"Required string field: '%s' (%s)" % (field, config_filename)
)
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
- if (
- not isinstance(as_info.get("url"), string_types)
- and as_info.get("url", "") is not None
- ):
+ if not isinstance(as_info.get("url"), str) and as_info.get("url", "") is not None:
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
@@ -138,7 +133,7 @@ def _load_appservice(hostname, as_info, config_filename):
ns,
regex_obj,
)
- if not isinstance(regex_obj.get("regex"), string_types):
+ if not isinstance(regex_obj.get("regex"), str):
raise ValueError("Missing/bad type 'regex' key in %s", regex_obj)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 0672538796..aff5b21ab2 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -15,6 +15,7 @@
import os
import re
+import threading
from typing import Callable, Dict
from ._base import Config, ConfigError
@@ -25,6 +26,9 @@ _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
# Map from canonicalised cache name to cache.
_CACHES = {}
+# a lock on the contents of _CACHES
+_CACHES_LOCK = threading.Lock()
+
_DEFAULT_FACTOR_SIZE = 0.5
_DEFAULT_EVENT_CACHE_SIZE = "10K"
@@ -66,7 +70,10 @@ def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
# Some caches have '*' in them which we strip out.
cache_name = _canonicalise_cache_name(cache_name)
- _CACHES[cache_name] = cache_resize_callback
+ # sometimes caches are initialised from background threads, so we need to make
+ # sure we don't conflict with another thread running a resize operation
+ with _CACHES_LOCK:
+ _CACHES[cache_name] = cache_resize_callback
# Ensure all loaded caches are sized appropriately
#
@@ -87,7 +94,8 @@ class CacheConfig(Config):
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
)
properties.resize_all_caches_func = None
- _CACHES.clear()
+ with _CACHES_LOCK:
+ _CACHES.clear()
def generate_config_section(self, **kwargs):
return """\
@@ -193,6 +201,8 @@ class CacheConfig(Config):
For each cache, run the mapped callback function with either
a specific cache factor or the default, global one.
"""
- for cache_name, callback in _CACHES.items():
- new_factor = self.cache_factors.get(cache_name, self.global_factor)
- callback(new_factor)
+ # block other threads from modifying _CACHES while we iterate it.
+ with _CACHES_LOCK:
+ for cache_name, callback in _CACHES.items():
+ new_factor = self.cache_factors.get(cache_name, self.global_factor)
+ callback(new_factor)
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 2c7b3a699f..264c274c52 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -36,6 +36,7 @@ from .ratelimiting import RatelimitConfig
from .redis import RedisConfig
from .registration import RegistrationConfig
from .repository import ContentRepositoryConfig
+from .room import RoomConfig
from .room_directory import RoomDirectoryConfig
from .saml2_config import SAML2Config
from .server import ServerConfig
@@ -79,6 +80,7 @@ class HomeServerConfig(RootConfig):
PasswordAuthProviderConfig,
PushConfig,
SpamCheckerConfig,
+ RoomConfig,
GroupsConfig,
UserDirectoryConfig,
ConsentConfig,
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index e24dd637bc..e0939bce84 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -89,7 +89,7 @@ class OIDCConfig(Config):
# use an OpenID Connect Provider for authentication, instead of its internal
# password database.
#
- # See https://github.com/matrix-org/synapse/blob/master/openid.md.
+ # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md.
#
oidc_config:
# Uncomment the following to enable authorization against an OpenID Connect
diff --git a/synapse/config/password.py b/synapse/config/password.py
index 9c0ea8c30a..6b2dae78b0 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 2dd94bae2b..b1981d4d15 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -76,6 +76,9 @@ class RatelimitConfig(Config):
)
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_third_party_invite = RateLimitConfig(
+ config.get("rc_third_party_invite", {})
+ )
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
@@ -115,6 +118,8 @@ class RatelimitConfig(Config):
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+ # - one that ratelimits third-party invites requests based on the account
+ # that's making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -140,6 +145,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
+ #rc_third_party_invite:
+ # per_second: 0.2
+ # burst_count: 10
+ #
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index fecced2d57..080b9bc445 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -18,8 +18,9 @@ from distutils.util import strtobool
import pkg_resources
+from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
-from synapse.types import RoomAlias
+from synapse.types import RoomAlias, UserID
from synapse.util.stringutils import random_string_with_symbols
@@ -99,8 +100,19 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
+ self.check_is_for_allowed_local_3pids = config.get(
+ "check_is_for_allowed_local_3pids", None
+ )
+ self.allow_invited_3pids = config.get("allow_invited_3pids", False)
+
+ self.disable_3pid_changes = config.get("disable_3pid_changes", False)
+
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
+ self.register_just_use_email_for_display_name = config.get(
+ "register_just_use_email_for_display_name", False
+ )
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
@@ -108,7 +120,21 @@ class RegistrationConfig(Config):
)
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+ if (
+ self.account_threepid_delegate_email
+ and not self.account_threepid_delegate_email.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.email must begin with http:// or https://"
+ )
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
+ if (
+ self.account_threepid_delegate_msisdn
+ and not self.account_threepid_delegate_msisdn.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.msisdn must begin with http:// or https://"
+ )
if self.account_threepid_delegate_msisdn and not self.public_baseurl:
raise ConfigError(
"The configuration option `public_baseurl` is required if "
@@ -127,13 +153,65 @@ class RegistrationConfig(Config):
for room_alias in self.auto_join_rooms:
if not RoomAlias.is_valid(room_alias):
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
+
+ # Options for creating auto-join rooms if they do not exist yet.
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
+ self.autocreate_auto_join_rooms_federated = config.get(
+ "autocreate_auto_join_rooms_federated", True
+ )
+ self.autocreate_auto_join_room_preset = (
+ config.get("autocreate_auto_join_room_preset")
+ or RoomCreationPreset.PUBLIC_CHAT
+ )
+ self.auto_join_room_requires_invite = self.autocreate_auto_join_room_preset in {
+ RoomCreationPreset.PRIVATE_CHAT,
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ }
+
+ # Pull the creater/inviter from the configuration, this gets used to
+ # send invites for invite-only rooms.
+ mxid_localpart = config.get("auto_join_mxid_localpart")
+ self.auto_join_user_id = None
+ if mxid_localpart:
+ # Convert the localpart to a full mxid.
+ self.auto_join_user_id = UserID(
+ mxid_localpart, self.server_name
+ ).to_string()
+
+ if self.autocreate_auto_join_rooms:
+ # Ensure the preset is a known value.
+ if self.autocreate_auto_join_room_preset not in {
+ RoomCreationPreset.PUBLIC_CHAT,
+ RoomCreationPreset.PRIVATE_CHAT,
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ }:
+ raise ConfigError("Invalid value for autocreate_auto_join_room_preset")
+ # If the preset requires invitations to be sent, ensure there's a
+ # configured user to send them from.
+ if self.auto_join_room_requires_invite:
+ if not mxid_localpart:
+ raise ConfigError(
+ "The configuration option `auto_join_mxid_localpart` is required if "
+ "`autocreate_auto_join_room_preset` is set to private_chat or trusted_private_chat, such that "
+ "Synapse knows who to send invitations from. Please "
+ "configure `auto_join_mxid_localpart`."
+ )
+
self.auto_join_rooms_for_guests = config.get("auto_join_rooms_for_guests", True)
self.enable_set_displayname = config.get("enable_set_displayname", True)
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+ self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
+ if not isinstance(self.replicate_user_profiles_to, list):
+ self.replicate_user_profiles_to = [self.replicate_user_profiles_to]
+
+ self.shadow_server = config.get("shadow_server", None)
+ self.rewrite_identity_server_urls = (
+ config.get("rewrite_identity_server_urls") or {}
+ )
+
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -143,6 +221,23 @@ class RegistrationConfig(Config):
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime
+ self.bind_new_user_emails_to_sydent = config.get(
+ "bind_new_user_emails_to_sydent"
+ )
+
+ if self.bind_new_user_emails_to_sydent:
+ if not isinstance(
+ self.bind_new_user_emails_to_sydent, str
+ ) or not self.bind_new_user_emails_to_sydent.startswith("http"):
+ raise ConfigError(
+ "Option bind_new_user_emails_to_sydent has invalid value"
+ )
+
+ # Remove trailing slashes
+ self.bind_new_user_emails_to_sydent = self.bind_new_user_emails_to_sydent.strip(
+ "/"
+ )
+
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
@@ -247,9 +342,32 @@ class RegistrationConfig(Config):
#
#disable_msisdn_registration: true
+ # Derive the user's matrix ID from a type of 3PID used when registering.
+ # This overrides any matrix ID the user proposes when calling /register
+ # The 3PID type should be present in registrations_require_3pid to avoid
+ # users failing to register if they don't specify the right kind of 3pid.
+ #
+ #register_mxid_from_3pid: email
+
+ # Uncomment to set the display name of new users to their email address,
+ # rather than using the default heuristic.
+ #
+ #register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+ # Use an Identity Server to establish which 3PIDs are allowed to register?
+ # Overrides allowed_local_3pids below.
+ #
+ #check_is_for_allowed_local_3pids: matrix.org
+ #
+ # If you are using an IS you can also check whether that IS registers
+ # pending invites for the given 3PID (and then allow it to sign up on
+ # the platform):
+ #
+ #allow_invited_3pids: false
+ #
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\\.org'
@@ -258,6 +376,11 @@ class RegistrationConfig(Config):
# - medium: msisdn
# pattern: '\\+44'
+ # If true, stop users from trying to change the 3PIDs associated with
+ # their accounts.
+ #
+ #disable_3pid_changes: false
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -307,6 +430,30 @@ class RegistrationConfig(Config):
# - matrix.org
# - vector.im
+ # If enabled, user IDs, display names and avatar URLs will be replicated
+ # to this server whenever they change.
+ # This is an experimental API currently implemented by sydent to support
+ # cross-homeserver user directories.
+ #
+ #replicate_user_profiles_to: example.com
+
+ # If specified, attempt to replay registrations, profile changes & 3pid
+ # bindings on the given target homeserver via the AS API. The HS is authed
+ # via a given AS token.
+ #
+ #shadow_server:
+ # hs_url: https://shadow.example.com
+ # hs: shadow.example.com
+ # as_token: 12u394refgbdhivsia
+
+ # If enabled, don't let users set their own display names/avatars
+ # other than for the very first time (unless they are a server admin).
+ # Useful when provisioning users based on the contents of a 3rd party
+ # directory and to avoid ambiguities.
+ #
+ #disable_set_displayname: false
+ #disable_set_avatar_url: false
+
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
@@ -357,7 +504,11 @@ class RegistrationConfig(Config):
#enable_3pid_changes: false
# Users who register on this homeserver will automatically be joined
- # to these rooms
+ # to these rooms.
+ #
+ # By default, any room aliases included in this list will be created
+ # as a publicly joinable room when the first user registers for the
+ # homeserver. This behaviour can be customised with the settings below.
#
#auto_join_rooms:
# - "#example:example.com"
@@ -365,10 +516,62 @@ class RegistrationConfig(Config):
# Where auto_join_rooms are specified, setting this flag ensures that the
# the rooms exist by creating them when the first user on the
# homeserver registers.
+ #
+ # By default the auto-created rooms are publicly joinable from any federated
+ # server. Use the autocreate_auto_join_rooms_federated and
+ # autocreate_auto_join_room_preset settings below to customise this behaviour.
+ #
# Setting to false means that if the rooms are not manually created,
# users cannot be auto-joined since they do not exist.
#
- #autocreate_auto_join_rooms: true
+ # Defaults to true. Uncomment the following line to disable automatically
+ # creating auto-join rooms.
+ #
+ #autocreate_auto_join_rooms: false
+
+ # Whether the auto_join_rooms that are auto-created are available via
+ # federation. Only has an effect if autocreate_auto_join_rooms is true.
+ #
+ # Note that whether a room is federated cannot be modified after
+ # creation.
+ #
+ # Defaults to true: the room will be joinable from other servers.
+ # Uncomment the following to prevent users from other homeservers from
+ # joining these rooms.
+ #
+ #autocreate_auto_join_rooms_federated: false
+
+ # The room preset to use when auto-creating one of auto_join_rooms. Only has an
+ # effect if autocreate_auto_join_rooms is true.
+ #
+ # This can be one of "public_chat", "private_chat", or "trusted_private_chat".
+ # If a value of "private_chat" or "trusted_private_chat" is used then
+ # auto_join_mxid_localpart must also be configured.
+ #
+ # Defaults to "public_chat", meaning that the room is joinable by anyone, including
+ # federated servers if autocreate_auto_join_rooms_federated is true (the default).
+ # Uncomment the following to require an invitation to join these rooms.
+ #
+ #autocreate_auto_join_room_preset: private_chat
+
+ # The local part of the user id which is used to create auto_join_rooms if
+ # autocreate_auto_join_rooms is true. If this is not provided then the
+ # initial user account that registers will be used to create the rooms.
+ #
+ # The user id is also used to invite new users to any auto-join rooms which
+ # are set to invite-only.
+ #
+ # It *must* be configured if autocreate_auto_join_room_preset is set to
+ # "private_chat" or "trusted_private_chat".
+ #
+ # Note that this must be specified in order for new users to be correctly
+ # invited to any auto-join rooms which have been set to invite-only (either
+ # at the time of creation or subsequently).
+ #
+ # Note that, if the room already exists, this user must be joined and
+ # have the appropriate permissions to invite new members.
+ #
+ #auto_join_mxid_localpart: system
# When auto_join_rooms is specified, setting this flag to false prevents
# guest accounts from being automatically joined to the rooms.
@@ -376,6 +579,31 @@ class RegistrationConfig(Config):
# Defaults to true.
#
#auto_join_rooms_for_guests: false
+
+ # Rewrite identity server URLs with a map from one URL to another. Applies to URLs
+ # provided by clients (which have https:// prepended) and those specified
+ # in `account_threepid_delegates`. URLs should not feature a trailing slash.
+ #
+ #rewrite_identity_server_urls:
+ # "https://somewhere.example.com": "https://somewhereelse.example.com"
+
+ # When a user registers an account with an email address, it can be useful to
+ # bind that email address to their mxid on an identity server. Typically, this
+ # requires the user to validate their email address with the identity server.
+ # However if Synapse itself is handling email validation on registration, the
+ # user ends up needing to validate their email twice, which leads to poor UX.
+ #
+ # It is possible to force Sydent, one identity server implementation, to bind
+ # threepids using its internal, unauthenticated bind API:
+ # https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api
+ #
+ # Configure the address of a Sydent server here to have Synapse attempt
+ # to automatically bind users' emails following registration. The
+ # internal bind API must be reachable from Synapse, but should NOT be
+ # exposed to any third party, as it allows the creation of bindings
+ # without validation.
+ #
+ #bind_new_user_emails_to_sydent: https://example.com:8091
"""
% locals()
)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index b751d02d37..54f565ad5b 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -94,10 +94,22 @@ class ContentRepositoryConfig(Config):
else:
self.can_load_media_repo = True
+ # Whether this instance should be the one to run the background jobs to
+ # e.g clean up old URL previews.
+ self.media_instance_running_background_jobs = config.get(
+ "media_instance_running_background_jobs",
+ )
+
self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M"))
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
+ self.max_avatar_size = config.get("max_avatar_size")
+ if self.max_avatar_size:
+ self.max_avatar_size = self.parse_size(self.max_avatar_size)
+
+ self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", [])
+
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
@@ -238,6 +250,30 @@ class ContentRepositoryConfig(Config):
#
#max_upload_size: 10M
+ # The largest allowed size for a user avatar. If not defined, no
+ # restriction will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #max_avatar_size: 10M
+
+ # Allow mimetypes for a user avatar. If not defined, no restriction will
+ # be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
diff --git a/synapse/config/room.py b/synapse/config/room.py
new file mode 100644
index 0000000000..6aa4de0672
--- /dev/null
+++ b/synapse/config/room.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.constants import RoomCreationPreset
+
+from ._base import Config, ConfigError
+
+logger = logging.Logger(__name__)
+
+
+class RoomDefaultEncryptionTypes(object):
+ """Possible values for the encryption_enabled_by_default_for_room_type config option"""
+
+ ALL = "all"
+ INVITE = "invite"
+ OFF = "off"
+
+
+class RoomConfig(Config):
+ section = "room"
+
+ def read_config(self, config, **kwargs):
+ # Whether new, locally-created rooms should have encryption enabled
+ encryption_for_room_type = config.get(
+ "encryption_enabled_by_default_for_room_type",
+ RoomDefaultEncryptionTypes.OFF,
+ )
+ if encryption_for_room_type == RoomDefaultEncryptionTypes.ALL:
+ self.encryption_enabled_by_default_for_room_presets = [
+ RoomCreationPreset.PRIVATE_CHAT,
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ RoomCreationPreset.PUBLIC_CHAT,
+ ]
+ elif encryption_for_room_type == RoomDefaultEncryptionTypes.INVITE:
+ self.encryption_enabled_by_default_for_room_presets = [
+ RoomCreationPreset.PRIVATE_CHAT,
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ ]
+ elif encryption_for_room_type == RoomDefaultEncryptionTypes.OFF:
+ self.encryption_enabled_by_default_for_room_presets = []
+ else:
+ raise ConfigError(
+ "Invalid value for encryption_enabled_by_default_for_room_type"
+ )
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ ## Rooms ##
+
+ # Controls whether locally-created rooms should be end-to-end encrypted by
+ # default.
+ #
+ # Possible options are "all", "invite", and "off". They are defined as:
+ #
+ # * "all": any locally-created room
+ # * "invite": any room created with the "private_chat" or "trusted_private_chat"
+ # room creation presets
+ # * "off": this option will take no effect
+ #
+ # The default value is "off".
+ #
+ # Note that this option will only affect rooms created after it is set. It
+ # will also not affect rooms created by other servers.
+ #
+ #encryption_enabled_by_default_for_room_type: invite
+ """
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index d0a19751e8..293643b2de 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -160,7 +160,7 @@ class SAML2Config(Config):
# session lifetime: in milliseconds
self.saml2_session_lifetime = self.parse_duration(
- saml2_config.get("saml_session_lifetime", "5m")
+ saml2_config.get("saml_session_lifetime", "15m")
)
template_dir = saml2_config.get("template_dir")
@@ -286,7 +286,7 @@ class SAML2Config(Config):
# The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
- # The default is 5 minutes.
+ # The default is 15 minutes.
#
#saml_session_lifetime: 5m
diff --git a/synapse/config/server.py b/synapse/config/server.py
index f57eefc99c..43ab5d62d6 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
import attr
import yaml
@@ -57,6 +57,64 @@ on how to configure the new listener.
--------------------------------------------------------------------------------"""
+KNOWN_LISTENER_TYPES = {
+ "http",
+ "metrics",
+ "manhole",
+ "replication",
+}
+
+KNOWN_RESOURCES = {
+ "client",
+ "consent",
+ "federation",
+ "keys",
+ "media",
+ "metrics",
+ "openid",
+ "replication",
+ "static",
+ "webclient",
+}
+
+
+@attr.s(frozen=True)
+class HttpResourceConfig:
+ names = attr.ib(
+ type=List[str],
+ factory=list,
+ validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore
+ )
+ compress = attr.ib(
+ type=bool,
+ default=False,
+ validator=attr.validators.optional(attr.validators.instance_of(bool)), # type: ignore[arg-type]
+ )
+
+
+@attr.s(frozen=True)
+class HttpListenerConfig:
+ """Object describing the http-specific parts of the config of a listener"""
+
+ x_forwarded = attr.ib(type=bool, default=False)
+ resources = attr.ib(type=List[HttpResourceConfig], factory=list)
+ additional_resources = attr.ib(type=Dict[str, dict], factory=dict)
+ tag = attr.ib(type=str, default=None)
+
+
+@attr.s(frozen=True)
+class ListenerConfig:
+ """Object describing the configuration of a single listener."""
+
+ port = attr.ib(type=int, validator=attr.validators.instance_of(int))
+ bind_addresses = attr.ib(type=List[str])
+ type = attr.ib(type=str, validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
+ tls = attr.ib(type=bool, default=False)
+
+ # http_options is only populated if type=http
+ http_options = attr.ib(type=Optional[HttpListenerConfig], default=None)
+
+
class ServerConfig(Config):
section = "server"
@@ -253,6 +311,12 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # True.
+ self.show_users_in_user_directory = config.get(
+ "show_users_in_user_directory", True
+ )
+
retention_config = config.get("retention")
if retention_config is None:
retention_config = {}
@@ -379,38 +443,21 @@ class ServerConfig(Config):
}
]
- self.listeners = [] # type: List[dict]
- for listener in config.get("listeners", []):
- if not isinstance(listener.get("port", None), int):
- raise ConfigError(
- "Listener configuration is lacking a valid 'port' option"
- )
+ self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]
- if listener.setdefault("tls", False):
- # no_tls is not really supported any more, but let's grandfather it in
- # here.
- if config.get("no_tls", False):
+ # no_tls is not really supported any more, but let's grandfather it in
+ # here.
+ if config.get("no_tls", False):
+ l2 = []
+ for listener in self.listeners:
+ if listener.tls:
logger.info(
- "Ignoring TLS-enabled listener on port %i due to no_tls"
+ "Ignoring TLS-enabled listener on port %i due to no_tls",
+ listener.port,
)
- continue
-
- bind_address = listener.pop("bind_address", None)
- bind_addresses = listener.setdefault("bind_addresses", [])
-
- # if bind_address was specified, add it to the list of addresses
- if bind_address:
- bind_addresses.append(bind_address)
-
- # if we still have an empty list of addresses, use the default list
- if not bind_addresses:
- if listener["type"] == "metrics":
- # the metrics listener doesn't support IPv6
- bind_addresses.append("0.0.0.0")
else:
- bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
-
- self.listeners.append(listener)
+ l2.append(listener)
+ self.listeners = l2
if not self.web_client_location:
_warn_if_webclient_configured(self.listeners)
@@ -446,43 +493,41 @@ class ServerConfig(Config):
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
+ http_options = HttpListenerConfig(
+ resources=[
+ HttpResourceConfig(names=["client"], compress=gzip_responses),
+ HttpResourceConfig(names=["federation"]),
+ ],
+ )
+
self.listeners.append(
- {
- "port": bind_port,
- "bind_addresses": [bind_host],
- "tls": True,
- "type": "http",
- "resources": [
- {"names": ["client"], "compress": gzip_responses},
- {"names": ["federation"], "compress": False},
- ],
- }
+ ListenerConfig(
+ port=bind_port,
+ bind_addresses=[bind_host],
+ tls=True,
+ type="http",
+ http_options=http_options,
+ )
)
unsecure_port = config.get("unsecure_port", bind_port - 400)
if unsecure_port:
self.listeners.append(
- {
- "port": unsecure_port,
- "bind_addresses": [bind_host],
- "tls": False,
- "type": "http",
- "resources": [
- {"names": ["client"], "compress": gzip_responses},
- {"names": ["federation"], "compress": False},
- ],
- }
+ ListenerConfig(
+ port=unsecure_port,
+ bind_addresses=[bind_host],
+ tls=False,
+ type="http",
+ http_options=http_options,
+ )
)
manhole = config.get("manhole")
if manhole:
self.listeners.append(
- {
- "port": manhole,
- "bind_addresses": ["127.0.0.1"],
- "type": "manhole",
- "tls": False,
- }
+ ListenerConfig(
+ port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+ )
)
metrics_port = config.get("metrics_port")
@@ -490,13 +535,14 @@ class ServerConfig(Config):
logger.warning(METRICS_PORT_WARNING)
self.listeners.append(
- {
- "port": metrics_port,
- "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
- "tls": False,
- "type": "http",
- "resources": [{"names": ["metrics"], "compress": False}],
- }
+ ListenerConfig(
+ port=metrics_port,
+ bind_addresses=[config.get("metrics_bind_host", "127.0.0.1")],
+ type="http",
+ http_options=HttpListenerConfig(
+ resources=[HttpResourceConfig(names=["metrics"])]
+ ),
+ )
)
_check_resource_config(self.listeners)
@@ -522,7 +568,7 @@ class ServerConfig(Config):
)
def has_tls_listener(self) -> bool:
- return any(listener["tls"] for listener in self.listeners)
+ return any(listener.tls for listener in self.listeners)
def generate_config_section(
self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
@@ -856,7 +902,7 @@ class ServerConfig(Config):
# number of monthly active users.
#
# 'limit_usage_by_mau' disables/enables monthly active user blocking. When
- # anabled and a limit is reached the server returns a 'ResourceLimitError'
+ # enabled and a limit is reached the server returns a 'ResourceLimitError'
# with error type Codes.RESOURCE_LIMIT_EXCEEDED
#
# 'max_mau_value' is the hard limit of monthly active users above which
@@ -928,6 +974,74 @@ class ServerConfig(Config):
#
#allow_per_room_profiles: false
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # 'true'.
+ #
+ #show_users_in_user_directory: false
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # which 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but not want
+ # that purge to be performed by a job that's iterating over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m:
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
+
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
@@ -1081,6 +1195,44 @@ def read_gc_thresholds(thresholds):
)
+def parse_listener_def(listener: Any) -> ListenerConfig:
+ """parse a listener config from the config file"""
+ listener_type = listener["type"]
+
+ port = listener.get("port")
+ if not isinstance(port, int):
+ raise ConfigError("Listener configuration is lacking a valid 'port' option")
+
+ tls = listener.get("tls", False)
+
+ bind_addresses = listener.get("bind_addresses", [])
+ bind_address = listener.get("bind_address")
+ # if bind_address was specified, add it to the list of addresses
+ if bind_address:
+ bind_addresses.append(bind_address)
+
+ # if we still have an empty list of addresses, use the default list
+ if not bind_addresses:
+ if listener_type == "metrics":
+ # the metrics listener doesn't support IPv6
+ bind_addresses.append("0.0.0.0")
+ else:
+ bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
+
+ http_config = None
+ if listener_type == "http":
+ http_config = HttpListenerConfig(
+ x_forwarded=listener.get("x_forwarded", False),
+ resources=[
+ HttpResourceConfig(**res) for res in listener.get("resources", [])
+ ],
+ additional_resources=listener.get("additional_resources", {}),
+ tag=listener.get("tag"),
+ )
+
+ return ListenerConfig(port, bind_addresses, listener_type, tls, http_config)
+
+
NO_MORE_WEB_CLIENT_WARNING = """
Synapse no longer includes a web client. To enable a web client, configure
web_client_location. To remove this warning, remove 'webclient' from the 'listeners'
@@ -1088,40 +1240,27 @@ configuration.
"""
-def _warn_if_webclient_configured(listeners):
+def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
for listener in listeners:
- for res in listener.get("resources", []):
- for name in res.get("names", []):
+ if not listener.http_options:
+ continue
+ for res in listener.http_options.resources:
+ for name in res.names:
if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
-KNOWN_RESOURCES = (
- "client",
- "consent",
- "federation",
- "keys",
- "media",
- "metrics",
- "openid",
- "replication",
- "static",
- "webclient",
-)
-
-
-def _check_resource_config(listeners):
+def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
resource_names = {
res_name
for listener in listeners
- for res in listener.get("resources", [])
- for res_name in res.get("names", [])
+ if listener.http_options
+ for res in listener.http_options.resources
+ for res_name in res.names
}
for resource in resource_names:
- if resource not in KNOWN_RESOURCES:
- raise ConfigError("Unknown listener resource '%s'" % (resource,))
if resource == "consent":
try:
check_requirements("resources.consent")
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index a65538562b..e368ea564d 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -20,8 +20,6 @@ from datetime import datetime
from hashlib import sha256
from typing import List
-import six
-
from unpaddedbase64 import encode_base64
from OpenSSL import SSL, crypto
@@ -59,7 +57,7 @@ class TlsConfig(Config):
logger.warning(ACME_SUPPORT_ENABLED_WARN)
# hyperlink complains on py2 if this is not a Unicode
- self.acme_url = six.text_type(
+ self.acme_url = str(
acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory")
)
self.acme_port = acme_config.get("port", 80)
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c8d19c5d6b..43b6c40456 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -26,6 +26,7 @@ class UserDirectoryConfig(Config):
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
+ self.user_directory_defer_to_id_server = None
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = user_directory_config.get(
@@ -34,6 +35,9 @@ class UserDirectoryConfig(Config):
self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False
)
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
@@ -52,4 +56,9 @@ class UserDirectoryConfig(Config):
#user_directory:
# enabled: true
# search_all_users: false
+ #
+ # # If this is set, user search will be delegated to this ID server instead
+ # # of synapse performing the search itself.
+ # # This is an experimental API.
+ # defer_to_id_server: https://id.example.com
"""
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index ed06b91a54..dbc661630c 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -16,6 +16,7 @@
import attr
from ._base import Config, ConfigError
+from .server import ListenerConfig, parse_listener_def
@attr.s
@@ -52,7 +53,9 @@ class WorkerConfig(Config):
if self.worker_app == "synapse.app.homeserver":
self.worker_app = None
- self.worker_listeners = config.get("worker_listeners", [])
+ self.worker_listeners = [
+ parse_listener_def(x) for x in config.get("worker_listeners", [])
+ ]
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")
self.worker_log_config = config.get("worker_log_config")
@@ -75,24 +78,11 @@ class WorkerConfig(Config):
manhole = config.get("worker_manhole")
if manhole:
self.worker_listeners.append(
- {
- "port": manhole,
- "bind_addresses": ["127.0.0.1"],
- "type": "manhole",
- "tls": False,
- }
+ ListenerConfig(
+ port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+ )
)
- if self.worker_listeners:
- for listener in self.worker_listeners:
- bind_address = listener.pop("bind_address", None)
- bind_addresses = listener.setdefault("bind_addresses", [])
-
- if bind_address:
- bind_addresses.append(bind_address)
- elif not bind_addresses:
- bind_addresses.append("")
-
# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map") or {}
self.instance_map = {
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index a9f4025bfe..dbfc3e8972 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,11 +15,9 @@
# limitations under the License.
import logging
+import urllib
from collections import defaultdict
-import six
-from six.moves import urllib
-
import attr
from signedjson.key import (
decode_verify_key_bytes,
@@ -661,7 +659,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
for response in query_response["server_keys"]:
# do this first, so that we can give useful errors thereafter
server_name = response.get("server_name")
- if not isinstance(server_name, six.string_types):
+ if not isinstance(server_name, str):
raise KeyLookupError(
"Malformed response from key notary server %s: invalid server_name"
% (perspective_name,)
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 533ba327f5..cc5deca75b 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -20,8 +20,6 @@ import os
from distutils.util import strtobool
from typing import Dict, Optional, Type
-import six
-
from unpaddedbase64 import encode_base64
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
@@ -290,7 +288,7 @@ class EventBase(metaclass=abc.ABCMeta):
return list(self._dict.items())
def keys(self):
- return six.iterkeys(self._dict)
+ return self._dict.keys()
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 7c5f620d09..f94cdcbaba 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -14,8 +14,6 @@
# limitations under the License.
from typing import Optional, Union
-from six import iteritems
-
import attr
from frozendict import frozendict
@@ -341,7 +339,7 @@ def _encode_state_dict(state_dict):
if state_dict is None:
return None
- return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)]
+ return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
def _decode_state_dict(input):
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 1ffc9525d1..3c11e317fd 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,7 +15,7 @@
# limitations under the License.
import inspect
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
from synapse.spam_checker_api import SpamCheckerApi
@@ -58,42 +58,82 @@ class SpamChecker(object):
return False
def user_may_invite(
- self, inviter_userid: str, invitee_userid: str, room_id: str
+ self,
+ inviter_userid: str,
+ invitee_userid: str,
+ third_party_invite: Optional[Dict],
+ room_id: str,
+ new_room: bool,
+ published_room: bool,
) -> bool:
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
- inviter_userid: The user ID of the sender of the invitation
- invitee_userid: The user ID targeted in the invitation
- room_id: The room ID
+ inviter_userid:
+ invitee_userid: The user ID of the invitee. Is None
+ if this is a third party invite and the 3PID is not bound to a
+ user ID.
+ third_party_invite: If a third party invite then is a
+ dict containing the medium and address of the invitee.
+ room_id:
+ new_room: Whether the user is being invited to the room as
+ part of a room creation, if so the invitee would have been
+ included in the call to `user_may_create_room`.
+ published_room: Whether the room the user is being invited
+ to has been published in the local homeserver's public room
+ directory.
Returns:
True if the user may send an invite, otherwise False
"""
for spam_checker in self.spam_checkers:
if (
- spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ spam_checker.user_may_invite(
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
+ )
is False
):
return False
return True
- def user_may_create_room(self, userid: str) -> bool:
+ def user_may_create_room(
+ self,
+ userid: str,
+ invite_list: List[str],
+ third_party_invite_list: List[Dict],
+ cloning: bool,
+ ) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
userid: The ID of the user attempting to create a room
+ invite_list: List of user IDs that would be invited to
+ the new room.
+ third_party_invite_list: List of third party invites
+ for the new room.
+ cloning: Whether the user is cloning an existing room, e.g.
+ upgrading a room.
Returns:
True if the user may create a room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room(userid) is False:
+ if (
+ spam_checker.user_may_create_room(
+ userid, invite_list, third_party_invite_list, cloning
+ )
+ is False
+ ):
return False
return True
@@ -134,6 +174,25 @@ class SpamChecker(object):
return True
+ def user_may_join_room(self, userid: str, room_id: str, is_invited: bool):
+ """Checks if a given users is allowed to join a room.
+
+ Not called when a user creates a room.
+
+ Args:
+ userid:
+ room_id:
+ is_invited: Whether the user is invited into the room
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ for spam_checker in self.spam_checkers:
+ if spam_checker.user_may_join_room(userid, room_id, is_invited) is False:
+ return False
+
+ return True
+
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index dd340be9a7..f6b507977f 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -16,8 +16,6 @@ import collections
import re
from typing import Any, Mapping, Union
-from six import string_types
-
from frozendict import frozendict
from twisted.internet import defer
@@ -318,7 +316,7 @@ def serialize_event(
if only_event_fields:
if not isinstance(only_event_fields, list) or not all(
- isinstance(f, string_types) for f in only_event_fields
+ isinstance(f, str) for f in only_event_fields
):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index b001c64bb4..588d222f36 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six import integer_types, string_types
-
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
@@ -53,7 +51,7 @@ class EventValidator(object):
event_strings = ["origin"]
for s in event_strings:
- if not isinstance(getattr(event, s), string_types):
+ if not isinstance(getattr(event, s), str):
raise SynapseError(400, "'%s' not a string type" % (s,))
# Depending on the room version, ensure the data is spec compliant JSON.
@@ -90,7 +88,7 @@ class EventValidator(object):
max_lifetime = event.content.get("max_lifetime")
if min_lifetime is not None:
- if not isinstance(min_lifetime, integer_types):
+ if not isinstance(min_lifetime, int):
raise SynapseError(
code=400,
msg="'min_lifetime' must be an integer",
@@ -124,7 +122,7 @@ class EventValidator(object):
)
if max_lifetime is not None:
- if not isinstance(max_lifetime, integer_types):
+ if not isinstance(max_lifetime, int):
raise SynapseError(
code=400,
msg="'max_lifetime' must be an integer",
@@ -183,7 +181,7 @@ class EventValidator(object):
strings.append("state_key")
for s in strings:
- if not isinstance(getattr(event, s), string_types):
+ if not isinstance(getattr(event, s), str):
raise SynapseError(400, "Not '%s' a string type" % (s,))
RoomID.from_string(event.room_id)
@@ -223,7 +221,7 @@ class EventValidator(object):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
- if not isinstance(d[s], string_types):
+ if not isinstance(d[s], str):
raise SynapseError(400, "'%s' not a string type" % (s,))
def _ensure_state_event(self, event):
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index c0012c6872..420df2385f 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -17,8 +17,6 @@ import logging
from collections import namedtuple
from typing import Iterable, List
-import six
-
from twisted.internet import defer
from twisted.internet.defer import Deferred, DeferredList
from twisted.python.failure import Failure
@@ -93,8 +91,8 @@ class FederationBase(object):
# *actual* redacted copy to be on the safe side.)
redacted_event = prune_event(pdu)
if set(redacted_event.keys()) == set(pdu.keys()) and set(
- six.iterkeys(redacted_event.content)
- ) == set(six.iterkeys(pdu.content)):
+ redacted_event.content.keys()
+ ) == set(pdu.content.keys()):
logger.info(
"Event %s seems to have been redacted; using our redacted "
"copy",
@@ -294,7 +292,7 @@ def event_from_pdu_json(
assert_params_in_dict(pdu_json, ("type", "depth"))
depth = pdu_json["depth"]
- if not isinstance(depth, six.integer_types):
+ if not isinstance(depth, int):
raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON)
if depth < 0:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 32a8a2ee46..e704cf2f44 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -17,11 +17,8 @@
import logging
from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
-import six
-from six import iteritems
-
from canonicaljson import json
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram
from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
@@ -73,6 +70,10 @@ received_queries_counter = Counter(
"synapse_federation_server_received_queries", "", ["type"]
)
+pdu_process_time = Histogram(
+ "synapse_federation_server_pdu_process_time", "Time taken to process an event",
+)
+
class FederationServer(FederationBase):
def __init__(self, hs):
@@ -274,21 +275,22 @@ class FederationServer(FederationBase):
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
- with nested_logging_context(event_id):
- try:
- await self._handle_received_pdu(origin, pdu)
- pdu_results[event_id] = {}
- except FederationError as e:
- logger.warning("Error handling PDU %s: %s", event_id, e)
- pdu_results[event_id] = {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- pdu_results[event_id] = {"error": str(e)}
- logger.error(
- "Failed to handle PDU %s",
- event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()),
- )
+ with pdu_process_time.time():
+ with nested_logging_context(event_id):
+ try:
+ await self._handle_received_pdu(origin, pdu)
+ pdu_results[event_id] = {}
+ except FederationError as e:
+ logger.warning("Error handling PDU %s: %s", event_id, e)
+ pdu_results[event_id] = {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ pdu_results[event_id] = {"error": str(e)}
+ logger.error(
+ "Failed to handle PDU %s",
+ event_id,
+ exc_info=(f.type, f.value, f.getTracebackObject()),
+ )
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
@@ -534,9 +536,9 @@ class FederationServer(FederationBase):
",".join(
(
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
+ for user_id, user_keys in json_result.items()
+ for device_id, device_keys in user_keys.items()
+ for key_id, _ in device_keys.items()
)
),
)
@@ -752,7 +754,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
- if not isinstance(acl_entry, six.string_types):
+ if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 52f4f54215..6bbd762681 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -33,8 +33,6 @@ import logging
from collections import namedtuple
from typing import Dict, List, Tuple, Type
-from six import iteritems
-
from sortedcontainers import SortedDict
from twisted.internet import defer
@@ -327,7 +325,7 @@ class FederationRemoteSendQueue(object):
# stream position.
keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
- for ((destination, edu_key), pos) in iteritems(keyed_edus):
+ for ((destination, edu_key), pos) in keyed_edus.items():
rows.append(
(
pos,
@@ -530,10 +528,10 @@ def process_rows_for_federation(transaction_queue, rows):
states=[state], destinations=destinations
)
- for destination, edu_map in iteritems(buff.keyed_edus):
+ for destination, edu_map in buff.keyed_edus.items():
for key, edu in edu_map.items():
transaction_queue.send_edu(edu, key)
- for destination, edu_list in iteritems(buff.edus):
+ for destination, edu_list in buff.edus.items():
for edu in edu_list:
transaction_queue.send_edu(edu, None)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index d473576902..23fb515683 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -16,8 +16,6 @@
import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
-from six import itervalues
-
from prometheus_client import Counter
from twisted.internet import defer
@@ -203,7 +201,15 @@ class FederationSender(object):
logger.debug("Sending %s to %r", event, destinations)
- self._send_pdu(event, destinations)
+ if destinations:
+ self._send_pdu(event, destinations)
+
+ now = self.clock.time_msec()
+ ts = await self.store.get_received_ts(event.event_id)
+
+ synapse.metrics.event_processing_lag_by_event.labels(
+ "federation_sender"
+ ).observe(now - ts)
async def handle_room_events(events: Iterable[EventBase]) -> None:
with Measure(self.clock, "handle_room_events"):
@@ -218,7 +224,7 @@ class FederationSender(object):
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
- for evs in itervalues(events_by_room)
+ for evs in events_by_room.values()
],
consumeErrors=True,
)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 060bf07197..1ac522c8a1 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,9 +15,8 @@
# limitations under the License.
import logging
-from typing import Any, Dict, Optional
-
-from six.moves import urllib
+import urllib
+from typing import Any, Dict, List, Optional
from twisted.internet import defer
@@ -1021,6 +1020,20 @@ class TransportLayerClient(object):
return self.client.get_json(destination=destination, path=path)
+ def get_info_of_users(self, destination: str, user_ids: List[str]):
+ """
+ Args:
+ destination: The remote server
+ user_ids: A list of user IDs to query info about
+
+ Returns:
+ Deferred[List]: A dictionary of User ID to information about that user.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info")
+ data = {"user_ids": user_ids}
+
+ return self.client.post_json(destination=destination, path=path, data=data)
+
def _create_path(federation_prefix, path, *args):
"""
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index af4595498c..cb6331d613 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -33,6 +33,7 @@ from synapse.api.urls import (
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
+ assert_params_in_dict,
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
@@ -849,6 +850,57 @@ class PublicRoomList(BaseFederationServlet):
return 200, data
+class FederationUserInfoServlet(BaseFederationServlet):
+ """
+ Return information about a set of users.
+
+ This API returns expiration and deactivation information about a set of
+ users. Requested users not local to this homeserver will be ignored.
+
+ Example request:
+ POST /users/info
+
+ {
+ "user_ids": [
+ "@alice:example.com",
+ "@bob:example.com"
+ ]
+ }
+
+ Example response
+ {
+ "@alice:example.com": {
+ "expired": false,
+ "deactivated": true
+ }
+ }
+ """
+
+ PATH = "/users/info"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
+ super(FederationUserInfoServlet, self).__init__(
+ handler, authenticator, ratelimiter, server_name
+ )
+ self.handler = handler
+
+ async def on_POST(self, origin, content, query):
+ assert_params_in_dict(content, required=["user_ids"])
+
+ user_ids = content.get("user_ids", [])
+
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ data = await self.handler.store.get_info_for_users(user_ids)
+ return 200, data
+
+
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
@@ -1410,6 +1462,7 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationUserInfoServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 8a9de913b3..8db8ab1b7b 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -17,8 +17,6 @@
import logging
-from six import string_types
-
from synapse.api.errors import Codes, SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
@@ -513,7 +511,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
for keyname in ("name", "avatar_url", "short_description", "long_description"):
if keyname in content:
value = content[keyname]
- if not isinstance(value, string_types):
+ if not isinstance(value, str):
raise SynapseError(400, "%r value is not a string" % (keyname,))
profile[keyname] = value
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 590135d19c..0c2bcda4d0 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -20,6 +20,8 @@ from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import List
+from twisted.internet import defer
+
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -43,6 +45,8 @@ class AccountValidityHandler(object):
self.clock = self.hs.get_clock()
self._account_validity = self.hs.config.account_validity
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
if (
self._account_validity.enabled
@@ -86,6 +90,13 @@ class AccountValidityHandler(object):
self.clock.looping_call(send_emails, 30 * 60 * 1000)
+ # If account_validity is enabled,check every hour to remove expired users from
+ # the user directory
+ if self._account_validity.enabled:
+ self.clock.looping_call(
+ self._mark_expired_users_as_inactive, 60 * 60 * 1000
+ )
+
async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
@@ -266,4 +277,25 @@ class AccountValidityHandler(object):
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ await self.profile_handler.set_active(
+ [UserID.from_string(user_id)], True, True
+ )
+
return expiration_ts
+
+ @defer.inlineCallbacks
+ def _mark_expired_users_as_inactive(self):
+ """Iterate over active, expired users. Mark them as inactive in order to hide them
+ from the user directory.
+
+ Returns:
+ Deferred
+ """
+ # Get active, expired users
+ active_expired_users = yield self.store.get_expired_users()
+
+ # Mark each as non-active
+ yield self.profile_handler.set_active(active_expired_users, False, True)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index fe62f78e67..f7d9fd621e 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,8 +15,6 @@
import logging
-from six import itervalues
-
from prometheus_client import Counter
from twisted.internet import defer
@@ -116,6 +114,12 @@ class ApplicationServicesHandler(object):
for service in services:
self.scheduler.submit_event_for_as(service, event)
+ now = self.clock.time_msec()
+ ts = yield self.store.get_received_ts(event.event_id)
+ synapse.metrics.event_processing_lag_by_event.labels(
+ "appservice_sender"
+ ).observe(now - ts)
+
@defer.inlineCallbacks
def handle_room_events(events):
for event in events:
@@ -125,7 +129,7 @@ class ApplicationServicesHandler(object):
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
- for evs in itervalues(events_by_room)
+ for evs in events_by_room.values()
],
consumeErrors=True,
)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bb3b43d5ae..c3f86e7414 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -297,7 +297,7 @@ class AuthHandler(BaseHandler):
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
- method = request.uri.decode("utf-8")
+ method = request.method.decode("utf-8")
# If there's no session ID, create a new session.
if not sid:
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 64aaa1335c..76f213723a 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,11 +14,10 @@
# limitations under the License.
import logging
+import urllib
import xml.etree.ElementTree as ET
from typing import Dict, Optional, Tuple
-from six.moves import urllib
-
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 2afb390a92..fe62c3f973 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -33,6 +33,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_handlers().identity_handler
+ self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running
@@ -104,6 +105,9 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None)
+ user = UserID.from_string(user_id)
+ await self._profile_handler.set_active([user], False, False)
+
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
await self.store.add_user_pending_deactivation(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 230d170258..31346b56c3 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -17,8 +17,6 @@
import logging
from typing import Any, Dict, Optional
-from six import iteritems, itervalues
-
from twisted.internet import defer
from synapse.api import errors
@@ -159,7 +157,7 @@ class DeviceWorkerHandler(BaseHandler):
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
- for key, event_id in iteritems(current_state_ids):
+ for key, event_id in current_state_ids.items():
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -182,7 +180,7 @@ class DeviceWorkerHandler(BaseHandler):
log_kv(
{"event": "encountered empty previous state", "room_id": room_id}
)
- for key, event_id in iteritems(current_state_ids):
+ for key, event_id in current_state_ids.items():
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -198,10 +196,10 @@ class DeviceWorkerHandler(BaseHandler):
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
- for state_dict in itervalues(prev_state_ids):
+ for state_dict in prev_state_ids.values():
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
- for key, event_id in iteritems(current_state_ids):
+ for key, event_id in current_state_ids.items():
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -211,14 +209,14 @@ class DeviceWorkerHandler(BaseHandler):
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
- for key, event_id in iteritems(current_state_ids):
+ for key, event_id in current_state_ids.items():
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
- for state_dict in itervalues(prev_state_ids):
+ for state_dict in prev_state_ids.values():
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
@@ -693,6 +691,7 @@ class DeviceListUpdater(object):
return False
+ @trace
@defer.inlineCallbacks
def _maybe_retry_device_resync(self):
"""Retry to resync device lists that are out of sync, except if another retry is
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 05c4b3eec0..610b08d00b 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -18,8 +18,6 @@ from typing import Any, Dict
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@@ -51,8 +49,7 @@ class DeviceMessageHandler(object):
self._device_list_updater = hs.get_device_handler().device_list_updater
- @defer.inlineCallbacks
- def on_direct_to_device_edu(self, origin, content):
+ async def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
@@ -82,11 +79,11 @@ class DeviceMessageHandler(object):
}
local_messages[user_id] = messages_by_device
- yield self._check_for_unknown_devices(
+ await self._check_for_unknown_devices(
message_type, sender_user_id, by_device
)
- stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
+ stream_id = await self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
)
@@ -94,14 +91,13 @@ class DeviceMessageHandler(object):
"to_device_key", stream_id, users=local_messages.keys()
)
- @defer.inlineCallbacks
- def _check_for_unknown_devices(
+ async def _check_for_unknown_devices(
self,
message_type: str,
sender_user_id: str,
by_device: Dict[str, Dict[str, Any]],
):
- """Checks inbound device messages for unkown remote devices, and if
+ """Checks inbound device messages for unknown remote devices, and if
found marks the remote cache for the user as stale.
"""
@@ -115,7 +111,7 @@ class DeviceMessageHandler(object):
requesting_device_ids.add(device_id)
# Check if we are tracking the devices of the remote user.
- room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+ room_ids = await self.store.get_rooms_for_user(sender_user_id)
if not room_ids:
logger.info(
"Received device message from remote device we don't"
@@ -127,7 +123,7 @@ class DeviceMessageHandler(object):
# If we are tracking check that we know about the sending
# devices.
- cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+ cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)
unknown_devices = requesting_device_ids - set(cached_devices)
if unknown_devices:
@@ -136,15 +132,14 @@ class DeviceMessageHandler(object):
sender_user_id,
unknown_devices,
)
- yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+ await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background
run_in_background(
self._device_list_updater.user_device_resync, sender_user_id
)
- @defer.inlineCallbacks
- def send_device_message(self, sender_user_id, message_type, messages):
+ async def send_device_message(self, sender_user_id, message_type, messages):
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
local_messages = {}
@@ -183,7 +178,7 @@ class DeviceMessageHandler(object):
}
log_kv({"local_messages": local_messages})
- stream_id = yield self.store.add_messages_to_device_inbox(
+ stream_id = await self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index f2f16b1e43..79a2df6201 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -17,8 +17,6 @@ import logging
import string
from typing import Iterable, List, Optional
-from twisted.internet import defer
-
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import (
AuthError,
@@ -55,8 +53,7 @@ class DirectoryHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
- @defer.inlineCallbacks
- def _create_association(
+ async def _create_association(
self,
room_alias: RoomAlias,
room_id: str,
@@ -76,13 +73,13 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
- users = yield self.state.get_current_users_in_room(room_id)
+ users = await self.state.get_current_users_in_room(room_id)
servers = {get_domain_from_id(u) for u in users}
if not servers:
raise SynapseError(400, "Failed to get server list")
- yield self.store.create_room_alias_association(
+ await self.store.create_room_alias_association(
room_alias, room_id, servers, creator=creator
)
@@ -93,7 +90,7 @@ class DirectoryHandler(BaseHandler):
room_id: str,
servers: Optional[List[str]] = None,
check_membership: bool = True,
- ):
+ ) -> None:
"""Attempt to create a new alias
Args:
@@ -103,9 +100,6 @@ class DirectoryHandler(BaseHandler):
servers: Iterable of servers that others servers should try and join via
check_membership: Whether to check if the user is in the room
before the alias can be set (if the server's config requires it).
-
- Returns:
- Deferred
"""
user_id = requester.user.to_string()
@@ -148,7 +142,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to create alias")
- can_create = await self.can_modify_alias(room_alias, user_id=user_id)
+ can_create = self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
400,
@@ -158,7 +152,9 @@ class DirectoryHandler(BaseHandler):
await self._create_association(room_alias, room_id, servers, creator=user_id)
- async def delete_association(self, requester: Requester, room_alias: RoomAlias):
+ async def delete_association(
+ self, requester: Requester, room_alias: RoomAlias
+ ) -> str:
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
@@ -169,7 +165,7 @@ class DirectoryHandler(BaseHandler):
room_alias
Returns:
- Deferred[unicode]: room id that the alias used to point to
+ room id that the alias used to point to
Raises:
NotFoundError: if the alias doesn't exist
@@ -191,7 +187,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.")
- can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
+ can_delete = self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
400,
@@ -208,8 +204,7 @@ class DirectoryHandler(BaseHandler):
return room_id
- @defer.inlineCallbacks
- def delete_appservice_association(
+ async def delete_appservice_association(
self, service: ApplicationService, room_alias: RoomAlias
):
if not service.is_interested_in_alias(room_alias.to_string()):
@@ -218,29 +213,27 @@ class DirectoryHandler(BaseHandler):
"This application service has not reserved this kind of alias",
errcode=Codes.EXCLUSIVE,
)
- yield self._delete_association(room_alias)
+ await self._delete_association(room_alias)
- @defer.inlineCallbacks
- def _delete_association(self, room_alias: RoomAlias):
+ async def _delete_association(self, room_alias: RoomAlias):
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
- room_id = yield self.store.delete_room_alias(room_alias)
+ room_id = await self.store.delete_room_alias(room_alias)
return room_id
- @defer.inlineCallbacks
- def get_association(self, room_alias: RoomAlias):
+ async def get_association(self, room_alias: RoomAlias):
room_id = None
if self.hs.is_mine(room_alias):
- result = yield self.get_association_from_room_alias(room_alias)
+ result = await self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
servers = result.servers
else:
try:
- result = yield self.federation.make_query(
+ result = await self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
args={"room_alias": room_alias.to_string()},
@@ -265,7 +258,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
- users = yield self.state.get_current_users_in_room(room_id)
+ users = await self.state.get_current_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers)
@@ -277,13 +270,12 @@ class DirectoryHandler(BaseHandler):
return {"room_id": room_id, "servers": servers}
- @defer.inlineCallbacks
- def on_directory_query(self, args):
+ async def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room Alias is not hosted on this homeserver")
- result = yield self.get_association_from_room_alias(room_alias)
+ result = await self.get_association_from_room_alias(room_alias)
if result is not None:
return {"room_id": result.room_id, "servers": result.servers}
@@ -344,16 +336,15 @@ class DirectoryHandler(BaseHandler):
ratelimit=False,
)
- @defer.inlineCallbacks
- def get_association_from_room_alias(self, room_alias: RoomAlias):
- result = yield self.store.get_association_from_room_alias(room_alias)
+ async def get_association_from_room_alias(self, room_alias: RoomAlias):
+ result = await self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
- result = yield as_handler.query_room_alias_exists(room_alias)
+ result = await as_handler.query_room_alias_exists(room_alias)
return result
- def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None):
+ def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None) -> bool:
# Any application service "interested" in an alias they are regexing on
# can modify the alias.
# Users can only modify the alias if ALL the interested services have
@@ -366,12 +357,12 @@ class DirectoryHandler(BaseHandler):
for service in interested_services:
if user_id == service.sender:
# this user IS the app service so they can do whatever they like
- return defer.succeed(True)
+ return True
elif service.is_exclusive_alias(alias.to_string()):
# another service has an exclusive lock on this alias.
- return defer.succeed(False)
+ return False
# either no interested services, or no service with an exclusive lock
- return defer.succeed(True)
+ return True
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias.
@@ -459,8 +450,7 @@ class DirectoryHandler(BaseHandler):
await self.store.set_room_is_public(room_id, making_public)
- @defer.inlineCallbacks
- def edit_published_appservice_room_list(
+ async def edit_published_appservice_room_list(
self, appservice_id: str, network_id: str, room_id: str, visibility: str
):
"""Add or remove a room from the appservice/network specific public
@@ -475,7 +465,7 @@ class DirectoryHandler(BaseHandler):
if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting")
- yield self.store.set_room_is_public_appservice(
+ await self.store.set_room_is_public_appservice(
room_id, appservice_id, network_id, visibility == "public"
)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 774a252619..a7e60cbc26 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -17,8 +17,6 @@
import logging
-from six import iteritems
-
import attr
from canonicaljson import encode_canonical_json, json
from signedjson.key import decode_verify_key_bytes
@@ -135,7 +133,7 @@ class E2eKeysHandler(object):
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
- for user_id, device_ids in iteritems(remote_queries):
+ for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
@@ -145,9 +143,9 @@ class E2eKeysHandler(object):
user_ids_not_in_cache,
remote_results,
) = yield self.store.get_user_devices_from_cache(query_list)
- for user_id, devices in iteritems(remote_results):
+ for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
- for device_id, device in iteritems(devices):
+ for device_id, device in devices.items():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
@@ -446,9 +444,9 @@ class E2eKeysHandler(object):
",".join(
(
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
+ for user_id, user_keys in json_result.items()
+ for device_id, device_keys in user_keys.items()
+ for key_id, _ in device_keys.items()
)
),
)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 9abaf13b8f..f55470a707 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -16,8 +16,6 @@
import logging
-from six import iteritems
-
from twisted.internet import defer
from synapse.api.errors import (
@@ -205,8 +203,8 @@ class E2eRoomKeysHandler(object):
)
to_insert = [] # batch the inserts together
changed = False # if anything has changed, we need to update the etag
- for room_id, room in iteritems(room_keys["rooms"]):
- for session_id, room_key in iteritems(room["sessions"]):
+ for room_id, room in room_keys["rooms"].items():
+ for session_id, room_key in room["sessions"].items():
if not isinstance(room_key["is_verified"], bool):
msg = (
"is_verified must be a boolean in keys for session %s in"
@@ -351,6 +349,7 @@ class E2eRoomKeysHandler(object):
raise
res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
+ res["etag"] = str(res["etag"])
return res
@trace
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d0b62f4cf2..8a552ca7b5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,12 +19,9 @@
import itertools
import logging
+from http import HTTPStatus
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
-import six
-from six import iteritems, itervalues
-from six.moves import http_client, zip
-
import attr
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
@@ -33,7 +30,12 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.api.constants import (
+ EventTypes,
+ Membership,
+ RejectedReason,
+ RoomEncryptionAlgorithms,
+)
from synapse.api.errors import (
AuthError,
CodeMessageException,
@@ -171,7 +173,7 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("handling received PDU: %s", pdu)
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
existing = await self.store.get_event(
@@ -286,6 +288,14 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
+ elif missing_prevs:
+ logger.info(
+ "[%s %s] Not recursively fetching %d missing prev_events: %s",
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -330,12 +340,6 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
- logger.info(
- "Event %s is missing prev_events: calculating state for a "
- "backwards extremity",
- event_id,
- )
-
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
@@ -353,7 +357,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "Requesting state at missing prev_event %s", event_id,
+ "[%s %s] Requesting state at missing prev_event %s",
+ room_id,
+ event_id,
+ p,
)
with nested_logging_context(p):
@@ -374,6 +381,7 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version_id(room_id)
state_map = await resolve_events_with_store(
+ self.clock,
room_id,
room_version,
state_maps,
@@ -387,13 +395,11 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
+ list(state_map.values()), get_prev_content=False,
)
event_map.update(evs)
- state = [event_map[e] for e in six.itervalues(state_map)]
+ state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "
@@ -742,7 +748,10 @@ class FederationHandler(BaseHandler):
if device:
keys = device.get("keys", {}).get("keys", {})
- if event.content.get("algorithm") == "m.megolm.v1.aes-sha2":
+ if (
+ event.content.get("algorithm")
+ == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2
+ ):
# For this algorithm we expect a curve25519 key.
key_name = "curve25519:%s" % (device_id,)
current_keys = [keys.get(key_name)]
@@ -1001,7 +1010,7 @@ class FederationHandler(BaseHandler):
"""
joined_users = [
(state_key, int(event.depth))
- for (e_type, state_key), event in iteritems(state)
+ for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
@@ -1091,16 +1100,16 @@ class FederationHandler(BaseHandler):
states = dict(zip(event_ids, [s.state for s in states]))
state_map = await self.store.get_events(
- [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
+ [e_id for ids in states.values() for e_id in ids.values()],
get_prev_content=False,
)
states = {
key: {
k: state_map[e_id]
- for k, e_id in iteritems(state_dict)
+ for k, e_id in state_dict.items()
if e_id in state_map
}
- for key, state_dict in iteritems(states)
+ for key, state_dict in states.items()
}
for e_id, _ in sorted_extremeties_tuple:
@@ -1188,7 +1197,7 @@ class FederationHandler(BaseHandler):
ev.event_id,
len(ev.prev_event_ids()),
)
- raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events")
if len(ev.auth_event_ids()) > 10:
logger.warning(
@@ -1196,7 +1205,7 @@ class FederationHandler(BaseHandler):
ev.event_id,
len(ev.auth_event_ids()),
)
- raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
@@ -1517,8 +1526,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = await self.store.is_room_published(event.room_id)
+
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1539,7 +1555,7 @@ class FederationHandler(BaseHandler):
# block any attempts to invite the server notices mxid
if event.state_key == self._server_notices_mxid:
- raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
+ raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
@@ -1725,7 +1741,7 @@ class FederationHandler(BaseHandler):
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
- _, state = list(iteritems(state_groups)).pop()
+ _, state = list(state_groups.items()).pop()
results = {(e.type, e.state_key): e for e in state}
if event.is_state():
@@ -2088,7 +2104,7 @@ class FederationHandler(BaseHandler):
room_version, state_sets, event
)
current_state_ids = {
- k: e.event_id for k, e in iteritems(current_state_ids)
+ k: e.event_id for k, e in current_state_ids.items()
}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
@@ -2104,7 +2120,7 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [
- e for k, e in iteritems(current_state_ids) if k in auth_types
+ e for k, e in current_state_ids.items() if k in auth_types
]
current_auth_events = await self.store.get_events(current_state_ids)
@@ -2420,7 +2436,7 @@ class FederationHandler(BaseHandler):
else:
event_key = None
state_updates = {
- k: a.event_id for k, a in iteritems(auth_events) if k != event_key
+ k: a.event_id for k, a in auth_events.items() if k != event_key
}
current_state_ids = await context.get_current_state_ids()
@@ -2431,7 +2447,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = dict(prev_state_ids)
- prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
+ prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
# create a new state group as a delta from the existing one.
prev_group = context.state_group
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index ebe8d25bd8..7cb106e365 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -16,8 +16,6 @@
import logging
-from six import iteritems
-
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import get_domain_from_id
@@ -227,7 +225,7 @@ class GroupsLocalWorkerHandler(object):
results = {}
failed_results = []
- for destination, dest_user_ids in iteritems(destinations):
+ for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 4ba0042768..a77088e295 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
+from twisted.internet import defer
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
@@ -32,6 +33,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
HttpResponseException,
+ ProxiedRequestError,
SynapseError,
)
from synapse.config.emailconfig import ThreepidBehaviour
@@ -43,29 +45,34 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-id_server_scheme = "https://"
-
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
- self.http_client = SimpleHttpClient(hs)
+ self.hs = hs
+ self.http_client = hs.get_simple_http_client()
# We create a blacklisting instance of SimpleHttpClient for contacting identity
# servers specified by clients
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
self.federation_http_client = hs.get_http_client()
- self.hs = hs
- async def threepid_from_creds(self, id_server, creds):
+ self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
+ self.trust_any_id_server_just_for_testing_do_not_use = (
+ hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
+ )
+ self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
+ self._enable_lookup = hs.config.enable_3pid_lookup
+
+ async def threepid_from_creds(self, id_server_url, creds):
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Args:
- id_server (str): The identity server to validate 3PIDs against. Must be a
+ id_server_url (str): The identity server to validate 3PIDs against. Must be a
complete URL including the protocol (http(s)://)
creds (dict[str, str]): Dictionary containing the following keys:
@@ -92,7 +99,14 @@ class IdentityHandler(BaseHandler):
query_params = {"sid": session_id, "client_secret": client_secret}
- url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
+ url = "%s%s" % (
+ id_server_url,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid",
+ )
try:
data = await self.http_client.get_json(url, query_params)
@@ -101,7 +115,7 @@ class IdentityHandler(BaseHandler):
except HttpResponseException as e:
logger.info(
"%s returned %i for threepid validation for: %s",
- id_server,
+ id_server_url,
e.code,
creds,
)
@@ -115,7 +129,7 @@ class IdentityHandler(BaseHandler):
if "medium" in data:
return data
- logger.info("%s reported non-validated threepid: %s", id_server, creds)
+ logger.info("%s reported non-validated threepid: %s", id_server_url, creds)
return None
async def bind_threepid(
@@ -146,14 +160,19 @@ class IdentityHandler(BaseHandler):
if id_access_token is None:
use_v2 = False
+ # if we have a rewrite rule set for the identity server,
+ # apply it now, but only for sending the request (not
+ # storing in the database).
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Decide which API endpoint URLs to use
headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
- bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,)
headers["Authorization"] = create_id_access_token_header(id_access_token)
else:
- bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+ bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,)
try:
# Use the blacklisting http client as this call is only to identity servers
@@ -238,9 +257,6 @@ class IdentityHandler(BaseHandler):
Deferred[bool]: True on success, otherwise False if the identity
server doesn't support unbinding
"""
- url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
- url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
-
content = {
"mxid": mxid,
"threepid": {"medium": threepid["medium"], "address": threepid["address"]},
@@ -249,6 +265,7 @@ class IdentityHandler(BaseHandler):
# we abuse the federation http client to sign the request, but we have to send it
# using the normal http client since we don't want the SRV lookup and want normal
# 'browser-like' HTTPS.
+ url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method="POST",
@@ -258,6 +275,15 @@ class IdentityHandler(BaseHandler):
)
headers = {b"Authorization": auth_headers}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ #
+ # Note that destination_is has to be the real id_server, not
+ # the server we connect to.
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,)
+
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
@@ -371,15 +397,34 @@ class IdentityHandler(BaseHandler):
return session_id
+ def rewrite_id_server_url(self, url: str, add_https=False) -> str:
+ """Given an identity server URL, optionally add a protocol scheme
+ before rewriting it according to the rewrite_identity_server_urls
+ config option
+
+ Adds https:// to the URL if specified, then tries to rewrite the
+ url. Returns either the rewritten URL or the URL with optional
+ protocol scheme additions.
+ """
+ rewritten_url = url
+ if add_https:
+ rewritten_url = "https://" + rewritten_url
+
+ rewritten_url = self.rewrite_identity_server_urls.get(
+ rewritten_url, rewritten_url
+ )
+ logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url)
+ return rewritten_url
+
async def requestEmailToken(
- self, id_server, email, client_secret, send_attempt, next_link=None
+ self, id_server_url, email, client_secret, send_attempt, next_link=None
):
"""
Request an external server send an email on our behalf for the purposes of threepid
validation.
Args:
- id_server (str): The identity server to proxy to
+ id_server_url (str): The identity server to proxy to
email (str): The email to send the message to
client_secret (str): The unique client_secret sends by the user
send_attempt (int): Which attempt this is
@@ -393,6 +438,11 @@ class IdentityHandler(BaseHandler):
"client_secret": client_secret,
"send_attempt": send_attempt,
}
+
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
if next_link:
params["next_link"] = next_link
@@ -407,7 +457,8 @@ class IdentityHandler(BaseHandler):
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
+ "%s/_matrix/identity/api/v1/validate/email/requestToken"
+ % (id_server_url,),
params,
)
return data
@@ -419,7 +470,7 @@ class IdentityHandler(BaseHandler):
async def requestMsisdnToken(
self,
- id_server,
+ id_server_url,
country,
phone_number,
client_secret,
@@ -430,7 +481,7 @@ class IdentityHandler(BaseHandler):
Request an external server send an SMS message on our behalf for the purposes of
threepid validation.
Args:
- id_server (str): The identity server to proxy to
+ id_server_url (str): The identity server to proxy to
country (str): The country code of the phone number
phone_number (str): The number to send the message to
client_secret (str): The unique client_secret sends by the user
@@ -458,9 +509,13 @@ class IdentityHandler(BaseHandler):
"details and update your config file."
)
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url)
try:
data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
+ "%s/_matrix/identity/api/v1/validate/msisdn/requestToken"
+ % (id_server_url,),
params,
)
except HttpResponseException as e:
@@ -554,6 +609,89 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
+ # TODO: The following two methods are used for proxying IS requests using
+ # the CS API. They should be consolidated with those in RoomMemberHandler
+ # https://github.com/matrix-org/synapse-dinsic/issues/25
+
+ @defer.inlineCallbacks
+ def proxy_lookup_3pid(self, id_server, medium, address):
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = yield self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
+ {"medium": medium, "address": address},
+ )
+
+ if "mxid" in data:
+ if "signatures" not in data:
+ raise AuthError(401, "No signatures on 3pid binding")
+ yield self._verify_any_signature(data, id_server)
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ return data
+
+ @defer.inlineCallbacks
+ def proxy_bulk_lookup_3pid(self, id_server, threepids):
+ """Looks up given 3pids in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ threepids ([[str, str]]): The third party identifiers to lookup, as
+ a list of 2-string sized lists ([medium, address]).
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ try:
+ data = yield self.http_client.post_json_get_json(
+ "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
+ {"threepids": threepids},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %s: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ defer.returnValue(data)
+
async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
@@ -568,10 +706,13 @@ class IdentityHandler(BaseHandler):
Returns:
str|None: the matrix ID of the 3pid, or None if it is not recognized.
"""
+ # Rewrite id_server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
if id_access_token is not None:
try:
results = await self._lookup_3pid_v2(
- id_server, id_access_token, medium, address
+ id_server_url, id_access_token, medium, address
)
return results
@@ -589,14 +730,15 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
- return await self._lookup_3pid_v1(id_server, medium, address)
+ return await self._lookup_3pid_v1(id_server, id_server_url, medium, address)
- async def _lookup_3pid_v1(self, id_server, medium, address):
+ async def _lookup_3pid_v1(self, id_server, id_server_url, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
+ id_server_url (str): The actual, reachable domain of the id server
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
@@ -604,8 +746,8 @@ class IdentityHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+ data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
{"medium": medium, "address": address},
)
@@ -621,12 +763,11 @@ class IdentityHandler(BaseHandler):
return None
- async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
+ async def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
- id_server (str): The server name (including port, if required)
- of the identity server to use.
+ id_server_url (str): The protocol scheme and domain of the id server
id_access_token (str): The access token to authenticate to the identity server with
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
@@ -636,8 +777,8 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
+ hash_details = await self.http_client.get_json(
+ "%s/_matrix/identity/v2/hash_details" % (id_server_url,),
{"access_token": id_access_token},
)
except TimeoutError:
@@ -645,15 +786,14 @@ class IdentityHandler(BaseHandler):
if not isinstance(hash_details, dict):
logger.warning(
- "Got non-dict object when checking hash details of %s%s: %s",
- id_server_scheme,
- id_server,
+ "Got non-dict object when checking hash details of %s: %s",
+ id_server_url,
hash_details,
)
raise SynapseError(
400,
- "Non-dict object from %s%s during v2 hash_details request: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Non-dict object from %s during v2 hash_details request: %s"
+ % (id_server_url, hash_details),
)
# Extract information from hash_details
@@ -667,8 +807,8 @@ class IdentityHandler(BaseHandler):
):
raise SynapseError(
400,
- "Invalid hash details received from identity server %s%s: %s"
- % (id_server_scheme, id_server, hash_details),
+ "Invalid hash details received from identity server %s: %s"
+ % (id_server_url, hash_details),
)
# Check if any of the supported lookup algorithms are present
@@ -690,7 +830,7 @@ class IdentityHandler(BaseHandler):
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
- id_server,
+ id_server_url,
supported_lookup_algorithms,
)
raise SynapseError(
@@ -703,8 +843,8 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = await self.blacklisting_http_client.post_json_get_json(
- "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
+ lookup_results = await self.http_client.post_json_get_json(
+ "%s/_matrix/identity/v2/lookup" % (id_server_url,),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
@@ -731,30 +871,31 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
- async def _verify_any_signature(self, data, server_hostname):
- if server_hostname not in data["signatures"]:
- raise AuthError(401, "No signature from server %s" % (server_hostname,))
- for key_name, signature in data["signatures"][server_hostname].items():
- try:
- key_data = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/pubkey/%s"
- % (id_server_scheme, server_hostname, key_name)
- )
- except TimeoutError:
- raise SynapseError(500, "Timed out contacting identity server")
+ async def _verify_any_signature(self, data, id_server):
+ if id_server not in data["signatures"]:
+ raise AuthError(401, "No signature from server %s" % (id_server,))
+
+ for key_name, signature in data["signatures"][id_server].items():
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ key_data = await self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_url, key_name)
+ )
if "public_key" not in key_data:
raise AuthError(
- 401, "No public key named %s from %s" % (key_name, server_hostname)
+ 401, "No public key named %s from %s" % (key_name, id_server)
)
verify_signed_json(
data,
- server_hostname,
+ id_server,
decode_verify_key_bytes(
key_name, decode_base64(key_data["public_key"])
),
)
return
+ raise AuthError(401, "No signature from server %s" % (id_server,))
+
async def ask_id_server_for_third_party_invite(
self,
requester,
@@ -814,15 +955,17 @@ class IdentityHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url,
}
+ # Rewrite the identity server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
- base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
+ base_url = "%s/_matrix/identity" % (id_server_url,)
if id_access_token:
- key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % (
+ id_server_url,
)
# Attempt a v2 lookup
@@ -841,9 +984,8 @@ class IdentityHandler(BaseHandler):
raise e
if data is None:
- key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
+ key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+ id_server_url,
)
url = base_url + "/api/v1/store-invite"
@@ -855,10 +997,7 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
- "Error trying to call /store-invite on %s%s: %s",
- id_server_scheme,
- id_server,
- e,
+ "Error trying to call /store-invite on %s: %s", id_server_url, e,
)
if data is None:
@@ -871,10 +1010,9 @@ class IdentityHandler(BaseHandler):
)
except HttpResponseException as e:
logger.warning(
- "Error calling /store-invite on %s%s with fallback "
+ "Error calling /store-invite on %s with fallback "
"encoding: %s",
- id_server_scheme,
- id_server,
+ id_server_url,
e,
)
raise e
@@ -895,6 +1033,30 @@ class IdentityHandler(BaseHandler):
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
+ async def bind_email_using_internal_sydent_api(
+ self, id_server_url: str, email: str, user_id: str,
+ ):
+ """Bind an email to a fully qualified user ID using the internal API of an
+ instance of Sydent.
+
+ Args:
+ id_server_url: The URL of the Sydent instance
+ email: The email address to bind
+ user_id: The user ID to bind the email to
+
+ Raises:
+ HTTPResponseException: On a non-2xx HTTP response.
+ """
+ # id_server_url is assumed to have no trailing slashes
+ url = id_server_url + "/_matrix/identity/internal/bind"
+ body = {
+ "address": email,
+ "medium": "email",
+ "mxid": user_id,
+ }
+
+ await self.http_client.post_json_get_json(url, body)
+
def create_id_access_token_header(id_access_token):
"""Create an Authorization header for passing to SimpleHttpClient as the header value
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 649ca1f08a..665ad19b5d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -17,8 +17,6 @@
import logging
from typing import Optional, Tuple
-from six import iteritems, itervalues, string_types
-
from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
@@ -246,7 +244,7 @@ class MessageHandler(object):
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
- for user_id, profile in iteritems(users_with_profile)
+ for user_id, profile in users_with_profile.items()
}
def maybe_schedule_expiry(self, event):
@@ -715,7 +713,7 @@ class EventCreationHandler(object):
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
- if not isinstance(spam_error, string_types):
+ if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
@@ -881,7 +879,9 @@ class EventCreationHandler(object):
"""
room_alias = RoomAlias.from_string(room_alias_str)
try:
- mapping = yield directory_handler.get_association(room_alias)
+ mapping = yield defer.ensureDeferred(
+ directory_handler.get_association(room_alias)
+ )
except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND:
@@ -988,7 +988,7 @@ class EventCreationHandler(object):
state_to_include_ids = [
e_id
- for k, e_id in iteritems(current_state_ids)
+ for k, e_id in current_state_ids.items()
if k[0] in self.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
@@ -1002,7 +1002,7 @@ class EventCreationHandler(object):
"content": e.content,
"sender": e.sender,
}
- for e in itervalues(state_to_include)
+ for e in state_to_include.values()
]
invitee = UserID.from_string(event.state_key)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index d7442c62a7..da06582d4b 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -15,9 +15,6 @@
# limitations under the License.
import logging
-from six import iteritems
-
-from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
@@ -99,8 +96,7 @@ class PaginationHandler(object):
job["longest_max_lifetime"],
)
- @defer.inlineCallbacks
- def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+ async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
"""Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its
@@ -139,13 +135,13 @@ class PaginationHandler(object):
include_null,
)
- rooms = yield self.store.get_rooms_for_retention_period_in_range(
+ rooms = await self.store.get_rooms_for_retention_period_in_range(
min_ms, max_ms, include_null
)
logger.debug("[purge] Rooms to purge: %s", rooms)
- for room_id, retention_policy in iteritems(rooms):
+ for room_id, retention_policy in rooms.items():
logger.info("[purge] Attempting to purge messages in room %s", room_id)
if room_id in self._purges_in_progress_by_room:
@@ -167,9 +163,9 @@ class PaginationHandler(object):
# Figure out what token we should start purging at.
ts = self.clock.time_msec() - max_lifetime
- stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts)
+ stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
- r = yield self.store.get_room_event_before_stream_ordering(
+ r = await self.store.get_room_event_before_stream_ordering(
room_id, stream_ordering,
)
if not r:
@@ -229,8 +225,7 @@ class PaginationHandler(object):
)
return purge_id
- @defer.inlineCallbacks
- def _purge_history(self, purge_id, room_id, token, delete_local_events):
+ async def _purge_history(self, purge_id, room_id, token, delete_local_events):
"""Carry out a history purge on a room.
Args:
@@ -239,14 +234,11 @@ class PaginationHandler(object):
token (str): topological token to delete events before
delete_local_events (bool): True to delete local events as well as
remote ones
-
- Returns:
- Deferred
"""
self._purges_in_progress_by_room.add(room_id)
try:
- with (yield self.pagination_lock.write(room_id)):
- yield self.storage.purge_events.purge_history(
+ with await self.pagination_lock.write(room_id):
+ await self.storage.purge_events.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] complete")
@@ -284,9 +276,7 @@ class PaginationHandler(object):
await self.store.get_room_version_id(room_id)
# first check that we have no users in this room
- joined = await defer.maybeDeferred(
- self.store.is_host_joined, room_id, self._server_name
- )
+ joined = await self.store.is_host_joined(room_id, self._server_name)
if joined:
raise SynapseError(400, "Users are still joined to this room")
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 3594f3b00f..d2f25ae12a 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,9 +25,7 @@ The methods that define policy are:
import abc
import logging
from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set
-
-from six import iteritems, itervalues
+from typing import Dict, Iterable, List, Set, Tuple
from prometheus_client import Counter
from typing_extensions import ContextManager
@@ -170,14 +168,14 @@ class BasePresenceHandler(abc.ABC):
for user_id in user_ids
}
- missing = [user_id for user_id, state in iteritems(states) if not state]
+ missing = [user_id for user_id, state in states.items() if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = await self.store.get_presence_for_users(missing)
states.update(res)
- missing = [user_id for user_id, state in iteritems(states) if not state]
+ missing = [user_id for user_id, state in states.items() if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id) for user_id in missing
@@ -632,7 +630,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states(
[
prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
- for prev_state in itervalues(prev_states)
+ for prev_state in prev_states.values()
]
)
self.external_process_last_updated_ms.pop(process_id, None)
@@ -775,7 +773,9 @@ class PresenceHandler(BasePresenceHandler):
return False
- async def get_all_presence_updates(self, last_id, current_id, limit):
+ async def get_all_presence_updates(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -787,10 +787,31 @@ class PresenceHandler(BasePresenceHandler):
- last_user_sync_ts(int)
- status_msg(int)
- currently_active(int)
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
+
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
- rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
+ rows = await self.store.get_all_presence_updates(
+ instance_name, last_id, current_id, limit
+ )
return rows
def notify_new_event(self):
@@ -1087,7 +1108,7 @@ class PresenceEventSource(object):
return (list(updates.values()), max_token)
else:
return (
- [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE],
+ [s for s in updates.values() if s.state != PresenceState.OFFLINE],
max_token,
)
@@ -1323,11 +1344,11 @@ def get_interested_remotes(store, states, state_handler):
# hosts in those rooms.
room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
- for room_id, states in iteritems(room_ids_to_states):
+ for room_id, states in room_ids_to_states.items():
hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states))
- for user_id, states in iteritems(users_to_states):
+ for user_id, states in users_to_states.items():
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 302efc1b9a..dd8979e750 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,10 +15,13 @@
# limitations under the License.
import logging
+from typing import List
-from six import raise_from
+from six.moves import range
-from twisted.internet import defer
+from signedjson.sign import sign_json
+
+from twisted.internet import defer, reactor
from synapse.api.errors import (
AuthError,
@@ -27,6 +31,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester, get_domain_from_id
@@ -46,6 +51,8 @@ class BaseProfileHandler(BaseHandler):
subclass MasterProfileHandler
"""
+ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
+
def __init__(self, hs):
super(BaseProfileHandler, self).__init__(hs)
@@ -56,6 +63,88 @@ class BaseProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ self.max_avatar_size = hs.config.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+ self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to
+
+ if hs.config.worker_app is None:
+ self.clock.looping_call(
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ reactor.callWhenRunning(self._assign_profile_replication_batches)
+ reactor.callWhenRunning(self._replicate_profiles)
+ # Add a looping call to replicate_profiles: this handles retries
+ # if the replication is unsuccessful when the user updated their
+ # profile.
+ self.clock.looping_call(
+ self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ )
+
+ @defer.inlineCallbacks
+ def _assign_profile_replication_batches(self):
+ """If no profile replication has been done yet, allocate replication batch
+ numbers to each profile to start the replication process.
+ """
+ logger.info("Assigning profile batch numbers...")
+ total = 0
+ while True:
+ assigned = yield self.store.assign_profile_batch()
+ total += assigned
+ if assigned == 0:
+ break
+ logger.info("Assigned %d profile batch numbers", total)
+
+ @defer.inlineCallbacks
+ def _replicate_profiles(self):
+ """If any profile data has been updated and not pushed to the replication targets,
+ replicate it.
+ """
+ host_batches = yield self.store.get_replication_hosts()
+ latest_batch = yield self.store.get_latest_profile_replication_batch_number()
+ if latest_batch is None:
+ latest_batch = -1
+ for repl_host in self.hs.config.replicate_user_profiles_to:
+ if repl_host not in host_batches:
+ host_batches[repl_host] = -1
+ try:
+ for i in range(host_batches[repl_host] + 1, latest_batch + 1):
+ yield self._replicate_host_profile_batch(repl_host, i)
+ except Exception:
+ logger.exception(
+ "Exception while replicating to %s: aborting for now", repl_host
+ )
+
+ @defer.inlineCallbacks
+ def _replicate_host_profile_batch(self, host, batchnum):
+ logger.info("Replicating profile batch %d to %s", batchnum, host)
+ batch_rows = yield self.store.get_profile_batch(batchnum)
+ batch = {
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
+ }
+
+ url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
+ signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ try:
+ yield self.http_client.post_json_get_json(url, signed_body)
+ yield self.store.update_replication_batch_for_host(host, batchnum)
+ logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
+ except Exception:
+ # This will get retried when the looping call next comes around
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
+ raise
+
@defer.inlineCallbacks
def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
@@ -84,7 +173,7 @@ class BaseProfileHandler(BaseHandler):
)
return result
except RequestSendFailed as e:
- raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ raise SynapseError(502, "Failed to fetch profile") from e
except HttpResponseException as e:
raise e.to_synapse_error()
@@ -135,7 +224,7 @@ class BaseProfileHandler(BaseHandler):
ignore_backoff=True,
)
except RequestSendFailed as e:
- raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ raise SynapseError(502, "Failed to fetch profile") from e
except HttpResponseException as e:
raise e.to_synapse_error()
@@ -155,7 +244,7 @@ class BaseProfileHandler(BaseHandler):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
- if not by_admin and target_user != requester.user:
+ if not by_admin and requester and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname:
@@ -175,13 +264,23 @@ class BaseProfileHandler(BaseHandler):
if new_displayname == "":
new_displayname = None
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
# If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms.
# This must be done by the target user himself.
if by_admin:
requester = create_requester(target_user)
- await self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ await self.store.set_profile_displayname(
+ target_user.localpart, new_displayname, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@@ -191,6 +290,50 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ @defer.inlineCallbacks
+ def set_active(
+ self, users: List[UserID], active: bool, hide: bool,
+ ):
+ """
+ Sets the 'active' flag on a set of user profiles. If set to false, the
+ accounts are considered deactivated or hidden.
+
+ If 'hide' is true, then we interpret active=False as a request to try to
+ hide the users rather than deactivating them. This means withholding the
+ profiles from replication (and mark it as inactive) rather than clearing
+ the profile from the HS DB.
+
+ Note that unlike set_displayname and set_avatar_url, this does *not*
+ perform authorization checks! This is because the only place it's used
+ currently is in account deactivation where we've already done these
+ checks anyway.
+
+ Args:
+ users: The users to modify
+ active: Whether to set the user to active or inactive
+ hide: Whether to hide the user (withold from replication). If
+ False and active is False, user will have their profile
+ erased
+
+ Returns:
+ Deferred
+ """
+ if len(self.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ yield self.store.set_profiles_active(users, active, hide, new_batchnum)
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user):
@@ -212,7 +355,7 @@ class BaseProfileHandler(BaseHandler):
ignore_backoff=True,
)
except RequestSendFailed as e:
- raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ raise SynapseError(502, "Failed to fetch profile") from e
except HttpResponseException as e:
raise e.to_synapse_error()
@@ -241,11 +384,51 @@ class BaseProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
+ # Enforce a max avatar size if one is defined
+ if self.max_avatar_size or self.allowed_avatar_mimetypes:
+ media_id = self._validate_and_parse_media_id_from_avatar_url(new_avatar_url)
+
+ # Check that this media exists locally
+ media_info = await self.store.get_local_media(media_id)
+ if not media_info:
+ raise SynapseError(
+ 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND
+ )
+
+ # Ensure avatar does not exceed max allowed avatar size
+ media_size = media_info["media_length"]
+ if self.max_avatar_size and media_size > self.max_avatar_size:
+ raise SynapseError(
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
+ )
+
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ raise SynapseError(
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
+ )
+
# Same like set_displayname
if by_admin:
requester = create_requester(target_user)
- await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ await self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ await self.store.set_profile_avatar_url(
+ target_user.localpart, new_avatar_url, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@@ -255,6 +438,23 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ def _validate_and_parse_media_id_from_avatar_url(self, mxc):
+ """Validate and parse a provided avatar url and return the local media id
+
+ Args:
+ mxc (str): A mxc URL
+
+ Returns:
+ str: The ID of the media
+ """
+ avatar_pieces = mxc.split("/")
+ if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:":
+ raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
+ return avatar_pieces[-1]
+
@defer.inlineCallbacks
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 51979ea43e..f223630d43 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -17,7 +17,7 @@
import logging
from synapse import types
-from synapse.api.constants import MAX_USERID_LENGTH, LoginType
+from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
@@ -26,7 +26,8 @@ from synapse.replication.http.register import (
ReplicationPostRegisterActionsServlet,
ReplicationRegisterServlet,
)
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.storage.state import StateFilter
+from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
@@ -47,6 +48,7 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
@@ -59,6 +61,8 @@ class RegistrationHandler(BaseHandler):
)
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._show_in_user_directory = self.hs.config.show_users_in_user_directory
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -74,8 +78,18 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
async def check_username(
- self, localpart, guest_access_token=None, assigned_user_id=None
+ self, localpart, guest_access_token=None, assigned_user_id=None,
):
+ """
+
+ Args:
+ localpart (str|None): The user's localpart
+ guest_access_token (str|None): A guest's access token
+ assigned_user_id (str|None): An existing User ID for this user if pre-calculated
+
+ Returns:
+ Deferred
+ """
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@@ -118,6 +132,8 @@ class RegistrationHandler(BaseHandler):
raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
+
+ # Retrieve guest user information from provided access token
user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError(
@@ -203,6 +219,11 @@ class RegistrationHandler(BaseHandler):
address=address,
)
+ if default_display_name:
+ await self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(localpart)
await self.user_directory_handler.handle_local_profile_change(
@@ -233,6 +254,10 @@ class RegistrationHandler(BaseHandler):
address=address,
)
+ await self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
# Successfully registered
break
except SynapseError:
@@ -266,55 +291,169 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- await self._register_email_threepid(user_id, threepid_dict, None)
+ await self.register_email_threepid(user_id, threepid_dict, None)
+
+ # Prevent the new user from showing up in the user directory if the server
+ # mandates it.
+ if not self._show_in_user_directory:
+ await self.store.add_account_data_for_user(
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
+ )
+ await self.profile_handler.set_active([user], False, True)
return user_id
- async def _auto_join_rooms(self, user_id):
- """Automatically joins users to auto join rooms - creating the room in the first place
- if the user is the first to be created.
+ async def _create_and_join_rooms(self, user_id: str):
+ """
+ Create the auto-join rooms and join or invite the user to them.
+
+ This should only be called when the first "real" user registers.
Args:
- user_id(str): The user to join
+ user_id: The user to join
"""
- # auto-join the user to any rooms we're supposed to dump them into
- fake_requester = create_requester(user_id)
+ # Getting the handlers during init gives a dependency loop.
+ room_creation_handler = self.hs.get_room_creation_handler()
+ room_member_handler = self.hs.get_room_member_handler()
- # try to create the room if we're the first real user on the server. Note
- # that an auto-generated support or bot user is not a real user and will never be
- # the user to create the room
- should_auto_create_rooms = False
- is_real_user = await self.store.is_real_user(user_id)
- if self.hs.config.autocreate_auto_join_rooms and is_real_user:
- count = await self.store.count_real_users()
- should_auto_create_rooms = count == 1
- for r in self.hs.config.auto_join_rooms:
+ # Generate a stub for how the rooms will be configured.
+ stub_config = {
+ "preset": self.hs.config.registration.autocreate_auto_join_room_preset,
+ }
+
+ # If the configuration providers a user ID to create rooms with, use
+ # that instead of the first user registered.
+ requires_join = False
+ if self.hs.config.registration.auto_join_user_id:
+ fake_requester = create_requester(
+ self.hs.config.registration.auto_join_user_id
+ )
+
+ # If the room requires an invite, add the user to the list of invites.
+ if self.hs.config.registration.auto_join_room_requires_invite:
+ stub_config["invite"] = [user_id]
+
+ # If the room is being created by a different user, the first user
+ # registered needs to join it. Note that in the case of an invitation
+ # being necessary this will occur after the invite was sent.
+ requires_join = True
+ else:
+ fake_requester = create_requester(user_id)
+
+ # Choose whether to federate the new room.
+ if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
+ stub_config["creation_content"] = {"m.federate": False}
+
+ for r in self.hs.config.registration.auto_join_rooms:
logger.info("Auto-joining %s to %s", user_id, r)
+
try:
- if should_auto_create_rooms:
- room_alias = RoomAlias.from_string(r)
- if self.hs.hostname != room_alias.domain:
- logger.warning(
- "Cannot create room alias %s, "
- "it does not match server domain",
- r,
- )
- else:
- # create room expects the localpart of the room alias
- room_alias_localpart = room_alias.localpart
-
- # getting the RoomCreationHandler during init gives a dependency
- # loop
- await self.hs.get_room_creation_handler().create_room(
- fake_requester,
- config={
- "preset": "public_chat",
- "room_alias_name": room_alias_localpart,
- },
+ room_alias = RoomAlias.from_string(r)
+
+ if self.hs.hostname != room_alias.domain:
+ logger.warning(
+ "Cannot create room alias %s, "
+ "it does not match server domain",
+ r,
+ )
+ else:
+ # A shallow copy is OK here since the only key that is
+ # modified is room_alias_name.
+ config = stub_config.copy()
+ # create room expects the localpart of the room alias
+ config["room_alias_name"] = room_alias.localpart
+
+ info, _ = await room_creation_handler.create_room(
+ fake_requester, config=config, ratelimit=False,
+ )
+
+ # If the room does not require an invite, but another user
+ # created it, then ensure the first user joins it.
+ if requires_join:
+ await room_member_handler.update_membership(
+ requester=create_requester(user_id),
+ target=UserID.from_string(user_id),
+ room_id=info["room_id"],
+ # Since it was just created, there are no remote hosts.
+ remote_room_hosts=[],
+ action="join",
ratelimit=False,
)
+
+ except ConsentNotGivenError as e:
+ # Technically not necessary to pull out this error though
+ # moving away from bare excepts is a good thing to do.
+ logger.error("Failed to join new user to %r: %r", r, e)
+ except Exception as e:
+ logger.error("Failed to join new user to %r: %r", r, e)
+
+ async def _join_rooms(self, user_id: str):
+ """
+ Join or invite the user to the auto-join rooms.
+
+ Args:
+ user_id: The user to join
+ """
+ room_member_handler = self.hs.get_room_member_handler()
+
+ for r in self.hs.config.registration.auto_join_rooms:
+ logger.info("Auto-joining %s to %s", user_id, r)
+
+ try:
+ room_alias = RoomAlias.from_string(r)
+
+ if RoomAlias.is_valid(r):
+ (
+ room_id,
+ remote_room_hosts,
+ ) = await room_member_handler.lookup_room_alias(room_alias)
+ room_id = room_id.to_string()
else:
- await self._join_user_to_room(fake_requester, r)
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (r,)
+ )
+
+ # Calculate whether the room requires an invite or can be
+ # joined directly. Note that unless a join rule of public exists,
+ # it is treated as requiring an invite.
+ requires_invite = True
+
+ state = await self.store.get_filtered_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
+ )
+
+ event_id = state.get((EventTypes.JoinRules, ""))
+ if event_id:
+ join_rules_event = await self.store.get_event(
+ event_id, allow_none=True
+ )
+ if join_rules_event:
+ join_rule = join_rules_event.content.get("join_rule", None)
+ requires_invite = join_rule and join_rule != JoinRules.PUBLIC
+
+ # Send the invite, if necessary.
+ if requires_invite:
+ await room_member_handler.update_membership(
+ requester=create_requester(
+ self.hs.config.registration.auto_join_user_id
+ ),
+ target=UserID.from_string(user_id),
+ room_id=room_id,
+ remote_room_hosts=remote_room_hosts,
+ action="invite",
+ ratelimit=False,
+ )
+
+ # Send the join.
+ await room_member_handler.update_membership(
+ requester=create_requester(user_id),
+ target=UserID.from_string(user_id),
+ room_id=room_id,
+ remote_room_hosts=remote_room_hosts,
+ action="join",
+ ratelimit=False,
+ )
+
except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do.
@@ -322,6 +461,29 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
+ async def _auto_join_rooms(self, user_id: str):
+ """Automatically joins users to auto join rooms - creating the room in the first place
+ if the user is the first to be created.
+
+ Args:
+ user_id: The user to join
+ """
+ # auto-join the user to any rooms we're supposed to dump them into
+
+ # try to create the room if we're the first real user on the server. Note
+ # that an auto-generated support or bot user is not a real user and will never be
+ # the user to create the room
+ should_auto_create_rooms = False
+ is_real_user = await self.store.is_real_user(user_id)
+ if self.hs.config.registration.autocreate_auto_join_rooms and is_real_user:
+ count = await self.store.count_real_users()
+ should_auto_create_rooms = count == 1
+
+ if should_auto_create_rooms:
+ await self._create_and_join_rooms(user_id)
+ else:
+ await self._join_rooms(user_id)
+
async def post_consent_actions(self, user_id):
"""A series of registration actions that can only be carried out once consent
has been granted
@@ -331,7 +493,10 @@ class RegistrationHandler(BaseHandler):
"""
await self._auto_join_rooms(user_id)
- async def appservice_register(self, user_localpart, as_token):
+ async def appservice_register(
+ self, user_localpart, as_token, password_hash, display_name
+ ):
+ # FIXME: this should be factored out and merged with normal register()
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -348,12 +513,25 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
+ display_name = display_name or user.localpart
+
await self.register_with_store(
user_id=user_id,
- password_hash="",
+ password_hash=password_hash,
appservice_id=service_id,
- create_profile_with_displayname=user.localpart,
+ create_profile_with_displayname=display_name,
+ )
+
+ await self.profile_handler.set_displayname(
+ user, None, display_name, by_admin=True
)
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = await self.store.get_profileinfo(user_localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
return user_id
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
@@ -380,6 +558,37 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
+ async def shadow_register(self, localpart, display_name, auth_result, params):
+ """Invokes the current registration on another server, using
+ shared secret registration, passing in any auth_results from
+ other registration UI auth flows (e.g. validated 3pids)
+ Useful for setting up shadow/backup accounts on a parallel deployment.
+ """
+
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ await self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
+ {
+ # XXX: auth_result is an unspecified extension for shadow registration
+ "auth_result": auth_result,
+ # XXX: another unspecified extension for shadow registration to ensure
+ # that the displayname is correctly set by the masters erver
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
+ )
+
async def _generate_user_id(self):
if self._next_generated_user_id is None:
with await self._generate_user_id_linearizer.queue(()):
@@ -392,30 +601,6 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id += 1
return str(id)
- async def _join_user_to_room(self, requester, room_identifier):
- room_member_handler = self.hs.get_room_member_handler()
- if RoomID.is_valid(room_identifier):
- room_id = room_identifier
- elif RoomAlias.is_valid(room_identifier):
- room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
- room_alias
- )
- room_id = room_id.to_string()
- else:
- raise SynapseError(
- 400, "%s was not legal room ID or room alias" % (room_identifier,)
- )
-
- await room_member_handler.update_membership(
- requester=requester,
- target=requester.user,
- room_id=room_id,
- remote_room_hosts=remote_room_hosts,
- action="join",
- ratelimit=False,
- )
-
def check_registration_ratelimit(self, address):
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
@@ -557,6 +742,7 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
+
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(
@@ -564,7 +750,32 @@ class RegistrationHandler(BaseHandler):
):
await self.store.upsert_monthly_active_user(user_id)
- await self._register_email_threepid(user_id, threepid, access_token)
+ await self.register_email_threepid(user_id, threepid, access_token)
+
+ if self.hs.config.bind_new_user_emails_to_sydent:
+ # Attempt to call Sydent's internal bind API on the given identity server
+ # to bind this threepid
+ id_server_url = self.hs.config.bind_new_user_emails_to_sydent
+
+ logger.debug(
+ "Attempting the bind email of %s to identity server: %s using "
+ "internal Sydent bind API.",
+ user_id,
+ self.hs.config.bind_new_user_emails_to_sydent,
+ )
+
+ try:
+ await self.identity_handler.bind_email_using_internal_sydent_api(
+ id_server_url, threepid["address"], user_id
+ )
+ except Exception as e:
+ logger.warning(
+ "Failed to bind email of '%s' to Sydent instance '%s' ",
+ "using Sydent internal bind API: %s",
+ user_id,
+ id_server_url,
+ e,
+ )
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
@@ -585,7 +796,7 @@ class RegistrationHandler(BaseHandler):
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- async def _register_email_threepid(self, user_id, threepid, token):
+ async def register_email_threepid(self, user_id, threepid, token):
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 61db3ccc43..ac9a58e766 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -24,9 +24,12 @@ import string
from collections import OrderedDict
from typing import Tuple
-from six import iteritems, string_types
-
-from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.api.constants import (
+ EventTypes,
+ JoinRules,
+ RoomCreationPreset,
+ RoomEncryptionAlgorithms,
+)
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events.utils import copy_power_levels_contents
@@ -56,31 +59,6 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler):
-
- PRESETS_DICT = {
- RoomCreationPreset.PRIVATE_CHAT: {
- "join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
- "original_invitees_have_ops": False,
- "guest_can_join": True,
- "power_level_content_override": {"invite": 0},
- },
- RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
- "join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
- "original_invitees_have_ops": True,
- "guest_can_join": True,
- "power_level_content_override": {"invite": 0},
- },
- RoomCreationPreset.PUBLIC_CHAT: {
- "join_rules": JoinRules.PUBLIC,
- "history_visibility": "shared",
- "original_invitees_have_ops": False,
- "guest_can_join": False,
- "power_level_content_override": {},
- },
- }
-
def __init__(self, hs):
super(RoomCreationHandler, self).__init__(hs)
@@ -89,6 +67,39 @@ class RoomCreationHandler(BaseHandler):
self.room_member_handler = hs.get_room_member_handler()
self.config = hs.config
+ # Room state based off defined presets
+ self._presets_dict = {
+ RoomCreationPreset.PRIVATE_CHAT: {
+ "join_rules": JoinRules.INVITE,
+ "history_visibility": "shared",
+ "original_invitees_have_ops": False,
+ "guest_can_join": True,
+ "power_level_content_override": {"invite": 0},
+ },
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
+ "join_rules": JoinRules.INVITE,
+ "history_visibility": "shared",
+ "original_invitees_have_ops": True,
+ "guest_can_join": True,
+ "power_level_content_override": {"invite": 0},
+ },
+ RoomCreationPreset.PUBLIC_CHAT: {
+ "join_rules": JoinRules.PUBLIC,
+ "history_visibility": "shared",
+ "original_invitees_have_ops": False,
+ "guest_can_join": False,
+ "power_level_content_override": {},
+ },
+ }
+
+ # Modify presets to selectively enable encryption by default per homeserver config
+ for preset_name, preset_config in self._presets_dict.items():
+ encrypted = (
+ preset_name
+ in self.config.encryption_enabled_by_default_for_room_presets
+ )
+ preset_config["encrypted"] = encrypted
+
self._replication = hs.get_replication_data_handler()
# linearizer to stop two upgrades happening at once
@@ -324,7 +335,19 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -364,7 +387,7 @@ class RoomCreationHandler(BaseHandler):
# map from event_id to BaseEvent
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
- for k, old_event_id in iteritems(old_room_state_ids):
+ for k, old_event_id in old_room_state_ids.items():
old_event = old_room_state_events.get(old_event_id)
if old_event:
initial_state[k] = old_event.content
@@ -417,7 +440,7 @@ class RoomCreationHandler(BaseHandler):
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
- for k, old_event in iteritems(old_room_member_state_events):
+ for k, old_event in old_room_member_state_events.items():
# Only transfer ban events
if (
"membership" in old_event.content
@@ -570,8 +593,14 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
+ invite_list = config.get("invite", [])
+ invite_3pid_list = config.get("invite_3pid", [])
+
if not is_requester_admin and not self.spam_checker.user_may_create_room(
- user_id
+ user_id,
+ invite_list=invite_list,
+ third_party_invite_list=invite_3pid_list,
+ cloning=False,
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -582,7 +611,7 @@ class RoomCreationHandler(BaseHandler):
"room_version", self.config.default_room_version.identifier
)
- if not isinstance(room_version_id, string_types):
+ if not isinstance(room_version_id, str):
raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON)
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
@@ -606,7 +635,6 @@ class RoomCreationHandler(BaseHandler):
else:
room_alias = None
- invite_list = config.get("invite", [])
for i in invite_list:
try:
uid = UserID.from_string(i)
@@ -628,8 +656,6 @@ class RoomCreationHandler(BaseHandler):
% (user_id,),
)
- invite_3pid_list = config.get("invite_3pid", [])
-
visibility = config.get("visibility", None)
is_public = visibility == "public"
@@ -731,6 +757,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -746,6 +773,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ new_room=True,
id_access_token=id_access_token,
)
@@ -798,7 +826,7 @@ class RoomCreationHandler(BaseHandler):
)
return last_stream_id
- config = RoomCreationHandler.PRESETS_DICT[preset_config]
+ config = self._presets_dict[preset_config]
creator_id = creator.user.to_string()
@@ -815,6 +843,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=False,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
@@ -888,6 +917,13 @@ class RoomCreationHandler(BaseHandler):
etype=etype, state_key=state_key, content=content
)
+ if config["encrypted"]:
+ last_sent_stream_id = await send(
+ etype=EventTypes.RoomEncryption,
+ state_key="",
+ content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
+ )
+
return last_sent_stream_id
async def _generate_room_id(
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 4cbc02b0d0..5e05be6181 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -17,8 +17,6 @@ import logging
from collections import namedtuple
from typing import Any, Dict, Optional
-from six import iteritems
-
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@@ -271,7 +269,7 @@ class RoomListHandler(BaseHandler):
event_map = yield self.store.get_events(
[
event_id
- for key, event_id in iteritems(current_state_ids)
+ for key, event_id in current_state_ids.items()
if key[0]
in (
EventTypes.Create,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0f7af982f0..d585114f2d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,10 +17,9 @@
import abc
import logging
+from http import HTTPStatus
from typing import Dict, Iterable, List, Optional, Tuple
-from six.moves import http_client
-
from synapse import types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
@@ -59,6 +58,7 @@ class RoomMemberHandler(object):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.identity_handler = hs.get_handlers().identity_handler
self.member_linearizer = Linearizer(name="member")
@@ -288,8 +288,10 @@ class RoomMemberHandler(object):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[Optional[str], int]:
+ """Update a user's membership in a room"""
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -303,6 +305,7 @@ class RoomMemberHandler(object):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
)
@@ -319,6 +322,7 @@ class RoomMemberHandler(object):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
) -> Tuple[Optional[str], int]:
content_specified = bool(content)
@@ -361,7 +365,7 @@ class RoomMemberHandler(object):
if effective_membership_state == Membership.INVITE:
# block any attempts to invite the server notices mxid
if target.to_string() == self._server_notices_mxid:
- raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
+ raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
block_invite = False
@@ -383,8 +387,15 @@ class RoomMemberHandler(object):
)
block_invite = True
+ is_published = await self.store.is_room_published(room_id)
+
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ requester.user.to_string(),
+ target.to_string(),
+ third_party_invite=None,
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -444,7 +455,7 @@ class RoomMemberHandler(object):
is_blocked = await self._is_server_notice_room(room_id)
if is_blocked:
raise SynapseError(
- http_client.FORBIDDEN,
+ HTTPStatus.FORBIDDEN,
"You cannot reject this invite",
errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
)
@@ -462,6 +473,25 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
+
+ inviter = await self._get_inviter(target.to_string(), room_id)
+ if not is_requester_admin:
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ if not new_room and not self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
if not is_host_in_room:
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
@@ -736,6 +766,7 @@ class RoomMemberHandler(object):
id_server: str,
requester: Requester,
txn_id: Optional[str],
+ new_room: bool = False,
id_access_token: Optional[str] = None,
) -> int:
if self.config.block_non_admin_invites:
@@ -759,6 +790,16 @@ class RoomMemberHandler(object):
Codes.FORBIDDEN,
)
+ can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
+
if not self._enable_lookup:
raise SynapseError(
403, "Looking up third-party identifiers is denied from this server"
@@ -768,6 +809,19 @@ class RoomMemberHandler(object):
id_server, medium, address, id_access_token
)
+ is_published = await self.store.is_room_published(room_id)
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ raise SynapseError(403, "Invites have been disabled on this server")
+
if invitee:
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 4d245b618b..5d34989f21 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6bdb24baff..4c7524493e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -18,8 +18,6 @@ import itertools
import logging
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
-from six import iteritems, itervalues
-
import attr
from prometheus_client import Counter
@@ -390,7 +388,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
- event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
+ event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
@@ -408,7 +406,7 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
- event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
+ event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
return now_token, ephemeral_by_room
@@ -454,7 +452,7 @@ class SyncHandler(object):
current_state_ids_map = await self.state.get_current_state_ids(
room_id
)
- current_state_ids = frozenset(itervalues(current_state_ids_map))
+ current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client(
self.storage,
@@ -509,7 +507,7 @@ class SyncHandler(object):
current_state_ids_map = await self.state.get_current_state_ids(
room_id
)
- current_state_ids = frozenset(itervalues(current_state_ids_map))
+ current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client(
self.storage,
@@ -909,7 +907,7 @@ class SyncHandler(object):
logger.debug("filtering state from %r...", state_ids)
state_ids = {
t: event_id
- for t, event_id in iteritems(state_ids)
+ for t, event_id in state_ids.items()
if cache.get(t[1]) != event_id
}
logger.debug("...to %r", state_ids)
@@ -1430,7 +1428,7 @@ class SyncHandler(object):
if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
- joined_sync.timeline.events, itervalues(joined_sync.state)
+ joined_sync.timeline.events, joined_sync.state.values()
)
for event in it:
if event.type == EventTypes.Member:
@@ -1505,7 +1503,7 @@ class SyncHandler(object):
newly_left_rooms = []
room_entries = []
invited = []
- for room_id, events in iteritems(mem_change_events_by_room_id):
+ for room_id, events in mem_change_events_by_room_id.items():
logger.debug(
"Membership changes in %s: [%s]",
room_id,
@@ -1993,17 +1991,17 @@ def _calculate_state(
event_id_to_key = {
e: key
for key, e in itertools.chain(
- iteritems(timeline_contains),
- iteritems(previous),
- iteritems(timeline_start),
- iteritems(current),
+ timeline_contains.items(),
+ previous.items(),
+ timeline_start.items(),
+ current.items(),
)
}
- c_ids = set(itervalues(current))
- ts_ids = set(itervalues(timeline_start))
- p_ids = set(itervalues(previous))
- tc_ids = set(itervalues(timeline_contains))
+ c_ids = set(current.values())
+ ts_ids = set(timeline_start.values())
+ p_ids = set(previous.values())
+ tc_ids = set(timeline_contains.values())
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
@@ -2017,7 +2015,7 @@ def _calculate_state(
if lazy_load_members:
p_ids.difference_update(
- e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member
+ e for t, e in timeline_start.items() if t[0] == EventTypes.Member
)
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c7bc14c623..6c7abaa578 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,9 +15,7 @@
import logging
from collections import namedtuple
-from typing import List
-
-from twisted.internet import defer
+from typing import List, Tuple
from synapse.api.errors import AuthError, SynapseError
from synapse.logging.context import run_in_background
@@ -115,8 +113,7 @@ class TypingHandler(object):
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
- @defer.inlineCallbacks
- def started_typing(self, target_user, auth_user, room_id, timeout):
+ async def started_typing(self, target_user, auth_user, room_id, timeout):
target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
@@ -126,7 +123,7 @@ class TypingHandler(object):
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
- yield self.auth.check_user_in_room(room_id, target_user_id)
+ await self.auth.check_user_in_room(room_id, target_user_id)
logger.debug("%s has started typing in %s", target_user_id, room_id)
@@ -145,8 +142,7 @@ class TypingHandler(object):
self._push_update(member=member, typing=True)
- @defer.inlineCallbacks
- def stopped_typing(self, target_user, auth_user, room_id):
+ async def stopped_typing(self, target_user, auth_user, room_id):
target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
@@ -156,7 +152,7 @@ class TypingHandler(object):
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
- yield self.auth.check_user_in_room(room_id, target_user_id)
+ await self.auth.check_user_in_room(room_id, target_user_id)
logger.debug("%s has stopped typing in %s", target_user_id, room_id)
@@ -164,12 +160,11 @@ class TypingHandler(object):
self._stopped_typing(member)
- @defer.inlineCallbacks
def user_left_room(self, user, room_id):
user_id = user.to_string()
if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user_id=user_id)
- yield self._stopped_typing(member)
+ self._stopped_typing(member)
def _stopped_typing(self, member):
if member.user_id not in self._room_typing.get(member.room_id, set()):
@@ -188,10 +183,9 @@ class TypingHandler(object):
self._push_update_local(member=member, typing=typing)
- @defer.inlineCallbacks
- def _push_remote(self, member, typing):
+ async def _push_remote(self, member, typing):
try:
- users = yield self.state.get_current_users_in_room(member.room_id)
+ users = await self.state.get_current_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
@@ -215,8 +209,7 @@ class TypingHandler(object):
except Exception:
logger.exception("Error pushing typing notif to remotes")
- @defer.inlineCallbacks
- def _recv_edu(self, origin, content):
+ async def _recv_edu(self, origin, content):
room_id = content["room_id"]
user_id = content["user_id"]
@@ -231,7 +224,7 @@ class TypingHandler(object):
)
return
- users = yield self.state.get_current_users_in_room(room_id)
+ users = await self.state.get_current_users_in_room(room_id)
domains = {get_domain_from_id(u) for u in users}
if self.server_name in domains:
@@ -259,14 +252,31 @@ class TypingHandler(object):
)
async def get_all_typing_updates(
- self, last_id: int, current_id: int, limit: int
- ) -> List[dict]:
- """Get up to `limit` typing updates between the given tokens, earliest
- updates first.
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ """Get updates for typing replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
- return []
+ return [], current_id, False
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
last_id
@@ -280,9 +290,16 @@ class TypingHandler(object):
serial = self._room_serials[room_id]
if last_id < serial <= current_id:
typing = self._room_typing[room_id]
- rows.append((serial, room_id, list(typing)))
+ rows.append((serial, [room_id, list(typing)]))
rows.sort()
- return rows[:limit]
+
+ limited = False
+ if len(rows) > limit:
+ rows = rows[:limit]
+ current_id = rows[-1][0]
+ limited = True
+
+ return rows, current_id, limited
def get_current_token(self):
return self._latest_room_serial
@@ -306,7 +323,7 @@ class TypingNotificationEventSource(object):
"content": {"user_ids": list(typing)},
}
- def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
@@ -320,7 +337,7 @@ class TypingNotificationEventSource(object):
events.append(self._make_event_for(room_id))
- return defer.succeed((events, handler._latest_room_serial))
+ return (events, handler._latest_room_serial)
def get_current_key(self):
return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 12423b909a..521b6d620d 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -15,8 +15,6 @@
import logging
-from six import iteritems, iterkeys
-
import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
@@ -289,7 +287,7 @@ class UserDirectoryHandler(StateDeltasHandler):
users_with_profile = await self.state.get_current_users_in_room(room_id)
# Remove every user from the sharing tables for that room.
- for user_id in iterkeys(users_with_profile):
+ for user_id in users_with_profile.keys():
await self.store.remove_user_who_share_room(user_id, room_id)
# Then, re-add them to the tables.
@@ -298,7 +296,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# which when ran over an entire room, will result in the same values
# being added multiple times. The batching upserts shouldn't make this
# too bad, though.
- for user_id, profile in iteritems(users_with_profile):
+ for user_id, profile in users_with_profile.items():
await self._handle_new_user(room_id, user_id, profile)
async def _handle_new_user(self, room_id, user_id, profile):
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 3cef747a4d..8743e9839d 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -15,11 +15,9 @@
# limitations under the License.
import logging
+import urllib
from io import BytesIO
-from six import raise_from, text_type
-from six.moves import urllib
-
import treq
from canonicaljson import encode_canonical_json, json
from netaddr import IPAddress
@@ -577,7 +575,7 @@ class SimpleHttpClient(object):
# This can happen e.g. because the body is too large.
raise
except Exception as e:
- raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e)
+ raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
return (
length,
@@ -638,7 +636,7 @@ def encode_urlencode_args(args):
def encode_urlencode_arg(arg):
- if isinstance(arg, text_type):
+ if isinstance(arg, str):
return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index f5f917f5ae..c5fc746f2f 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -48,6 +48,9 @@ class MatrixFederationAgent(object):
tls_client_options_factory (FederationPolicyForHTTPS|None):
factory to use for fetching client tls options, or none to disable TLS.
+ user_agent (bytes):
+ The user agent header to use for federation requests.
+
_srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
@@ -61,6 +64,7 @@ class MatrixFederationAgent(object):
self,
reactor,
tls_client_options_factory,
+ user_agent,
_srv_resolver=None,
_well_known_resolver=None,
):
@@ -78,6 +82,7 @@ class MatrixFederationAgent(object):
),
pool=self._pool,
)
+ self.user_agent = user_agent
if _well_known_resolver is None:
_well_known_resolver = WellKnownResolver(
@@ -87,6 +92,7 @@ class MatrixFederationAgent(object):
pool=self._pool,
contextFactory=tls_client_options_factory,
),
+ user_agent=self.user_agent,
)
self._well_known_resolver = _well_known_resolver
@@ -149,7 +155,7 @@ class MatrixFederationAgent(object):
parsed_uri = urllib.parse.urlparse(uri)
# We need to make sure the host header is set to the netloc of the
- # server.
+ # server and that a user-agent is provided.
if headers is None:
headers = Headers()
else:
@@ -157,6 +163,8 @@ class MatrixFederationAgent(object):
if not headers.hasHeader(b"host"):
headers.addRawHeader(b"host", parsed_uri.netloc)
+ if not headers.hasHeader(b"user-agent"):
+ headers.addRawHeader(b"user-agent", self.user_agent)
res = yield make_deferred_yieldable(
self._agent.request(method, uri, headers, bodyProducer)
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 7ddfad286d..89a3b041ce 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -23,6 +23,7 @@ import attr
from twisted.internet import defer
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
+from twisted.web.http_headers import Headers
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
@@ -78,7 +79,12 @@ class WellKnownResolver(object):
"""
def __init__(
- self, reactor, agent, well_known_cache=None, had_well_known_cache=None
+ self,
+ reactor,
+ agent,
+ user_agent,
+ well_known_cache=None,
+ had_well_known_cache=None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -92,6 +98,7 @@ class WellKnownResolver(object):
self._well_known_cache = well_known_cache
self._had_valid_well_known_cache = had_well_known_cache
self._well_known_agent = RedirectAgent(agent)
+ self.user_agent = user_agent
@defer.inlineCallbacks
def get_well_known(self, server_name):
@@ -227,6 +234,10 @@ class WellKnownResolver(object):
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
+ headers = {
+ b"User-Agent": [self.user_agent],
+ }
+
i = 0
while True:
i += 1
@@ -234,7 +245,9 @@ class WellKnownResolver(object):
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
- self._well_known_agent.request(b"GET", uri)
+ self._well_known_agent.request(
+ b"GET", uri, headers=Headers(headers)
+ )
)
body = yield make_deferred_yieldable(readBody(response))
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 2d47b9ea00..18f6a8fd29 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,11 +17,9 @@ import cgi
import logging
import random
import sys
+import urllib
from io import BytesIO
-from six import raise_from, string_types
-from six.moves import urllib
-
import attr
import treq
from canonicaljson import encode_canonical_json
@@ -199,7 +197,14 @@ class MatrixFederationHttpClient(object):
self.reactor = Reactor()
- self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory)
+ user_agent = hs.version_string
+ if hs.config.user_agent_suffix:
+ user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
+ user_agent = user_agent.encode("ascii")
+
+ self.agent = MatrixFederationAgent(
+ self.reactor, tls_client_options_factory, user_agent
+ )
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
@@ -432,10 +437,10 @@ class MatrixFederationHttpClient(object):
except TimeoutError as e:
raise RequestSendFailed(e, can_retry=True) from e
except DNSLookupError as e:
- raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
+ raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
logger.info("Failed to send request: %s", e)
- raise_from(RequestSendFailed(e, can_retry=True), e)
+ raise RequestSendFailed(e, can_retry=True) from e
incoming_responses_counter.labels(
request.method, response.code
@@ -487,7 +492,7 @@ class MatrixFederationHttpClient(object):
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
if response.code == 429:
- raise_from(RequestSendFailed(e, can_retry=True), e)
+ raise RequestSendFailed(e, can_retry=True) from e
else:
raise e
@@ -998,7 +1003,7 @@ def encode_query_args(args):
encoded_args = {}
for k, vs in args.items():
- if isinstance(vs, string_types):
+ if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2331a2a4b0..d192de7923 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -16,10 +16,10 @@
import collections
import html
-import http.client
import logging
import types
import urllib
+from http import HTTPStatus
from io import BytesIO
from typing import Awaitable, Callable, TypeVar, Union
@@ -188,7 +188,7 @@ def return_html_error(
exc_info=(f.type, f.value, f.getTracebackObject()),
)
else:
- code = http.HTTPStatus.INTERNAL_SERVER_ERROR
+ code = HTTPStatus.INTERNAL_SERVER_ERROR
msg = "Internal server error"
logger.error(
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 167293c46d..cbc37eac6e 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,6 +19,7 @@ from typing import Optional
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
+from synapse.config.server import ListenerConfig
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
@@ -350,7 +351,7 @@ class SynapseSite(Site):
self,
logger_name,
site_tag,
- config,
+ config: ListenerConfig,
resource,
server_version_string,
*args,
@@ -360,7 +361,8 @@ class SynapseSite(Site):
self.site_tag = site_tag
- proxied = config.get("x_forwarded", False)
+ assert config.http_options is not None
+ proxied = config.http_options.x_forwarded
self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py
index fbf570c756..d736ad5b9b 100644
--- a/synapse/logging/formatter.py
+++ b/synapse/logging/formatter.py
@@ -16,8 +16,7 @@
import logging
import traceback
-
-from six import StringIO
+from io import StringIO
class LogFormatter(logging.Formatter):
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 5dddf57008..73bef5e5ca 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -171,8 +171,9 @@ import logging
import re
import types
from functools import wraps
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Optional, Type
+import attr
from canonicaljson import json
from twisted.internet import defer
@@ -232,6 +233,30 @@ except ImportError:
LogContextScopeManager = None # type: ignore
+try:
+ from rust_python_jaeger_reporter import Reporter
+
+ @attr.s(slots=True, frozen=True)
+ class _WrappedRustReporter:
+ """Wrap the reporter to ensure `report_span` never throws.
+ """
+
+ _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
+
+ def set_process(self, *args, **kwargs):
+ return self._reporter.set_process(*args, **kwargs)
+
+ def report_span(self, span):
+ try:
+ return self._reporter.report_span(span)
+ except Exception:
+ logger.exception("Failed to report span")
+
+ RustReporter = _WrappedRustReporter # type: Optional[Type[_WrappedRustReporter]]
+except ImportError:
+ RustReporter = None
+
+
logger = logging.getLogger(__name__)
@@ -320,11 +345,19 @@ def init_tracer(hs: "HomeServer"):
set_homeserver_whitelist(hs.config.opentracer_whitelist)
- JaegerConfig(
+ config = JaegerConfig(
config=hs.config.jaeger_config,
service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
scope_manager=LogContextScopeManager(hs.config),
- ).initialize_tracer()
+ )
+
+ # If we have the rust jaeger reporter available let's use that.
+ if RustReporter:
+ logger.info("Using rust_python_jaeger_reporter library")
+ tracer = config.create_tracer(RustReporter(), config.sampler)
+ opentracing.set_global_tracer(tracer)
+ else:
+ config.initialize_tracer()
# Whitelisting
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 9cf31f96b3..6035672698 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -22,8 +22,6 @@ import threading
import time
from typing import Callable, Dict, Iterable, Optional, Tuple, Union
-import six
-
import attr
from prometheus_client import Counter, Gauge, Histogram
from prometheus_client.core import (
@@ -83,7 +81,7 @@ class LaterGauge(object):
return
if isinstance(calls, dict):
- for k, v in six.iteritems(calls):
+ for k, v in calls.items():
g.add_metric(k, v)
else:
g.add_metric([], calls)
@@ -194,7 +192,7 @@ class InFlightGauge(object):
gauge = GaugeMetricFamily(
"_".join([self.name, name]), "", labels=self.labels
)
- for key, metrics in six.iteritems(metrics_by_key):
+ for key, metrics in metrics_by_key.items():
gauge.add_metric(key, getattr(metrics, name))
yield gauge
@@ -465,6 +463,12 @@ event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name"
# finished being processed.
event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"])
+event_processing_lag_by_event = Histogram(
+ "synapse_event_processing_lag_by_event",
+ "Time between an event being persisted and it being queued up to be sent to the relevant remote servers",
+ ["name"],
+)
+
# Build info of the running server.
build_info = Gauge(
"synapse_build_info", "Build information", ["pythonversion", "version", "osversion"]
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index ab7f948ed4..4304c60d56 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -208,6 +208,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
raise
self.send_response(200)
self.send_header("Content-Type", CONTENT_TYPE_LATEST)
+ self.send_header("Content-Length", str(len(output)))
self.end_headers()
self.wfile.write(output)
@@ -261,4 +262,6 @@ class MetricsResource(Resource):
def render_GET(self, request):
request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
- return generate_latest(self.registry)
+ response = generate_latest(self.registry)
+ request.setHeader(b"Content-Length", str(len(response)))
+ return response
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index e75d964ac8..43ffe6faf0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -17,8 +17,6 @@
import logging
from collections import namedtuple
-from six import iteritems, itervalues
-
from prometheus_client import Counter
from twisted.internet import defer
@@ -130,7 +128,7 @@ class BulkPushRuleEvaluator(object):
event, prev_state_ids, for_verification=False
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -162,7 +160,7 @@ class BulkPushRuleEvaluator(object):
condition_cache = {}
- for uid, rules in iteritems(rules_by_user):
+ for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
@@ -395,7 +393,7 @@ class RulesForRoom(object):
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
- for event_id in itervalues(member_event_ids):
+ for event_id in member_event_ids.values():
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
@@ -404,7 +402,7 @@ class RulesForRoom(object):
interested_in_user_ids = {
user_id
- for user_id, membership in itervalues(members)
+ for user_id, membership in members.values()
if membership == Membership.JOIN
}
@@ -415,7 +413,7 @@ class RulesForRoom(object):
)
user_ids = {
- uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
+ uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
}
logger.debug("With pushers: %r", user_ids)
@@ -436,7 +434,7 @@ class RulesForRoom(object):
)
ret_rules_by_user.update(
- item for item in iteritems(rules_by_user) if item[0] is not None
+ item for item in rules_by_user.items() if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index eaaa7afc91..ed60dbc1bf 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -129,6 +129,8 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _update_badge(self):
+ # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
+ # to be largely redundant. perhaps we can remove it.
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
yield self._send_badge(badge)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index d57a66a697..dda560b2c2 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -17,12 +17,11 @@ import email.mime.multipart
import email.utils
import logging
import time
+import urllib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Iterable, List, TypeVar
-from six.moves import urllib
-
import bleach
import jinja2
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 11032491af..8e0d3a416d 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -18,8 +18,6 @@ import logging
import re
from typing import Pattern
-from six import string_types
-
from synapse.events import EventBase
from synapse.types import UserID
from synapse.util.caches import register_cache
@@ -131,7 +129,7 @@ class PushRuleEvaluatorForEvent(object):
# XXX: optimisation: cache our pattern regexps
if condition["key"] == "content.body":
body = self._event.content.get("body", None)
- if not body:
+ if not body or not isinstance(body, str):
return False
return _glob_matches(pattern, body, word_boundary=True)
@@ -147,7 +145,7 @@ class PushRuleEvaluatorForEvent(object):
return False
body = self._event.content.get("body", None)
- if not body:
+ if not body or not isinstance(body, str):
return False
# Similar to _glob_matches, but do not treat display_name as a glob.
@@ -244,7 +242,7 @@ def _flatten_dict(d, prefix=[], result=None):
if result is None:
result = {}
for key, value in d.items():
- if isinstance(value, string_types):
+ if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 88d203aa44..f6a5458681 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -215,11 +215,9 @@ class PusherPool:
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
- updated_receipts = yield self.store.get_all_updated_receipts(
+ users_affected = yield self.store.get_users_sent_receipts_between(
min_stream_id - 1, max_stream_id
)
- # This returns a tuple, user_id is at index 3
- users_affected = {r[3] for r in updated_receipts}
for u in users_affected:
if u in self.pushers:
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 8b4312e5a3..b1cac901eb 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -66,11 +66,9 @@ REQUIREMENTS = [
"pymacaroons>=0.13.0",
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
- "six>=1.10",
"prometheus_client>=0.0.18,<0.8.0",
- # we use attr.s(slots), which arrived in 16.0.0
- # Twisted 18.7.0 requires attrs>=17.4.0
- "attrs>=17.4.0",
+ # we use attr.validators.deep_iterable, which arrived in 19.1.0
+ "attrs>=19.1.0",
"netaddr>=0.7.18",
"Jinja2>=2.9",
"bleach>=1.4.3",
@@ -95,7 +93,12 @@ CONDITIONAL_REQUIREMENTS = {
"oidc": ["authlib>=0.14.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
- "test": ["mock>=2.0", "parameterized"],
+ # Dependencies which are exclusively required by unit test code. This is
+ # NOT a list of all modules that are necessary to run the unit tests.
+ # Tests assume that all optional dependencies are installed.
+ #
+ # parameterized_class decorator was introduced in parameterized 0.7.0
+ "test": ["mock>=2.0", "parameterized>=0.7.0"],
"sentry": ["sentry-sdk>=0.7.2"],
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
"jwt": ["pyjwt>=1.6.4"],
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 793cef6c26..9caf1e80c1 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,12 +16,10 @@
import abc
import logging
import re
+import urllib
from inspect import signature
from typing import Dict, List, Tuple
-from six import raise_from
-from six.moves import urllib
-
from twisted.internet import defer
from synapse.api.errors import (
@@ -220,7 +218,7 @@ class ReplicationEndpoint(object):
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
except RequestSendFailed as e:
- raise_from(SynapseError(502, "Failed to talk to master"), e)
+ raise SynapseError(502, "Failed to talk to master") from e
return result
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index c04f622816..ea5937a20c 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -149,7 +149,7 @@ class RdataCommand(Command):
class PositionCommand(Command):
- """Sent by the server to tell the client the stream postition without
+ """Sent by the server to tell the client the stream position without
needing to send an RDATA.
Format::
@@ -188,7 +188,7 @@ class ErrorCommand(_SimpleCommand):
class PingCommand(_SimpleCommand):
- """Sent by either side as a keep alive. The data is arbitary (often timestamp)
+ """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
"""
NAME = "PING"
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index cbcf46f3ae..e6a2e2598b 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -112,8 +112,8 @@ class ReplicationCommandHandler:
"replication_position", clock=self._clock
)
- # Map of stream to batched updates. See RdataCommand for info on how
- # batching works.
+ # Map of stream name to batched updates. See RdataCommand for info on
+ # how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
# The factory used to create connections.
@@ -123,7 +123,8 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
- # For each connection, the incoming streams that are coming from that connection
+ # For each connection, the incoming stream names that are coming from
+ # that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge(
@@ -310,7 +311,28 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
- await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+ stream = self._streams.get(stream_name)
+ if not stream:
+ logger.error("Got RDATA for unknown stream: %s", stream_name)
+ return
+
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
+
+ # Discard this data if this token is earlier than the current
+ # position. Note that streams can be reset (in which case you
+ # expect an earlier token), but that must be preceded by a
+ # POSITION command.
+ if cmd.token <= current_token:
+ logger.debug(
+ "Discarding RDATA from stream %s at position %s before previous position %s",
+ stream_name,
+ cmd.token,
+ current_token,
+ )
+ else:
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4acefc8a96..f196eff072 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -264,7 +264,7 @@ class BackfillStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_current_backfill_token),
- db_query_to_update_function(store.get_all_new_backfill_event_rows),
+ store.get_all_new_backfill_event_rows,
)
@@ -291,9 +291,7 @@ class PresenceStream(Stream):
if hs.config.worker_app is None:
# on the master, query the presence handler
presence_handler = hs.get_presence_handler()
- update_function = db_query_to_update_function(
- presence_handler.get_all_presence_updates
- )
+ update_function = presence_handler.get_all_presence_updates
else:
# Query master process
update_function = make_http_update_function(hs, self.NAME)
@@ -318,9 +316,7 @@ class TypingStream(Stream):
if hs.config.worker_app is None:
# on the master, query the typing handler
- update_function = db_query_to_update_function(
- typing_handler.get_all_typing_updates
- )
+ update_function = typing_handler.get_all_typing_updates
else:
# Query master process
update_function = make_http_update_function(hs, self.NAME)
@@ -352,7 +348,7 @@ class ReceiptsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_receipt_stream_id),
- db_query_to_update_function(store.get_all_updated_receipts),
+ store.get_all_updated_receipts,
)
@@ -367,26 +363,17 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
+
super(PushRulesStream, self).__init__(
- hs.get_instance_name(), self._current_token, self._update_function
+ hs.get_instance_name(),
+ self._current_token,
+ self.store.get_all_push_rule_updates,
)
def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- async def _update_function(
- self, instance_name: str, from_token: Token, to_token: Token, limit: int
- ):
- rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-
- limited = False
- if len(rows) == limit:
- to_token = rows[-1][0]
- limited = True
-
- return [(row[0], (row[2],)) for row in rows], to_token, limited
-
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 46e458e95b..2e81eeff65 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -118,6 +118,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 8173baef8f..e07c32118d 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -15,7 +15,7 @@
import logging
from typing import List, Optional
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -77,7 +77,7 @@ class ShutdownRoomRestServlet(RestServlet):
info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
config={
- "preset": "public_chat",
+ "preset": RoomCreationPreset.PUBLIC_CHAT,
"name": room_name,
"power_level_content_override": {"users_default": -10},
},
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index fefc8f71fa..e4330c39d6 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,9 +16,7 @@ import hashlib
import hmac
import logging
import re
-
-from six import text_type
-from six.moves import http_client
+from http import HTTPStatus
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -215,10 +213,7 @@ class UserRestServletV2(RestServlet):
await self.store.set_server_admin(target_user, set_admin_to)
if "password" in body:
- if (
- not isinstance(body["password"], text_type)
- or len(body["password"]) > 512
- ):
+ if not isinstance(body["password"], str) or len(body["password"]) > 512:
raise SynapseError(400, "Invalid password")
else:
new_password = body["password"]
@@ -252,7 +247,7 @@ class UserRestServletV2(RestServlet):
password = body.get("password")
password_hash = None
if password is not None:
- if not isinstance(password, text_type) or len(password) > 512:
+ if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
password_hash = await self.auth_handler.hash(password)
@@ -370,10 +365,7 @@ class UserRegisterServlet(RestServlet):
400, "username must be specified", errcode=Codes.BAD_JSON
)
else:
- if (
- not isinstance(body["username"], text_type)
- or len(body["username"]) > 512
- ):
+ if not isinstance(body["username"], str) or len(body["username"]) > 512:
raise SynapseError(400, "Invalid username")
username = body["username"].encode("utf-8")
@@ -386,7 +378,7 @@ class UserRegisterServlet(RestServlet):
)
else:
password = body["password"]
- if not isinstance(password, text_type) or len(password) > 512:
+ if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
password_bytes = password.encode("utf-8")
@@ -477,7 +469,7 @@ class DeactivateAccountRestServlet(RestServlet):
erase = body.get("erase", False)
if not isinstance(erase, bool):
raise SynapseError(
- http_client.BAD_REQUEST,
+ HTTPStatus.BAD_REQUEST,
"Param 'erase' must be a boolean, if given",
Codes.BAD_JSON,
)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index dceb2792fa..bf0f9bd077 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -60,10 +60,18 @@ def login_id_thirdparty_from_phone(identifier):
Returns: Login identifier dict of type 'm.id.threepid'
"""
- if "country" not in identifier or "number" not in identifier:
+ if "country" not in identifier or (
+ # The specification requires a "phone" field, while Synapse used to require a "number"
+ # field. Accept both for backwards compatibility.
+ "phone" not in identifier
+ and "number" not in identifier
+ ):
raise SynapseError(400, "Invalid phone-type identifier")
- msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
+ # Accept both "phone" and "number" as valid keys in m.id.phone
+ phone_number = identifier.get("phone", identifier["number"])
+
+ msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
@@ -73,7 +81,8 @@ class LoginRestServlet(RestServlet):
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
- JWT_TYPE = "m.login.jwt"
+ JWT_TYPE = "org.matrix.login.jwt"
+ JWT_TYPE_DEPRECATED = "m.login.jwt"
def __init__(self, hs):
super(LoginRestServlet, self).__init__()
@@ -108,6 +117,7 @@ class LoginRestServlet(RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
+ flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
if self.cas_enabled:
# we advertise CAS for backwards compat, though MSC1721 renamed it
@@ -141,6 +151,7 @@ class LoginRestServlet(RestServlet):
try:
if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
+ or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
result = await self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index eec16f8ad8..970fdd5834 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -17,8 +17,6 @@
"""
import logging
-from six import string_types
-
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -51,7 +49,9 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(state, self.clock.time_msec())
+ state = format_user_presence_state(
+ state, self.clock.time_msec(), include_user_id=False
+ )
return 200, state
@@ -71,7 +71,7 @@ class PresenceStatusRestServlet(RestServlet):
if "status_msg" in content:
state["status_msg"] = content.pop("status_msg")
- if not isinstance(state["status_msg"], string_types):
+ if not isinstance(state["status_msg"], str):
raise SynapseError(400, "status_msg must be a string.")
if content:
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..165313b572 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -63,11 +65,27 @@ class ProfileDisplaynameRestServlet(RestServlet):
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -76,6 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -114,11 +133,27 @@ class ProfileAvatarURLRestServlet(RestServlet):
user, requester, new_avatar_url, is_admin
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 105e0cf4d2..bbf98a5cfe 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -18,8 +18,7 @@
import logging
import re
from typing import List, Optional
-
-from six.moves.urllib import parse as urlparse
+from urllib import parse as urlparse
from canonicaljson import json
@@ -726,7 +725,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
- content.get("id_access_token"),
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index b58a77826f..aeaf322985 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
-from six.moves import http_client
+from http import HTTPStatus
+
+from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
@@ -29,6 +32,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.push.mailer import Mailer, load_jinja2_templates
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
@@ -91,7 +95,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
@@ -221,6 +225,7 @@ class PasswordRestServlet(RestServlet):
self.datastore = self.hs.get_datastore()
self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -252,13 +257,17 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
- params = await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ # blindly trust ASes without UI-authing them
+ if requester.app_service:
+ params = body
+ else:
+ params = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
user_id = requester.user.to_string()
else:
requester = None
@@ -298,11 +307,29 @@ class PasswordRestServlet(RestServlet):
user_id, new_password_hash, logout_devices, requester
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_password(params, shadow_user.to_string())
+
return 200, {}
def on_OPTIONS(self, _):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_password(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -320,7 +347,7 @@ class DeactivateAccountRestServlet(RestServlet):
erase = body.get("erase", False)
if not isinstance(erase, bool):
raise SynapseError(
- http_client.BAD_REQUEST,
+ HTTPStatus.BAD_REQUEST,
"Param 'erase' must be a boolean, if given",
Codes.BAD_JSON,
)
@@ -397,13 +424,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", email)):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existing_user_id = await self.store.get_user_id_by_threepid(
"email", body["email"]
)
@@ -467,13 +496,15 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
@@ -632,7 +663,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- self.datastore = self.hs.get_datastore()
+ self.datastore = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
@@ -651,6 +683,29 @@ class ThreepidRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ # skip validation if this is a shadow 3PID from an AS
+ if requester.app_service:
+ # XXX: ASes pass in a validated threepid directly to bypass the IS.
+ # This makes the API entirely change shape when we have an AS token;
+ # it really should be an entirely separate API - perhaps
+ # /account/3pid/replicate or something.
+ threepid = body.get("threepid")
+
+ await self.auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return 200, {}
+
threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
if threepid_creds is None:
raise SynapseError(
@@ -672,12 +727,36 @@ class ThreepidRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
@@ -688,6 +767,7 @@ class ThreepidAddRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -723,12 +803,34 @@ class ThreepidAddRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
@@ -798,6 +900,7 @@ class ThreepidDeleteRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
@@ -822,6 +925,12 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid_delete(body, shadow_user.to_string())
+
if ret:
id_server_unbind_result = "success"
else:
@@ -829,6 +938,77 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+ @defer.inlineCallbacks
+ def shadow_3pid_delete(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ """Proxy a /_matrix/identity/api/v1/lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ # Verify query parameters
+ query_params = request.args
+ assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
+
+ # Retrieve needed information from query parameters
+ medium = parse_string(request, "medium")
+ address = parse_string(request, "address")
+ id_server = parse_string(request, "id_server")
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
+
+ defer.returnValue((200, ret))
+
+
+class ThreepidBulkLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidBulkLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["threepids", "id_server"])
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.proxy_bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
+ )
+
+ defer.returnValue((200, ret))
+
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
@@ -857,4 +1037,6 @@ def register_servlets(hs, http_server):
ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ ThreepidLookupRestServlet(hs).register(http_server)
+ ThreepidBulkLookupRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index c1d4cd0caf..d31ec7c29d 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,6 +17,7 @@ import logging
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -39,6 +40,7 @@ class AccountDataServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
+ self._profile_handler = hs.get_profile_handler()
async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
@@ -50,6 +52,11 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if account_data_type == "im.vector.hide_profile":
+ user = UserID.from_string(user_id)
+ hide_profile = body.get("hide_profile")
+ await self._profile_handler.set_active([user], not hide_profile, True)
+
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c8d2de7b54..6255cd2e21 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +17,9 @@
import hmac
import logging
+import re
from typing import List, Union
-from six import string_types
-
import synapse
import synapse.api.auth
import synapse.types
@@ -122,10 +122,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized to register on this server",
+ "Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
@@ -194,7 +194,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ assert_valid_client_secret(body["client_secret"])
+
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
@@ -350,15 +352,9 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
- ip = self.hs.get_ip_from_request(request)
- with self.ratelimiter.ratelimit(ip) as wait_deferred:
- await wait_deferred
-
- username = parse_string(request, "username", required=True)
-
- await self.registration_handler.check_username(username)
-
- return 200, {"available": True}
+ # We are not interested in logging in via a username in this deployment.
+ # Simply allow anything here as it won't be used later.
+ return 200, {"available": True}
class RegisterRestServlet(RestServlet):
@@ -409,9 +405,10 @@ class RegisterRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
+ desired_password_hash = None
if "password" in body:
password = body.pop("password")
- if not isinstance(password, string_types) or len(password) > 512:
+ if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(password)
@@ -420,15 +417,19 @@ class RegisterRestServlet(RestServlet):
if "password_hash" in body:
raise SynapseError(400, "Unexpected property: password_hash")
body["password_hash"] = await self.auth_handler.hash(password)
+ desired_password_hash = body["password_hash"]
+ # We don't care about usernames for this deployment. In fact, the act
+ # of checking whether they exist already can leak metadata about
+ # which users are already registered.
+ #
+ # Usernames are already derived via the provided email.
+ # So, if they're not necessary, just ignore them.
+ #
+ # (we do still allow appservices to set them below)
desired_username = None
- if "username" in body:
- if (
- not isinstance(body["username"], string_types)
- or len(body["username"]) > 512
- ):
- raise SynapseError(400, "Invalid username")
- desired_username = body["username"]
+
+ desired_display_name = body.get("display_name")
appservice = None
if self.auth.has_access_token(request):
@@ -442,7 +443,7 @@ class RegisterRestServlet(RestServlet):
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
- desired_username = body.get("user", desired_username)
+ desired_username = body.get("user", body.get("username"))
# XXX we should check that desired_username is valid. Currently
# we give appservices carte blanche for any insanity in mxids,
@@ -451,21 +452,16 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request)
- if isinstance(desired_username, string_types):
+ if isinstance(desired_username, str):
result = await self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username,
+ desired_password_hash,
+ desired_display_name,
+ access_token,
+ body,
)
return 200, result # we throw for non 200 responses
- # for regular registration, downcase the provided username before
- # attempting to register it. This should mean
- # that people who try to register with upper-case in their usernames
- # don't get a nasty surprise. (Note that we treat username
- # case-insenstively in login, so they are free to carry on imagining
- # that their username is CrAzYh4cKeR if that keeps them happy)
- if desired_username is not None:
- desired_username = desired_username.lower()
-
# == Normal User Registration == (everyone else)
if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled")
@@ -491,13 +487,6 @@ class RegisterRestServlet(RestServlet):
session_id, "registered_user_id", None
)
- if desired_username is not None:
- await self.registration_handler.check_username(
- desired_username,
- guest_access_token=guest_access_token,
- assigned_user_id=registered_user_id,
- )
-
auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows,
request,
@@ -518,7 +507,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not (await check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
@@ -526,6 +515,80 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ existingUid = await self.store.get_user_id_by_threepid(
+ medium, address
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
+ )
+
+ if self.hs.config.register_mxid_from_3pid:
+ # override the desired_username based on the 3PID if any.
+ # reset it first to avoid folks picking their own username.
+ desired_username = None
+
+ # we should have an auth_result at this point if we're going to progress
+ # to register the user (i.e. we haven't picked up a registered_user_id
+ # from our session store), in which case get ready and gen the
+ # desired_username
+ if auth_result:
+ if (
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
+ ):
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
+ desired_username = synapse.types.strip_invalid_mxid_characters(
+ address.replace("@", "-").lower()
+ )
+
+ # find a unique mxid for the account, suffixing numbers
+ # if needed
+ while True:
+ try:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+ # if we got this far we passed the check.
+ break
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
+ if m:
+ desired_username = m.group(1) + str(
+ int(m.group(2)) + 1
+ )
+ else:
+ desired_username += "1"
+ else:
+ # something else went wrong.
+ break
+
+ if self.hs.config.register_just_use_email_for_display_name:
+ desired_display_name = address
+ else:
+ # Custom mapping between email address and display name
+ desired_display_name = _map_email_to_displayname(address)
+ elif (
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
+ ):
+ desired_username = auth_result[LoginType.MSISDN]["address"]
+ else:
+ raise SynapseError(
+ 400, "Cannot derive mxid from 3pid; no recognised 3pid"
+ )
+
+ if desired_username is not None:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session", registered_user_id
@@ -536,7 +599,12 @@ class RegisterRestServlet(RestServlet):
# NB: This may be from the auth handler and NOT from the POST
assert_params_in_dict(params, ["password_hash"])
- desired_username = params.get("username", None)
+ if not self.hs.config.register_mxid_from_3pid:
+ desired_username = params.get("username", None)
+ else:
+ # we keep the original desired_username derived from the 3pid above
+ pass
+
guest_access_token = params.get("guest_access_token", None)
new_password_hash = params.get("password_hash", None)
@@ -573,6 +641,7 @@ class RegisterRestServlet(RestServlet):
localpart=desired_username,
password_hash=new_password_hash,
guest_access_token=guest_access_token,
+ default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
)
@@ -584,6 +653,14 @@ class RegisterRestServlet(RestServlet):
):
await self.store.upsert_monthly_active_user(registered_user_id)
+ if self.hs.config.shadow_server:
+ await self.registration_handler.shadow_register(
+ localpart=desired_username,
+ display_name=desired_display_name,
+ auth_result=auth_result,
+ params=params,
+ )
+
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
await self.auth_handler.set_session_data(
@@ -608,11 +685,30 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _):
return 200, {}
- async def _do_appservice_registration(self, username, as_token, body):
+ async def _do_appservice_registration(
+ self, username, password_hash, display_name, as_token, body
+ ):
+ # FIXME: appservice_register() is horribly duplicated with register()
+ # and they should probably just be combined together with a config flag.
user_id = await self.registration_handler.appservice_register(
- username, as_token
+ username, as_token, password_hash, display_name
)
- return await self._create_registration_details(user_id, body)
+ result = await self._create_registration_details(user_id, body)
+
+ auth_result = body.get("auth_result")
+ if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ await self.registration_handler.register_email_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ if auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ await self.registration_handler.register_msisdn_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ return result
async def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
@@ -663,6 +759,60 @@ class RegisterRestServlet(RestServlet):
)
+def cap(name):
+ """Capitalise parts of a name containing different words, including those
+ separated by hyphens.
+ For example, 'John-Doe'
+
+ Args:
+ name (str): The name to parse
+ """
+ if not name:
+ return name
+
+ # Split the name by whitespace then hyphens, capitalizing each part then
+ # joining it back together.
+ capatilized_name = " ".join(
+ "-".join(part.capitalize() for part in space_part.split("-"))
+ for space_part in name.split()
+ )
+ return capatilized_name
+
+
+def _map_email_to_displayname(address):
+ """Custom mapping from an email address to a user displayname
+
+ Args:
+ address (str): The email address to process
+ Returns:
+ str: The new displayname
+ """
+ # Split the part before and after the @ in the email.
+ # Replace all . with spaces in the first part
+ parts = address.replace(".", " ").split("@")
+
+ # Figure out which org this email address belongs to
+ org_parts = parts[1].split(" ")
+
+ # If this is a ...matrix.org email, mark them as an Admin
+ if org_parts[-2] == "matrix" and org_parts[-1] == "org":
+ org = "Tchap Admin"
+
+ # Is this is a ...gouv.fr address, set the org to whatever is before
+ # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their
+ # org as "gouv"
+ elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
+ org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
+
+ # Otherwise, mark their org as the email's second-level domain name
+ else:
+ org = org_parts[-2]
+
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
+
+ return desired_display_name
+
+
def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index f067b5edac..e15927c4ea 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -14,9 +14,7 @@
# limitations under the License.
import logging
-
-from six import string_types
-from six.moves import http_client
+from http import HTTPStatus
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
@@ -47,15 +45,15 @@ class ReportEventRestServlet(RestServlet):
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("reason", "score"))
- if not isinstance(body["reason"], string_types):
+ if not isinstance(body["reason"], str):
raise SynapseError(
- http_client.BAD_REQUEST,
+ HTTPStatus.BAD_REQUEST,
"Param 'reason' must be a string",
Codes.BAD_JSON,
)
if not isinstance(body["score"], int):
raise SynapseError(
- http_client.BAD_REQUEST,
+ HTTPStatus.BAD_REQUEST,
"Param 'score' must be an integer",
Codes.BAD_JSON,
)
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index bef91a2d3e..6e8300d6a5 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,9 +14,17 @@
# limitations under the License.
import logging
+from typing import Dict
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from signedjson.sign import sign_json
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import UserID
from ._base import client_patterns
@@ -35,6 +43,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
"""Searches for users in directory
@@ -61,6 +70,16 @@ class UserDirectorySearchRestServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if self.hs.config.user_directory_defer_to_id_server:
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
+ url = "%s/_matrix/identity/api/v1/user_directory/search" % (
+ self.hs.config.user_directory_defer_to_id_server,
+ )
+ resp = await self.http_client.post_json_get_json(url, signed_body)
+ return 200, resp
+
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -76,5 +95,125 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
+class SingleUserInfoServlet(RestServlet):
+ """
+ Deprecated and replaced by `/users/info`
+
+ GET /user/{user_id}/info HTTP/1.1
+ """
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
+
+ def __init__(self, hs):
+ super(SingleUserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+ registry = hs.get_federation_registry()
+
+ if not registry.query_handlers.get("user_info"):
+ registry.register_query_handler("user_info", self._on_federation_query)
+
+ async def on_GET(self, request, user_id):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ # Attempt to make a federation request to the server that owns this user
+ args = {"user_id": user_id}
+ res = await self.transport_layer.make_query(
+ user.domain, "user_info", args, retry_on_dns_fail=True
+ )
+ return 200, res
+
+ user_id_to_info = await self.store.get_info_for_users([user_id])
+ return 200, user_id_to_info[user_id]
+
+ async def _on_federation_query(self, args):
+ """Called when a request for user information appears over federation
+
+ Args:
+ args (dict): Dictionary of query arguments provided by the request
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ user_id = args.get("user_id")
+ if not user_id:
+ raise SynapseError(400, "user_id not provided")
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
+ return user_ids_to_info_dict[user_id]
+
+
+class UserInfoServlet(RestServlet):
+ """Bulk version of `/user/{user_id}/info` endpoint
+
+ GET /users/info HTTP/1.1
+
+ Returns a dictionary of user_id to info dictionary. Supports remote users
+ """
+
+ PATTERNS = client_patterns("/users/info$", unstable=True, releases=())
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+
+ async def on_POST(self, request):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ # Extract the user_ids from the request
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, required=["user_ids"])
+
+ user_ids = body["user_ids"]
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ # Separate local and remote users
+ local_user_ids = set()
+ remote_server_to_user_ids = {} # type: Dict[str, set]
+ for user_id in user_ids:
+ user = UserID.from_string(user_id)
+
+ if self.hs.is_mine(user):
+ local_user_ids.add(user_id)
+ else:
+ remote_server_to_user_ids.setdefault(user.domain, set())
+ remote_server_to_user_ids[user.domain].add(user_id)
+
+ # Retrieve info of all local users
+ user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)
+
+ # Request info of each remote user from their remote homeserver
+ for server_name, user_id_set in remote_server_to_user_ids.items():
+ # Make a request to the given server about their own users
+ res = await self.transport_layer.get_info_of_users(
+ server_name, list(user_id_set)
+ )
+
+ for user_id, info in res:
+ user_id_to_info_dict[user_id] = info
+
+ return 200, user_id_to_info_dict
+
+
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ SingleUserInfoServlet(hs).register(http_server)
+ UserInfoServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0d668df0b6..b1999d051b 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -57,9 +57,12 @@ class VersionsRestServlet(RestServlet):
# MSC2326.
"org.matrix.label_based_filtering": True,
# Implements support for cross signing as described in MSC1756
- "org.matrix.e2e_cross_signing": True,
+ # "org.matrix.e2e_cross_signing": True,
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
+ # Tchap does not currently assume this rule for r0.5.0
+ # XXX: Remove this when it does
+ "m.lazy_load_members": True,
},
},
)
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 4a20282d1b..0a890c98cb 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -16,10 +16,9 @@
import hmac
import logging
from hashlib import sha256
+from http import HTTPStatus
from os import path
-from six.moves import http_client
-
import jinja2
from jinja2 import TemplateNotFound
@@ -219,4 +218,4 @@ class ConsentResource(DirectServeResource):
)
if not compare_digest(want_mac, userhmac):
- raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect")
+ raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect")
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 3689777266..595849f9d5 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -16,8 +16,7 @@
import logging
import os
-
-from six.moves import urllib
+import urllib
from twisted.internet import defer
from twisted.protocols.basic import FileSender
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index fd10d42f2f..45628c07b4 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -20,8 +20,6 @@ import os
import shutil
from typing import Dict, Tuple
-from six import iteritems
-
import twisted.internet.error
import twisted.web.http
from twisted.web.resource import Resource
@@ -340,7 +338,7 @@ class MediaRepository(object):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
request_path = "/".join(
- ("/_matrix/media/v1/download", server_name, media_id)
+ ("/_matrix/media/r0/download", server_name, media_id)
)
try:
length, headers = await self.client.get_file(
@@ -606,7 +604,7 @@ class MediaRepository(object):
thumbnails[(t_width, t_height, r_type)] = r_method
# Now we generate the thumbnails for each dimension, store it
- for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
+ for (t_width, t_height, t_type), t_method in thumbnails.items():
# Generate the thumbnail
if t_method == "crop":
t_byte_source = await defer_to_thread(
@@ -705,7 +703,7 @@ class MediaRepositoryResource(Resource):
Uploads are POSTed to a resource which returns a token which is used to GET
the download::
- => POST /_matrix/media/v1/upload HTTP/1.1
+ => POST /_matrix/media/r0/upload HTTP/1.1
Content-Type: <media-type>
Content-Length: <content-length>
@@ -716,7 +714,7 @@ class MediaRepositoryResource(Resource):
{ "content_uri": "mxc://<server-name>/<media-id>" }
- => GET /_matrix/media/v1/download/<server-name>/<media-id> HTTP/1.1
+ => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1
<= HTTP/1.1 200 OK
Content-Type: <media-type>
@@ -727,7 +725,7 @@ class MediaRepositoryResource(Resource):
Clients can get thumbnails by supplying a desired width and height and
thumbnailing method::
- => GET /_matrix/media/v1/thumbnail/<server_name>
+ => GET /_matrix/media/r0/thumbnail/<server_name>
/<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1
<= HTTP/1.1 200 OK
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 683a79c966..79cb0dddbe 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -17,9 +17,6 @@ import contextlib
import logging
import os
import shutil
-import sys
-
-import six
from twisted.internet import defer
from twisted.protocols.basic import FileSender
@@ -117,12 +114,11 @@ class MediaStorage(object):
with open(fname, "wb") as f:
yield f, fname, finish
except Exception:
- t, v, tb = sys.exc_info()
try:
os.remove(fname)
except Exception:
pass
- six.reraise(t, v, tb)
+ raise
if not finished_called:
raise Exception("Finished callback not called")
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index f206605727..b4645cd608 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -24,10 +24,7 @@ import shutil
import sys
import traceback
from typing import Dict, Optional
-
-import six
-from six import string_types
-from six.moves import urllib_parse as urlparse
+from urllib import parse as urlparse
from canonicaljson import json
@@ -85,6 +82,15 @@ class PreviewUrlResource(DirectServeResource):
self.primary_base_path = media_repo.primary_base_path
self.media_storage = media_storage
+ # We run the background jobs if we're the instance specified (or no
+ # instance is specified, where we assume there is only one instance
+ # serving media).
+ instance_running_jobs = hs.config.media.media_instance_running_background_jobs
+ self._worker_run_media_background_jobs = (
+ instance_running_jobs is None
+ or instance_running_jobs == hs.get_instance_name()
+ )
+
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
self.url_preview_accept_language = hs.config.url_preview_accept_language
@@ -97,9 +103,10 @@ class PreviewUrlResource(DirectServeResource):
expiry_ms=60 * 60 * 1000,
)
- self._cleaner_loop = self.clock.looping_call(
- self._start_expire_url_cache_data, 10 * 1000
- )
+ if self._worker_run_media_background_jobs:
+ self._cleaner_loop = self.clock.looping_call(
+ self._start_expire_url_cache_data, 10 * 1000
+ )
def render_OPTIONS(self, request):
request.setHeader(b"Allow", b"OPTIONS, GET")
@@ -188,7 +195,7 @@ class PreviewUrlResource(DirectServeResource):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
og = cache_result["og"]
- if isinstance(og, six.text_type):
+ if isinstance(og, str):
og = og.encode("utf8")
return og
@@ -400,6 +407,8 @@ class PreviewUrlResource(DirectServeResource):
"""
# TODO: Delete from backup media store
+ assert self._worker_run_media_background_jobs
+
now = self.clock.time_msec()
logger.debug("Running url preview cache expiry")
@@ -631,7 +640,7 @@ def _iterate_over_text(tree, *tags_to_ignore):
if el is None:
return
- if isinstance(el, string_types):
+ if isinstance(el, str):
yield el
elif el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
new file mode 100644
index 0000000000..6f2a1931c5
--- /dev/null
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.config._base import ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DomainRuleChecker(object):
+ """
+ A re-implementation of the SpamChecker that prevents users in one domain from
+ inviting users in other domains to rooms, based on a configuration.
+
+ Takes a config in the format:
+
+ spam_checker:
+ module: "rulecheck.DomainRuleChecker"
+ config:
+ domain_mapping:
+ "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ]
+ "other_inviter_domain": [ "invitee_domain_permitted" ]
+ default: False
+
+ # Only let local users join rooms if they were explicitly invited.
+ can_only_join_rooms_with_invite: false
+
+ # Only let local users create rooms if they are inviting only one
+ # other user, and that user matches the rules above.
+ can_only_create_one_to_one_rooms: false
+
+ # Only let local users invite during room creation, regardless of the
+ # domain mapping rules above.
+ can_only_invite_during_room_creation: false
+
+ # Prevent local users from inviting users from certain domains to
+ # rooms published in the room directory.
+ domains_prevented_from_being_invited_to_published_rooms: []
+
+ # Allow third party invites
+ can_invite_by_third_party_id: true
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config):
+ self.domain_mapping = config["domain_mapping"] or {}
+ self.default = config["default"]
+
+ self.can_only_join_rooms_with_invite = config.get(
+ "can_only_join_rooms_with_invite", False
+ )
+ self.can_only_create_one_to_one_rooms = config.get(
+ "can_only_create_one_to_one_rooms", False
+ )
+ self.can_only_invite_during_room_creation = config.get(
+ "can_only_invite_during_room_creation", False
+ )
+ self.can_invite_by_third_party_id = config.get(
+ "can_invite_by_third_party_id", True
+ )
+ self.domains_prevented_from_being_invited_to_published_rooms = config.get(
+ "domains_prevented_from_being_invited_to_published_rooms", []
+ )
+
+ def check_event_for_spam(self, event):
+ """Implements synapse.events.SpamChecker.check_event_for_spam
+ """
+ return False
+
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room=False,
+ ):
+ """Implements synapse.events.SpamChecker.user_may_invite
+ """
+ if self.can_only_invite_during_room_creation and not new_room:
+ return False
+
+ if not self.can_invite_by_third_party_id and third_party_invite:
+ return False
+
+ # This is a third party invite (without a bound mxid), so unless we have
+ # banned all third party invites (above) we allow it.
+ if not invitee_userid:
+ return True
+
+ inviter_domain = self._get_domain_from_id(inviter_userid)
+ invitee_domain = self._get_domain_from_id(invitee_userid)
+
+ if inviter_domain not in self.domain_mapping:
+ return self.default
+
+ if (
+ published_room
+ and invitee_domain
+ in self.domains_prevented_from_being_invited_to_published_rooms
+ ):
+ return False
+
+ return invitee_domain in self.domain_mapping[inviter_domain]
+
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
+ """Implements synapse.events.SpamChecker.user_may_create_room
+ """
+
+ if cloning:
+ return True
+
+ if not self.can_invite_by_third_party_id and third_party_invite_list:
+ return False
+
+ number_of_invites = len(invite_list) + len(third_party_invite_list)
+
+ if self.can_only_create_one_to_one_rooms and number_of_invites != 1:
+ return False
+
+ return True
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Implements synapse.events.SpamChecker.user_may_create_room_alias
+ """
+ return True
+
+ def user_may_publish_room(self, userid, room_id):
+ """Implements synapse.events.SpamChecker.user_may_publish_room
+ """
+ return True
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Implements synapse.events.SpamChecker.user_may_join_room
+ """
+ if self.can_only_join_rooms_with_invite and not is_invited:
+ return False
+
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ """Implements synapse.events.SpamChecker.parse_config
+ """
+ if "default" in config:
+ return config
+ else:
+ raise ConfigError("No default set for spam_config DomainRuleChecker")
+
+ @staticmethod
+ def _get_domain_from_id(mxid):
+ """Parses a string and returns the domain part of the mxid.
+
+ Args:
+ mxid (str): a valid mxid
+
+ Returns:
+ str: the domain part of the mxid
+
+ """
+ idx = mxid.find(":")
+ if idx == -1:
+ raise Exception("Invalid ID: %r" % (mxid,))
+ return mxid[idx + 1 :]
diff --git a/synapse/server.py b/synapse/server.py
index fe94836a2c..d14b6b722c 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -174,6 +174,7 @@ class HomeServer(object):
"event_builder_factory",
"filtering",
"http_client_context_factory",
+ "proxied_http_client",
"simple_http_client",
"proxied_http_client",
"media_repository",
@@ -212,6 +213,7 @@ class HomeServer(object):
"replication_streamer",
"replication_data_handler",
"replication_streams",
+ "password_policy_handler",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -570,9 +572,6 @@ class HomeServer(object):
def build_event_client_serializer(self):
return EventClientSerializer(self)
- def build_password_policy_handler(self):
- return PasswordPolicyHandler(self)
-
def build_storage(self) -> Storage:
return Storage(self, self.datastores)
@@ -585,6 +584,9 @@ class HomeServer(object):
def build_replication_streams(self):
return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
+ def build_password_policy_handler(self):
+ return PasswordPolicyHandler(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 3bf330da49..3bfc8d7278 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -14,8 +14,6 @@
# limitations under the License.
import logging
-from six import iteritems, string_types
-
from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder
from synapse.config import ConfigError
@@ -118,10 +116,10 @@ def copy_with_str_subst(x, substitutions):
Returns:
copy of x
"""
- if isinstance(x, string_types):
+ if isinstance(x, str):
return x % substitutions
if isinstance(x, dict):
- return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)}
+ return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()}
if isinstance(x, (list, tuple)):
return [copy_with_str_subst(y) for y in x]
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 73f2cedb5c..4404ceff93 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -14,8 +14,6 @@
# limitations under the License.
import logging
-from six import iteritems
-
from synapse.api.constants import (
EventTypes,
LimitBlockingTypes,
@@ -214,7 +212,7 @@ class ResourceLimitsServerNotices(object):
referenced_events = list(pinned_state_event.content.get("pinned", []))
events = await self._store.get_events(referenced_events)
- for event_id, event in iteritems(events):
+ for event_id, event in events.items():
if event.type != EventTypes.Message:
continue
if event.content.get("msgtype") == ServerNoticeMsgType:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 2fa529fcd0..495d9f04c8 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -18,8 +18,6 @@ import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Set
-from six import iteritems, itervalues
-
import attr
from frozendict import frozendict
from prometheus_client import Histogram
@@ -34,6 +32,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import StateMap
+from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
@@ -144,7 +143,7 @@ class StateHandler(object):
list(state.values()), get_prev_content=False
)
state = {
- key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
+ key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
return state
@@ -416,6 +415,7 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
+ self.clock,
event.room_id,
room_version,
state_set_ids,
@@ -423,7 +423,7 @@ class StateHandler(object):
state_res_store=StateResolutionStore(self.store),
)
- new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)}
+ new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
return new_state
@@ -505,8 +505,8 @@ class StateResolutionHandler(object):
# resolve_events_with_store do it?
new_state = {}
conflicted_state = False
- for st in itervalues(state_groups_ids):
- for key, e_id in iteritems(st):
+ for st in state_groups_ids.values():
+ for key, e_id in st.items():
if key in new_state:
conflicted_state = True
break
@@ -518,9 +518,10 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
+ self.clock,
room_id,
room_version,
- list(itervalues(state_groups_ids)),
+ list(state_groups_ids.values()),
event_map=event_map,
state_res_store=state_res_store,
)
@@ -561,12 +562,12 @@ def _make_state_cache_entry(new_state, state_groups_ids):
# not get persisted.
# first look for exact matches
- new_state_event_ids = set(itervalues(new_state))
- for sg, state in iteritems(state_groups_ids):
+ new_state_event_ids = set(new_state.values())
+ for sg, state in state_groups_ids.items():
if len(new_state_event_ids) != len(state):
continue
- old_state_event_ids = set(itervalues(state))
+ old_state_event_ids = set(state.values())
if new_state_event_ids == old_state_event_ids:
# got an exact match.
return _StateCacheEntry(state=new_state, state_group=sg)
@@ -579,8 +580,8 @@ def _make_state_cache_entry(new_state, state_groups_ids):
prev_group = None
delta_ids = None
- for old_group, old_state in iteritems(state_groups_ids):
- n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v}
+ for old_group, old_state in state_groups_ids.items():
+ n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
delta_ids = n_delta_ids
@@ -591,6 +592,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
def resolve_events_with_store(
+ clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
@@ -627,7 +629,7 @@ def resolve_events_with_store(
)
else:
return v2.resolve_events_with_store(
- room_id, room_version, state_sets, event_map, state_res_store
+ clock, room_id, room_version, state_sets, event_map, state_res_store
)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 9bf98d06f2..7b531a8337 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,8 +17,6 @@ import hashlib
import logging
from typing import Callable, Dict, List, Optional
-from six import iteritems, iterkeys, itervalues
-
from twisted.internet import defer
from synapse import event_auth
@@ -70,11 +68,11 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state = _seperate(state_sets)
needed_events = {
- event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
+ event_id for event_ids in conflicted_state.values() for event_id in event_ids
}
needed_event_count = len(needed_events)
if event_map is not None:
- needed_events -= set(iterkeys(event_map))
+ needed_events -= set(event_map.keys())
logger.info(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
@@ -102,11 +100,11 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state, state_map
)
- new_needed_events = set(itervalues(auth_events))
+ new_needed_events = set(auth_events.values())
new_needed_event_count = len(new_needed_events)
new_needed_events -= needed_events
if event_map is not None:
- new_needed_events -= set(iterkeys(event_map))
+ new_needed_events -= set(event_map.keys())
logger.info(
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
@@ -152,7 +150,7 @@ def _seperate(state_sets):
conflicted_state = {}
for state_set in state_set_iterator:
- for key, value in iteritems(state_set):
+ for key, value in state_set.items():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
@@ -178,7 +176,7 @@ def _seperate(state_sets):
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
- for event_ids in itervalues(conflicted_state):
+ for event_ids in conflicted_state.values():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
@@ -194,7 +192,7 @@ def _resolve_with_state(
unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
):
conflicted_state = {}
- for key, event_ids in iteritems(conflicted_state_ids):
+ for key, event_ids in conflicted_state_ids.items():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
@@ -203,7 +201,7 @@ def _resolve_with_state(
auth_events = {
key: state_map[ev_id]
- for key, ev_id in iteritems(auth_event_ids)
+ for key, ev_id in auth_event_ids.items()
if ev_id in state_map
}
@@ -214,7 +212,7 @@ def _resolve_with_state(
raise
new_state = unconflicted_state_ids
- for key, event in iteritems(resolved_state):
+ for key, event in resolved_state.items():
new_state[key] = event.event_id
return new_state
@@ -238,21 +236,21 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(events, auth_events)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 18484e2fa6..bf6caa0946 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -18,8 +18,6 @@ import itertools
import logging
from typing import Dict, List, Optional
-from six import iteritems, itervalues
-
from twisted.internet import defer
import synapse.state
@@ -29,12 +27,20 @@ from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import StateMap
+from synapse.util import Clock
logger = logging.getLogger(__name__)
+# We want to yield to the reactor occasionally during state res when dealing
+# with large data sets, so that we don't exhaust the reactor. This is done by
+# yielding to reactor during loops every N iterations.
+_YIELD_AFTER_ITERATIONS = 100
+
+
@defer.inlineCallbacks
def resolve_events_with_store(
+ clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
@@ -44,13 +50,11 @@ def resolve_events_with_store(
"""Resolves the state using the v2 state resolution algorithm
Args:
+ clock
room_id: the room we are working in
-
room_version: The room version
-
state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
-
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
@@ -87,7 +91,7 @@ def resolve_events_with_store(
full_conflicted_set = set(
itertools.chain(
- itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff
+ itertools.chain.from_iterable(conflicted_state.values()), auth_diff
)
)
@@ -115,13 +119,14 @@ def resolve_events_with_store(
)
sorted_power_events = yield _reverse_topological_power_sort(
- room_id, power_events, event_map, state_res_store, full_conflicted_set
+ clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
resolved_state = yield _iterative_auth_checks(
+ clock,
room_id,
room_version,
sorted_power_events,
@@ -135,20 +140,22 @@ def resolve_events_with_store(
# OK, so we've now resolved the power events. Now sort the remaining
# events using the mainline of the resolved power level.
+ set_power_events = set(sorted_power_events)
leftover_events = [
- ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
+ ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events
]
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort(
- room_id, leftover_events, pl, event_map, state_res_store
+ clock, room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks(
+ clock,
room_id,
room_version,
leftover_events,
@@ -318,12 +325,13 @@ def _add_event_and_auth_chain_to_graph(
@defer.inlineCallbacks
def _reverse_topological_power_sort(
- room_id, event_ids, event_map, state_res_store, auth_diff
+ clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
+ clock (Clock)
room_id (str): the room we are working in
event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent])
@@ -335,18 +343,28 @@ def _reverse_topological_power_sort(
"""
graph = {}
- for event_id in event_ids:
+ for idx, event_id in enumerate(event_ids, start=1):
yield _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
)
+ # We yield occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _YIELD_AFTER_ITERATIONS == 0:
+ yield clock.sleep(0)
+
event_to_pl = {}
- for event_id in graph:
+ for idx, event_id in enumerate(graph, start=1):
pl = yield _get_power_level_for_sender(
room_id, event_id, event_map, state_res_store
)
event_to_pl[event_id] = pl
+ # We yield occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _YIELD_AFTER_ITERATIONS == 0:
+ yield clock.sleep(0)
+
def _get_power_order(event_id):
ev = event_map[event_id]
pl = event_to_pl[event_id]
@@ -362,12 +380,13 @@ def _reverse_topological_power_sort(
@defer.inlineCallbacks
def _iterative_auth_checks(
- room_id, room_version, event_ids, base_state, event_map, state_res_store
+ clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
):
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
+ clock (Clock)
room_id (str)
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
@@ -381,7 +400,7 @@ def _iterative_auth_checks(
resolved_state = base_state.copy()
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
- for event_id in event_ids:
+ for idx, event_id in enumerate(event_ids, start=1):
event = event_map[event_id]
auth_events = {}
@@ -419,17 +438,23 @@ def _iterative_auth_checks(
except AuthError:
pass
+ # We yield occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _YIELD_AFTER_ITERATIONS == 0:
+ yield clock.sleep(0)
+
return resolved_state
@defer.inlineCallbacks
def _mainline_sort(
- room_id, event_ids, resolved_power_event_id, event_map, state_res_store
+ clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
+ clock (Clock)
room_id (str): room we're working in
event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID
@@ -439,8 +464,14 @@ def _mainline_sort(
Returns:
Deferred[list[str]]: The sorted list
"""
+ if not event_ids:
+ # It's possible for there to be no event IDs here to sort, so we can
+ # skip calculating the mainline in that case.
+ return []
+
mainline = []
pl = resolved_power_event_id
+ idx = 0
while pl:
mainline.append(pl)
pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
@@ -454,17 +485,29 @@ def _mainline_sort(
pl = aid
break
+ # We yield occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
+ yield clock.sleep(0)
+
+ idx += 1
+
mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
event_ids = list(event_ids)
order_map = {}
- for ev_id in event_ids:
+ for idx, ev_id in enumerate(event_ids, start=1):
depth = yield _get_mainline_depth_for_event(
event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
+ # We yield occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _YIELD_AFTER_ITERATIONS == 0:
+ yield clock.sleep(0)
+
event_ids.sort(key=lambda ev_id: order_map[ev_id])
return event_ids
@@ -572,7 +615,7 @@ def lexicographical_topological_sort(graph, key):
# `(key(node), node)` so that sorting does the right thing
zero_outdegree = []
- for node, edges in iteritems(graph):
+ for node, edges in graph.items():
if len(edges) == 0:
zero_outdegree.append((key(node), node))
diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html
index 6fefdaaff7..9e6daf38ac 100644
--- a/synapse/static/client/login/index.html
+++ b/synapse/static/client/login/index.html
@@ -1,24 +1,24 @@
<!doctype html>
<html>
<head>
-<title> Login </title>
-<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<link rel="stylesheet" href="style.css">
-<script src="js/jquery-3.4.1.min.js"></script>
-<script src="js/login.js"></script>
+ <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
+ <title> Login </title>
+ <meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+ <link rel="stylesheet" href="style.css">
+ <script src="js/jquery-3.4.1.min.js"></script>
+ <script src="js/login.js"></script>
</head>
<body onload="matrixLogin.onLoad()">
- <center>
- <br/>
+ <div id="container">
<h1 id="title"></h1>
- <span id="feedback" style="color: #f00"></span>
+ <span id="feedback"></span>
<div id="loading">
<img src="spinner.gif" />
</div>
- <div id="sso_flow" class="login_flow" style="display:none">
+ <div id="sso_flow" class="login_flow" style="display: none;">
Single-sign on:
<form id="sso_form" action="/_matrix/client/r0/login/sso/redirect" method="get">
<input id="sso_redirect_url" type="hidden" name="redirectUrl" value=""/>
@@ -26,9 +26,9 @@
</form>
</div>
- <div id="password_flow" class="login_flow" style="display:none">
+ <div id="password_flow" class="login_flow" style="display: none;">
Password Authentication:
- <form onsubmit="matrixLogin.password_login(); return false;">
+ <form onsubmit="matrixLogin.passwordLogin(); return false;">
<input id="user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" />
<br/>
<input id="password" size="32" type="password" placeholder="Password"/>
@@ -38,9 +38,9 @@
</form>
</div>
- <div id="no_login_types" type="button" class="login_flow" style="display:none">
+ <div id="no_login_types" type="button" class="login_flow" style="display: none;">
Log in currently unavailable.
</div>
- </center>
+ </div>
</body>
</html>
diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js
index ba8048b23f..3678670ec7 100644
--- a/synapse/static/client/login/js/login.js
+++ b/synapse/static/client/login/js/login.js
@@ -5,11 +5,11 @@ window.matrixLogin = {
};
// Titles get updated through the process to give users feedback.
-var TITLE_PRE_AUTH = "Log in with one of the following methods";
-var TITLE_POST_AUTH = "Logging in...";
+const TITLE_PRE_AUTH = "Log in with one of the following methods";
+const TITLE_POST_AUTH = "Logging in...";
// The cookie used to store the original query parameters when using SSO.
-var COOKIE_KEY = "synapse_login_fallback_qs";
+const COOKIE_KEY = "synapse_login_fallback_qs";
/*
* Submit a login request.
@@ -20,9 +20,9 @@ var COOKIE_KEY = "synapse_login_fallback_qs";
* login request, e.g. device_id.
* callback: (Optional) Function to call on successful login.
*/
-var submitLogin = function(type, data, extra, callback) {
+function submitLogin(type, data, extra, callback) {
console.log("Logging in with " + type);
- set_title(TITLE_POST_AUTH);
+ setTitle(TITLE_POST_AUTH);
// Add the login type.
data.type = type;
@@ -41,12 +41,15 @@ var submitLogin = function(type, data, extra, callback) {
}
matrixLogin.onLogin(response);
}).fail(errorFunc);
-};
+}
-var errorFunc = function(err) {
+/*
+ * Display an error to the user and show the login form again.
+ */
+function errorFunc(err) {
// We want to show the error to the user rather than redirecting immediately to the
// SSO portal (if SSO is the only login option), so we inhibit the redirect.
- show_login(true);
+ showLogin(true);
if (err.responseJSON && err.responseJSON.error) {
setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
@@ -54,27 +57,42 @@ var errorFunc = function(err) {
else {
setFeedbackString("Request failed: " + err.status);
}
-};
+}
-var setFeedbackString = function(text) {
+/*
+ * Display an error to the user.
+ */
+function setFeedbackString(text) {
$("#feedback").text(text);
-};
+}
-var show_login = function(inhibit_redirect) {
- // Set the redirect to come back to this page, a login token will get added
- // and handled after the redirect.
- var this_page = window.location.origin + window.location.pathname;
- $("#sso_redirect_url").val(this_page);
+/*
+ * (Maybe) Show the login forms.
+ *
+ * This actually does a few unrelated functions:
+ *
+ * * Configures the SSO redirect URL to come back to this page.
+ * * Configures and shows the SSO form, if the server supports SSO.
+ * * Otherwise, shows the password form.
+ */
+function showLogin(inhibitRedirect) {
+ setTitle(TITLE_PRE_AUTH);
- // If inhibit_redirect is false, and SSO is the only supported login method,
+ // If inhibitRedirect is false, and SSO is the only supported login method,
// we can redirect straight to the SSO page.
if (matrixLogin.serverAcceptsSso) {
+ // Set the redirect to come back to this page, a login token will get
+ // added as a query parameter and handled after the redirect.
+ $("#sso_redirect_url").val(window.location.origin + window.location.pathname);
+
// Before submitting SSO, set the current query parameters into a cookie
// for retrieval later.
var qs = parseQsFromUrl();
setCookie(COOKIE_KEY, JSON.stringify(qs));
- if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) {
+ // If password is not supported and redirects are allowed, then submit
+ // the form (redirecting to the SSO provider).
+ if (!inhibitRedirect && !matrixLogin.serverAcceptsPassword) {
$("#sso_form").submit();
return;
}
@@ -87,30 +105,39 @@ var show_login = function(inhibit_redirect) {
$("#password_flow").show();
}
+ // If neither password or SSO are supported, show an error to the user.
if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) {
$("#no_login_types").show();
}
- set_title(TITLE_PRE_AUTH);
-
$("#loading").hide();
-};
+}
-var show_spinner = function() {
+/*
+ * Hides the forms and shows a loading throbber.
+ */
+function showSpinner() {
$("#password_flow").hide();
$("#sso_flow").hide();
$("#no_login_types").hide();
$("#loading").show();
-};
+}
-var set_title = function(title) {
+/*
+ * Helper to show the page's main title.
+ */
+function setTitle(title) {
$("#title").text(title);
-};
+}
-var fetch_info = function(cb) {
+/*
+ * Query the login endpoint for the homeserver's supported flows.
+ *
+ * This populates matrixLogin.serverAccepts* variables.
+ */
+function fetchLoginFlows(cb) {
$.get(matrixLogin.endpoint, function(response) {
- var serverAcceptsPassword = false;
- for (var i=0; i<response.flows.length; i++) {
+ for (var i = 0; i < response.flows.length; i++) {
var flow = response.flows[i];
if ("m.login.sso" === flow.type) {
matrixLogin.serverAcceptsSso = true;
@@ -126,27 +153,41 @@ var fetch_info = function(cb) {
}).fail(errorFunc);
}
+/*
+ * Called on load to fetch login flows and attempt SSO login (if a token is available).
+ */
matrixLogin.onLoad = function() {
- fetch_info(function() {
- if (!try_token()) {
- show_login(false);
+ fetchLoginFlows(function() {
+ // (Maybe) attempt logging in via SSO if a token is available.
+ if (!tryTokenLogin()) {
+ showLogin(false);
}
});
};
-matrixLogin.password_login = function() {
+/*
+ * Submit simple user & password login.
+ */
+matrixLogin.passwordLogin = function() {
var user = $("#user_id").val();
var pwd = $("#password").val();
setFeedbackString("");
- show_spinner();
+ showSpinner();
submitLogin(
"m.login.password",
{user: user, password: pwd},
parseQsFromUrl());
};
+/*
+ * The onLogin function gets called after a succesful login.
+ *
+ * It is expected that implementations override this to be notified when the
+ * login is complete. The response to the login call is provided as the single
+ * parameter.
+ */
matrixLogin.onLogin = function(response) {
// clobber this function
console.warn("onLogin - This function should be replaced to proceed.");
@@ -155,7 +196,7 @@ matrixLogin.onLogin = function(response) {
/*
* Process the query parameters from the current URL into an object.
*/
-var parseQsFromUrl = function() {
+function parseQsFromUrl() {
var pos = window.location.href.indexOf("?");
if (pos == -1) {
return {};
@@ -174,12 +215,12 @@ var parseQsFromUrl = function() {
result[key] = val;
});
return result;
-};
+}
/*
* Process the cookies and return an object.
*/
-var parseCookies = function() {
+function parseCookies() {
var allCookies = document.cookie;
var result = {};
allCookies.split(";").forEach(function(part) {
@@ -196,32 +237,32 @@ var parseCookies = function() {
result[key] = val;
});
return result;
-};
+}
/*
* Set a cookie that is valid for 1 hour.
*/
-var setCookie = function(key, value) {
+function setCookie(key, value) {
// The maximum age is set in seconds.
var maxAge = 60 * 60;
// Set the cookie, this defaults to the current domain and path.
document.cookie = key + "=" + encodeURIComponent(value) + ";max-age=" + maxAge + ";sameSite=lax";
-};
+}
/*
* Removes a cookie by key.
*/
-var deleteCookie = function(key) {
+function deleteCookie(key) {
// Delete a cookie by setting the expiration to 0. (Note that the value
// doesn't matter.)
document.cookie = key + "=deleted;expires=0";
-};
+}
/*
* Submits the login token if one is found in the query parameters. Returns a
* boolean of whether the login token was found or not.
*/
-var try_token = function() {
+function tryTokenLogin() {
// Check if the login token is in the query parameters.
var qs = parseQsFromUrl();
@@ -233,18 +274,18 @@ var try_token = function() {
// Retrieve the original query parameters (from before the SSO redirect).
// They are stored as JSON in a cookie.
var cookies = parseCookies();
- var original_query_params = JSON.parse(cookies[COOKIE_KEY] || "{}")
+ var originalQueryParams = JSON.parse(cookies[COOKIE_KEY] || "{}")
// If the login is successful, delete the cookie.
- var callback = function() {
+ function callback() {
deleteCookie(COOKIE_KEY);
}
submitLogin(
"m.login.token",
{token: loginToken},
- original_query_params,
+ originalQueryParams,
callback);
return true;
-};
+}
diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css
index 1cce5ed950..83e4f6abc8 100644
--- a/synapse/static/client/login/style.css
+++ b/synapse/static/client/login/style.css
@@ -31,20 +31,44 @@ form {
margin: 10px 0 0 0;
}
+/*
+ * Add some padding to the viewport.
+ */
+#container {
+ padding: 10px;
+}
+/*
+ * Center all direct children of the main form.
+ */
+#container > * {
+ display: block;
+ margin-left: auto;
+ margin-right: auto;
+ text-align: center;
+}
+
+/*
+ * A wrapper around each login flow.
+ */
.login_flow {
width: 300px;
text-align: left;
padding: 10px;
margin-bottom: 40px;
- -webkit-border-radius: 10px;
- -moz-border-radius: 10px;
border-radius: 10px;
-
- -webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
- -moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
background-color: #f8f8f8;
border: 1px #ccc solid;
}
+
+/*
+ * Used to show error content.
+ */
+#feedback {
+ /* Red text. */
+ color: #ff0000;
+ /* A little space to not overlap the box-shadow. */
+ margin-bottom: 20px;
+}
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 71f8d43a76..995d4764a9 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -15,8 +15,6 @@
import logging
-from six import iteritems
-
from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -421,7 +419,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
):
self.database_engine.lock_table(txn, "user_ips")
- for entry in iteritems(to_update):
+ for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
@@ -530,7 +528,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
"user_agent": user_agent,
"last_seen": last_seen,
}
- for (access_token, ip), (user_agent, last_seen) in iteritems(results)
+ for (access_token, ip), (user_agent, last_seen) in results.items()
]
@wrap_as_background_process("prune_old_user_ips")
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index fb9f798e29..0ff0542453 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -17,8 +17,6 @@
import logging
from typing import List, Optional, Set, Tuple
-from six import iteritems
-
from canonicaljson import json
from twisted.internet import defer
@@ -208,7 +206,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
# add the updated cross-signing keys to the results list
- for user_id, result in iteritems(cross_signing_keys_by_user):
+ for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
results.append(("org.matrix.signing_key_update", result))
@@ -269,7 +267,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
results = []
- for user_id, user_devices in iteritems(devices):
+ for user_id, user_devices in devices.items():
# The prev_id for the first row is always the last row before
# `from_stream_id`
prev_id = yield self._get_last_device_update_for_remote_user(
@@ -493,7 +491,7 @@ class DeviceWorkerStore(SQLBaseStore):
if devices:
user_devices = devices[user_id]
results = []
- for device_id, device in iteritems(user_devices):
+ for device_id, device in user_devices.items():
result = {"device_id": device_id}
key_json = device.get("key_json", None)
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 20698bfd16..1a0842d4b0 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -16,8 +16,6 @@
# limitations under the License.
from typing import Dict, List
-from six import iteritems
-
from canonicaljson import encode_canonical_json, json
from twisted.enterprise.adbapi import Connection
@@ -64,9 +62,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
rv = {}
- for user_id, device_keys in iteritems(results):
+ for user_id, device_keys in results.items():
rv[user_id] = {}
- for device_id, device_info in iteritems(device_keys):
+ for device_id, device_info in device_keys.items():
r = db_to_json(device_info.pop("key_json"))
r["unsigned"] = {}
display_name = device_info["device_display_name"]
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 24ce8c4330..a6bb3221ff 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,10 +14,9 @@
# limitations under the License.
import itertools
import logging
+from queue import Empty, PriorityQueue
from typing import Dict, List, Optional, Set, Tuple
-from six.moves.queue import Empty, PriorityQueue
-
from twisted.internet import defer
from synapse.api.errors import StoreError
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 0321274de2..bc9f4f08ea 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -16,8 +16,6 @@
import logging
-from six import iteritems
-
from canonicaljson import json
from twisted.internet import defer
@@ -455,7 +453,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql,
(
_gen_entry(user_id, actions)
- for user_id, actions in iteritems(user_id_actions)
+ for user_id, actions in user_id_actions.items()
),
)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index a6572571b4..cfd24d2f06 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -21,9 +21,6 @@ from collections import OrderedDict, namedtuple
from functools import wraps
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
-from six import integer_types, iteritems, text_type
-from six.moves import range
-
import attr
from canonicaljson import json
from prometheus_client import Counter
@@ -232,10 +229,10 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
- for room_id, new_state in iteritems(current_state_for_room):
+ for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
- for room_id, latest_event_ids in iteritems(new_forward_extremeties):
+ for room_id, latest_event_ids in new_forward_extremeties.items():
self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
)
@@ -461,7 +458,7 @@ class PersistEventsStore:
state_delta_by_room: Dict[str, DeltaState],
stream_id: int,
):
- for room_id, delta_state in iteritems(state_delta_by_room):
+ for room_id, delta_state in state_delta_by_room.items():
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
@@ -545,7 +542,7 @@ class PersistEventsStore:
""",
[
(room_id, key[0], key[1], ev_id, ev_id)
- for key, ev_id in iteritems(to_insert)
+ for key, ev_id in to_insert.items()
],
)
@@ -642,7 +639,7 @@ class PersistEventsStore:
def _update_forward_extremities_txn(
self, txn, new_forward_extremities, max_stream_order
):
- for room_id, new_extrem in iteritems(new_forward_extremities):
+ for room_id, new_extrem in new_forward_extremities.items():
self.db.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
@@ -655,7 +652,7 @@ class PersistEventsStore:
table="event_forward_extremities",
values=[
{"event_id": ev_id, "room_id": room_id}
- for room_id, new_extrem in iteritems(new_forward_extremities)
+ for room_id, new_extrem in new_forward_extremities.items()
for ev_id in new_extrem
],
)
@@ -672,7 +669,7 @@ class PersistEventsStore:
"event_id": event_id,
"stream_ordering": max_stream_order,
}
- for room_id, new_extrem in iteritems(new_forward_extremities)
+ for room_id, new_extrem in new_forward_extremities.items()
for event_id in new_extrem
],
)
@@ -727,7 +724,7 @@ class PersistEventsStore:
event.depth, depth_updates.get(event.room_id, event.depth)
)
- for room_id, depth in iteritems(depth_updates):
+ for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
def _update_outliers_txn(self, txn, events_and_contexts):
@@ -893,8 +890,7 @@ class PersistEventsStore:
"received_ts": self._clock.time_msec(),
"sender": event.sender,
"contains_url": (
- "url" in event.content
- and isinstance(event.content["url"], text_type)
+ "url" in event.content and isinstance(event.content["url"], str)
),
}
for event, _ in events_and_contexts
@@ -1345,10 +1341,10 @@ class PersistEventsStore:
):
if (
"min_lifetime" in event.content
- and not isinstance(event.content.get("min_lifetime"), integer_types)
+ and not isinstance(event.content.get("min_lifetime"), int)
) or (
"max_lifetime" in event.content
- and not isinstance(event.content.get("max_lifetime"), integer_types)
+ and not isinstance(event.content.get("max_lifetime"), int)
):
# Ignore the event if one of the value isn't an integer.
return
@@ -1497,11 +1493,11 @@ class PersistEventsStore:
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
- for event_id, state_group_id in iteritems(state_groups)
+ for event_id, state_group_id in state_groups.items()
],
)
- for event_id, state_group_id in iteritems(state_groups):
+ for event_id, state_group_id in state_groups.items():
txn.call_after(
self.store._get_state_group_for_event.prefill,
(event_id,),
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index f54c8b1ee0..62d28f44dc 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -15,8 +15,6 @@
import logging
-from six import text_type
-
from canonicaljson import json
from twisted.internet import defer
@@ -133,7 +131,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
contains_url = "url" in content
if contains_url:
- contains_url &= isinstance(content["url"], text_type)
+ contains_url &= isinstance(content["url"], str)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 213d69100a..a48c7a96ca 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -1077,9 +1077,32 @@ class EventsWorkerStore(SQLBaseStore):
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)
- def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+ async def get_all_new_backfill_event_rows(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ """Get updates for backfill replication stream, including all new
+ backfilled events and events that have gone from being outliers to not.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
if last_id == current_id:
- return defer.succeed([])
+ return [], current_id, False
def get_all_new_backfill_event_rows(txn):
sql = (
@@ -1094,10 +1117,12 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, limit))
- new_event_updates = txn.fetchall()
+ new_event_updates = [(row[0], row[1:]) for row in txn]
+ limited = False
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
+ limited = True
else:
upper_bound = current_id
@@ -1114,11 +1139,15 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound))
- new_event_updates.extend(txn.fetchall())
+ new_event_updates.extend((row[0], row[1:]) for row in txn)
- return new_event_updates
+ if len(new_event_updates) >= limit:
+ upper_bound = new_event_updates[-1][0]
+ limited = True
- return self.db.runInteraction(
+ return new_event_updates, upper_bound, limited
+
+ return await self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 8aecd414c2..15bc13cbd0 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -81,6 +81,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media",
)
+ def mark_local_media_as_safe(self, media_id: str):
+ """Mark a local media as safe from quarantining."""
+ return self.db.simple_update_one(
+ table="local_media_repository",
+ keyvalues={"media_id": media_id},
+ updatevalues={"safe_from_quarantine": True},
+ desc="mark_local_media_as_safe",
+ )
+
def get_url_cache(self, url, ts):
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index dab31e0c2d..7574612619 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
@@ -73,9 +75,32 @@ class PresenceStore(SQLBaseStore):
)
txn.execute(sql + clause, [stream_id] + list(args))
- def get_all_presence_updates(self, last_id, current_id, limit):
+ async def get_all_presence_updates(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ """Get updates for presence replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
if last_id == current_id:
- return defer.succeed([])
+ return [], current_id, False
def get_all_presence_updates_txn(txn):
sql = """
@@ -89,9 +114,17 @@ class PresenceStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
+ updates = [(row[0], row[1:]) for row in txn]
+
+ upper_bound = current_id
+ limited = False
+ if len(updates) >= limit:
+ upper_bound = updates[-1][0]
+ limited = True
+
+ return updates, upper_bound, limited
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"get_all_presence_updates", get_all_presence_updates_txn
)
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index bfc9369f0b..cd2472feb0 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.roommember import ProfileInfo
+from synapse.types import UserID
+from synapse.util.caches.descriptors import cached
+
+BATCH_SIZE = 100
class ProfileWorkerStore(SQLBaseStore):
@@ -41,6 +48,7 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
+ @cached(max_entries=5000)
def get_profile_displayname(self, user_localpart):
return self.db.simple_select_one_onecol(
table="profiles",
@@ -49,6 +57,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_displayname",
)
+ @cached(max_entries=5000)
def get_profile_avatar_url(self, user_localpart):
return self.db.simple_select_one_onecol(
table="profiles",
@@ -57,6 +66,54 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.db.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return self.db.runInteraction("get_latest_profile_replication_batch_number", f)
+
+ def get_profile_batch(self, batchnum):
+ return self.db.simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ def assign_profile_batch(self):
+ def f(txn):
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return self.db.runInteraction("assign_profile_batch", f)
+
+ def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return self.db.runInteraction("get_replication_hosts", f)
+
+ def update_replication_batch_for_host(self, host, last_synced_batch):
+ return self.db.simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
def get_from_remote_profile_cache(self, user_id):
return self.db.simple_select_one(
table="remote_profile_cache",
@@ -71,24 +128,83 @@ class ProfileWorkerStore(SQLBaseStore):
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self.db.simple_update_one(
+ def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
+ # Invalidate the read cache for this user
+ self.get_profile_displayname.invalidate((user_localpart,))
+
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.db.simple_update_one(
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
+ # Invalidate the read cache for this user
+ self.get_profile_avatar_url.invalidate((user_localpart,))
+
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ def set_profiles_active(
+ self, users: List[UserID], active: bool, hide: bool, batchnum: int,
+ ):
+ """Given a set of users, set active and hidden flags on them.
+
+ Args:
+ users: A list of UserIDs
+ active: Whether to set the users to active or inactive
+ hide: Whether to hide the users (withold from replication). If
+ False and active is False, users will have their profiles
+ erased
+ batchnum: The batch number, used for profile replication
+
+ Returns:
+ Deferred
+ """
+ # Convert list of localparts to list of tuples containing localparts
+ user_localparts = [(user.localpart,) for user in users]
+
+ # Generate list of value tuples for each user
+ value_names = ("active", "batch")
+ values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple]
+
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile information
+ value_names += ("avatar_url", "displayname")
+ values = [v + (None, None) for v in values]
+
+ return self.db.runInteraction(
+ "set_profiles_active",
+ self.db.simple_upsert_many_txn,
+ table="profiles",
+ key_names=("user_id",),
+ key_values=user_localparts,
+ value_names=value_names,
+ value_values=values,
)
class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
+
+ super(ProfileStore, self).__init__(database, db_conn, hs)
+
+ self.db.updates.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
@@ -107,7 +223,7 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.db.simple_update(
+ return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
updatevalues={
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index ef8f40959f..f6e78ca590 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -16,7 +16,7 @@
import abc
import logging
-from typing import Union
+from typing import List, Tuple, Union
from canonicaljson import json
@@ -348,23 +348,53 @@ class PushRulesWorkerStore(
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
return results
- def get_all_push_rule_updates(self, last_id, current_id, limit):
- """Get all the push rules changes that have happend on the server"""
+ async def get_all_push_rule_updates(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for push_rules replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
if last_id == current_id:
- return defer.succeed([])
+ return [], current_id, False
def get_all_push_rule_updates_txn(txn):
- sql = (
- "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
- " op, priority_class, priority, conditions, actions"
- " FROM push_rules_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
+ sql = """
+ SELECT stream_id, user_id
+ FROM push_rules_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
+ updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
+
+ limited = False
+ upper_bound = current_id
+ if len(updates) == limit:
+ limited = True
+ upper_bound = updates[-1][0]
+
+ return updates, upper_bound, limited
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index cebdcd409f..8f5505bd67 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -16,6 +16,7 @@
import abc
import logging
+from typing import List, Tuple
from canonicaljson import json
@@ -24,6 +25,7 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -266,26 +268,79 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
return results
- def get_all_updated_receipts(self, last_id, current_id, limit=None):
+ def get_users_sent_receipts_between(self, last_id: int, current_id: int):
+ """Get all users who sent receipts between `last_id` exclusive and
+ `current_id` inclusive.
+
+ Returns:
+ Deferred[List[str]]
+ """
+
if last_id == current_id:
return defer.succeed([])
- def get_all_updated_receipts_txn(txn):
- sql = (
- "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
- " FROM receipts_linearized"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- )
- args = [last_id, current_id]
- if limit is not None:
- sql += " LIMIT ?"
- args.append(limit)
- txn.execute(sql, args)
+ def _get_users_sent_receipts_between_txn(txn):
+ sql = """
+ SELECT DISTINCT user_id FROM receipts_linearized
+ WHERE ? < stream_id AND stream_id <= ?
+ """
+ txn.execute(sql, (last_id, current_id))
- return [r[0:5] + (json.loads(r[5]),) for r in txn]
+ return [r[0] for r in txn]
return self.db.runInteraction(
+ "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
+ )
+
+ async def get_all_updated_receipts(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ """Get updates for receipts replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updatees.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def get_all_updated_receipts_txn(txn):
+ sql = """
+ SELECT stream_id, room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, limit))
+
+ updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn]
+
+ limited = False
+ upper_bound = current_id
+
+ if len(updates) == limit:
+ limited = True
+ upper_bound = updates[-1][0]
+
+ return updates, upper_bound, limited
+
+ return await self.db.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
@@ -300,10 +355,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_id, None, update_metrics=False
)
- # first handle the Deferred case
- if isinstance(res, defer.Deferred):
- if res.called:
- res = res.result
+ # first handle the ObservableDeferred case
+ if isinstance(res, ObservableDeferred):
+ if res.has_called():
+ res = res.get_result()
else:
res = None
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 9768981891..efb1a4fb4c 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -17,9 +17,7 @@
import logging
import re
-from typing import Optional
-
-from six import iterkeys
+from typing import List, Optional
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -159,6 +157,37 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
+ def get_expired_users(self):
+ """Get UserIDs of all expired users.
+
+ Users who are not active, or do not have profile information, are
+ excluded from the results.
+
+ Returns:
+ Deferred[List[UserID]]: List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ # We need to use pattern matching as profiles.user_id is confusingly just the
+ # user's localpart, whereas account_validity.user_id is a full user ID
+ sql = """
+ SELECT av.user_id from account_validity AS av
+ LEFT JOIN profiles as p
+ ON av.user_id LIKE '%%' || p.user_id || ':%%'
+ WHERE expiration_ts_ms <= ?
+ AND p.active = 1
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+
+ return [UserID.from_string(row[0]) for row in rows]
+
+ res = yield self.db.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
+ )
+ return res
+
+ @defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
"""Defines a renewal token for a given user.
@@ -274,6 +303,55 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
+ @defer.inlineCallbacks
+ def get_info_for_users(
+ self, user_ids: List[str],
+ ):
+ """Return the user info for a given set of users
+
+ Args:
+ user_ids: A list of users to return information about
+
+ Returns:
+ Deferred[Dict[str, bool]]: A dictionary mapping each user ID to
+ a dict with the following keys:
+ * expired - whether this is an expired user
+ * deactivated - whether this is a deactivated user
+ """
+ # Get information of all our local users
+ def _get_info_for_users_txn(txn):
+ rows = []
+
+ for user_id in user_ids:
+ sql = """
+ SELECT u.name, u.deactivated, av.expiration_ts_ms
+ FROM users as u
+ LEFT JOIN account_validity as av
+ ON av.user_id = u.name
+ WHERE u.name = ?
+ """
+
+ txn.execute(sql, (user_id,))
+ row = txn.fetchone()
+ if row:
+ rows.append(row)
+
+ return rows
+
+ info_rows = yield self.db.runInteraction(
+ "get_info_for_users", _get_info_for_users_txn
+ )
+
+ return {
+ user_id: {
+ "expired": (
+ expiration is not None and self.clock.time_msec() >= expiration
+ ),
+ "deactivated": deactivated == 1,
+ }
+ for user_id, deactivated, expiration in info_rows
+ }
+
async def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.
@@ -753,7 +831,7 @@ class RegistrationWorkerStore(SQLBaseStore):
last_send_attempt, validated_at
FROM threepid_validation_session WHERE %s
""" % (
- " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+ " AND ".join("%s = ?" % k for k in keyvalues.keys()),
)
if validated is not None:
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 46f643c6b9..d72d6affb8 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -340,6 +340,24 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ @defer.inlineCallbacks
+ def is_room_published(self, room_id):
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id (str)
+ Returns:
+ bool: Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = yield self.get_room(room_id)
+ if not room_info:
+ defer.returnValue(False)
+
+ # Check the is_public value
+ defer.returnValue(room_info.get("is_public", False))
+
async def get_rooms_paginate(
self,
start: int,
@@ -548,6 +566,11 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
"""
+ # If the room retention feature is disabled, return a policy with no minimum nor
+ # maximum, in order not to filter out events we should filter out when sending to
+ # the client.
+ if not self.config.retention_enabled:
+ defer.returnValue({"min_lifetime": None, "max_lifetime": None})
def get_retention_policy_for_room_txn(txn):
txn.execute(
@@ -626,36 +649,10 @@ class RoomWorkerStore(SQLBaseStore):
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
- total_media_quarantined = 0
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany(
- """
- UPDATE local_media_repository
- SET quarantined_by = ?
- WHERE media_id = ?
- """,
- ((quarantined_by, media_id) for media_id in local_mxcs),
- )
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_mxcs
- ),
+ return self._quarantine_media_txn(
+ txn, local_mxcs, remote_mxcs, quarantined_by
)
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
-
- return total_media_quarantined
-
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -805,17 +802,17 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
The total number of media items quarantined
"""
- total_media_quarantined = 0
-
# Update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
- WHERE media_id = ?
+ WHERE media_id = ? AND safe_from_quarantine = ?
""",
- ((quarantined_by, media_id) for media_id in local_mxcs),
+ ((quarantined_by, media_id, False) for media_id in local_mxcs),
)
+ # Note that a rowcount of -1 can be used to indicate no rows were affected.
+ total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
txn.executemany(
"""
@@ -825,9 +822,7 @@ class RoomWorkerStore(SQLBaseStore):
""",
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
)
-
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
return total_media_quarantined
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 137ebac833..44bab65eac 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -17,8 +17,6 @@
import logging
from typing import Iterable, List, Set
-from six import iteritems, itervalues
-
from canonicaljson import json
from twisted.internet import defer
@@ -544,7 +542,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = {}
member_event_ids = [
e_id
- for key, e_id in iteritems(current_state_ids)
+ for key, e_id in current_state_ids.items()
if key[0] == EventTypes.Member
]
@@ -561,7 +559,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = dict(prev_res)
member_event_ids = [
e_id
- for key, e_id in iteritems(context.delta_ids)
+ for key, e_id in context.delta_ids.items()
if key[0] == EventTypes.Member
]
for etype, state_key in context.delta_ids:
@@ -1101,7 +1099,7 @@ class _JoinedHostsCache(object):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
+ for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue
@@ -1131,7 +1129,7 @@ class _JoinedHostsCache(object):
self.state_group = state_entry.state_group
else:
self.state_group = object()
- self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
+ self._len = sum(len(v) for v in self.hosts_to_joined_users.values())
return frozenset(self.hosts_to_joined_users)
def __len__(self):
diff --git a/synapse/storage/data_stores/main/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py
index 9b95411fb6..b42c02710a 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/as_users.py
+++ b/synapse/storage/data_stores/main/schema/delta/30/as_users.py
@@ -13,8 +13,6 @@
# limitations under the License.
import logging
-from six.moves import range
-
from synapse.config.appservice import load_appservices
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql b/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
diff --git a/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..c8893ecbe8
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * realy ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
diff --git a/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..18a0f7e10c
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('profile_replication_status_host_index', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/55/room_retention.sql
index ee6cdf7a14..ee6cdf7a14 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/room_retention.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
new file mode 100644
index 0000000000..597f2ffd3d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The local_media_repository should have files which do not get quarantined,
+-- e.g. files from sticker packs.
+ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE;
diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
new file mode 100644
index 0000000000..69db89ac0e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The local_media_repository should have files which do not get quarantined,
+-- e.g. files from sticker packs.
+ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0;
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 13f49d8060..a8381dc577 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -17,8 +17,6 @@ import logging
import re
from collections import namedtuple
-from six import string_types
-
from canonicaljson import json
from twisted.internet import defer
@@ -180,7 +178,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# skip over it.
continue
- if not isinstance(value, string_types):
+ if not isinstance(value, str):
# If the event body, name or topic isn't a string
# then skip over it
continue
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index e89f0bffb5..379d758b5d 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -40,8 +40,6 @@ import abc
import logging
from collections import namedtuple
-from six.moves import range
-
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 4219018302..f8c776be3f 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -16,8 +16,6 @@
import logging
-from six.moves import range
-
from canonicaljson import json
from twisted.internet import defer
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
index 1d8ee22fb1..ec2f38c373 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -186,7 +186,7 @@ class UIAuthWorkerStore(SQLBaseStore):
# The clientdict gets stored as JSON.
clientdict_json = json.dumps(clientdict)
- self.db.simple_update_one(
+ await self.db.simple_update_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
updatevalues={"clientdict": clientdict_json},
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
index ff000bc9ec..be1fe97d79 100644
--- a/synapse/storage/data_stores/state/bg_updates.py
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -15,8 +15,6 @@
import logging
-from six import iteritems
-
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
@@ -280,7 +278,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
delta_state = {
key: value
- for key, value in iteritems(curr_state)
+ for key, value in curr_state.items()
if prev_state.get(key, None) != value
}
@@ -316,7 +314,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(delta_state)
+ for key, state_id in delta_state.items()
],
)
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index f3ad1e4369..5db9f20135 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -17,9 +17,6 @@ import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple
-from six import iteritems
-from six.moves import range
-
from twisted.internet import defer
from synapse.api.constants import EventTypes
@@ -263,7 +260,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
- for group, group_state_dict in iteritems(group_to_state_dict):
+ for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
@@ -341,11 +338,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
else:
non_member_types = non_member_filter.concrete_types()
- for group, group_state_dict in iteritems(group_to_state_dict):
+ for group, group_state_dict in group_to_state_dict.items():
state_dict_members = {}
state_dict_non_members = {}
- for k, v in iteritems(group_state_dict):
+ for k, v in group_state_dict.items():
if k[0] == EventTypes.Member:
state_dict_members[k] = v
else:
@@ -432,7 +429,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(delta_ids)
+ for key, state_id in delta_ids.items()
],
)
else:
@@ -447,7 +444,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(current_state_ids)
+ for key, state_id in current_state_ids.items()
],
)
@@ -458,7 +455,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
current_member_state_ids = {
s: ev
- for (s, ev) in iteritems(current_state_ids)
+ for (s, ev) in current_state_ids.items()
if s[0] == EventTypes.Member
}
txn.call_after(
@@ -470,7 +467,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
current_non_member_state_ids = {
s: ev
- for (s, ev) in iteritems(current_state_ids)
+ for (s, ev) in current_state_ids.items()
if s[0] != EventTypes.Member
}
txn.call_after(
@@ -555,7 +552,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(curr_state)
+ for key, state_id in curr_state.items()
],
)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b112ff3df2..3be20c866a 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import time
+from sys import intern
from time import monotonic as monotonic_time
from typing import (
Any,
@@ -29,9 +30,6 @@ from typing import (
TypeVar,
)
-from six import iteritems, iterkeys, itervalues
-from six.moves import intern, range
-
from prometheus_client import Histogram
from twisted.enterprise import adbapi
@@ -259,7 +257,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration_secs, limit=3):
counters = []
- for name, (count, cum_time) in iteritems(self.current_counters):
+ for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(
(
@@ -1053,7 +1051,7 @@ class Database(object):
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
- sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
txn.execute(sql, list(keyvalues.values()))
else:
txn.execute(sql)
@@ -1191,7 +1189,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
- for key, value in iteritems(keyvalues):
+ for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -1212,7 +1210,7 @@ class Database(object):
@staticmethod
def simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
where = ""
@@ -1351,7 +1349,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
- for key, value in iteritems(keyvalues):
+ for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -1388,7 +1386,7 @@ class Database(object):
txn.close()
if cache:
- min_val = min(itervalues(cache))
+ min_val = min(cache.values())
else:
min_val = max_value
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index f159400a87..ec894a91cb 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -20,9 +20,6 @@ import logging
from collections import deque, namedtuple
from typing import Iterable, List, Optional, Set, Tuple
-from six import iteritems
-from six.moves import range
-
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -218,7 +215,7 @@ class EventsPersistenceStorage(object):
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
- for room_id, evs_ctxs in iteritems(partitioned):
+ for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
)
@@ -319,7 +316,7 @@ class EventsPersistenceStorage(object):
(event, context)
)
- for room_id, ev_ctx_rm in iteritems(events_by_room):
+ for room_id, ev_ctx_rm in events_by_room.items():
latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
room_id
)
@@ -674,7 +671,7 @@ class EventsPersistenceStorage(object):
to_insert = {
key: ev_id
- for key, ev_id in iteritems(current_state)
+ for key, ev_id in current_state.items()
if ev_id != existing_state.get(key)
}
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c522c80922..dc568476f4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -16,8 +16,6 @@
import logging
from typing import Iterable, List, TypeVar
-from six import iteritems, itervalues
-
import attr
from twisted.internet import defer
@@ -51,7 +49,7 @@ class StateFilter(object):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
- self.types = {k: v for k, v in iteritems(self.types) if v is not None}
+ self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod
def all():
@@ -150,7 +148,7 @@ class StateFilter(object):
has_non_member_wildcard = self.include_others or any(
state_keys is None
- for t, state_keys in iteritems(self.types)
+ for t, state_keys in self.types.items()
if t != EventTypes.Member
)
@@ -199,7 +197,7 @@ class StateFilter(object):
# First we build up a lost of clauses for each type/state_key combo
clauses = []
- for etype, state_keys in iteritems(self.types):
+ for etype, state_keys in self.types.items():
if state_keys is None:
clauses.append("(type = ?)")
where_args.append(etype)
@@ -251,7 +249,7 @@ class StateFilter(object):
return dict(state_dict)
filtered_state = {}
- for k, v in iteritems(state_dict):
+ for k, v in state_dict.items():
typ, state_key = k
if typ in self.types:
state_keys = self.types[typ]
@@ -279,7 +277,7 @@ class StateFilter(object):
"""
return self.include_others or any(
- state_keys is None for state_keys in itervalues(self.types)
+ state_keys is None for state_keys in self.types.values()
)
def concrete_types(self):
@@ -292,7 +290,7 @@ class StateFilter(object):
"""
return [
(t, s)
- for t, state_keys in iteritems(self.types)
+ for t, state_keys in self.types.items()
if state_keys is not None
for s in state_keys
]
@@ -324,7 +322,7 @@ class StateFilter(object):
member_filter = StateFilter.none()
non_member_filter = StateFilter(
- types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
+ types={k: v for k, v in self.types.items() if k != EventTypes.Member},
include_others=self.include_others,
)
@@ -366,7 +364,7 @@ class StateGroupStorage(object):
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
- groups = set(itervalues(event_to_groups))
+ groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups)
return group_to_state
@@ -400,8 +398,8 @@ class StateGroupStorage(object):
state_event_map = yield self.stores.main.get_events(
[
ev_id
- for group_ids in itervalues(group_to_ids)
- for ev_id in itervalues(group_ids)
+ for group_ids in group_to_ids.values()
+ for ev_id in group_ids.values()
],
get_prev_content=False,
)
@@ -409,10 +407,10 @@ class StateGroupStorage(object):
return {
group: [
state_event_map[v]
- for v in itervalues(event_id_map)
+ for v in event_id_map.values()
if v in state_event_map
]
- for group, event_id_map in iteritems(group_to_ids)
+ for group, event_id_map in group_to_ids.items()
}
def _get_state_groups_from_groups(
@@ -444,23 +442,23 @@ class StateGroupStorage(object):
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
- groups = set(itervalues(event_to_groups))
+ groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter
)
state_event_map = yield self.stores.main.get_events(
- [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
+ [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
event_to_state = {
event_id: {
k: state_event_map[v]
- for k, v in iteritems(group_to_state[group])
+ for k, v in group_to_state[group].items()
if v in state_event_map
}
- for event_id, group in iteritems(event_to_groups)
+ for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
@@ -481,14 +479,14 @@ class StateGroupStorage(object):
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
- groups = set(itervalues(event_to_groups))
+ groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter
)
event_to_state = {
event_id: group_to_state[group]
- for event_id, group in iteritems(event_to_groups)
+ for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
diff --git a/synapse/third_party_rules/__init__.py b/synapse/third_party_rules/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/synapse/third_party_rules/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
new file mode 100644
index 0000000000..2c9155d15c
--- /dev/null
+++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,588 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.utils
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.config._base import ConfigError
+from synapse.types import get_domain_from_id
+
+ACCESS_RULES_TYPE = "im.vector.room.access_rules"
+ACCESS_RULE_RESTRICTED = "restricted"
+ACCESS_RULE_UNRESTRICTED = "unrestricted"
+ACCESS_RULE_DIRECT = "direct"
+
+VALID_ACCESS_RULES = (
+ ACCESS_RULE_DIRECT,
+ ACCESS_RULE_RESTRICTED,
+ ACCESS_RULE_UNRESTRICTED,
+)
+
+# Rules to which we need to apply the power levels restrictions.
+#
+# These are all of the rules that neither:
+# * forbid users from joining based on a server blacklist (which means that there
+# is no need to apply power level restrictions), nor
+# * target direct chats (since we allow both users to be room admins in this case).
+#
+# The power-level restrictions, when they are applied, prevent the following:
+# * the default power level for users (users_default) being set to anything other than 0.
+# * a non-default power level being assigned to any user which would be forbidden from
+# joining a restricted room.
+RULES_WITH_RESTRICTED_POWER_LEVELS = (ACCESS_RULE_UNRESTRICTED,)
+
+
+class RoomAccessRules(object):
+ """Implementation of the ThirdPartyEventRules module API that allows federation admins
+ to define custom rules for specific events and actions.
+ Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+ Takes a config in the format:
+
+ third_party_event_rules:
+ module: third_party_rules.RoomAccessRules
+ config:
+ # List of domains (server names) that can't be invited to rooms if the
+ # "restricted" rule is set. Defaults to an empty list.
+ domains_forbidden_when_restricted: []
+
+ # Identity server to use when checking the HS an email address belongs to
+ # using the /info endpoint. Required.
+ id_server: "vector.im"
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config, http_client):
+ self.http_client = http_client
+
+ self.id_server = config["id_server"]
+
+ self.domains_forbidden_when_restricted = config.get(
+ "domains_forbidden_when_restricted", []
+ )
+
+ @staticmethod
+ def parse_config(config):
+ if "id_server" in config:
+ return config
+ else:
+ raise ConfigError("No IS for event rules TchapEventRules")
+
+ def on_create_room(self, requester, config, is_requester_admin) -> bool:
+ """Implements synapse.events.ThirdPartyEventRules.on_create_room
+
+ Checks if a im.vector.room.access_rules event is being set during room creation.
+ If yes, make sure the event is correct. Otherwise, append an event with the
+ default rule to the initial state.
+ """
+ is_direct = config.get("is_direct")
+ preset = config.get("preset")
+ access_rule = None
+ join_rule = None
+
+ # If there's a rules event in the initial state, check if it complies with the
+ # spec for im.vector.room.access_rules and deny the request if not.
+ for event in config.get("initial_state", []):
+ if event["type"] == ACCESS_RULES_TYPE:
+ access_rule = event["content"].get("rule")
+
+ # Make sure the event has a valid content.
+ if access_rule is None:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule name is valid.
+ if access_rule not in VALID_ACCESS_RULES:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule is "direct" if the room is a direct chat.
+ if (is_direct and access_rule != ACCESS_RULE_DIRECT) or (
+ access_rule == ACCESS_RULE_DIRECT and not is_direct
+ ):
+ raise SynapseError(400, "Invalid access rule")
+
+ if event["type"] == EventTypes.JoinRules:
+ join_rule = event["content"].get("join_rule")
+
+ if access_rule is None:
+ # If there's no access rules event in the initial state, create one with the
+ # default setting.
+ if is_direct:
+ default_rule = ACCESS_RULE_DIRECT
+ else:
+ # If the default value for non-direct chat changes, we should make another
+ # case here for rooms created with either a "public" join_rule or the
+ # "public_chat" preset to make sure those keep defaulting to "restricted"
+ default_rule = ACCESS_RULE_RESTRICTED
+
+ if not config.get("initial_state"):
+ config["initial_state"] = []
+
+ config["initial_state"].append(
+ {
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {"rule": default_rule},
+ }
+ )
+
+ access_rule = default_rule
+
+ # Check that the preset or the join rule in use is compatible with the access
+ # rule, whether it's a user-defined one or the default one (i.e. if it involves
+ # a "public" join rule, the access rule must be "restricted").
+ if (
+ join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT
+ ) and access_rule != ACCESS_RULE_RESTRICTED:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Check if the creator can override values for the power levels.
+ allowed = self._is_power_level_content_allowed(
+ config.get("power_level_content_override", {}), access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content override")
+
+ # Second loop for events we need to know the current rule to process.
+ for event in config.get("initial_state", []):
+ if event["type"] == EventTypes.PowerLevels:
+ allowed = self._is_power_level_content_allowed(
+ event["content"], access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content")
+
+ return True
+
+ @defer.inlineCallbacks
+ def check_threepid_can_be_invited(self, medium, address, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited
+
+ Check if a threepid can be invited to the room via a 3PID invite given the current
+ rules and the threepid's address, by retrieving the HS it's mapped to from the
+ configured identity server, and checking if we can invite users from it.
+ """
+ rule = self._get_rule_from_state(state_events)
+
+ if medium != "email":
+ defer.returnValue(False)
+
+ if rule != ACCESS_RULE_RESTRICTED:
+ # Only "restricted" requires filtering 3PID invites. We don't need to do
+ # anything for "direct" here, because only "restricted" requires filtering
+ # based on the HS the address is mapped to.
+ defer.returnValue(True)
+
+ parsed_address = email.utils.parseaddr(address)[1]
+ if parsed_address != address:
+ # Avoid reproducing the security issue described here:
+ # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2
+ # It's probably not worth it but let's just be overly safe here.
+ defer.returnValue(False)
+
+ # Get the HS this address belongs to from the identity server.
+ res = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/info" % (self.id_server,),
+ {"medium": medium, "address": address},
+ )
+
+ # Look for a domain that's not forbidden from being invited.
+ if not res.get("hs"):
+ defer.returnValue(False)
+ if res.get("hs") in self.domains_forbidden_when_restricted:
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
+ def check_event_allowed(self, event, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_event_allowed
+
+ Checks the event's type and the current rule and calls the right function to
+ determine whether the event can be allowed.
+ """
+ if event.type == ACCESS_RULES_TYPE:
+ return self._on_rules_change(event, state_events)
+
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ if event.type == EventTypes.PowerLevels:
+ return self._is_power_level_content_allowed(event.content, rule)
+
+ if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite:
+ return self._on_membership_or_invite(event, rule, state_events)
+
+ if event.type == EventTypes.JoinRules:
+ return self._on_join_rule_change(event, rule)
+
+ if event.type == EventTypes.RoomAvatar:
+ return self._on_room_avatar_change(event, rule)
+
+ if event.type == EventTypes.Name:
+ return self._on_room_name_change(event, rule)
+
+ if event.type == EventTypes.Topic:
+ return self._on_room_topic_change(event, rule)
+
+ return True
+
+ def _on_rules_change(self, event, state_events):
+ """Implement the checks and behaviour specified on allowing or forbidding a new
+ im.vector.room.access_rules event.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ new_rule = event.content.get("rule")
+
+ # Check for invalid values.
+ if new_rule not in VALID_ACCESS_RULES:
+ return False
+
+ # We must not allow rooms with the "public" join rule to be given any other access
+ # rule than "restricted".
+ join_rule = self._get_join_rule_from_state(state_events)
+ if join_rule == JoinRules.PUBLIC and new_rule != ACCESS_RULE_RESTRICTED:
+ return False
+
+ # Make sure we don't apply "direct" if the room has more than two members.
+ if new_rule == ACCESS_RULE_DIRECT:
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ if len(existing_members) > 2 or len(threepid_tokens) > 1:
+ return False
+
+ prev_rules_event = state_events.get((ACCESS_RULES_TYPE, ""))
+
+ # Now that we know the new rule doesn't break the "direct" case, we can allow any
+ # new rule in rooms that had none before.
+ if prev_rules_event is None:
+ return True
+
+ prev_rule = prev_rules_event.content.get("rule")
+
+ # Currently, we can only go from "restricted" to "unrestricted".
+ if prev_rule == ACCESS_RULE_RESTRICTED and new_rule == ACCESS_RULE_UNRESTRICTED:
+ return True
+
+ return False
+
+ def _on_membership_or_invite(self, event, rule, state_events):
+ """Applies the correct rule for incoming m.room.member and
+ m.room.third_party_invite events.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if rule == ACCESS_RULE_RESTRICTED:
+ ret = self._on_membership_or_invite_restricted(event)
+ elif rule == ACCESS_RULE_UNRESTRICTED:
+ ret = self._on_membership_or_invite_unrestricted()
+ elif rule == ACCESS_RULE_DIRECT:
+ ret = self._on_membership_or_invite_direct(event, state_events)
+ else:
+ # We currently apply the default (restricted) if we don't know the rule, we
+ # might want to change that in the future.
+ ret = self._on_membership_or_invite_restricted(event)
+
+ return ret
+
+ def _on_membership_or_invite_restricted(self, event):
+ """Implements the checks and behaviour specified for the "restricted" rule.
+
+ "restricted" currently means that users can only invite users if their server is
+ included in a limited list of domains.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # We're not applying the rules on m.room.third_party_member events here because
+ # the filtering on threepids is done in check_threepid_can_be_invited, which is
+ # called before check_event_allowed.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return True
+
+ # We only need to process "join" and "invite" memberships, in order to be backward
+ # compatible, e.g. if a user from a blacklisted server joined a restricted room
+ # before the rules started being enforced on the server, that user must be able to
+ # leave it.
+ if event.membership not in [Membership.JOIN, Membership.INVITE]:
+ return True
+
+ invitee_domain = get_domain_from_id(event.state_key)
+ return invitee_domain not in self.domains_forbidden_when_restricted
+
+ def _on_membership_or_invite_unrestricted(self):
+ """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+ "unrestricted" currently means that every event is allowed.
+
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return True
+
+ def _on_membership_or_invite_direct(self, event, state_events):
+ """Implements the checks and behaviour specified for the "direct" rule.
+
+ "direct" currently means that no member is allowed apart from the two initial
+ members the room was created for (i.e. the room's creator and their first
+ invitee).
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Get the room memberships and 3PID invite tokens from the room's state.
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ # There should never be more than one 3PID invite in the room state: if the second
+ # original user came and left, and we're inviting them using their email address,
+ # given we know they have a Matrix account binded to the address (so they could
+ # join the first time), Synapse will successfully look it up before attempting to
+ # store an invite on the IS.
+ if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite:
+ # If we already have a 3PID invite in flight, don't accept another one, unless
+ # the new one has the same invite token as its state key. This is because 3PID
+ # invite revocations must be allowed, and a revocation is basically a new 3PID
+ # invite event with an empty content and the same token as the invite it
+ # revokes.
+ return event.state_key in threepid_tokens
+
+ if len(existing_members) == 2:
+ # If the user was within the two initial user of the room, Synapse would have
+ # looked it up successfully and thus sent a m.room.member here instead of
+ # m.room.third_party_invite.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return False
+
+ # We can only have m.room.member events here. The rule in this case is to only
+ # allow the event if its target is one of the initial two members in the room,
+ # i.e. the state key of one of the two m.room.member states in the room.
+ return event.state_key in existing_members
+
+ # We're alone in the room (and always have been) and there's one 3PID invite in
+ # flight.
+ if len(existing_members) == 1 and len(threepid_tokens) == 1:
+ # We can only have m.room.member events here. In this case, we can only allow
+ # the event if it's either a m.room.member from the joined user (we can assume
+ # that the only m.room.member event is a join otherwise we wouldn't be able to
+ # send an event to the room) or an an invite event which target is the invited
+ # user.
+ target = event.state_key
+ is_from_threepid_invite = self._is_invite_from_threepid(
+ event, threepid_tokens[0]
+ )
+ if is_from_threepid_invite or target == existing_members[0]:
+ return True
+
+ return False
+
+ return True
+
+ def _is_power_level_content_allowed(self, content, access_rule):
+ """Check if a given power levels event is permitted under the given access rule.
+
+ It shouldn't be allowed if it either changes the default PL to a non-0 value or
+ gives a non-0 PL to a user that would have been forbidden from joining the room
+ under a more restrictive access rule.
+
+ Args:
+ content (dict[]): The content of the m.room.power_levels event to check.
+ access_rule (str): The access rule in place in this room.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Check if we need to apply the restrictions with the current rule.
+ if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS:
+ return True
+
+ # If users_default is explicitly set to a non-0 value, deny the event.
+ users_default = content.get("users_default", 0)
+ if users_default:
+ return False
+
+ users = content.get("users", {})
+ for user_id, power_level in users.items():
+ server_name = get_domain_from_id(user_id)
+ # Check the domain against the blacklist. If found, and the PL isn't 0, deny
+ # the event.
+ if (
+ server_name in self.domains_forbidden_when_restricted
+ and power_level != 0
+ ):
+ return False
+
+ return True
+
+ def _on_join_rule_change(self, event, rule):
+ """Check whether a join rule change is allowed. A join rule change is always
+ allowed unless the new join rule is "public" and the current access rule isn't
+ "restricted".
+ The rationale is that external users (those whose server would be denied access
+ to rooms enforcing the "restricted" access rule) should always rely on non-
+ external users for access to rooms, therefore they shouldn't be able to access
+ rooms that don't require an invite to be joined.
+
+ Note that we currently rely on the default access rule being "restricted": during
+ room creation, the m.room.join_rules event will be sent *before* the
+ im.vector.room.access_rules one, so the access rule that will be considered here
+ in this case will be the default "restricted" one. This is fine since the
+ "restricted" access rule allows any value for the join rule, but we should keep
+ that in mind if we need to change the default access rule in the future.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if event.content.get("join_rule") == JoinRules.PUBLIC:
+ return rule == ACCESS_RULE_RESTRICTED
+
+ return True
+
+ def _on_room_avatar_change(self, event, rule):
+ """Check whether a change of room avatar is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_name_change(self, event, rule):
+ """Check whether a change of room name is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_topic_change(self, event, rule):
+ """Check whether a change of room topic is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ @staticmethod
+ def _get_rule_from_state(state_events):
+ """Extract the rule to be applied from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the rule (either "direct", "restricted" or "unrestricted")
+ """
+ access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
+ if access_rules is None:
+ rule = ACCESS_RULE_RESTRICTED
+ else:
+ rule = access_rules.content.get("rule")
+ return rule
+
+ @staticmethod
+ def _get_join_rule_from_state(state_events):
+ """Extract the room's join rule from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the join rule (either "public", or "invite")
+ """
+ join_rule_event = state_events.get((EventTypes.JoinRules, ""))
+ if join_rule_event is None:
+ return None
+ return join_rule_event.content.get("join_rule")
+
+ @staticmethod
+ def _get_members_and_tokens_from_state(state_events):
+ """Retrieves from a list of state events the list of users that have a
+ m.room.member event in the room, and the tokens of 3PID invites in the room.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ existing_members (list[str]): List of targets of the m.room.member events in
+ the state.
+ threepid_invite_tokens (list[str]): List of tokens of the 3PID invites in the
+ state.
+ """
+ existing_members = []
+ threepid_invite_tokens = []
+ for key, state_event in state_events.items():
+ if key[0] == EventTypes.Member and state_event.content:
+ existing_members.append(state_event.state_key)
+ if key[0] == EventTypes.ThirdPartyInvite and state_event.content:
+ # Don't include revoked invites.
+ threepid_invite_tokens.append(state_event.state_key)
+
+ return existing_members, threepid_invite_tokens
+
+ @staticmethod
+ def _is_invite_from_threepid(invite, threepid_invite_token):
+ """Checks whether the given invite follows the given 3PID invite.
+
+ Args:
+ invite (EventBase): The m.room.member event with "invite" membership.
+ threepid_invite_token (str): The state key from the 3PID invite.
+ """
+ token = (
+ invite.content.get("third_party_invite", {})
+ .get("signed", {})
+ .get("token", "")
+ )
+
+ return token == threepid_invite_token
diff --git a/synapse/types.py b/synapse/types.py
index acf60baddc..9f28f9a192 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -19,6 +19,8 @@ import sys
from collections import namedtuple
from typing import Any, Dict, Tuple, TypeVar
+from six.moves import filter
+
import attr
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
@@ -267,6 +269,19 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+def strip_invalid_mxid_characters(localpart):
+ """Removes any invalid characters from an mxid
+
+ Args:
+ localpart (basestring): the localpart to be stripped
+
+ Returns:
+ localpart (basestring): the localpart having been stripped
+ """
+ filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart)
+ return "".join(filtered)
+
+
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f7af2bca7f..65abf0846e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -19,8 +19,6 @@ import logging
from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
-from six.moves import range
-
import attr
from twisted.internet import defer
@@ -95,7 +93,7 @@ class ObservableDeferred(object):
This returns a brand new deferred that is resolved when the underlying
deferred is resolved. Interacting with the returned deferred does not
- effect the underdlying deferred.
+ effect the underlying deferred.
"""
if not self._result:
d = defer.Deferred()
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index cd48262420..64f35fc288 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -21,8 +21,6 @@ import threading
from typing import Any, Tuple, Union, cast
from weakref import WeakValueDictionary
-from six import itervalues
-
from prometheus_client import Gauge
from typing_extensions import Protocol
@@ -281,7 +279,7 @@ class Cache(object):
def invalidate_all(self):
self.check_thread()
self.cache.clear()
- for entry in itervalues(self._pending_deferred_cache):
+ for entry in self._pending_deferred_cache.values():
entry.invalidate()
self._pending_deferred_cache.clear()
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2726b67b6d..89a3420f92 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -16,8 +16,6 @@
import logging
from collections import OrderedDict
-from six import iteritems, itervalues
-
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import register_cache
@@ -150,7 +148,7 @@ class ExpiringCache(object):
keys_to_delete = set()
- for key, cache_entry in iteritems(self._cache):
+ for key, cache_entry in self._cache.items():
if now - cache_entry.time > self._expiry_ms:
keys_to_delete.add(key)
@@ -170,7 +168,7 @@ class ExpiringCache(object):
def __len__(self):
if self.iterable:
- return sum(len(entry.value) for entry in itervalues(self._cache))
+ return sum(len(entry.value) for entry in self._cache.values())
else:
return len(self._cache)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 2a161bf244..c541bf4579 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -17,8 +17,6 @@ import logging
import math
from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
-from six import integer_types
-
from sortedcontainers import SortedDict
from synapse.types import Collection
@@ -88,7 +86,7 @@ class StreamChangeCache:
def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
"""Returns True if the entity may have been updated since stream_pos
"""
- assert type(stream_pos) in integer_types
+ assert isinstance(stream_pos, int)
if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses()
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 2ea4e4e911..ecd9948e79 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,7 +1,5 @@
from typing import Dict
-from six import itervalues
-
SENTINEL = object()
@@ -81,7 +79,7 @@ def iterate_tree_cache_entry(d):
can contain dicts.
"""
if isinstance(d, dict):
- for value_d in itervalues(d):
+ for value_d in d.values():
for value in iterate_tree_cache_entry(value_d):
yield value
else:
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 8b17d1c8b8..6a3f6177b1 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six.moves import queue
+import queue
from twisted.internet import threads
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 9815bb8667..eab78dd256 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six import binary_type, text_type
-
from canonicaljson import json
from frozendict import frozendict
@@ -26,7 +24,7 @@ def freeze(o):
if isinstance(o, frozendict):
return o
- if isinstance(o, (binary_type, text_type)):
+ if isinstance(o, (bytes, str)):
return o
try:
@@ -41,7 +39,7 @@ def unfreeze(o):
if isinstance(o, (dict, frozendict)):
return dict({k: unfreeze(v) for k, v in o.items()})
- if isinstance(o, (binary_type, text_type)):
+ if isinstance(o, (bytes, str)):
return o
try:
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..34ce7cac16 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ defered bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,8 +73,8 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
- return False
+ defer.returnValue(False)
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 9bf6a44f75..023beb5ede 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six.moves import range
-
class _Entry(object):
__slots__ = ["end_key", "queue"]
diff --git a/synapse/visibility.py b/synapse/visibility.py
index bab41182b9..3dfd4af26c 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -16,9 +16,6 @@
import logging
import operator
-from six import iteritems, itervalues
-from six.moves import map
-
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
@@ -298,7 +295,7 @@ def filter_events_for_server(
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
- for ev in itervalues(state):
+ for ev in state.values():
if ev.type != EventTypes.Member:
continue
try:
@@ -332,7 +329,7 @@ def filter_events_for_server(
)
visibility_ids = set()
- for sids in itervalues(event_to_state_ids):
+ for sids in event_to_state_ids.values():
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist:
visibility_ids.add(hist)
@@ -345,7 +342,7 @@ def filter_events_for_server(
event_map = yield storage.main.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
- for e in itervalues(event_map)
+ for e in event_map.values()
)
if not check_history_visibility_only:
@@ -394,8 +391,8 @@ def filter_events_for_server(
#
event_id_to_state_key = {
event_id: key
- for key_to_eid in itervalues(event_to_state_ids)
- for key, event_id in iteritems(key_to_eid)
+ for key_to_eid in event_to_state_ids.values()
+ for key, event_id in key_to_eid.items()
}
def include(typ, state_key):
@@ -409,20 +406,16 @@ def filter_events_for_server(
return state_key[idx + 1 :] == server_name
event_map = yield storage.main.get_events(
- [
- e_id
- for e_id, key in iteritems(event_id_to_state_key)
- if include(key[0], key[1])
- ]
+ [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
)
event_to_state = {
e_id: {
key: event_map[inner_e_id]
- for key, inner_e_id in iteritems(key_to_eid)
+ for key, inner_e_id in key_to_eid.items()
if inner_e_id in event_map
}
- for e_id, key_to_eid in iteritems(event_to_state_ids)
+ for e_id, key_to_eid in event_to_state_ids.items()
}
to_return = []
|